当前位置:   article > 正文

Spark SQL自定义函数UDF、UDAF聚合函数以及开窗函数的使用_scala中 spark.udf.register的作用

scala中 spark.udf.register的作用

一、UDF的使用

1、Spark SQL自定义函数就是可以通过scala写一个类,然后在SparkSession上注册一个函数并对应这个类,然后在SQL语句中就可以使用该函数了,首先定义UDF函数,那么创建一个SqlUdf类,并且继承UDF1或UDF2等等,UDF后边的数字表示了当调用函数时会传入进来有几个参数,最后一个R则表示返回的数据类型,如下图所示:

2、这里选择继承UDF2,如下代码所示:

  1. package com.udf
  2. import org.apache.spark.sql.api.java.UDF2
  3. class SqlUDF extends UDF2[String,Integer,String] {
  4. override def call(t1: String, t2: Integer): String = {
  5. t1+"_udf_test_"+t2
  6. }
  7. }

3、然后在SparkSession生成的对象上通过sparkSession.udf.register进行注册,如下代码所示:

  1. val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
  2. val sparkSession=SparkSession.builder().config(conf).getOrCreate()
  3. //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
  4. //第三个参数是返回类型
  5. sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)

4、生成模拟数据,并注册一个临时表,如下代码所示:

  1. var rows=Seq[Row]()
  2. val random=new Random()
  3. for(i <- 0 until 10){
  4. val name="name"+i
  5. val age=random.nextInt(30)%15+15
  6. val row=Row(name,age)
  7. rows +:=row
  8. }
  9. val rowsRDD=sparkSession.sparkContext.parallelize(rows)
  10. val schema=DataTypes.createStructType(Array[StructField](
  11. DataTypes.createStructField("name",DataTypes.StringType,true),
  12. DataTypes.createStructField("age",DataTypes.IntegerType,true))
  13. )
  14. val df=sparkSession.createDataFrame(rowsRDD,schema)
  15. df.createOrReplaceTempView("person")
  16. df.show()

输出 结果如下图所示:

5、在sql语句中使用自定义函数splicing_t1_t2,然后将函数的返回结果定义一个别名name_age,如下代码所示:

  1. val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"
  2. sparkSession.sql(sql).show()

输出结果如下:

6、由此可以看到在自定义的UDF类中,想如何操作都可以了,完整代码如下;

  1. package com.udf
  2. import org.apache.spark.SparkConf
  3. import org.apache.spark.sql.{Row, SparkSession}
  4. import org.apache.spark.sql.types.{DataTypes, StructField}
  5. import scala.util.Random
  6. object AppUdf {
  7. def main(args:Array[String]):Unit={
  8. val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
  9. val sparkSession=SparkSession.builder().config(conf).getOrCreate()
  10. //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
  11. //第三个参数是返回类型
  12. sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
  13. var rows=Seq[Row]()
  14. val random=new Random()
  15. for(i <- 0 until 10){
  16. val name="name"+i
  17. val age=random.nextInt(30)%15+15
  18. val row=Row(name,age)
  19. rows +:=row
  20. }
  21. val rowsRDD=sparkSession.sparkContext.parallelize(rows)
  22. val schema=DataTypes.createStructType(Array[StructField](
  23. DataTypes.createStructField("name",DataTypes.StringType,true),
  24. DataTypes.createStructField("age",DataTypes.IntegerType,true))
  25. )
  26. val df=sparkSession.createDataFrame(rowsRDD,schema)
  27. df.createOrReplaceTempView("person")
  28. val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"
  29. sparkSession.sql(sql).show()
  30. sparkSession.close()
  31. }
  32. }

二、无类型的用户自定于聚合函数:UserDefinedAggregateFunction

1、它是一个接口,需要实现的方法有:

  1. class AvgAge extends UserDefinedAggregateFunction {
  2. //设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同
  3. override def inputSchema: StructType = ???
  4. //指定缓冲数据的字段与类型
  5. override def bufferSchema: StructType = ???
  6. //指定数据的返回类型
  7. override def dataType: DataType = ???
  8. //指定是否是确定性,对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出
  9. override def deterministic: Boolean = ???
  10. //initialize用户初始化缓存数据
  11. override def initialize(buffer: MutableAggregationBuffer): Unit = ???
  12. //当有新的输入数据时,update就会更新缓存变量
  13. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
  14. //将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行
  15. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
  16. //一个计算方法,用于计算我们的最终结果
  17. override def evaluate(buffer: Row): Any = ???
  18. }

