当前位置:   article > 正文

Spark sql 自定义函数(UDF、UDTF、UDAF)Spark版本 3.0之前 AND 3.0之后_sparksql中没有原生自定义 udtf 函数的方法,用什么方式可以完成 udtf 拆分数据的

sparksql中没有原生自定义 udtf 函数的方法,用什么方式可以完成 udtf 拆分数据的

Spark sql 自定义函数(UDF、UDTF、UDAF)

概念

UDF :输入一行,返回一个结果 ;一对一;比如定义一个函数,功能是输入一个生日,返回一个对应的年龄
UDTF:输入一行,返回多行(hive);一对多;sparkSQL中没有UDTF,spark中用flatMap即可实现该功能
UDAF:输入多行,返回一行;aggregate(聚合),count,sum这些是spark自带的聚合函数,但是复杂的业务,要自己定义
  • 1
  • 2
  • 3

案例(Spark < 3.0)

UDF

object UdfTest {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName("Test")
      .master("local[*]")       //设置本地模式运行
      .getOrCreate()

    //使用自定义udf函数
    val test = udf((str:String)=>{str.replace("1","2")})
    
	//测试数据,创建一个Dataset
      val tmp_da = List(Data(Array("1测试数据"), Array("1测试数据")))
          .toDF("1test1","1test2")
    val testData = tmp_da.withColumn("test3",test(col("1test2")))
    import spark.implicits._
    val result: DataFrame = testData.select($"test3")
    //展示结果
    result.show()
    spark.stop()
  }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

UDAF

1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
IN:输入的数据类型
BUF:缓冲区的数据类型
OUT:输出的数据类型
2.重写方法
3.注册自定义聚合函数
val udaf_test = new 对象 extends UserDefinedAggregateFunction
4. 直接可以使用udaf_test进行使用 
val testData = tmp_da.agg(udaf_test(col("1test2")) as "")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

在Spark中,自定义聚合函数要继承UserDefinedAggregateFunction这个抽象类,重写里面的方法。
先来看一下这个类的源码:

abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A `StructType` represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def inputSchema: StructType

  /**
   * A `StructType` represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
   * the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def bufferSchema: StructType

  /**
   * The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].
   *
   * @since 1.5.0
   */
  def dataType: DataType

  /**
   * Returns true iff this function is deterministic, i.e. given the same input,
   * always return the same output.
   *
   * @since 1.5.0
   */
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   *
   * @since 1.5.0
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   *
   * @since 1.5.0
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   *
   * @since 1.5.0
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   *
   * @since 1.5.0
   */
  def evaluate(buffer: Row): Any
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90

可以看出,继承这个类之后,要重写里面的八个方法。
每个方法代表的含义是:

inputSchema:输入数据的类型
bufferSchema:产生中间结果的数据类型
dataType:最终返回的结果类型
deterministic:确保一致性(输入什么类型的数据就返回什么类型的数据),一般用true
initialize:指定初始值
update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
merge:全局聚合(将每个分区的结果进行聚合)
evaluate:计算最终的结果
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

案例(Spark>=3.0)

UDAF(自定义聚合函数)(注册函数)

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.functions._

object SchedulerDemo {
  // 计算输出结果的数据来源
  case class Buff(var total: Long, var count: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("Test")
      .master("local[*]") //设置本地模式运行
      .getOrCreate()
    import spark.implicits._

    //测试数据,创建一个Dataset
    val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40)))
    val df: DataFrame = rdd.toDF("username", "age")
    df.createOrReplaceTempView("test")
    /**
     * 自定义聚合函数类:计算年龄的平均值
     * 1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
     * IN:输入的数据类型
     * BUF:缓冲区的数据类型
     * OUT:输出的数据类型
     * 2.重写方法
     */
    //使用自定义聚合函数
    val demo = new Aggregator[Long,Buff,Long] {
      // 初始值
      override def zero: Buff =  Buff(0L, 0L)
      // 缓冲区数据计算
      override def reduce(buff: Buff, in: Long): Buff = {
        buff.total += in
        buff.count += 1
        buff
      }
      // 合并缓冲区
      override def merge(buff1: Buff, buff2: Buff): Buff = {
        buff1.total += buff2.total
        buff1.count += buff2.count
        buff1
      }
      // 输出值计算
      override def finish(buff: Buff): Long = {
        buff.total / buff.count
      }
      // 缓冲区编码设置  Encoders.product是进行scala元组和case类转换的编码器
      override def bufferEncoder: Encoder[Buff] = Encoders.product
      // 输出编码
      override def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }
    spark.udf.register("MyAgeAvg",functions.udaf(demo))
    val result = spark.sql("select MyAgeAvg(age) from test")
    //展示结果
    result.show()
    spark.stop()
  }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61

UDAF(自定义聚合函数)(DSL)

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.functions._

object SchedulerDemo {
  // 计算输出结果的数据来源
  case class Buff(var total: Long, var count: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("Test")
      .master("local[*]") //设置本地模式运行
      .getOrCreate()
    import spark.implicits._
    //测试数据,创建一个Dataset
    val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40)))
    val df: DataFrame = rdd.toDF("username", "age")
    /**
     * 自定义聚合函数类:计算年龄的平均值
     * 1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
     * IN:输入的数据类型
     * BUF:缓冲区的数据类型
     * OUT:输出的数据类型
     * 2.重写方法
     */
    //使用自定义聚合函数
    val demo = new Aggregator[Long,Buff,Long] {
      // 初始值
      override def zero: Buff =  Buff(0L, 0L)
      // 缓冲区数据计算
      override def reduce(buff: Buff, in: Long): Buff = {
        buff.total += in
        buff.count += 1
        buff
      }
      // 合并缓冲区
      override def merge(buff1: Buff, buff2: Buff): Buff = {
        buff1.total += buff2.total
        buff1.count += buff2.count
        buff1
      }
      // 输出值计算
      override def finish(buff: Buff): Long = {
        buff.total / buff.count
      }
      // 缓冲区编码设置  Encoders.product是进行scala元组和case类转换的编码器
      override def bufferEncoder: Encoder[Buff] = Encoders.product
      // 输出编码
      override def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }
    val agg_1 = demo.toColumn.name("agg_column")
    //展示结果
    val result = df.select($"age").as[Long].select(agg_1)
    result.show()
    spark.stop()
  }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

总结

spark版本3.0之后 使用 Aggregator的方式为强类型,而3.0之前继承 UserDefinedAggregateFunction 属于弱类型的方式实现
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/603749
推荐阅读
相关标签
  

闽ICP备14008679号