赞
踩
// Hive-enabled SparkSession. App name and master are pinned in code here; drop the two setters
// when they are supplied externally (e.g. by spark-submit).
val session = SparkSession
  .builder()
  .config(
    new SparkConf()
      .setAppName("test01")
      .setMaster("local")
  )
  // Only with Hive support does Spark SQL on Hive accept Hive DDL; without it Spark has just its own catalog.
  .enableHiveSupport()
  .config("hive.metastore.uris", "thrift://192.168.7.11:9083")
  .getOrCreate()
val sc: SparkContext = session.sparkContext
sc.setLogLevel("ERROR")
// Reading files with SparkContext:
//   sc.textFile(bigFile, minPartitions = 2)
//     -> hadoopFile(path, TextInputFormat, keyClass, valueClass, minPartitions) -> HadoopRDD
//   sc.wholeTextFiles(bigFile) -> WholeTextFileRDD (extends NewHadoopRDD)
//   sc.newAPIHadoopFile(bigFile) -> NewHadoopRDD
//sc.textFile(bigFile, minPartitions = 2).take(10).foreach(println)
// The same read spelled out through hadoopFile: keep the line text (value), drop the byte offset (key).
sc.hadoopFile(bigFile, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], 2)
  .map(_._2.toString)
  .setName(bigFile)
  .take(10)
  .foreach(println)

// Writing to HDFS: remove the previous output, then write.
val hadoopConf = sc.hadoopConfiguration
val hdfs = org.apache.hadoop.fs.FileSystem.get(hadoopConf)
val outputDir = new org.apache.hadoop.fs.Path("/data/hdfsPath")
if (hdfs.exists(outputDir)) {
  // A previous run leaves a non-empty directory here, and FileSystem.delete(dir, false) throws
  // for a non-empty directory. Delete recursively, but only this job's own output path.
  hdfs.delete(outputDir, true)
}
sc.textFile(bigFile).map(t => ("filename.txt", t)).saveAsHadoopFile( // (key, value) pairs
  outputDir.toString,                        // output directory
  classOf[String],                           // key class
  classOf[String],                           // value class
  classOf[PairRDDMultipleOTextOutputFormat]  // overrides generateFileNameForKeyValue() (file name) and generateActualKey() (whether the key is written into the content)
  //classOf[SnappyCodec]
)
package com.chauncy.spark_sql.file_read_write

import com.chauncy.spark_sql.InitSparkSession
import com.chauncy.spark_sql.project.PairRDDMultipleOTextOutputFormat
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SaveMode

/**
 * Text / JSON / ORC read-write examples using SparkContext and SparkSession.
 *
 * @author Chauncy Jin
 * @date 2021/12/15
 */
object MyTextReadWrite {

  /**
   * Deletes `path` if it exists.
   *
   * @param hdfs      file system to delete from
   * @param path      path to delete
   * @param recursive must be `true` to delete a non-empty directory (Hadoop throws otherwise).
   *                  Defaults to `false` so a mistyped path cannot wipe out a whole tree by accident.
   * @return `true` if the path existed and the file system reported it deleted, `false` otherwise
   */
  def deleteFile(hdfs: FileSystem, path: String, recursive: Boolean = false): Boolean = {
    val pathFile = new Path(path)
    hdfs.exists(pathFile) && hdfs.delete(pathFile, recursive)
  }

