pyspark與py4j執行緒模型簡析

post200發表於2021-09-09

事由

上週工作中遇到一個bug,現象是一個spark streaming的job會不定期地hang住,不退出也不繼續執行。這個job經是用pyspark寫的,以kafka為資料來源,會在每個batch結束時將統計結果寫入mysql。經過排查,我們在driver程式中發現有有若干執行緒都出於Sl狀態(睡眠狀態),進而使用gdb除錯發現了一處死鎖。

這是MySQLdb庫舊版本中的一處bug,在此不再贅述,有興趣的可以看。不過這倒是提起了我對另外一件事的興趣,就是driver程式——嚴格的說應該是driver程式的python子程式——中的這些執行緒是從哪來的?當然,這些執行緒的存在很容易理解,我們開啟了spark.streaming.concurrentJobs引數,有多個batch可以同時執行,每個執行緒對應一個batch。但翻遍pyspark的python程式碼,都沒有找到有相關執行緒啟動的地方,於是簡單調研了一下pyspark到底是怎麼工作的,做個記錄。

本文概括

  1. Py4J的執行緒模型

  2. pyspark基本原理(driver端)

  3. CPython中的deque的執行緒安全

涉及軟體版本

  • spark: 2.1.0

  • py4j: 0.10.4

Py4J

spark是由scala語言編寫的,pyspark並沒有像豆瓣開源的用python復刻了spark,而只是提供了一層可以與原生JVM通訊的python API,就是python與JVM之間的這座橋樑。這個庫分為Java和Python兩部分,基本原理是:

  1. Java部分,透過py4j.GatewayServer監聽一個tcp socket(記做server_socket)

  2. Python部分,所有對JVM中物件的訪問或者方法的呼叫,都是透過py4j.JavaGateway向上面這個socket完成的。

  3. 另外,Python部分在建立JavaGateway物件時,可以選擇同時建立一個CallbackServer,它會在Python這冊監聽一個tcp socket(記做callback_socket),用來給Java回撥Python程式碼提供一條渠道。

  4. Py4J提供了一套文字協議用來在tcp socket間傳遞命令。

pyspark driver工作流程

  1. 首先,一個spark job被提交後,如果被判定這是一個python的job,spark driver會找到相應的入口,即org.apache.spark.deploy.PythonRunnermain函式,這個函式中會啟動GatewayServer

    // Launch a Py4J gateway server for the process to connect to; this will let it see our
    // Java system properties and such
    val gatewayServer = new py4j.GatewayServer(null, 0)    val thread = new Thread(new Runnable() {      override def run(): Unit = Utils.logUncaughtExceptions {
        gatewayServer.start()
      }
    })
    thread.setName("py4j-gateway-init")
    thread.setDaemon(true)
    thread.start()
  1. 然後,會建立一個Python子程式來執行我們提交上來的python入口檔案,並把剛才GatewayServer監聽的那個埠寫入到子程式的環境變數中去(這樣Python才知道要透過那個埠訪問JVM)

    // Launch Python process
    val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)    val env = builder.environment()
    env.put("PYTHONPATH", pythonPath)    // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)    // pass conf spark.pyspark.python to python process, the only way to pass info to
    // python process is through environment variable.
    sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
  1. Python子程式這邊,我們是透過pyspark提供的python API編寫的這個程式,在建立SparkContext(python)時,會初始化_gateway變數(JavaGateway物件)和_jvm變數(JVMView物件)

    @classmethod
    def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
        """
        Checks whether a SparkContext is initialized or not.
        Throws error if a SparkContext is already running.
        """
        with SparkContext._lock:            if not SparkContext._gateway:
                SparkContext._gateway = gateway or launch_gateway(conf)
                SparkContext._jvm = SparkContext._gateway.jvm            if instance:                if (SparkContext._active_spark_context and
                        SparkContext._active_spark_context != instance):
                    currentMaster = SparkContext._active_spark_context.master
                    currentAppName = SparkContext._active_spark_context.appName
                    callsite = SparkContext._active_spark_context._callsite                    # Raise error if there is already a running Spark context
                    raise ValueError(                        "Cannot run multiple SparkContexts at once; "
                        "existing SparkContext(app=%s, master=%s)"
                        " created by %s at %s:%s "
                        % (currentAppName, currentMaster,
                            callsite.function, callsite.file, callsite.linenum))                else:
                    SparkContext._active_spark_context = instance

其中launch_gateway函式可見pyspark/java_gateway.py

  1. 上面初始化的這個_jvm物件值得一說,在pyspark中很多對JVM的呼叫其實都是透過它來進行的,比如很多python種對應的spark物件都有一個_jsc變數,它是JVM中的SparkContext物件在Python中的影子,它是這麼初始化的

    def _initialize_context(self, jconf):
        """
        Initialize SparkContext in function to allow subclass specific initialization
        """
        return self._jvm.JavaSparkContext(jconf)

