Requirement: group by dish (pid) and, for each dish, count the distinct orders that have no discount amount.
package cd.custom.jde.job.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by roy, 2020-02-12.
 * Counts distinct order IDs, keeping only rows whose discount amount
 * is <= 0 (i.e. orders without a discount).
 */
class CountDistinctAndIf extends UserDefinedAggregateFunction {

  // Input columns: the order ID and its discount amount
  override def inputSchema: StructType = {
    new StructType().add("orderid", StringType, nullable = true)
      .add("price", DoubleType, nullable = true)
  }

  // Intermediate state: the order IDs seen so far, stored as an array
  override def bufferSchema: StructType = {
    new StructType().add("items", ArrayType(StringType, containsNull = true), nullable = true)
  }

  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  // Add the incoming row's order ID to the buffer when it has no discount;
  // null order IDs and null amounts are skipped to avoid NullPointerExceptions
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0) && !input.isNullAt(1) && input.getDouble(1) <= 0) {
      buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
    }
  }

  // Merge two partial buffers by taking the union of their order-ID sets
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0)).toSeq
  }

  // Final result: the number of distinct qualifying order IDs
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }
}
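Note that `UserDefinedAggregateFunction` was current when this post was written but has been deprecated since Spark 3.0 in favor of the typed `Aggregator` registered through `functions.udaf`. A minimal sketch of the same logic on that API (the `OrderInput` case class and the object name below are illustrative assumptions, not from the original code):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// Assumed input shape for this sketch: an order ID and its discount amount
case class OrderInput(orderid: String, price: Double)

// Same semantics as CountDistinctAndIf: collect the distinct order IDs
// whose discount amount is <= 0, then return how many there are
object CountDistinctAndIfAgg extends Aggregator[OrderInput, Set[String], Int] {
  def zero: Set[String] = Set.empty[String]
  def reduce(acc: Set[String], in: OrderInput): Set[String] =
    if (in.price <= 0) acc + in.orderid else acc
  def merge(a: Set[String], b: Set[String]): Set[String] = a ++ b
  def finish(acc: Set[String]): Int = acc.size
  def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

Registering it as `spark.udf.register("cCountDistinct2", functions.udaf(CountDistinctAndIfAgg))` would let the SQL in the example below run unchanged.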
Example usage:
package spark.udf

import cd.custom.jde.job.udf.CountDistinctAndIf
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object MyOrderTest {

  Logger.getRootLogger.setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {

    // Sample rows: store ID, order ID, discount amount, dish (product) ID
    val data = Seq(
      Row("a", "a100", 0.0, "300"),
      Row("a", "a100", 7.0, "300"),
      Row("a", "a101", 6.0, "300"),
      Row("a", "a101", 5.0, "301"),
      Row("a", "a100", 0.0, "300")
    )
    val schema = new StructType()
      .add("storeid", StringType)
      .add("orderid", StringType)
      .add("yhPrice", DoubleType)
      .add("pid", StringType)
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
    df.show()
    df.createOrReplaceTempView("tab_tmp")

    // Register the UDAF and use it alongside the built-in aggregates
    val cCountDistinct2 = new CountDistinctAndIf
    spark.udf.register("cCountDistinct2", cCountDistinct2)
    spark.sql(
      """
        |select pid, count(1) pid_num,
        |sum(if(yhPrice <= 0, 1, 0)) as zk_all_order_num,
        |cCountDistinct2(orderid, yhPrice) as zk_order_num
        |from tab_tmp group by pid
      """.stripMargin).show()
  }

  /* Expected output:
     +---+-------+----------------+------------+
     |pid|pid_num|zk_all_order_num|zk_order_num|
     +---+-------+----------------+------------+
     |300|      4|               2|           1|
     |301|      1|               0|           0|
     +---+-------+----------------+------------+ */
}
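For this particular metric the custom UDAF is not strictly required: `count(distinct ...)` ignores NULLs, so a conditional expression inside it reproduces `zk_order_num` in plain Spark SQL. A minimal equivalent of the query above, against the same `tab_tmp` view:

// COUNT(DISTINCT ...) skips NULLs, so the CASE expression only exposes
// the order IDs of undiscounted rows (yhPrice <= 0)
spark.sql(
  """
    |select pid, count(1) pid_num,
    |sum(if(yhPrice <= 0, 1, 0)) as zk_all_order_num,
    |count(distinct case when yhPrice <= 0 then orderid end) as zk_order_num
    |from tab_tmp group by pid
  """.stripMargin).show()

The UDAF remains a useful template for cases where the per-group de-duplication condition grows more complex than a single CASE expression.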