  def main(args: Array[String]): Unit = {
    val sparkSession = InitSparkSession.sparkSession
    val sc = sparkSession.sparkContext
    val conf: Configuration = sc.hadoopConfiguration
    val hdfs: FileSystem = FileSystem.get(conf)
    sc.setLogLevel("Error")

    val bigFile = "file:///Users/jinxingguang/java_project/bigdata3/spark-demo/data/pvuvdata"
    val jsonPath = "file:///Users/jinxingguang/java_project/bigdata3/spark-demo/data/test.json"

    /*
     * SparkContext: write text through a custom output format.
     */
    val outputPath = "/data/hdfsPath/text"
    // Each run leaves a non-empty directory at the output paths; a non-recursive delete of a
    // non-empty directory throws, so these job-owned output directories are deleted recursively.
    deleteFile(hdfs, outputPath, recursive = true)
    sc.textFile(bigFile).map(t => ("filename.txt", t)).saveAsHadoopFile(
      outputPath,
      classOf[String],
      classOf[String],
      // Overrides generateFileNameForKeyValue() (file name) and generateActualKey() (whether the key is written into the content)
      classOf[PairRDDMultipleOTextOutputFormat]
      //classOf[SnappyCodec] -- adding this codec caused problems
    )

    /*
     * SparkSession: read JSON, write JSON and ORC.
     */
    val dataFrame = sparkSession.read.json(jsonPath)
    //dataFrame.show()
    val jsonOutputPath = "/data/hdfsPath/json"
    val orcOutputPath = "/data/hdfsPath/orc"
    deleteFile(hdfs, jsonOutputPath, recursive = true)
    deleteFile(hdfs, orcOutputPath, recursive = true)
    dataFrame.write.mode(SaveMode.Append).json(jsonOutputPath)
    dataFrame.write.mode(SaveMode.Append).orc(orcOutputPath)
    //println(new String("hello\n".getBytes, "GBK")) // reading a Windows (GBK-encoded) file
  }
}
// Databases known to the metastore, then the tables of the current database.
val catalog = session.catalog
catalog.listDatabases().show()
catalog.listTables().show()

// Expose a JSON file as the temporary view "user_json" and query it.
session.read.json("path.json").createTempView("user_json")
val data = session.sql("select * from user_json")
data.show()        // print the query result
data.printSchema() // print the column names and types

// Cache the view, then release every cached table/view again.
catalog.cacheTable("user_json")
catalog.clearCache()
// Input records look like {"name":"zhangsan","age":20}...
val spark = SparkSession
  .builder()
  .config(new SparkConf())
  .master("local")
  .appName("hello")
  .getOrCreate()
val frame = spark.read.json("file:///Users/jinxingguang/java_project/bigdata3/spark-demo/data/json")
frame.show()
println(frame.count())

import spark.implicits._
frame.where($"age" >= 20).show()

frame.createTempView("user") // temporary view, registered in this session's catalog only
// Persists the data as the table "hive_user" (not a temporary view).
// NOTE(review): this session is built without enableHiveSupport(), so the table may be written to
// Spark's own catalog rather than the Hive metastore -- confirm.
frame.write.saveAsTable("hive_user")
package com.chauncy

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
 * Turns a '&'-delimited text file into a typed Dataset and then a named-column DataFrame.
 *
 * @author Chauncy Jin
 * @date 2022/1/20
 */
object DataSetTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("test")
      .master("local")
      .config("spark.sql.shuffle.partitions", "10")
      .config("hive.metastore.uris", "thrift://192.168.7.11:9083")
      .enableHiveSupport()
      .getOrCreate()
    spark.sparkContext.setLogLevel("Error")

    import spark.implicits._
    // Sample input line:
    // 1087718415194492928&229071822400790528&升级&0&3&1&null&913683594017214464&2019-01-22 22:27:34.0&0&null&null&null
    val lines = spark.read.textFile("/project/jxc.db/jxc/ods/ODS_AGENT_UPGRADE_LOG/*")
    // Keep lines with more than 8 fields and pick fields 1, 3, 4 and 8.
    val picked: Dataset[(String, String, String, String)] = lines.flatMap { line =>
      val cols = line.split("&")
      if (cols.length > 8) Seq((cols(1), cols(3), cols(4), cols(8))) else Seq.empty
    }
    val frame: DataFrame = picked.toDF("uid", "ago_lev", "after_lev", "date")
    frame.printSchema()
  }
}
package com.chauncy

import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

/**
 * Reads a text file into a DataFrame (RDD[Row] + StructType) and saves it into a Hive table.
 *
 * @author Chauncy Jin
 * @date 2022/1/20
 */
object DataTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName("test")
      .config("spark.sql.shuffle.partitions", "10")
      .config("hive.metastore.uris", "thrift://192.168.7.11:9083")
      .enableHiveSupport()
      .getOrCreate()
    val sc = spark.sparkContext
    sc.setLogLevel("Error")

    // data + metadata == DataFrame (behaves like a table)
    // Approach 1: RDD of Row + StructType

    // 1. Data: RDD[Row], one Row per input line.
    // Sample input line:
    // 1087718415194492928&229071822400790528&升级&0&3&1&null&913683594017214464&2019-01-22 22:27:34.0&0&null&null&null
    val das = sc.textFile("/project/jxc.db/jxc/ods/ODS_AGENT_UPGRADE_LOG/*")
    val rowRDD = das
      .map(_.split("&"))   // split each line once
      .filter(_.length > 8) // field 8 is the highest index read below
      .map(ste => Row(ste(1), ste(3), ste(4), ste(8)))

    // 2. Metadata: the table definition as a StructType.
    val schema = StructType(Array(
      StructField("uid", DataTypes.StringType, nullable = true),
      StructField("ago_lev", DataTypes.StringType, nullable = true),
      StructField("after_lev", DataTypes.StringType, nullable = true),
      StructField("date", DataTypes.StringType, nullable = true)
    ))

    // 3. Create the DataFrame.
    val dataFrame = spark.createDataFrame(rowRDD, schema)
    dataFrame.show(10, truncate = true) // always bound show() with a row count; large data gets expensive
    dataFrame.printSchema()             // print column names and types
    dataFrame.createTempView("temp_change_log") // register the view name in the session catalog
    spark.sql("select * from temp_change_log limit 10").show()

    spark.sql("use jxc ")
    // No trailing ';' inside spark.sql(): it takes a single statement, and older Spark parsers reject the terminator.
    spark.sql(
      """
        |CREATE EXTERNAL TABLE IF NOT EXISTS `ods_change_log_chauncy` (
        |  `mall_user_id` string,
        |  `ago_lead_id` string,
        |  `after_lead_id` string,
        |  `create_date` string
        |)
        |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\001'
        |LOCATION 'hdfs://mycluster/project/jxc.db/jxc/ods/ods_change_log_chauncy'
        |""".stripMargin)
    // Positional insert: the view's columns map onto the table's columns in order.
    spark.sql(
      """
        |insert overwrite table jxc.ods_change_log_chauncy select * from temp_change_log
        |""".stripMargin)

    spark.stop()
  }
}
// person.txt (space-separated: name age sex):
// chauncy 18 0
// lisa 22 1
// yiyun 99 1

// Build a DataFrame whose column names and types are described at runtime.
val rdd = sc.textFile("file:///Users/jinxingguang/java_project/bigdata-chauncy/spark-demo/data/person.txt")
// "<column name> <type name>" for each column position.
val userSchema = Array(
  "name string",
  "age int",
  "sex int"
)

// Converts (raw text, column index) into the Scala value matching that column's declared type.
// Keep in sync with getDataType below: every type name that maps to a non-string DataType there
// needs a conversion here. Otherwise the Row carries a String where the schema expects e.g. a Long,
// and the DataFrame fails at runtime.
def toDataType(col: (String, Int)) = {
  userSchema(col._2).split("[ ]")(1) match {
    case "int"       => col._1.toInt
    case "long"      => col._1.toLong
    case "boolean"   => col._1.toBoolean
    case "byte"      => col._1.toByte
    case "binary"    => col._1.getBytes(java.nio.charset.StandardCharsets.UTF_8) // the raw text as UTF-8 bytes
    case "date"      => java.sql.Date.valueOf(col._1)      // e.g. 2022-12-31
    case "timestamp" => java.sql.Timestamp.valueOf(col._1) // e.g. 2022-12-31 23:59:59
    case _           => col._1.toString
  }
}

// 1. RDD[Row]
// rdd.map(_.split(" ")).map(line => Row.apply(line(0), line(1).toInt)) // column types hard-coded
val rddRow: RDD[Row] = rdd.map(_.split(" "))
  .map(x => x.zipWithIndex) // [(chauncy,0), (18,1)]
  .map(x => x.map(toDataType(_)))
  .map(line => Row.fromSeq(line)) // a Row holds all columns; each value must already have its column's exact type