这是一个计算平均年龄的自定义聚合函数,实现代码如下所示:

  1. package com.udf
  2. import java.math.BigDecimal
  3. import org.apache.spark.sql.Row
  4. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  5. import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
  6. /**
  7. * 用于计算平均年龄的聚合函数
  8. */
  9. class AvgAge extends UserDefinedAggregateFunction {
  10. /**
  11. * 设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同
  12. * 比如计算平均年龄,输入的是age这一列的数据,注意此处的age名称可以随意命名
  13. * @return
  14. */
  15. override def inputSchema: StructType = DataTypes.createStructType(Array[StructField](DataTypes.createStructField("age",DataTypes.IntegerType,true)))
  16. /**
  17. * 指定缓冲数据的字段与类型,相当于中间变量
  18. * 由于要计算平均值,首先要计算出总和与个数才能计算平均值,因此需要进来一个值就要累加并计数才能计算出平均值
  19. * 所以要定义两个变量作为累加和以及计数的变量
  20. * @return
  21. */
  22. override def bufferSchema: StructType = DataTypes.createStructType(Array[StructField](
  23. DataTypes.createStructField("sum",DataTypes.DoubleType,true),
  24. DataTypes.createStructField("count",DataTypes.IntegerType,true)
  25. ))
  26. //指定数据的返回类型,由于平均值是double类型,因此定义DoubleType
  27. override def dataType: DataType = DataTypes.DoubleType
  28. /**
  29. * 设置该函数是否为幂等函数
  30. * 幂等函数:即只要输入的数据相同,结果一定相同
  31. * true表示是幂等函数,false表示不是
  32. * @return
  33. */
  34. override def deterministic: Boolean = true
  35. /**
  36. * initialize用于初始化缓存变量的值,也就是初始化bufferSchema函数中定义的两个变量的值sum,count
  37. * 其中buffer(0)就表示sum值,buffer(1)就表示count的值,如果还有第3个,则使用buffer(3)表示
  38. * @param buffer
  39. */
  40. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  41. buffer.update(0,0.0) //或使用buffer(0)=0.0
  42. buffer.update(1,0) //或使用buffer(1)=0
  43. }
  44. /**
  45. * 当有一行数据进来时就会调用update一次,有多少行就会调用多少次,input就表示在调用自定义函数中有多少个参数,最终会将
  46. * 这些参数生成一个Row对象,在使用时可以通过input.getString或inpu.getLong等方式获得对应的值
  47. * 缓冲中的变量sum,count使用buffer(0)或buffer.getDouble(0)的方式获取到
  48. * @param buffer
  49. * @param input
  50. */
  51. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  52. val sum=buffer.getDouble(0)
  53. val count=buffer.getInt(1)
  54. buffer.update(0,sum+input.getInt(0).toDouble)
  55. buffer.update(1,count+1)
  56. }
  57. /**
  58. * 将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行
  59. * 其中buffer1是本节点上的缓存变量,而buffer2是从其他节点上过来的缓存变量然后转换为一个Row对象,然后将buffer2
  60. * 中的数据合并到buffer1中去即可
  61. * @param buffer1
  62. * @param buffer2
  63. */
  64. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  65. val sum1=buffer1.getDouble(0)
  66. val count1=buffer1.getInt(1)
  67. val sum2=buffer2.getDouble(0)
  68. val count2=buffer2.getInt(1)
  69. buffer1.update(0,sum1+sum2)
  70. buffer1.update(1,count1+count2)
  71. }
  72. /**
  73. * 一个计算方法,用于计算我们的最终结果,也就相当于返回值
  74. * @param buffer
  75. * @return
  76. */
  77. override def evaluate(buffer: Row): Any = {
  78. val bd = new BigDecimal(buffer.getDouble(0)/buffer.getInt(1).toDouble)
  79. bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数
  80. }
  81. }

2、注册该类,并指定到一个自定义函数中,如下图所示:

3、在表中加一列字段id,通过GROUP BY进行分组计算,如

4、在sql语句中使用group_age_avg,如下图所示:

