PySpark原始碼解析,教你用Python呼叫高效Scala介面,搞定大規模資料分析

機器之心發表於2019-12-24

PySpark原始碼解析,教你用Python呼叫高效Scala介面,搞定大規模資料分析

眾所周知,Spark 框架主要是由 Scala 語言實現,同時也包含少量 Java 程式碼。Spark 面向使用者的程式設計介面,也是 Scala。然而,在資料科學領域,Python 一直佔據比較重要的地位,仍然有大量的資料工程師在使用各類 Python 資料處理和科學計算的庫,例如 numpy、Pandas、scikit-learn 等。同時,Python 語言的入門門檻也顯著低於 Scala。

為此,Spark 推出了 PySpark,在 Spark 框架上提供一套 Python 的介面,方便廣大資料科學家使用。本文主要從原始碼實現層面解析 PySpark 的實現原理,包括以下幾個方面:

  • PySpark 的多程式架構;

  • Python 端呼叫 Java、Scala 介面;

  • Python Driver 端 RDD、SQL 介面;

  • Executor 端程式間通訊和序列化;

  • Pandas UDF;

  • 總結。

PySpark專案地址:https://github.com/apache/spark/tree/master/python

1、PySpark 的多程式架構

PySpark 採用了 Python、JVM 程式分離的多程式架構,在 Driver、Executor 端均會同時有 Python、JVM 兩個程式。當通過 spark-submit 提交一個 PySpark 的 Python 指令碼時,Driver 端會直接執行這個 Python 指令碼,並從 Python 中啟動 JVM;而在 Python 中呼叫的 RDD 或者 DataFrame 的操作,會通過 Py4j 呼叫到 Java 的介面。

在 Executor 端恰好是反過來,首先由 Driver 啟動了 JVM 的 Executor 程式,然後在 JVM 中去啟動 Python 的子程式,用以執行 Python 的 UDF,這其中是使用了 socket 來做程式間通訊。總體的架構圖如下所示:

PySpark原始碼解析,教你用Python呼叫高效Scala介面,搞定大規模資料分析

2、Python Driver 如何呼叫 Java 的介面

上面提到,通過 spark-submit 提交 PySpark 作業後,Driver 端首先是執行使用者提交的 Python 指令碼,然而 Spark 提供的大多數 API 都是 Scala 或者 Java 的,那麼就需要能夠在 Python 中去呼叫 Java 介面。這裡 PySpark 使用了 Py4j 這個開源庫。當建立 Python 端的 SparkContext 物件時,實際會啟動 JVM,並建立一個 Scala 端的 SparkContext 物件。程式碼實現在 python/pyspark/context.py:

def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
    """
    Checks whether a SparkContext is initialized or not.
    Throws error if a SparkContext is already running.
    """
    with SparkContext._lock:
        if not SparkContext._gateway:
            SparkContext._gateway = gateway or launch_gateway(conf)
            SparkContext._jvm = SparkContext._gateway.jvm

在 launch_gateway (python/pyspark/java_gateway.py) 中,首先啟動 JVM 程式:

SPARK_HOME = _find_spark_home()
# Launch the Py4j gateway using Spark's run command so that we pick up the
# proper classpath and settings from spark-env.sh
on_windows = platform.system() == "Windows"
script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
command = [os.path.join(SPARK_HOME, script)]

然後建立 JavaGateway 並 import 一些關鍵的 class:

gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
                                             auto_convert=True))
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")

拿到 JavaGateway 物件,即可以通過它的 jvm 屬性,去呼叫 Java 的類了,例如:

gateway = JavaGateway()

gateway = JavaGateway()
jvm = gateway.jvm
l = jvm.java.util.ArrayList()

然後會繼續建立 JVM 中的 SparkContext 物件:

def _initialize_context(self, jconf):
    """
    Initialize SparkContext in function to allow subclass specific initialization
    """
    return self._jvm.JavaSparkContext(jconf)

# Create the Java SparkContext through Py4J
self._jsc = jsc or self._initialize_context(self._conf._jconf)

3、Python Driver 端的 RDD、SQL 介面

在 PySpark 中,繼續初始化一些 Python 和 JVM 的環境後,Python 端的 SparkContext 物件就建立好了,它實際是對 JVM 端介面的一層封裝。和 Scala API 類似,SparkContext 物件也提供了各類建立 RDD 的介面,和 Scala API 基本一一對應,我們來看一些例子。

def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
                     valueConverter=None, conf=None, batchSize=0):
    jconf = self._dictToJavaMap(conf)
    jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
                                                valueClass, keyConverter, valueConverter,
                                                jconf, batchSize)
    return RDD(jrdd, self)