// 2. StructType
// Maps a type name from userSchema to its Spark SQL DataType.
def getDataType(v: String) = {
  v match {
    case "int" => DataTypes.IntegerType      // 24
    case "binary" => DataTypes.BinaryType
    case "boolean" => DataTypes.BooleanType  // true false
    case "byte" => DataTypes.ByteType
    case "date" => DataTypes.DateType        // 2022-12-31
    case "long" => DataTypes.LongType
    case "timestamp" => DataTypes.TimestampType
    case _ => DataTypes.StringType           // string
  }
}
// Column definitions
val fields: Array[StructField] = userSchema.map(_.split(" ")).map(x => StructField.apply(x(0), getDataType(x(1))))
val schema: StructType = StructType.apply(fields)
// schema and schema01 describe the same columns
val schema01: StructType = StructType.fromDDL("name string,age int,sex int")
val dataFrame = session.createDataFrame(rddRow, schema01)
dataFrame.show()
dataFrame.printSchema()
// register the view with the catalog through the session
dataFrame.createTempView("user")
session.sql("select * from user").show()
// JavaBean used as the element type of the RDD.
// Serializable lets instances be captured by closures or serialized by Spark (e.g. serialized caching).
class Person extends Serializable {
  @BeanProperty var name: String = ""
  @BeanProperty var age: Int = 0
  @BeanProperty var sex: Int = 0
}

// Approach 2: RDD of beans + JavaBean class
val rdd = sc.textFile("file:///Users/jinxingguang/java_project/bigdata-chauncy/spark-demo/data/person.txt")
// Background:
// 1. MR and Spark pipeline records through iterators: one record is read, computed and
//    serialized before the next one is pulled.
// 2. In distributed execution the driver serializes the computation logic and ships it to the
//    executor JVMs.
// Because of (1), one Person created outside map() and mutated for every record happens to work in
// a straight pipeline. That is also why it had to be Serializable: the closure captured it. But then
// every element of the RDD is the same object. Anything that keeps several elements at once
// (cache(), collect(), sorting, ...) would see all of them with the last record's values.
// A fresh bean per record avoids that.
val rddBean: RDD[Person] = rdd.map(_.split(" ")).map(arr => {
  val person = new Person
  person.setName(arr(0))
  person.setAge(arr(1).toInt)
  person.setSex(arr(2).toInt)
  person
})
val dataFrame = session.createDataFrame(rddBean, classOf[Person])
dataFrame.show()
dataFrame.printSchema()
// register the view with the catalog through the session
dataFrame.createTempView("user")
session.sql("select * from user").show()
// Typed record for the JSON rows ({"name":...,"age":...}); case classes are already Serializable.
case class User(name: String, age: BigInt)
def main(args: Array[String]): Unit = {
  val spark = SparkSession
    .builder()
    .config(new SparkConf())
    .master("local")
    .appName("hello")
    .getOrCreate()
  import spark.implicits._

  // Input records look like {"name":"zhangsan","age":20}...
  val jsonDir = "file:///Users/jinxingguang/java_project/bigdata3/spark-demo/data/json"
  val users = spark.read.json(jsonDir).as[User] // typed Dataset[User]
  users.show()
  users.where($"age" >= 20).show()
}
/*
 Text files should be turned into *structured* data before computing on them.
 text file -> intermediate data: the ETL step.
 File format, partitioning and bucketing: partitioning reduces the data a computation has to load,
 bucketing reduces how much data a shuffle has to move.
*/
// A Spark Dataset can be used like a collection (RDD-style) as well as through SQL.
val rddData: Dataset[String] = session.read.textFile("file:///Users/jinxingguang/java_project/bigdata-chauncy/spark-demo/data/person.txt")
import session.implicits._
// The (String, Int) encoder comes from session.implicits; passing
// Encoders.tuple(Encoders.STRING, Encoders.scalaInt) explicitly would work as well.
val person: Dataset[(String, Int)] = rddData.map { line =>
  val cols = line.split(" ")
  val name = cols(0)
  val age = cols(1).toInt
  (name, age)
}
// Attach column names to the tuple fields.
val cPerson = person.toDF("name", "age")
cPerson.show()
cPerson.printSchema()
package com.chauncy

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.elasticsearch.spark.sql.EsSparkSQL