输出结果如下图所示:

5、完整代码如下:

  1. package com.udf
  2. import org.apache.spark.SparkConf
  3. import org.apache.spark.sql.{Row, SparkSession}
  4. import org.apache.spark.sql.types.{DataTypes, StructField}
  5. import scala.util.Random
  6. object AppUdf {
  7. def main(args:Array[String]):Unit={
  8. val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
  9. val sparkSession=SparkSession.builder().config(conf).getOrCreate()
  10. //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
  11. //第三个参数是返回类型
  12. sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
  13. //UDAF不用设置返回类型,因此使用两个参数即可
  14. sparkSession.udf.register("group_age_avg",new AvgAge)
  15. var rows=Seq[Row]()
  16. val random=new Random()
  17. for(i <- 0 until 10){
  18. val name="name"+i
  19. val age=random.nextInt(30)%15+15
  20. val row=Row(random.nextInt(2),name,age)
  21. rows +:=row
  22. }
  23. val rowsRDD=sparkSession.sparkContext.parallelize(rows)
  24. val schema=DataTypes.createStructType(Array[StructField](
  25. DataTypes.createStructField("id",DataTypes.IntegerType,true),
  26. DataTypes.createStructField("name",DataTypes.StringType,true),
  27. DataTypes.createStructField("age",DataTypes.IntegerType,true))
  28. )
  29. val df=sparkSession.createDataFrame(rowsRDD,schema)
  30. df.createOrReplaceTempView("person")
  31. df.show()
  32. val sql="SELECT id, group_age_avg(age) avg_age FROM person GROUP BY id"
  33. sparkSession.sql(sql).show()
  34. sparkSession.close()
  35. }
  36. }

三、类型安全的用户自定于聚合函数:Aggregator

1、它是一个接口,需要继承与Aggregator,而Aggregator有3个参数,分别是IN,BUF,OUT,IN表示输入的值是什么,可以是一个自定类对象包含多个值,也可以是单个值,BUF就是需要用来缓存值使用的,如果需要缓存多个值也需要定义一个对象,而返回值也可以是一个对象返回多个值,需要实现的方法有:

  1. package com.udf
  2. import org.apache.spark.sql.Encoder
  3. import org.apache.spark.sql.expressions.Aggregator
  4. case class DataBuf(var sum:Double,var count:Int)
  5. object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{
  6. /**
  7. * 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的
  8. * @return
  9. */
  10. override def zero: DataBuf = ???
  11. /**
  12. * reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b
  13. * @param b
  14. * @param a
  15. * @return
  16. */
  17. override def reduce(b: DataBuf, a: Int): DataBuf = ???
  18. /**
  19. * merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并,
  20. * 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中
  21. * @param b1
  22. * @param b2
  23. * @return
  24. */
  25. override def merge(b1: DataBuf, b2: DataBuf): DataBuf = ???
  26. /**
  27. * finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值
  28. * 返回值可以是一个对象
  29. * @param reduction
  30. * @return
  31. */
  32. override def finish(reduction: DataBuf): Double = ???
  33. /**
  34. * 缓冲数据编码方式
  35. * @return
  36. */
  37. override def bufferEncoder: Encoder[DataBuf] = ???
  38. /**
  39. * 最终数据输出编码方式
  40. * @return
  41. */
  42. override def outputEncoder: Encoder[Double] = ???
  43. }

