UDF: takes one row and returns one result (one-to-one). For example, a function that takes a birthday and returns the corresponding age.
UDTF: takes one row and returns multiple rows (Hive, one-to-many). Spark SQL has no UDTF; in Spark the same effect is achieved with flatMap (see the sketch after these definitions).
UDAF: takes multiple rows and returns one row (aggregation). count, sum and the like are Spark's built-in aggregate functions, but more complex business logic requires defining your own.
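Since Spark SQL has no UDTF, the one-to-many behavior is usually obtained with flatMap on a Dataset (or with the built-in explode function). Below is a minimal sketch assuming a toy Dataset of comma-separated strings; the object name FlatMapAsUdtf and the sample data are made up for illustration:

import org.apache.spark.sql.SparkSession

object FlatMapAsUdtf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("FlatMapAsUdtf")
      .master("local[*]") // run in local mode
      .getOrCreate()
    import spark.implicits._

    // one input row per element; each row holds several comma-separated words
    val ds = Seq("a,b,c", "d,e").toDS()

    // flatMap turns every input row into zero or more output rows (UDTF-like, one-to-many)
    val exploded = ds.flatMap(line => line.split(","))
    exploded.show() // 5 rows: a, b, c, d, e

    spark.stop()
  }
}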
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, udf}

object UdfTest {
  // simple case class used to build the test DataFrame
  case class Data(test1: String, test2: String)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("Test")
      .master("local[*]") // run in local mode
      .getOrCreate()
    import spark.implicits._

    // custom udf: replace "1" with "2" in the input string
    val test = udf((str: String) => str.replace("1", "2"))

    // test data: build a DataFrame
    val tmp_da = List(Data("1test data", "1test data")).toDF("1test1", "1test2")
    val testData = tmp_da.withColumn("test3", test(col("1test2")))

    val result: DataFrame = testData.select($"test3")
    // show the result
    result.show()
    spark.stop()
  }
}
1. Extend org.apache.spark.sql.expressions.Aggregator and define its type parameters:
IN: the input data type
BUF: the buffer data type
OUT: the output data type
2. Override its methods.
3. Register the custom aggregate function, i.e. create an instance of a class that extends UserDefinedAggregateFunction (pseudocode):
val udaf_test = new <YourUdafClass> // extends UserDefinedAggregateFunction
4. udaf_test can then be used directly:
val testData = tmp_da.agg(udaf_test(col("1test2")) as "")
In Spark, a custom aggregate function is written by extending the abstract class UserDefinedAggregateFunction and overriding its methods.
Let's first look at the source of this class:
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A `StructType` represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def inputSchema: StructType

  /**
   * A `StructType` represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
   * the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  def bufferSchema: StructType

  /**
   * The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].
   *
   * @since 1.5.0
   */
  def dataType: DataType

  /**
   * Returns true iff this function is deterministic, i.e. given the same input,
   * always return the same output.
   *
   * @since 1.5.0
   */
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   *
   * @since 1.5.0
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   *
   * @since 1.5.0
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   *
   * @since 1.5.0
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   *
   * @since 1.5.0
   */
  def evaluate(buffer: Row): Any
}
As you can see, extending this class means overriding its eight methods.
What each method means:
inputSchema: the schema (types) of the input data
bufferSchema: the schema of the intermediate (buffer) values
dataType: the type of the final result
deterministic: whether the function is deterministic, i.e. the same input always produces the same output; usually true
initialize: sets the initial values of the buffer
update: updates the intermediate result every time a row is processed (update runs inside each partition)
merge: the global aggregation (merges the results from the individual partitions)
evaluate: computes the final result
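To make the eight methods concrete, here is a minimal sketch of a weak-typed UDAF that computes the average age; the class name MyAvgUDAF, the column layout, and the null handling are illustrative choices, not taken from the original code:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyAvgUDAF extends UserDefinedAggregateFunction {
  // input: a single Long column (the age)
  override def inputSchema: StructType = new StructType().add("age", LongType)

  // buffer: running total and row count
  override def bufferSchema: StructType =
    new StructType().add("total", LongType).add("count", LongType)

  // type of the final result
  override def dataType: DataType = DoubleType

  // the same input always produces the same output
  override def deterministic: Boolean = true

  // initial values of the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // called once per input row inside a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  // merge the partial results of two partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // compute the final average from the merged buffer
  override def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0L) 0.0
    else buffer.getLong(0).toDouble / buffer.getLong(1)
}

Such an instance could then be used either by registering it, e.g. spark.udf.register("myAvg", new MyAvgUDAF), or directly in the DataFrame API, e.g. val myAvg = new MyAvgUDAF; df.agg(myAvg(col("age"))).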
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.functions._

object SchedulerDemo {
  // buffer holding the data from which the output is computed
  case class Buff(var total: Long, var count: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("Test")
      .master("local[*]") // run in local mode
      .getOrCreate()
    import spark.implicits._

    // test data: create a DataFrame and register it as a temp view
    val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(
      List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40)))
    val df: DataFrame = rdd.toDF("username", "age")
    df.createOrReplaceTempView("test")

    /**
     * Custom aggregate function: compute the average age
     * 1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters
     *    IN:  input data type
     *    BUF: buffer data type
     *    OUT: output data type
     * 2. Override the methods
     */
    // the custom aggregate function
    val demo = new Aggregator[Long, Buff, Long] {
      // initial (zero) value of the buffer
      override def zero: Buff = Buff(0L, 0L)

      // update the buffer with one input value
      override def reduce(buff: Buff, in: Long): Buff = {
        buff.total += in
        buff.count += 1
        buff
      }

      // merge two buffers
      override def merge(buff1: Buff, buff2: Buff): Buff = {
        buff1.total += buff2.total
        buff1.count += buff2.count
        buff1
      }

      // compute the output value
      override def finish(buff: Buff): Long = {
        buff.total / buff.count
      }

      // buffer encoder; Encoders.product is the encoder for Scala tuples and case classes
      override def bufferEncoder: Encoder[Buff] = Encoders.product

      // output encoder
      override def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }

    // register the Aggregator as a SQL function (Spark 3.0+) and call it in a query
    spark.udf.register("MyAgeAvg", functions.udaf(demo))
    val result = spark.sql("select MyAgeAvg(age) from test")
    // show the result
    result.show()
    spark.stop()
  }
}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.functions._

object SchedulerDemo {
  // buffer holding the data from which the output is computed
  case class Buff(var total: Long, var count: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("Test")
      .master("local[*]") // run in local mode
      .getOrCreate()
    import spark.implicits._

    // test data: create a DataFrame
    val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(
      List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40)))
    val df: DataFrame = rdd.toDF("username", "age")

    /**
     * Custom aggregate function: compute the average age
     * 1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters
     *    IN:  input data type
     *    BUF: buffer data type
     *    OUT: output data type
     * 2. Override the methods
     */
    // the custom aggregate function
    val demo = new Aggregator[Long, Buff, Long] {
      // initial (zero) value of the buffer
      override def zero: Buff = Buff(0L, 0L)

      // update the buffer with one input value
      override def reduce(buff: Buff, in: Long): Buff = {
        buff.total += in
        buff.count += 1
        buff
      }

      // merge two buffers
      override def merge(buff1: Buff, buff2: Buff): Buff = {
        buff1.total += buff2.total
        buff1.count += buff2.count
        buff1
      }

      // compute the output value
      override def finish(buff: Buff): Long = {
        buff.total / buff.count
      }

      // buffer encoder; Encoders.product is the encoder for Scala tuples and case classes
      override def bufferEncoder: Encoder[Buff] = Encoders.product

      // output encoder
      override def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }

    // use the Aggregator through the typed DSL: convert it to a TypedColumn
    val agg_1 = demo.toColumn.name("agg_column")
    // show the result
    val result = df.select($"age").as[Long].select(agg_1)
    result.show()
    spark.stop()
  }
}
Since Spark 3.0, writing the UDAF with Aggregator is the strongly-typed approach; before 3.0, extending UserDefinedAggregateFunction was the weakly-typed way to implement it.