/**
 * Reads data_prod.ads_fact_metric_agent from Hive and writes it to the Elasticsearch index
 * "ads_fact_metric_agent" with EsSparkSQL.saveToEs.
 *
 * Related jar:
 * add jar hdfs://emr-header-1.cluster-246415:9000/jars/es-7.6.2/elasticsearch-hadoop-hive-7.6.2.jar;
 *
 * @author Chauncy Jin
 * @date 2022/1/20
 */
object HiveToEs {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("ES_Hive").setMaster("local")
    sparkConf.set("es.nodes", "172.20.5.11,172.20.5.12,172.20.5.13")
    //sparkConf.set("es.nodes","172.20.5.11")
    sparkConf.set("es.port", "9200")
    sparkConf.set("es.index.auto.create", "true")
    sparkConf.set("es.write.operation", "index")
    sparkConf.set("spark.es.batch.size.entries", "10000")
    sparkConf.set("spark.es.batch.write.refresh", "false")
    sparkConf.set("spark.es.scroll.size", "10000")
    sparkConf.set("spark.es.input.use.sliced.partitions", "false")
    sparkConf.set("hive.metastore.uris", "thrift://172.20.1.232:9083")
    val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
    try {
      sparkSession.sparkContext.setLogLevel("ERROR")

      // Alternative source: read over JDBC instead of Hive (needs java.util.Properties).
      // Never commit real credentials here; supply them from configuration or a secret store.
      // NOTE(review): the host/user/password previously committed in this comment must be treated as leaked and rotated.
      //val url: String = "jdbc:mysql://<host>:3306/data_bi?useSSL=false&useUnicode=true&characterEncoding=UTF-8&rewriteBatchedStatements=true"
      //val table: String = "temp_zz"
      //val properties: Properties = new Properties()
      //properties.put("user", "<user>")
      //properties.put("password", "<password>")
      //properties.put("driver", "com.mysql.jdbc.Driver")
      //properties.setProperty("batchsize", "10000")
      //properties.setProperty("fetchsize", "10000")
      //val course: DataFrame = sparkSession.read.jdbc(url, table, properties)

      sparkSession.sql("use data_prod ")
      val course: DataFrame = sparkSession.sql(
        """
          |select
          |a.user_id as user_id,
          |cast(a.metric_value as string) as metric_value,
          |a.load_date as load_date,
          |a.dt as dt,
          |a.metric_name as metric_name
          |from ads_fact_metric_agent a
          |""".stripMargin)
      //course.show()
      EsSparkSQL.saveToEs(course, "ads_fact_metric_agent")
    } finally {
      // Release the session even when the query or the ES write fails.
      sparkSession.stop()
    }
    /*
    Warning:scalac: While parsing annotations in /Users/jinxingguang/.m2/repository/org/apache/spark/spark-core_2.12/3.1.2/spark-core_2.12-3.1.2.jar(org/apache/spark/rdd/RDDOperationScope.class), could not find NON_ABSENT in enum object JsonInclude$Include.
    This is likely due to an implementation restriction: an annotation argument cannot refer to a member of the annotated class (SI-7014).
    */
  }
}
val conf = new SparkConf().setMaster("local").setAppName("sql hive")
val session = SparkSession
  .builder()
  .config(conf)
  // .enableHiveSupport()
  .getOrCreate()
val sc = session.sparkContext
sc.setLogLevel("Error")
import session.implicits._

// One-column DataFrame; the column is named "line".
val dataDF: DataFrame = List(
  "hello world",
  "hello world",
  "hello msb",
  "hello world",
  "hello world",
  "hello spark",
  "hello world",
  "hello spark"
).toDF("line")
dataDF.createTempView("ooxx") // register in the session catalog
val df: DataFrame = session.sql("select * from ooxx")
df.show()
df.printSchema()

// Word count in SQL:
// session.sql(" select word,count(1) from (select explode(split(line,' ')) word from ooxx) as tt group by word").show()

