Extend Spark ML to Build Your Own Model and Transformer Types



By Holden Karau

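Although Spark ML pipelines offer a wide variety of algorithms, you may still find yourself wanting additional functionality without leaving the pipeline model. In Spark MLlib this is not much of a problem: you can implement your own algorithm with RDD transformations and carry on from there. The same approach works for Spark ML pipelines, but you lose some of the pipeline's nicer properties, including the ability to run meta-algorithms automatically, such as parameter search with cross-validation. In this article you will learn how to extend the Spark ML pipeline model, starting from the standard wordcount example (when it comes to big data, you can never really escape the wordcount example).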

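To add your own algorithm to a Spark pipeline, you need to implement either Estimator or Transformer, both of which implement the PipelineStage interface. For algorithms that do not require training, implement the Transformer interface; for algorithms that do, implement the Estimator interface. Both are defined in org.apache.spark.ml. Note that training is not limited to complex machine learning models: even a min-max scaler (MinMaxScaler) needs a training step to determine the range. If your algorithm needs training, it must be built as an Estimator rather than a Transformer.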
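To make the distinction concrete, here is a small sketch that uses two stages shipped with Spark: Tokenizer (a Transformer) and MinMaxScaler (an Estimator). The object name, method names, and column names are illustrative assumptions rather than part of the original example, and the caller is assumed to supply a SparkSession.

```scala
import org.apache.spark.ml.feature.{MinMaxScaler, Tokenizer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Illustrative names only; not part of the article's own code.
object EstimatorVsTransformer {
  // Transformer: there is no training step, so transform can be called right away.
  def tokenize(spark: SparkSession, sentence: String): Seq[String] = {
    import spark.implicits._
    val df = Seq(sentence).toDF("text")
    new Tokenizer().setInputCol("text").setOutputCol("words")
      .transform(df).select("words").as[Seq[String]].head()
  }

  // Estimator: fit must look at the data first to learn the value range.
  // It returns a Model, which is itself a Transformer.
  def learnedRange(spark: SparkSession, values: Seq[Double]): (Double, Double) = {
    import spark.implicits._
    val df = values.map(v => Tuple1(Vectors.dense(v))).toDF("features")
    val model = new MinMaxScaler()
      .setInputCol("features").setOutputCol("scaled").fit(df)
    (model.originalMin(0), model.originalMax(0))
  }
}
```

The scaler cannot produce output until fit has seen the data, which is exactly why it has to be an Estimator.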



class HardCodedWordCountStage(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("hardcodedwordcount"))

  // Delegate copying to the default implementation
  def copy(extra: ParamMap): HardCodedWordCountStage = {
    defaultCopy(extra)
  }

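The code above shows the start of a pipeline stage together with its copy delegation. transformSchema must produce the expected output schema of your pipeline stage, given an input schema and any parameters that have been set. Most pipeline stages only add new fields; few drop existing fields, since those fields might still be used. This can sometimes leave records carrying more data than downstream stages need, which can hurt performance. If you find this to be a problem in your pipeline, you can create your own stage to drop the unneeded fields.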

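In addition to producing the output schema, the transformSchema method should validate that the input schema is suitable for the stage (for example, that the input columns have the expected types).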


  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex("happy_pandas")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field
    schema.add(StructField("happy_panda_counts", IntegerType, false))
  }



  def transform(df: Dataset[_]): DataFrame = {
    val wordcount = udf { in: String => in.split(" ").size }
    df.select(col("*"), wordcount(df.col("happy_pandas")).as("happy_panda_counts"))
  }
}
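The word-count UDF relies on in.split(" "), which has a few edge cases worth knowing about before you reuse it. The sketch below compares that expression with a stricter variant; the object and method names are hypothetical, and the stricter rules (null-safe, any whitespace, no empty tokens) are an assumption about what you might want, not something the original stage does.

```scala
// Illustrative helper; WordCounts, naive and robust are made-up names.
object WordCounts {
  // The same expression as in the UDF above.
  def naive(in: String): Int = in.split(" ").length

  // Assumed stricter policy: null counts as zero words, any run of
  // whitespace separates words, and empty tokens are ignored.
  def robust(in: String): Int =
    Option(in).map(_.trim.split("\\s+").count(_.nonEmpty)).getOrElse(0)
}

object WordCountsDemo {
  def main(args: Array[String]): Unit = {
    println(WordCounts.naive("happy pandas"))   // 2
    println(WordCounts.naive("happy  pandas"))  // 3: the double space yields an empty token
    println(WordCounts.naive(""))               // 1: splitting "" gives one empty string
    println(WordCounts.robust("happy  pandas")) // 2
    println(WordCounts.robust(""))              // 0
  }
}
```

Inside a UDF a null input would make the naive version throw, so some form of null handling is usually worth adding when the input column is nullable.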







class ConfigurableWordCount(override val uid: String) extends Transformer {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("configurablewordcount"))

  // The copy must be a ConfigurableWordCount, not the hard-coded stage
  def copy(extra: ParamMap): ConfigurableWordCount = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field
    schema.add(StructField($(outputCol), IntegerType, false))
  }


  def transform(df: Dataset[_]): DataFrame = {
    val wordcount = udf { in: String => in.split(" ").size }
    df.select(col("*"), wordcount(df.col($(inputCol))).as($(outputCol)))
  }
}
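For readers who want something they can compile as a whole, the following is a stand-alone sketch of the configurable stage with the imports it needs, plus a small helper that runs it inside a Pipeline. SketchWordCount, SketchWordCountDemo, and the column names text and n are hypothetical names chosen for this illustration; the classes are assumed to be compiled as top-level definitions with a Spark dependency available.

```scala
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical stand-alone variant of the configurable word-count stage.
class SketchWordCount(override val uid: String) extends Transformer {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("sketchwordcount"))

  override def copy(extra: ParamMap): SketchWordCount = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    // fieldIndex throws if the column is missing; require rejects non-string input
    val field = schema.fields(schema.fieldIndex($(inputCol)))
    require(field.dataType == StringType,
      s"Input type ${field.dataType} did not match input type StringType")
    schema.add(StructField($(outputCol), IntegerType, false))
  }

  override def transform(df: Dataset[_]): DataFrame = {
    val wordcount = udf { in: String => in.split(" ").length }
    df.select(col("*"), wordcount(df.col($(inputCol))).as($(outputCol)))
  }
}

object SketchWordCountDemo {
  // Runs the stage inside a Pipeline and returns the counts it produced.
  def counts(spark: SparkSession, lines: Seq[String]): Seq[Int] = {
    import spark.implicits._
    val df = lines.toDF("text")
    val stage = new SketchWordCount().setInputCol("text").setOutputCol("n")
    val model = new Pipeline().setStages(Array(stage)).fit(df)
    model.transform(df).select("n").as[Int].collect().toSeq
  }
}
```

Because Pipeline.fit checks the schema through transformSchema before doing any work, pointing inputCol at a non-string column fails up front instead of partway through a job.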



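Algorithms that require training can be implemented with the Estimator interface, although for many algorithms the helper classes org.apache.spark.ml.Predictor or org.apache.spark.ml.classification.Classifier are easier to work with. The main difference between the Estimator and Transformer interfaces is that, rather than applying a transformation directly to the input, an Estimator first performs a training step in its fit method (the Predictor helper exposes this step as train). A string indexer is one of the simplest estimators you can implement. Although one already ships with Spark, it is still a very good illustration of how to use the Estimator interface.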

trait SimpleIndexerParams extends Params {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")
}

class SimpleIndexer(override val uid: String) extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {

  def setInputCol(value: String) = set(inputCol, value)

  def setOutputCol(value: String) = set(outputCol, value)

  def this() = this(Identifiable.randomUID("simpleindexer"))

  override def copy(extra: ParamMap): SimpleIndexer = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field (the model emits Double indices, so declare DoubleType)
    schema.add(StructField($(outputCol), DoubleType, false))
  }


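  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    import dataset.sparkSession.implicits._
    // Collect the distinct input values; their positions become the indices
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct
      .collect()
    // Copy the params over so the model knows its input and output columns
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }
}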


class SimpleIndexerModel(
  override val uid: String, words: Array[String]) extends Model[SimpleIndexerModel] with SimpleIndexerParams {

  // defaultCopy needs a constructor that takes only a uid, so build the copy explicitly
  override def copy(extra: ParamMap): SimpleIndexerModel = {
    copyValues(new SimpleIndexerModel(uid, words), extra)
  }

  private val labelToIndex: Map[String, Double] = words.zipWithIndex.
    map{case (x, y) => (x, y.toDouble)}.toMap

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field (the indexer UDF returns Double, so declare DoubleType)
    schema.add(StructField($(outputCol), DoubleType, false))
  }


  override def transform(dataset: Dataset[_]): DataFrame = {
    val indexer = udf { label: String => labelToIndex(label) }
    dataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol)))
  }
}
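One thing to keep in mind with this model is that labelToIndex(label) throws a NoSuchElementException for any label that was not present during fit. The sketch below shows the same mapping in plain Scala together with one possible fallback. The names LabelLookup, index, and lookup are hypothetical, and sending unseen labels to an extra "unknown" index is an assumed policy, not what the model above does.

```scala
// Illustrative helper; not part of the article's model.
object LabelLookup {
  // Build the label -> index map the same way the model does.
  def index(words: Array[String]): Map[String, Double] =
    words.zipWithIndex.map { case (w, i) => (w, i.toDouble) }.toMap

  // Assumed policy: every unseen label shares one extra index,
  // equal to the number of known labels.
  def lookup(mapping: Map[String, Double], label: String): Double =
    mapping.getOrElse(label, mapping.size.toDouble)
}

object LabelLookupDemo {
  def main(args: Array[String]): Unit = {
    val mapping = LabelLookup.index(Array("panda", "bamboo"))
    println(LabelLookup.lookup(mapping, "bamboo")) // 1.0
    println(LabelLookup.lookup(mapping, "tiger"))  // 2.0, the "unknown" bucket
  }
}
```

Whether to fail, skip the row, or bucket unseen labels depends on the downstream model, so it is a good candidate for a param on your own estimator.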






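The Predictor interface adds the most commonly used parameters (input and output columns) in the form of a label column, a features column, and a prediction column, and it automatically handles the schema transformation for us.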

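The Classifier interface does much the same, except that it also adds a raw prediction column (rawPredictionCol) and provides tools for detecting the number of classes (the getNumClasses method) and for converting the input DataFrame into an RDD of LabeledPoints (which makes it easier to wrap legacy MLlib classification algorithms).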
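You can see these shared parameters on any classifier that ships with Spark. The sketch below reads their default values from LogisticRegression, which extends Classifier; the object and method names are illustrative assumptions.

```scala
import org.apache.spark.ml.classification.LogisticRegression

// Illustrative helper; SharedClassifierParams is a made-up name.
object SharedClassifierParams {
  // Column parameters inherited from Predictor and Classifier, with their defaults.
  def defaults(): Map[String, String] = {
    val lr = new LogisticRegression()
    Map(
      "labelCol" -> lr.getLabelCol,
      "featuresCol" -> lr.getFeaturesCol,
      "predictionCol" -> lr.getPredictionCol,
      "rawPredictionCol" -> lr.getRawPredictionCol)
  }
}
```

A custom Classifier inherits the same getters and the matching setters, which is why the naive Bayes example below never declares its own column params.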


// Simple Bernoulli Naive Bayes classifier - no sanity checks for brevity
// Example only - not for production use.
class SimpleNaiveBayes(val uid: String)
    extends Classifier[Vector, SimpleNaiveBayes, SimpleNaiveBayesModel] {

  def this() = this(Identifiable.randomUID("simple-naive-bayes"))

  override def train(ds: Dataset[_]): SimpleNaiveBayesModel = {
    import ds.sparkSession.implicits._
    // Cache the input, since it is scanned several times below
    ds.cache()


    // Note: you can use getNumClasses and extractLabeledPoints to get an RDD instead.
    // Using the RDD approach is common when integrating with legacy machine learning code
    // or iterative algorithms which can create large query plans.
    // Here we use Datasets since neither of those apply.

    // Compute the number of documents
    val numDocs = ds.count
    // Get the number of classes.
    // Note this estimator assumes they start at 0 and go up to numClasses - 1
    val numClasses = getNumClasses(ds)
    // Get the number of features by peeking at the first row
    val numFeatures: Integer = ds.select(col($(featuresCol))).head
      .get(0).asInstanceOf[Vector].size


    // Determine the number of records for each class and
    // collect them to the driver as a label -> count Map
    val groupedByLabel = ds.select(col($(labelCol)).as[Double]).groupByKey(x => x)
    val classCounts = groupedByLabel.agg(count("*").as[Long])
      .collect().toMap


    // Select the labels and features so we can more easily map over them.
    // Note: we do this as a DataFrame using the untyped API because the Vector
    // UDT is no longer public.
    val df = ds.select(col($(labelCol)).cast(DoubleType), col($(featuresCol)))
    // Figure out the non-zero frequency of each feature for each label and
    // output label index pairs using a case class (LabeledToken, whose definition
    // is not shown here) to make it easier to work with.
    val labelCounts: Dataset[LabeledToken] = df.flatMap {
      case Row(label: Double, features: Vector) =>
        // Pair each value with its 0-based index and keep the non-zero ones
        features.toArray.zipWithIndex
          .filter{case (v, _) => v != 0.0}
          .map{case (_, idx) => LabeledToken(label, idx)}
    }


    // Use the typed Dataset aggregation API to count the number of non-zero
    // features for each label-feature index.
    val aggregatedCounts: Array[((Double, Integer), Long)] = labelCounts
      .groupByKey(x => (x.label, x.index))
      .agg(count("*").as[Long]).collect()


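    // Start every entry at the smoothed zero-count value, log(1) - log(classCount + 2),
    // so that features never seen with a label do not end up with probability 1.
    val theta = Array.tabulate(numClasses) { c =>
      Array.fill(numFeatures)(-math.log(classCounts.getOrElse(c.toDouble, 0L) + 2.0)) }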

    // Compute the denominator for the general priors
    val piLogDenom = math.log(numDocs + numClasses)
    // Compute the priors for each class, ordered by label so pi(i) belongs to class i
    val pi = classCounts.toSeq.sortBy(_._1).map{case(_, cc) =>
      math.log(cc.toDouble) - piLogDenom }.toArray

    // For each label/feature update the probabilities
    aggregatedCounts.foreach{case ((label, featureIndex), count) =>
      // log of number of documents for this label + 2.0 (smoothing)
      val thetaLogDenom = math.log(
        classCounts.get(label).map(_.toDouble).getOrElse(0.0) + 2.0)
      theta(label.toInt)(featureIndex) = math.log(count + 1.0) - thetaLogDenom
    }


    // Unpersist now that we are done computing everything
    ds.unpersist()
    // Construct a model
    new SimpleNaiveBayesModel(uid, numClasses, numFeatures, Vectors.dense(pi),
      new DenseMatrix(numClasses, theta(0).length, theta.flatten, true))
  }


  override def copy(extra: ParamMap): SimpleNaiveBayes = defaultCopy(extra)
}
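The arithmetic inside train is easier to follow without the Dataset plumbing. The sketch below repeats it on an in-memory collection: the prior for class c is log(N_c) - log(N + C), and the smoothed feature probability is log(n_cj + 1) - log(N_c + 2), where N is the number of documents, C the number of classes, N_c the number of documents in class c, and n_cj the number of those documents with a non-zero feature j. The object name and the tiny data set used to exercise it are assumptions made for illustration.

```scala
// Plain-Scala restatement of the training arithmetic; names are illustrative.
object BernoulliNaiveBayesMath {
  // docs: (label, binary feature values). Returns (pi, theta), both in log space.
  def train(docs: Seq[(Int, Array[Double])],
            numClasses: Int): (Array[Double], Array[Array[Double]]) = {
    val numFeatures = docs.head._2.length
    val classCounts = docs.groupBy(_._1).map { case (c, ds) => (c, ds.size) }
    // Same prior as in the classifier: log(N_c) - log(N + C)
    val piLogDenom = math.log(docs.size + numClasses)
    val pi = Array.tabulate(numClasses) { c =>
      math.log(classCounts.getOrElse(c, 0).toDouble) - piLogDenom
    }
    // Same smoothing as in the classifier: log(n_cj + 1) - log(N_c + 2)
    val theta = Array.tabulate(numClasses, numFeatures) { (c, j) =>
      val n = docs.count { case (label, f) => label == c && f(j) != 0.0 }
      math.log(n + 1.0) - math.log(classCounts.getOrElse(c, 0) + 2.0)
    }
    (pi, theta)
  }
}
```

For three documents, two in class 0 with features (1, 0) and (1, 1) and one in class 1 with features (0, 1), this gives priors of 2/5 and 1/5 and, for class 0, feature probabilities of 3/4 and 2/4 once exponentiated.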




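// Simplified Naive Bayes Model
case class SimpleNaiveBayesModel(
  override val uid: String,
  override val numClasses: Int,
  override val numFeatures: Int,
  val pi: Vector,
  val theta: DenseMatrix) extends
    ClassificationModel[Vector, SimpleNaiveBayesModel] {

  // defaultCopy needs a constructor that takes only a uid, so build the copy explicitly
  override def copy(extra: ParamMap): SimpleNaiveBayesModel = {
    copyValues(new SimpleNaiveBayesModel(uid, numClasses, numFeatures, pi, theta), extra)
  }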

  // We have to do some tricks here because we are using Spark's
  // Vector/DenseMatrix calculations - but for your own model don't feel
  // limited to Spark's native ones.
  val negThetaArray = theta.values.map(v => math.log(1.0 - math.exp(v)))
  val negTheta = new DenseMatrix(numClasses, numFeatures, negThetaArray, true)
  val thetaMinusNegThetaArray = theta.values.zip(negThetaArray)
    .map{case (v, nv) => v - nv}
  val thetaMinusNegTheta = new DenseMatrix(
    numClasses, numFeatures, thetaMinusNegThetaArray, true)
  val onesVec = Vectors.dense(Array.fill(theta.numCols)(1.0))
  val negThetaSum: Array[Double] = negTheta.multiply(onesVec).toArray

  // Here is the prediction functionality you need to implement. For ClassificationModels,
  // transform automatically wraps this, but if you might benefit from broadcasting your
  // model or other optimizations you can also override transform.
  def predictRaw(features: Vector): Vector = {
    // Toy implementation: BLAS or similar would be better suited to
    // summing the three vectors, but that functionality isn't exposed.
    Vectors.dense(thetaMinusNegTheta.multiply(features).toArray.zip(pi.toArray)
      .map{case (x, y) => x + y}.zip(negThetaSum).map{case (x, y) => x + y})
  }
}
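The three terms summed in predictRaw come from rearranging the Bernoulli log-likelihood. For one class, the score is the log prior plus the sum over features of x_j * log(t_j) + (1 - x_j) * log(1 - t_j). Pulling x_j out gives x . (log t - log(1 - t)) + sum of log(1 - t) + log prior, which corresponds to thetaMinusNegTheta, negThetaSum, and pi in the model. The plain-Scala sketch below checks that both forms agree for a single class; the object and method names are hypothetical.

```scala
// Illustrative check of the rearrangement; names are made up.
object BernoulliScore {
  // Direct form: sum of x_j*log(t_j) + (1 - x_j)*log(1 - t_j), plus the log prior.
  def direct(logPrior: Double, logTheta: Array[Double], x: Array[Double]): Double =
    logPrior + logTheta.zip(x).map { case (lt, xj) =>
      xj * lt + (1.0 - xj) * math.log(1.0 - math.exp(lt))
    }.sum

  // Rearranged form: a dot product with (logTheta - negTheta) plus two
  // terms that do not depend on the input and can be precomputed.
  def rearranged(logPrior: Double, logTheta: Array[Double], x: Array[Double]): Double = {
    val negTheta = logTheta.map(lt => math.log(1.0 - math.exp(lt)))
    val dot = logTheta.zip(negTheta).zip(x)
      .map { case ((lt, nt), xj) => (lt - nt) * xj }.sum
    dot + logPrior + negTheta.sum
  }
}
```

The rearranged form is what lets the model do a single matrix-vector multiply per prediction, since everything except the dot product is fixed once training is done.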





