SparkSQL中的UDF、UDAF、UDTF實現

jim8973發表於2020-11-08

分類

根據輸入輸出之間的關係來分類:

UDF —— 輸入一行,輸出一行
UDAF —— 輸入多行,輸出一行
UDTF —— 輸入一行,輸出多行

UDF函式

1、資料

大狗	三國,水滸,紅樓
二狗	金瓶梅
二條	西遊,唐詩宋詞

2、需求:求出每個人的愛好個數
3、實現

def main(args: Array[String]): Unit = {
    // Local Spark session for this UDF demo.
    val spark = SparkSession.builder
      .master("local")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    import spark.implicits._

    // Load "name<TAB>comma-separated-likes" lines into a DataFrame.
    val likesDf = spark.sparkContext
      .textFile("D:\\ssc\\likes.txt")
      .map(line => line.split("\t"))
      .map(fields => Likes(fields(0), fields(1)))
      .toDF()
    likesDf.createOrReplaceTempView("t_team")

    // UDF: number of comma-separated items in the `teams` column.
    // Registered for SQL use and kept as a value for the DataFrame API.
    val teamsLengthUDF = spark.udf.register("teams_length", (input: String) => input.split(",").length)

    println("--------------SQL方式----------------")
    spark.sql("select name,teams,teams_length(teams) as teams_length from t_team").show(false)

    println("--------------API方式----------------")
    likesDf.select($"name", $"teams", teamsLengthUDF($"teams").as("teams_length")).show(false)

    spark.stop()
  }
  // Row model for the input file: a person's name and a comma-separated string of likes.
  case class Likes(name:String,teams:String)
}

UDAF

如果要自定義UDAF,需要繼承UserDefinedAggregateFunction

SparkSQL自帶的聚合函式:

 /**
    * Demo of a built-in aggregate function (many rows in, one row out):
    * sums `age` per `sex` using plain SQL `sum`.
    * @param spark active SparkSession
    */
  def udafWithSum(spark:SparkSession): Unit = {
    // Hand-built rows and schema for a tiny in-memory user table.
    val userRows = new util.ArrayList[Row]()
    userRows.add(Row("Luck", 30, "M"))
    userRows.add(Row("Jack", 60, "M"))
    userRows.add(Row("Jim", 19, "F"))
    userRows.add(Row("Lily", 20, "F"))

    val userSchema = StructType(
      StructField("name", StringType, false) ::
        StructField("age", IntegerType, false) ::
        StructField("sex", StringType, false) :: Nil
    )

    val userDf = spark.createDataFrame(userRows, userSchema)
    userDf.createOrReplaceTempView("user")

    // Built-in `sum` aggregates all ages within each sex group.
    spark.sql("select sex,sum(age) from user group by sex").show(false)
  }

自定義UDAF
需求:男性和女性各自的平均年齡

  // avg = sum / number of non-null values that took part in the aggregation.
  // NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0 in
  // favour of org.apache.spark.sql.expressions.Aggregator — migrate when possible.
  object JimAvgUDAF extends UserDefinedAggregateFunction {
    /**
      * Input type: a single nullable Double column.
      * @return schema of the function's input
      */
    override def inputSchema: StructType = {
      StructType(
        StructField("nums",DoubleType,true) :: Nil
      )
    }

    /**
      * Intermediate aggregation buffer: running sum and running count.
      * @return schema of the buffer
      */
    override def bufferSchema: StructType = {
      StructType(
        StructField("buffer1",DoubleType,true) ::     // running sum of the values
          StructField("buffer2",LongType,true) :: Nil // count of values aggregated
      )
    }

    /** Result type of the aggregation. */
    override def dataType: DataType = DoubleType

    /** Same input always yields the same output. */
    override def deterministic: Boolean = true

    /**
      * Start every group with sum = 0.0 and count = 0.
      * @param buffer mutable aggregation buffer to initialize
      */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0,0.0)
      buffer.update(1,0L)
    }

    /**
      * Per-partition aggregation: fold one input row into the buffer.
      * Null inputs are skipped — calling getDouble on a null cell would throw
      * a NullPointerException, and a null must not inflate the count either.
      * @param buffer partial aggregation state
      * @param input  one input row holding the value to aggregate
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer.update(0,buffer.getDouble(0) + input.getDouble(0))
        buffer.update(1,buffer.getLong(1) + 1)
      }
    }

    /**
      * Global aggregation: merge a partial buffer into the accumulated one.
      * @param buffer1 accumulated state (mutated in place)
      * @param buffer2 partial state from another partition
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0,buffer1.getDouble(0) + buffer2.getDouble(0))
      buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1))
    }

    /**
      * Final result: average = sum / count.
      * Returns null (SQL NULL) when no non-null value was aggregated, instead
      * of the confusing NaN that 0.0 / 0 would produce.
      * @param buffer final merged buffer
      */
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      if (count == 0L) null else buffer.getDouble(0) / count
    }
  }

呼叫

/**
    * Registers JimAvgUDAF and runs the equivalent of:
    *   select sex,jim_avg(age) from user group by sex
    * @param spark active SparkSession
    */
  def udafAvgWithSex(spark:SparkSession): Unit = {
    // Small in-memory user table: (name, age, sex).
    val userRows = new util.ArrayList[Row]()
    userRows.add(Row("Luck", 30, "M"))
    userRows.add(Row("Jack", 60, "M"))
    userRows.add(Row("Jim", 19, "F"))
    userRows.add(Row("Lily", 20, "F"))

    val userSchema = StructType(
      StructField("name", StringType, false) ::
        StructField("age", IntegerType, false) ::
        StructField("sex", StringType, false) :: Nil
    )

    spark.createDataFrame(userRows, userSchema).createOrReplaceTempView("user")

    // Make the custom UDAF callable from SQL under the name "avg_udaf".
    spark.udf.register("avg_udaf",JimAvgUDAF)

    // Average age per sex, computed by the custom UDAF.
    spark.sql("select sex,avg_udaf(age) as avg_age from user group by sex").show(false)
  }

UDTF

一進多出

/**
  * UDTF-style "one row in, many rows out": explodes a comma-separated
  * course list into one (teacher, course) row per course.
  * @param spark active SparkSession
  */
def udtfExplode(spark:SparkSession): Unit = {
    // Input: one row per teacher with a comma-separated course string.
    val teacherRows = new util.ArrayList[Row]()
    teacherRows.add(Row("Luck","Java,JavaScript,Scala"))
    teacherRows.add(Row("Jack","History,English,Math"))

    val teacherSchema = StructType(
      StructField("teacher",StringType,false) ::
        StructField("courses",StringType,false) :: Nil
    )

    val teacherDf = spark.createDataFrame(teacherRows, teacherSchema)

    // Product encoder for the Course case class comes from spark.implicits._
    import spark.implicits._

    // Each input row fans out into one Course per comma-separated entry.
    val courseDs = teacherDf.flatMap { row =>
      val teacher = row.getString(0)
      row.getString(1).split(",").toSeq.map(course => Course(teacher, course))
    }

    courseDs.printSchema()

    courseDs.show(false)
  }

  // Output row model for udtfExplode: one (teacher, course) pair per exploded course.
  case class Course(teacher:String,course: String)

相關文章