Spark UDAF實現舉例 -- average pooling

野路子程式設計師發表於2020-12-31

1.UDAF定義

spark中的UDF(UserDefinedFunction)大家都不會陌生, UDF其實就是將一個普通的函式, 包裝為可以按 操作DataFrame中指定Columns的函式.

例如, 對某一列的所有元素進行+1操作, 它對應mapreduce操作中的map操作. 這種操作有的主要特點是:

  • 行與行之間的操作是獨立的, 可以非常方便的平行計算
  • 每一行的操作完成後, map的任務就完成了, 直接將結果返回就行, 它是一種”無狀態的“

但是UDAF(UserDefinedAggregateFunction)則不同, 由於存在聚合(Aggregate)操作, 它對應mapreduce操作中的reduce操作. SparkSQL中有很多現成的聚合函式, 常用的sum, count, avg等等都是. 這種操作的主要特點是:

  • 每一輪reduce之間可以是並行, 但是多輪reduce的執行是序列的, 下一輪依靠前一輪的結果, 它是一種“有狀態的”, 需要記錄中間的計算結果

分析上圖, 96 => (96, 1)這一步是一個map操作, 給每個樣本新增一個1, 表示它的數量. 它們之間的計算是獨立的, 也不影響資料的行數. 然後(96, 1)和(54, 1)求和, 得到(150, 2), 它是一輪reduce的其中一箇中間結果, 等三個中間結果都結束了, 才能繼續後續的reduce, 得到最終的reduce結果(303, 6), 因此完整的reduce需要記錄並不斷更新中間結果.

2.向量平均(average pooling)

向量平均是個很常用的操作, 比如我們現在有1000個64維的向量, 想要求這1000個點的中心點. 通常來說我們不會用64列float column去儲存一個向量, 因此無法使用原生的avg函式.

下面介紹如何自定義一個avgvector函式, 去處理array[float] column的平均值計算問題. 通過這個例子學會如何在spark下實現自定義的聚合函式

2.1 average的並行化

average演算法非常簡單, 求個和, 然後除以樣本個數就好了. 它的並行化也很好理解

  • reduce的過程只進行sum的累積和樣本數num的累積, 在最後一步將sum/num

因此我們的在reduce的過程中, 需要時刻記錄當前task處理的樣本的個數, 和它們的和.

由於這樣的原因, 不像UDF只需要定義一個函式就可以, UDAF通常需要定義一個類, 用來儲存中間結果

2.2 程式碼實現

// 從基類UserDefinedAggregateFunction繼承
class VectorMean64 extends UserDefinedAggregateFunction {
  // 定義輸入的格式
  // 這個函式將會處理的那一列的資料型別, 因為是64維的向量, 因此是Array[Float]
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("vector", ArrayType(FloatType)) :: Nil)

  // 這個就是上面提到的狀態
  // 在reduce過程中, 需要記錄的中間結果. vector_count即為已經統計的向量個數, 而vector_sum即為已經統計的向量的和
  override def bufferSchema: StructType =
    StructType(
      StructField("vector_count", IntegerType) ::
        StructField("vector_sum", ArrayType(FloatType)) :: Nil)

  // 最終的輸出格式
  // 既然是求平均, 最後當然還是一個向量, 依然是Array[Float]
  override def dataType: DataType = ArrayType(FloatType)

  override def deterministic: Boolean = true

  // 初始化
  // buffer的格式即為bufferSchema, 因此buffer(0)就是向量個數, 初始化當然是0, buffer(1)為向量和, 初始化為零向量
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = Array.fill[Float](64)(0).toSeq
  }

  // 定義reduce的更新操作: 如何根據一行新資料, 更新一個聚合buffer的中間結果
  // 一行資料是一個向量, 因此需要將count+1, 然後sum+新向量
  // addTwoEmb為向量相加的基本實現
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1

    val inputVector = input.getAs[Seq[Float]](0)
    buffer(1) = addTwoEmb(buffer.getAs[Seq[Float]](1), inputVector)
  }

  // 定義reduce的merge操作: 兩個buffer結果合併到其中一個bufer上
  // 兩個buffer各自統計的樣本個數相加; 兩個buffer各自的sum也相加
  // 注意: 為什麼buffer1和buffer2的資料型別不一樣?一個是MutableAggregationBuffer, 一個是Row
  // 因為: 在將所有中間task的結果進行reduce的過程中, 兩兩合併時是將一個結果合到另外一個上面, 因此一個是mutable的, 它們兩者的schema其實是一樣的, 都對應bufferSchema
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = addTwoEmb(buffer1.getAs[Seq[Float]](1), buffer2.getAs[Seq[Float]](1))
  }

  // 最終的結果, 依賴最終的buffer中的資料計算的到, 就是將sum/count
  override def evaluate(buffer: Row): Any = {
    val result = buffer.getAs[Seq[Float]](1).toArray
    val count = buffer.getInt(0)
    for (i <- result.indices) {
      result(i) /= (count + 1)
    }
    result.toSeq
  }

	// 向量相加
  private def addTwoEmb(emb1: Seq[Float], emb2: Seq[Float]): Seq[Float] = {
    val result = Array.fill[Float](emb1.length)(0)
    for (i <- emb1.indices) {
      result(i) = emb1(i) + emb2(i)
    }
    result.toSeq
  }

解釋可以參考上面的程式碼註釋. 核心就是定義四個模組:

  • 中間結果的格式 - bufferSchema
  • 將一行資料更新到中間結果buffer中 - update
  • 將兩個中間結果buffer合併 - merge
  • 從最後的buffer計算需要的結果 - evaluate

2.3 使用

// 註冊一下, 使其可以在Spark SQL中使用
spark.udf.register("avgVector64", new VectorMean64)
spark.sql("""
|select group_id, avgVector64(embedding) as avg_embedding
|from embedding_table_name
|group by group_id
""".stripMargin)

// 當然不註冊也可以用, 只是不能在SQL中用, 可以直接用來操作DataFrame
val avgVector64 = new VectorMean64
val df = spark.sql("select group_id, embedding from embedding_table_name")
df.groupBy("group_id").agg(avgVector64(col("embedding")))

參考

https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html

相關文章