
Spark UDF/UDAF: Aggregator (Scala)

Spark SQL offers two ways to write a user-defined aggregate function (UDAF): the weakly typed UserDefinedAggregateFunction and the strongly typed Aggregator. Both are shown below, computing the average age.

Weakly typed
package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

// Goal: compute the average age.
// 1. The aggregation needs two values: the running sum of ages and the row count.

// user-defined aggregate function
object Spark_UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("wc").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    // Bring the implicit conversions into scope
    import spark.implicits._
    // Create the aggregate function
    val udaf = new MyAgeFunction
    // Register it so it can be called from SQL
    spark.udf.register("avgAge", udaf)
    val frame: DataFrame = spark.read.json("in/person.json")
    // Expose the DataFrame as a temporary view
    frame.createOrReplaceTempView("stu")
    spark.sql("select avgAge(age) from stu").show()
  }
}

// Extend UserDefinedAggregateFunction and implement its methods
class MyAgeFunction extends UserDefinedAggregateFunction {
  // Schema of the input: a single Long field named "age"
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }
  // Schema of the aggregation buffer: the running sum and the count
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }
  // Type of the value the function returns: sum / count, a Double
  override def dataType: DataType = DoubleType
  // Whether the function is deterministic: the same input always yields the same output
  override def deterministic: Boolean = true
  // Initialize the buffer before any rows are processed, i.e. sum = 0 and count = 0
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // buffer(0) = 0L and buffer(1) = 0L would work as well
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }
  // Update the buffer with each input row (within a single node/partition)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Add the incoming age to the running sum; increment the count by 1
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // Merge the buffers produced on different nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Compute the final result: sum / count
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
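For a quick check without the in/person.json file, the registered function can be exercised against an inline DataFrame. A minimal sketch (the object name, sample names, and ages below are made up for illustration):

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object Spark_UDAF_Check {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("wc").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    spark.udf.register("avgAge", new MyAgeFunction)
    // Hypothetical sample rows standing in for in/person.json
    val stu = Seq(("zhangsan", 20L), ("lisi", 30L)).toDF("name", "age")
    stu.createOrReplaceTempView("stu")
    spark.sql("select avgAge(age) from stu").show() // the average here is 25.0
    spark.stop()
  }
}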

package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator

// Goal: compute the average age.
// 1. The aggregation needs two values: the running sum of ages and the row count.

// user-defined aggregate function, strongly typed
object Spark_UDAF2 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("wc").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    // Bring the implicit conversions into scope
    import spark.implicits._
    // Create the aggregate function
    val udaf = new MyAgeFunctionOb
    // Turn the aggregator into a typed column that can be selected
    val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avgAge")

    val frame: DataFrame = spark.read.json("in/person.json")
    val userDS: Dataset[UserBean] = frame.as[UserBean]
    userDS.select(avgCol).show()
    spark.stop()
  }
}
case class UserBean(name: String, age: BigInt)
// Buffer class that carries the intermediate state
case class AvgBuffer(var sum: BigInt, var count: Int)

// Strongly typed: extend Aggregator[-IN, BUF, OUT]
class MyAgeFunctionOb extends Aggregator[UserBean, AvgBuffer, Double] {
  // Initial value of the buffer, i.e. sum = 0 and count = 0
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }
  // Combine the buffer with one input object and return the updated buffer
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }
  // Merge two buffers
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }
  // Compute the final result: sum / count
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }
  // Encoders for the buffer and output types; boilerplate for custom types
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
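Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0; the recommended replacement is to register an Aggregator through functions.udaf, which makes the strongly typed approach usable from SQL as well. A minimal sketch, assuming Spark 3.0+ and reusing the AvgBuffer case class above (the AvgAgeLong and avgAgeV2 names are illustrative):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Takes the raw column value (Long) as input, so it can be called directly from SQL
object AvgAgeLong extends Aggregator[Long, AvgBuffer, Double] {
  override def zero: AvgBuffer = AvgBuffer(0, 0)
  override def reduce(b: AvgBuffer, age: Long): AvgBuffer = {
    b.sum = b.sum + age
    b.count = b.count + 1
    b
  }
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }
  override def finish(r: AvgBuffer): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration and use, analogous to the weakly typed version above:
// spark.udf.register("avgAgeV2", functions.udaf(AvgAgeLong, Encoders.scalaLong))
// spark.sql("select avgAgeV2(age) from stu").show()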