// Word count through the DataFrame API (dataDF plays the role of "from table"):
val words = dataDF.selectExpr("explode(split(line,' ')) word")
val byWord: RelationalGroupedDataset = words.groupBy("word")
val res = byWord.count()

// Write the counts as Parquet, then read them back.
val outDir = "file:///Users/jinxingguang/java_project/bigdata-chauncy/spark-demo/data/out/ooxx"
res.write.mode(SaveMode.Append).parquet(outDir) // Append: every run adds another copy of the counts
val frame: DataFrame = session.read.parquet(outDir)
frame.show()
frame.printSchema()
/*
 File-based sources (Parquet and ORC are columnar; text, JSON and CSV are row-oriented):
   session.read.parquet()
   session.read.textFile()
   session.read.json()
   session.read.csv()
 Whatever the source format, it is read into a DataFrame.
 Writers:
   res.write.parquet()
   res.write.orc()
   res.write.text()
*/
package com.chauncy.spark_sql.file_read_write

import java.util.Properties

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
 * Reads from and writes to MySQL over JDBC with Spark SQL, including a Hive -> MySQL copy and a join.
 *
 * @author Chauncy Jin
 * @date 2021/10/13
 */
object MySparkMysql {
  def main(args: Array[String]): Unit = {
    // val conf = new SparkConf().setMaster("local").setAppName("mysql")
    val session = SparkSession
      .builder()
      .master("local")
      .appName("mysql")
      .config(new SparkConf()) // optional
      // Spark's documented default for spark.sql.shuffle.partitions is 200; 1 is enough for this local demo.
      .config("spark.sql.shuffle.partitions", "1")
      .config("hive.metastore.uris", "thrift://192.168.7.11:9083")
      .enableHiveSupport()
      .getOrCreate()
    val sc = session.sparkContext
    // sc.setLogLevel("ERROR")
    sc.setLogLevel("INFO")

    // Placeholders: fill in real credentials (ideally from configuration, not source code).
    val properties = new Properties()
    properties.setProperty("user", "用户")
    properties.setProperty("password", "密码")
    properties.setProperty("driver", "com.mysql.jdbc.Driver")
    properties.setProperty("batchsize", "10000") // rows per JDBC batch when writing
    properties.setProperty("fetchsize", "10000") // rows fetched per round trip when reading

    // Placeholders: ip, port and the path segment ("tablename" here), which is the MySQL database name.
    // rewriteBatchedStatements=true turns on the MySQL driver's batch rewriting.
    val jdbcUrl = "jdbc:mysql://ip:port/tablename?useSSL=false&useUnicode=true&characterEncoding=UTF-8&rewriteBatchedStatements=true"

    // Hive table -> MySQL
    session.sql("use jxc")
    session.sql("select * from ODS_AGENT_UPGRADE_LOG")
      .write.mode(SaveMode.Overwrite).jdbc(jdbcUrl, "chauncy_agent_upgrade_log", properties)

    // Without Hive involved, every data source is read as a Dataset/DataFrame.
    val jdbcDF: DataFrame = session.read.jdbc(jdbcUrl, "student", properties)
    jdbcDF.show(10, truncate = true)
    jdbcDF.createTempView("student_spark")
    session.sql("select * from student_spark").show()

    // Write back to MySQL
    import org.apache.spark.sql.functions._ // brings in udf, col, lit, ...
    jdbcDF.withColumn("status", lit(1)) // add a constant column
      .write.mode(SaveMode.Overwrite)
      .jdbc(jdbcUrl, "student_copy", properties)
    // jdbcDF.write.jdbc(jdbcUrl, "student_copy", properties) // plain copy

    /*
     * Join two MySQL tables through Spark SQL.
     */
    val usersDF: DataFrame = session.read.jdbc(jdbcUrl, "student", properties)
    val scoreDF: DataFrame = session.read.jdbc(jdbcUrl, "score", properties)
    usersDF.createTempView("userstab")
    scoreDF.createTempView("scoretab")
    val resDF: DataFrame = session.sql(
      """
        |SELECT
        | userstab.s_id,
        | userstab.s_name,
        | scoretab.s_score
        |FROM
        | userstab
        | JOIN scoretab ON userstab.s_id = scoretab.s_id
        |""".stripMargin)
    resDF.show()
    resDF.printSchema()
    // Log line seen in an earlier run:
    // 21/10/13 07:47:05 INFO DAGScheduler: Submitting 100 missing tasks from ResultStage 11
  }
}
// Hive support without an external metastore URI; tables are stored under spark.sql.warehouse.dir.
val ss: SparkSession = SparkSession
  .builder()
  .master("local")
  .appName("standalone hive")
  .config("spark.sql.shuffle.partitions", 1)
  .config("spark.sql.warehouse.dir", "file:///Users/jinxingguang/java_project/bigdata-chauncy/spark/warehouse")
  // NOTE(review): with no hive.metastore.uris, Spark presumably runs its own local metastore -- confirm.
  .enableHiveSupport()
  .getOrCreate()