這裡_jvm為什麼能直接呼叫JavaSparkContext這個JVM環境中的建構函式呢?我們看JVMView中的__getattr__方法:

    def __getattr__(self, name):
        if name == UserHelpAutoCompletion.KEY:            return UserHelpAutoCompletion()

        answer = self._gateway_client.send_command(
            proto.REFLECTION_COMMAND_NAME +
            proto.REFL_GET_UNKNOWN_SUB_COMMAND_NAME + name + "n" + self._id +            "n" + proto.END_COMMAND_PART)        if answer == proto.SUCCESS_PACKAGE:            return JavaPackage(name, self._gateway_client, jvm_id=self._id)        elif answer.startswith(proto.SUCCESS_CLASS):            return JavaClass(
                answer[proto.CLASS_FQN_START:], self._gateway_client)        else:            raise Py4JError("{0} does not exist in the JVM".format(name))

self._gateway_client.send_command其實就是向server_socket傳送訪問物件請求的命令了,最後根據響應值生成不同型別的影子物件,針對我們這裡的JavaSparkContext,就是一個JavaClass物件。這個系列的型別還包括了JavaMemberJavaPackage等等,他們也透過__getattr__來實現Java物件屬性訪問以及方法的呼叫。

  1. 我們剛才介紹Py4j時說過Python端在建立JavaGateway時,可以選擇同時建立一個CallbackClient,預設情況下,一個普通的pyspark job是不會啟動回撥服務的,因為用不著,所有的互動都是Python --> JVM這種模式的。那什麼時候需要呢?streaming job就需要(具體流程我們稍後介紹),這就(終於!)引出了我們今天主要討論的Py4J執行緒模型的問題。

Py4J執行緒模型

我們已經知道了Python與JVM雙方向的通訊分別是透過server_socketcallack_socket來完成的,這兩個socket的處理模型都是多執行緒模型,即,每收到一個連線就啟動一個執行緒來處理。我們只看Python --> JVM這條通路的情況,另外一邊是一樣的

Server端(Java)

    protected void processSocket(Socket socket) {        try {            this.lock.lock();            if(!this.isShutdown) {
                socket.setSoTimeout(this.readTimeout);
                Py4JServerConnection gatewayConnection = this.createConnection(this.gateway, socket);                this.connections.add(gatewayConnection);                this.fireConnectionStarted(gatewayConnection);
            }
        } catch (Exception var6) {            this.fireConnectionError(var6);
        } finally {            this.lock.unlock();
        }
    }

繼續看createConnection:

    protected Py4JServerConnection createConnection(Gateway gateway, Socket socket) throws IOException {
        GatewayConnection connection = new GatewayConnection(gateway, socket, this.customCommands, this.listeners);
        connection.startConnection();        return connection;
    }

其中connection.startConnection其實就是建立了一個新執行緒,來負責處理這個連線。

Client端(Python)

我們來看GatewayClient中的send_command方法:

    def send_command(self, command, retry=True, binary=False):
        """Sends a command to the JVM. This method is not intended to be
           called directly by Py4J users. It is usually called by
           :class:`JavaMember` instances.

        :param command: the `string` command to send to the JVM. The command
         must follow the Py4J protocol.

        :param retry: if `True`, the GatewayClient tries to resend a message
         if it fails.

        :param binary: if `True`, we won't wait for a Py4J-protocol response
         from the other end; we'll just return the raw connection to the
         caller. The caller becomes the owner of the connection, and is
         responsible for closing the connection (or returning it this
         `GatewayClient` pool using `_give_back_connection`).

        :rtype: the `string` answer received from the JVM (The answer follows
         the Py4J protocol). The guarded `GatewayConnection` is also returned
         if `binary` is `True`.
        """
        connection = self._get_connection()        try:
            response = connection.send_command(command)            if binary:                return response, self._create_connection_guard(connection)            else:
                self._give_back_connection(connection)        except Py4JNetworkError as pne:            if connection:
                reset = False
                if isinstance(pne.cause, socket.timeout):
                    reset = True
                connection.close(reset)            if self._should_retry(retry, connection, pne):
                logging.info("Exception while sending command.", exc_info=True)
                response = self.send_command(command, binary=binary)            else:
                logging.exception(                    "Exception while sending command.")
                response = proto.ERROR        return response

這裡這個self._get_connection是這麼實現的

    def _get_connection(self):        if not self.is_connected:
            raise Py4JNetworkError("Gateway is not connected.")        try:
            connection = self.deque.pop()
        except IndexError:
            connection = self._create_connection()        return connection

這裡使用了一個deque(也就是Python標準庫中的collections.deque)來維護一個連線池,如果有空閒的連線,就可以直接使用,如果沒有,就新建一個連線。現在問題來了,如果deque不是執行緒安全的,那麼這段程式碼在多執行緒環境就會有問題。那麼deque是不是執行緒安全的呢?

deque的執行緒安全

當然是了,Py4J當然不會犯這樣的低階錯誤,我們看標準庫的文件:

Deques support thread-safe, memory efficient appends and pops from either    side of the deque with approximately the same O(1) performance in either direction.