可以看到,這裡 Python 端基本就是直接呼叫了 Java/Scala 介面。而 PythonRDD (core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala),則是一個 Scala 中封裝的伴生物件,提供了常用的 RDD IO 相關的介面。另外一些介面會通過 self._jsc 物件去建立 RDD。其中 self._jsc 就是 JVM 中的 SparkContext 物件。拿到 RDD 物件之後,可以像 Scala、Java API 一樣,對 RDD 進行各類操作,這些大部分都封裝在 python/pyspark/rdd.py 中。

這裡的程式碼中出現了 jrdd 這樣一個物件,這實際上是 Scala 為提供 Java 互操作的 RDD 的一個封裝,用來提供 Java 的 RDD 介面,具體實現在 core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala 中。可以看到每個 Python 的 RDD 物件需要用一個 JavaRDD 物件去建立。

對於 DataFrame 介面,Python 層也同樣提供了 SparkSession、DataFrame 物件,它們也都是對 Java 層介面的封裝,這裡不一一贅述。

4、Executor 端程式間通訊和序列化

對於 Spark 內建的運算元,在 Python 中呼叫 RDD、DataFrame 的介面後,從上文可以看出會通過 JVM 去呼叫到 Scala 的介面,最後執行和直接使用 Scala 並無區別。而對於需要使用 UDF 的情形,在 Executor 端就需要啟動一個 Python worker 子程式,然後執行 UDF 的邏輯。那麼 Spark 是怎樣判斷需要啟動子程式的呢?

在 Spark 編譯使用者的 DAG 的時候,Catalyst Optimizer 會建立 BatchEvalPython 或者 ArrowEvalPython 這樣的 Logical Operator,隨後會被轉換成 PythonEvals 這個 Physical Operator。在 PythonEvals(sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala)中:

object PythonEvals extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ArrowEvalPython(udfs, output, child, evalType) =>
      ArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
    case BatchEvalPython(udfs, output, child) =>
      BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
    case _ =>
      Nil
  }
}

建立了 ArrowEvalPythonExec 或者 BatchEvalPythonExec,而這二者內部會建立 ArrowPythonRunner、PythonUDFRunner 等類的物件例項,並呼叫了它們的 compute 方法。由於它們都繼承了 BasePythonRunner,基類的 compute 方法中會去啟動 Python 子程式:

def compute(
      inputIterator: Iterator[IN],
      partitionIndex: Int,
      context: TaskContext): Iterator[OUT] = {
  // ......

  val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
  // Start a thread to feed the process input from our parent's iterator
  val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
  writerThread.start()
  val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

  val stdoutIterator = newReaderIterator(
    stream, writerThread, startTime, env, worker, releasedOrClosed, context)
  new InterruptibleIterator(context, stdoutIterator)

這裡 env.createPythonWorker 會通過 PythonWorkerFactory(core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala)去啟動 Python 程式。Executor 端啟動 Python 子程式後,會建立一個 socket 與 Python 建立連線。所有 RDD 的資料都要序列化後,通過 socket 傳送,而結果資料需要同樣的方式序列化傳回 JVM。

對於直接使用 RDD 的計算,或者沒有開啟 spark.sql.execution.arrow.enabled 的 DataFrame,是將輸入資料按行傳送給 Python,可想而知,這樣效率極低。

在 Spark 2.2 後提供了基於 Arrow 的序列化、反序列化的機制(從 3.0 起是預設開啟),從 JVM 傳送資料到 Python 程式的程式碼在 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala。這個類主要是重寫了 newWriterThread 這個方法,使用了 ArrowWriter 向 socket 傳送資料:

val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()

while (inputIterator.hasNext) {
val nextBatch = inputIterator.next()

while (nextBatch.hasNext) {
    arrowWriter.write(nextBatch.next())
}

arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()

可以看到,每次取出一個 batch,填充給 ArrowWriter,實際資料會儲存在 root 物件中,然後由 ArrowStreamWriter 將 root 物件中的整個 batch 的資料寫入到 socket 的 DataOutputStream 中去。ArrowStreamWriter 會呼叫 writeBatch 方法去序列化訊息並寫資料,程式碼參考 ArrowWriter.java#L131。

protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException {
  ArrowBlock block = MessageSerializer.serialize(out, batch, option);
  LOGGER.debug("RecordBatch at {}, metadata: {}, body: {}",
      block.getOffset(), block.getMetadataLength(), block.getBodyLength());
  return block;
}

在 MessageSerializer 中,使用了 flatbuffer 來序列化資料。flatbuffer 是一種比較高效的序列化協議,它的主要優點是反序列化的時候,不需要解碼,可以直接通過裸 buffer 來讀取欄位,可以認為反序列化的開銷為零。我們來看看 Python 程式收到訊息後是如何反序列化的。

Python 子程式實際上是執行了 worker.py 的 main 函式 (python/pyspark/worker.py):

if __name__ == '__main__':
    # Read information about how to connect back to the JVM from the environment.
    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
    main(sock_file, sock_file)

這裡會去向 JVM 建立連線,並從 socket 中讀取指令和資料。對於如何進行序列化、反序列化,是通過 UDF 的型別來區分:

eval_type = read_int(infile)
if eval_type == PythonEvalType.NON_UDF:
    func, profiler, deserializer, serializer = read_command(pickleSer, infile)
else:
    func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

在 read_udfs 中,如果是 PANDAS 類的 UDF,會建立 ArrowStreamPandasUDFSerializer,其餘的 UDF 型別建立 BatchedSerializer。我們來看看 ArrowStreamPandasUDFSerializer(python/pyspark/serializers.py):

def dump_stream(self, iterator, stream):
    import pyarrow as pa
    writer = None
    try:
        for batch in iterator:
            if writer is None:
                writer = pa.RecordBatchStreamWriter(stream, batch.schema)
            writer.write_batch(batch)
    finally:
        if writer is not None:
            writer.close()

def load_stream(self, stream):
    import pyarrow as pa
    reader = pa.ipc.open_stream(stream)
    for batch in reader:
        yield batch

可以看到,這裡雙向的序列化、反序列化,都是呼叫了 PyArrow 的 ipc 的方法,和前面看到的 Scala 端是正好對應的,也是按 batch 來讀寫資料。對於 Pandas 的 UDF,讀到一個 batch 後,會將 Arrow 的 batch 轉換成 Pandas Series。

def arrow_to_pandas(self, arrow_column):
    from pyspark.sql.types import _check_series_localize_timestamps

    # If the given column is a date type column, creates a series of datetime.date directly
    # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
    # datetime64[ns] type handling.
    s = arrow_column.to_pandas(date_as_object=True)

    s = _check_series_localize_timestamps(s, self._timezone)
    return s

def load_stream(self, stream):
    """
    Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
    """
    batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
    import pyarrow as pa
    for batch in batches:
        yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

5、Pandas UDF

前面我們已經看到,PySpark 提供了基於 Arrow 的程式間通訊來提高效率,那麼對於使用者在 Python 層的 UDF,是不是也能直接使用到這種高效的記憶體格式呢?答案是肯定的,這就是 PySpark 推出的 Pandas UDF。區別於以往以行為單位的 UDF,Pandas UDF 是以一個 Pandas Series 為單位,batch 的大小可以由 spark.sql.execution.arrow.maxRecordsPerBatch 這個引數來控制。這是一個來自官方文件的示例:

def multiply_func(a, b):
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

df.select(multiply(col("x"), col("x"))).show()

上文已經解析過,PySpark 會將 DataFrame 以 Arrow 的方式傳遞給 Python 程式,Python 中會轉換為 Pandas Series,傳遞給使用者的 UDF。在 Pandas UDF 中,可以使用 Pandas 的 API 來完成計算,在易用性和效能上都得到了很大的提升。

6、總結

PySpark 為使用者提供了 Python 層對 RDD、DataFrame 的操作介面,同時也支援了 UDF,通過 Arrow、Pandas 向量化的執行,對提升大規模資料處理的吞吐是非常重要的,一方面可以讓資料以向量的形式進行計算,提升 cache 命中率,降低函式呼叫的開銷,另一方面對於一些 IO 的操作,也可以降低網路延遲對效能的影響。

然而 PySpark 仍然存在著一些不足,主要有:

  • 程式間通訊消耗額外的 CPU 資源;

  • 程式設計介面仍然需要理解 Spark 的分散式計算原理;

  • Pandas UDF 對返回值有一定的限制,返回多列資料不太方便。

Databricks 提出了新的 Koalas 介面來使得使用者可以以接近單機版 Pandas 的形式來編寫分散式的 Spark 計算作業,對資料科學家會更加友好。而 Vectorized Execution 的推進,有望在 Spark 內部一切資料都是用 Arrow 的格式來存放,對跨語言支援將會更加友好。同時也能看到,在這裡仍然有很大的效能、易用性的優化空間,這也是我們平臺近期的主要發力方向之一。

陳緒,匯量科技(Mobvista)高階演算法科學家,負責匯量科技大規模資料智慧計算引擎和平臺的研發工作。在此之前陳緒是阿里巴巴高階技術專家,負責阿里集團大規模機器學習平臺的研發。

相關文章