val sc: SparkContext = ss.sparkContext
// sc.setLogLevel("ERROR")

// One-time setup; table xxx must exist before the select below:
// ss.sql("create table xxx(name string,age int)")
// ss.sql("insert into xxx values ('zhangsan',18),('lisi',22)")
ss.sql("select * from xxx").show()
ss.catalog.listTables().show()

// Databases are supported too. IF NOT EXISTS keeps the snippet re-runnable.
ss.sql("create database IF NOT EXISTS chauncy_db")
ss.sql("use chauncy_db")
ss.sql("create table IF NOT EXISTS meifute(name string,age int)")
ss.catalog.listTables().show()
val ss: SparkSession = SparkSession
  .builder()
  .appName("cluster on hive")
  .master("local")
  .config("hive.metastore.uris", "thrift://node01:9083")
  .enableHiveSupport()
  .getOrCreate()
val sc: SparkContext = ss.sparkContext
sc.setLogLevel("ERROR")

ss.sql("create database IF NOT EXISTS spark_hive ")
ss.sql("use spark_hive")
// This failed once with "Class org.apache.hadoop.hive.hbase.HBaseSerDe not found";
// the fix used was to drop the HBase-backed table from Hive.
ss.catalog.listTables().show()

// Read a Hive table as a DataFrame with SparkSession.table(name).
ss.table("TO_YCAK_MAC_LOC_D")

import ss.implicits._
val df01: DataFrame = List(
  "zhangsan",
  "lisi"
).toDF("name")
df01.createTempView("ooxx") // temporary view: registered only in this session's catalog, not stored in Hive

