
Spark UDAF: a custom aggregate function (UserDefinedAggregateFunction) for conditional distinct counting

Counting distinct values of a field in Spark SQL with a custom aggregate function that extends UserDefinedAggregateFunction.

Requirement: group by menu item (product) and, for each item, count the number of distinct orders that have no discount amount.

package cd.custom.jde.job.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by roy on 2020-02-12.
 * Counts distinct order ids, keeping only the rows that match the
 * discount condition (yhPrice <= 0, i.e. no discount amount).
 */
class CountDistinctAndIf extends UserDefinedAggregateFunction {

  // Input columns: the order id and the discount amount.
  override def inputSchema: StructType = {
    new StructType()
      .add("orderid", StringType, nullable = true)
      .add("price", DoubleType, nullable = true)
  }

  // Intermediate buffer: the set of matching order ids, stored as an array.
  override def bufferSchema: StructType = {
    new StructType().add("items", ArrayType(StringType, true), nullable = true)
  }

  // The final result is the number of distinct matching order ids.
  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  // Add the incoming order id to the buffer when the row has no discount amount.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.getDouble(1) <= 0) {
      buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
    }
  }

  // Merge two partial buffers by taking the union of their order-id sets.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }
}
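Side note: starting with Spark 3.0, UserDefinedAggregateFunction is deprecated in favour of the typed Aggregator API, which can be registered for SQL use via functions.udaf. A minimal sketch of the same logic under that API (the OrderInput case class and the object name are illustrative, not part of the original code):

import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical input type for the typed Aggregator (not in the original code).
case class OrderInput(orderid: String, price: Double)

object CountDistinctNoDiscount extends Aggregator[OrderInput, Set[String], Int] {
  // Empty buffer.
  override def zero: Set[String] = Set.empty[String]
  // Keep the order id only when the row has no discount amount.
  override def reduce(buf: Set[String], in: OrderInput): Set[String] =
    if (in.price <= 0) buf + in.orderid else buf
  // Union of two partial buffers.
  override def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 ++ b2
  // Final result: number of distinct matching order ids.
  override def finish(buf: Set[String]): Int = buf.size
  override def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// Registration for SQL use would then look like:
// spark.udf.register("cCountDistinct2", functions.udaf(CountDistinctNoDiscount, Encoders.product[OrderInput]))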

Example usage:

package spark.udf

import cd.custom.jde.job.udf.CountDistinctAndIf
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object MyOrderTest {

  Logger.getRootLogger.setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {
    val data = Seq(
      Row("a", "a100", 0.0, "300"),
      Row("a", "a100", 7.0, "300"),
      Row("a", "a101", 6.0, "300"),
      Row("a", "a101", 5.0, "301"),
      Row("a", "a100", 0.0, "300")
    )
    val schema = new StructType()
      .add("storeid", StringType)
      .add("orderid", StringType)
      .add("yhPrice", DoubleType)
      .add("pid", StringType)

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
    df.show()
    df.createOrReplaceTempView("tab_tmp")

    // Register the custom aggregate function so it can be called from SQL.
    spark.udf.register("cCountDistinct2", new CountDistinctAndIf)

    spark.sql(
      """
        |select pid, count(1) as pid_num,
        |sum(if(yhPrice <= 0, 1, 0)) as zk_all_order_num,
        |cCountDistinct2(orderid, yhPrice) as zk_order_num
        |from tab_tmp group by pid
      """.stripMargin).show()
  }

  /* Expected output:
  +---+-------+----------------+------------+
  |pid|pid_num|zk_all_order_num|zk_order_num|
  +---+-------+----------------+------------+
  |300|      4|               2|           1|
  |301|      1|               0|           0|
  +---+-------+----------------+------------+ */
}
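For this particular requirement the custom UDAF is not strictly necessary: Spark SQL's count(distinct ...) ignores NULL values, so a CASE expression that yields NULL for discounted rows drops them from the distinct count. A minimal sketch against the tab_tmp view registered above:

// Equivalent query without the custom UDAF: count(distinct ...) skips NULLs,
// so rows with a discount (yhPrice > 0) contribute nothing to the distinct count.
spark.sql(
  """
    |select pid, count(1) as pid_num,
    |sum(if(yhPrice <= 0, 1, 0)) as zk_all_order_num,
    |count(distinct case when yhPrice <= 0 then orderid end) as zk_order_num
    |from tab_tmp group by pid
  """.stripMargin).show()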

 