2、具体实现如下代码所示:

  1. package com.udf
  2. import java.math.BigDecimal
  3. import org.apache.spark.sql.{Encoder, Encoders}
  4. import org.apache.spark.sql.expressions.Aggregator
  5. case class DataBuf(var sum:Double,var count:Int)
  6. object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{
  7. /**
  8. * 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的
  9. * @return
  10. */
  11. override def zero: DataBuf = DataBuf(0.0,0)
  12. /**
  13. * reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b
  14. * @param b
  15. * @param a
  16. * @return
  17. */
  18. override def reduce(b: DataBuf, a: Int): DataBuf = {
  19. b.count+=1
  20. b.sum+=a.toDouble
  21. b
  22. }
  23. /**
  24. * merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并,
  25. * 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中
  26. * @param b1
  27. * @param b2
  28. * @return
  29. */
  30. override def merge(b1: DataBuf, b2: DataBuf): DataBuf = {
  31. b1.sum+=b2.sum
  32. b1.count+=b2.count
  33. b1
  34. }
  35. /**
  36. * finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值
  37. * 返回值可以是一个对象
  38. * @param reduction
  39. * @return
  40. */
  41. override def finish(reduction: DataBuf): Double = {
  42. val bd = new BigDecimal(reduction.sum/reduction.count.toDouble)
  43. bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数
  44. }
  45. /**
  46. * 缓冲数据编码方式,如果Encoder中指定的类型时对象,则设置为product,如果是具体的类型,则需设置为具体的类型
  47. * @return
  48. */
  49. override def bufferEncoder: Encoder[DataBuf] = Encoders.product
  50. /**
  51. * 最终数据输出编码方式,如果Encoder中指定的类型,则设置为具体的类型,比如Double则设置为scalaDouble
  52. * @return
  53. */
  54. override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  55. }

3、而使用此聚合函数就不能通过注册函数来使用了,需要通过Dataset对象的select来使用,如下图所示:

执行结果如下图所示:

因此无类型的用户自定于聚合函数:UserDefinedAggregateFunction和类型安全的用户自定于聚合函数:Aggregator之间的区别是

(1)UserDefinedAggregateFunction不能够带类型而Aggregator是可以带类型的。

(2)使用方法不同UserDefinedAggregateFunction通过注册可以在DataFram的sql语句中使用,而Aggregator必须是在Dataset上使用。

四、开窗函数的使用

1、在Spark 1.5.x版本以后,在Spark SQL和DataFrame中引入了开窗函数,其中比较常用的开窗函数就是row_number该函数的作用是根据表中字段进行分组,然后根据表中的字段排序;其实就是根据其排序顺序,给组中的每条记录添加一个序号;且每组的序号都是从1开始,可利用它的这个特性进行分组取top-n。它是放在select子句中的,其格式为:

ROW_NUMBER() OVER (PARTITION BY area ORDER BY click_count DESC) rank 

首先可以,在SELECT查询时,使用row_number()函数,其次row_number()函数后面先跟上OVER关键字,然后括号中,是PARTITION BY,也就是说根据哪个字段进行分组,其次是可以用ORDER BY进行组内排序, 然后row_number()就可以给每个组内的行,一个组内行号,然后rank就是每一组的行号

2、使用方法的sql语句为:

SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc

意思是在sql语句中加一个rank字段,该字段记录了以id为分组,在组内按照age升序排序,并记录行号,最后先按照id降序排序,如果id相同则按照rank降序排序

3、代码如下:

  1. package com.udf
  2. import org.apache.spark.SparkConf
  3. import org.apache.spark.sql.{Row, SparkSession}
  4. import org.apache.spark.sql.types.{DataTypes, StructField}
  5. import scala.util.Random
  6. object AppUdf {
  7. def main(args:Array[String]):Unit={
  8. val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
  9. val sparkSession=SparkSession.builder().config(conf).getOrCreate()
  10. //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
  11. //第三个参数是返回类型
  12. sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
  13. //UDAF不用设置返回类型,因此使用两个参数即可
  14. sparkSession.udf.register("group_age_avg",new AvgAge)
  15. var rows=Seq[Row]()
  16. val random=new Random()
  17. for(i <- 0 until 10){
  18. val name="name"+i
  19. val age=random.nextInt(30)%15+15
  20. val row=Row(random.nextInt(2),name,age)
  21. rows +:=row
  22. }
  23. val rowsRDD=sparkSession.sparkContext.parallelize(rows)
  24. val schema=DataTypes.createStructType(Array[StructField](
  25. DataTypes.createStructField("id",DataTypes.IntegerType,true),
  26. DataTypes.createStructField("name",DataTypes.StringType,true),
  27. DataTypes.createStructField("age",DataTypes.IntegerType,true))
  28. )
  29. val df=sparkSession.createDataFrame(rowsRDD,schema)
  30. df.createOrReplaceTempView("person")
  31. df.show()
  32. val sql="SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc"
  33. sparkSession.sql(sql).show()
  34. sparkSession.close()
  35. }
  36. }

输出结果如下:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/603694
推荐阅读
相关标签
  

闽ICP备14008679号