是執行緒安全的,不過措辭有點模糊,沒有明確指出哪些方法是執行緒安全的,不過可以明確的是至少append的pop都是。之所以去查一下,是因為我也有點含糊,因為Python標準庫還有另外一個Queue.Queue,在多執行緒程式設計中經常使用,肯定是執行緒安全的,於是很容易誤以為deque不是執行緒安全的,所以我們才要一個新的Queue。這個問題,推薦閱讀stackoverflow上——他的回答不是被採納的最高票,不過我認為他的回答比高票更有說服力

  1. 高票答案一直強調說deque是執行緒安全的這個事實是個意外,是CPython中存在GIL造成的,其他Python直譯器就不一定遵守。關於這一點我是不認同的,deque在CPython中的實現確實依賴的GIL才變成了執行緒安全的,但deque的雙端append的pop是執行緒安全的這件事是白紙黑字寫在Python文件中的,其他虛擬機器的實現必須遵守,否則就不能稱之為合格的Python實現。

  2. 那為什麼還要有一個內部顯式用了鎖來做執行緒同步的Queue.Queue呢?Jonathan給出的回答是Queueputget可以是blocking的,而deque不行,這樣一來,當你需要在多個執行緒中進行通訊時(比如最簡單的一個Producer - Consumer模式的實現),Queue往往是最佳選擇。

關於deque是否是執行緒安全這個問題,我將調研的結果寫在了這個知乎問題的答案下,就不在贅述了,這篇文章已經太長了。

關於Py4J執行緒模型的問題,還可以參考。

pyspark streaming與CallbackServer

剛才提到,如果是streaming的job,GatewayServer在初始化時會同時建立一個CallbackServer,提供JVM --> Python這條通路。

    @classmethod
    def _ensure_initialized(cls):
        SparkContext._ensure_initialized()
        gw = SparkContext._gateway

        java_import(gw.jvm, "org.apache.spark.streaming.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")        # start callback server
        # getattr will fallback to JVM, so we cannot test by hasattr()
        if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
            gw.callback_server_parameters.eager_load = True
            gw.callback_server_parameters.daemonize = True
            gw.callback_server_parameters.daemonize_connections = True
            gw.callback_server_parameters.port = 0
            gw.start_callback_server(gw.callback_server_parameters)
            cbport = gw._callback_server.server_socket.getsockname()[1]
            gw._callback_server.port = cbport            # gateway with real port
            gw._python_proxy_port = gw._callback_server.port            # get the GatewayServer object in JVM by ID
            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)            # update the port of CallbackClient with real port
            jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)        # register serializer for TransformFunction
        # it happens before creating SparkContext when loading from checkpointing
        cls._transformerSerializer = TransformFunctionSerializer(
            SparkContext._active_spark_context, CloudPickleSerializer(), gw)

為什麼需要這樣呢?一個streaming job通常需要呼叫foreachRDD,並提供一個函式,這個函式會在每個batch被回撥:

    def foreachRDD(self, func):
        """
        Apply a function to each RDD in this DStream.
        """
        if func.__code__.co_argcount == 1:
            old_func = func
            func = lambda t, rdd: old_func(rdd)
        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
        api = self._ssc._jvm.PythonDStream
        api.callForeachRDD(self._jdstream, jfunc)

這裡,Python函式func被封裝成了一個TransformFunction物件,在scala端spark也定義了同樣介面一個trait:

/**
 * Interface for Python callback function which is used to transform RDDs
 */private[python] trait PythonTransformFunction {  def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]  /**
   * Get the failure, if any, in the last call to `call`.
   *
   * @return the failure message if there was a failure, or `null` if there was no failure.
   */
  def getLastFailure: String}

這樣是Py4J提供的機制,這樣就可以讓JVM透過這個影子介面回撥Python中的物件了,下面就是scala中的callForeachRDD函式,它把PythonTransformFunction又封裝了一層成為scala中的TransformFunction, 但不管如何封裝,最後都會呼叫PythonTransformFunction介面中的call方法完成對Python的回撥。

  /**
   * helper function for DStream.foreachRDD(),
   * cannot be `foreachRDD`, it will confusing py4j
   */
  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
    val func = new TransformFunction((pfunc))
    jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
  }

所以,終於要回答這個問題了,我們一開始看到的driver中的多個執行緒是怎麼來的

  1. python呼叫foreachRDD提供一個TranformFunction給scala端

  2. scala端呼叫自己的foreachRDD進行正常的spark streaming作業

  3. 由於我們開啟了spark.streaming.concurrentJobs,多個batch可以同時執行,這在scala端是透過執行緒池來進行的,每個batch都需要回撥Python中的TranformFunction,而按照我們之前介紹的Py4J執行緒模型,多個併發的回撥會發現沒有可用的socket連線而生成新的,而在CallbackServer(Python)這端,每個新連線都會建立一個新執行緒來處理。這樣就出現了driver的Python程式中出現多個執行緒的現象。



作者:Garfieldog
連結:


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/2249/viewspace-2819315/,如需轉載,請註明出處,否則將追究法律責任。

相關文章