// Create a table and insert data with SQL
ss.sql("create table IF NOT EXISTS hddffs ( id int,age int)") // DDL
// DML: Spark itself writes the data to HDFS, so the Hadoop config (core-site.xml, hdfs-site.xml) is required
ss.sql("insert into hddffs values (4,3),(8,4),(9,3)")
ss.sql("show tables").show() // includes hddffs and the temporary view ooxx (not saved to Hive)
// Persists oxox as a real Hive table, unlike the temporary view ooxx.
// NOTE(review): default save mode errors if oxox already exists, so a rerun fails here.
df01.write.saveAsTable("oxox")
ss.sql("show tables").show() // now also includes oxox
Spark 访问 Hive 时,只需在 Spark 的 conf 目录下的 hive-site.xml 中配置 metastore 地址即可:
# Write a minimal hive-site.xml into Spark's conf directory; it sets only hive.metastore.uris.
cat > /opt/bigdata/spark-2.3.4/conf/hive-site.xml <<-EOF
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<?xml-stylesheet type="text/xsl" href="configuration.xsl"?>
<configuration>
<property>
<name>hive.metastore.uris</name>
<value>thrift://node01:9083</value>
<description>metastore地址</description>
</property>
</configuration>
EOF
启动spark-shell
# Start spark-shell with YARN as the master.
cd /opt/bigdata/spark-2.3.4/bin
./spark-shell --master yarn
scala> spark.sql("show tables").show
启动spark-sql
# Start the spark-sql CLI with YARN as the master.
cd /opt/bigdata/spark-2.3.4/bin
./spark-sql --master yarn
查看网页 http://node03:8088/cluster 会出现SparkSQL
可以直接执行SQL,跟hive中共享,两边都可以操作
spark-sql> show tables;
# 对外暴露JDBC服务,接受SQL执行
# Start the Thrift JDBC/ODBC server on YARN so clients can submit SQL over JDBC.
cd /opt/bigdata/spark-2.3.4/sbin
./start-thriftserver.sh --master yarn
查看网页 http://node03:8088/cluster **多了一个 Thrift JDBC/ODBC Server**
# 使用spark的beeline连接
# Connect to the Thrift server on node01:10000 (database "default") as user "god".
/opt/bigdata/spark-2.3.4/bin/beeline -u jdbc:hive2://node01:10000/default -n god
打印 Connected to: Spark SQL (version 2.3.4)
# Same connection with a different Spark install path, against localhost:10000 as user "hadoop".
/usr/lib/spark-current/bin/beeline -u jdbc:hive2://localhost:10000 -n hadoop
show tables;
内置 SQL 函数对应的 Scala 源码位于 org/apache/spark/sql/functions.scala
package com.chauncy.spark_dataframe_dataset

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/** Spark SQL function examples: ordering, a scalar UDF and a user-defined aggregate (average). */
object MySpark_Sql_functions {
  def main(args: Array[String]): Unit = {
    val ss: SparkSession = SparkSession.builder()
      .master("local")
      .appName("ProduceClientLog")
      .config("hive.metastore.uris", "thrift://192.168.7.11:9083") // Hive metastore address
      //.config(new SparkConf())
      .enableHiveSupport()
      .getOrCreate()
    ss.sparkContext.setLogLevel("ERROR")
    import ss.implicits._

    // List -> DataFrame
    val dataDF: DataFrame = List(
      ("A", 1, 67),
      ("D", 1, 87),
      ("B", 1, 54),
      ("D", 2, 24),
      ("C", 3, 64),
      ("R", 2, 54),
      ("E", 1, 74)
    ).toDF("name", "class", "score")
    dataDF.createTempView("users")

    // Grouping and sorting:
    // ss.sql("select name,sum(score) " +
    //   " from users " +
    //   "group by name" +
    //   " order by name").show()
    ss.sql("select * from users order by name asc,score desc").show()

    // Scalar UDF
    ss.udf.register("ooxx", (x: Int) => {
      x * 10
    })
    ss.sql("select *,ooxx(score) mut_10 from users ").show()

    // User-defined aggregate function: average score per name
    ss.udf.register("myagg", new MyAvgFun)
    ss.sql("select name," +
      " myagg(score) " +
      " from users " +
      " group by name").show()
  }

  /**
   * Average of an integer column, returned as a Double.
   *
   * Null inputs are skipped, as SQL avg does. A group with no non-null input yields NaN (0.0 / 0)
   * instead of throwing.
   * NOTE: UserDefinedAggregateFunction is deprecated since Spark 3.0 in favour of
   * Aggregator + functions.udaf; it is kept here as in the original example.
   */
  class MyAvgFun extends UserDefinedAggregateFunction {
    // Input columns: myagg(score)
    override def inputSchema: StructType =
      StructType(Array(StructField("score", IntegerType, false)))

    // Aggregation buffer: running sum (Long, so large groups cannot overflow Int) and row count.
    override def bufferSchema: StructType = StructType(Array(
      StructField("sum", LongType, false),
      StructField("count", IntegerType, false)
    ))

    override def dataType: DataType = DoubleType

    // The same input always produces the same output.
    override def deterministic: Boolean = true

    // Start each buffer at sum = 0, count = 0.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0
    }

    // Called once per input row within a group.
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getInt(0) // sum
        buffer(1) = buffer.getInt(1) + 1                // count
      }
    }

    // Combines two partial buffers (e.g. built on different partitions).
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
    }

    // Final result. Floating-point division keeps the fractional part;
    // the original Int / Int truncated, e.g. (87 + 24) / 2 gave 55.0 instead of 55.5.
    override def evaluate(buffer: Row): Double =
      buffer.getLong(0).toDouble / buffer.getInt(1)
  }
}
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。