当前位置:   article > 正文

TF-IDF算法详解_tfidf算法

tfidf算法

TF-IDF算法详解

此算法多用于情感语义分析,提取每条评论中的权重词用来分析,分类!
TF:(Term Frequency,缩写为TF)也就是词频.
IDF:(Inverse Document Frequency) 逆文档频率
下面就是具体的公式:

1.计算词频TF

在这里插入图片描述
考虑到文章有长短之分,为了方便不同文章的比较,进行"词频"标准化.在这里插入图片描述
再或者
在这里插入图片描述

2.计算逆文档频率

需要一个语料库(corpus),用来模拟语言的使用环境
在这里插入图片描述
如果一个词越常见,那么他的分母越大,逆文档频率就越小接近0,分母之所以要加1,是为了避免分母为0(即所有文档都不包含该词)log标示对得到的值去对数.

3.计算TF-IDF

在这里插入图片描述
可以看到,TF-IDF与一个词在文档中的出现次数成正比,与该词在整个语言环境中的出现的次数成反比,计算出每个文档的每个词的TF-IDF值,然后降序排列,去排在最前面的几个词.

代码举例

有四个文档,1,2,3,4,后面的doc为每个文档的关键词
docid,doc
1,a a a a a a x x y
2,b b b x y
3,c c x y
4,d x
需求:利用TF-IDF求出每个关键词的权重值

import cn.doitedu.commons.util.SparkUtil
import org.apache.log4j.{Level, Logger}

/**
  * 利用SQL的方式求出每个关键词的权重值
  */
object TFIDF_SQL {

  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkUtil.getSparkSession(this.getClass.getSimpleName)
    import spark.implicits._

    val df = spark.read.option("header","true").csv("userprofile/data/demo/tfidf")

    // 将定义好的udf注册到sql解析引擎
    import cn.doitedu.ml.util.TF_IDF_Util._
    spark.udf.register("doc2tf",doc2tf)

    // 1. 将原始数据,加工成tf值 词向量
    val tf_df = df.selectExpr("docid","doc2tf(doc,26) as tfarr")
    println("tf特征向量结果: ----------")
    tf_df.show(100,false)

    // 2.利用tf_df,计算出每个词所出现的文档数
    // 将tf_df做一个变换: tf数组中的非零值全部替换成1,以便于后续的计数操作
    spark.udf.register("arr2one",arr2One)
    val tf_df_one = tf_df.selectExpr("docid","arr2one(tfarr) as flag")

    tf_df_one.show(10,false)

    // 接下来将tf_df_one这个dataframe中的所有行的flag数组,对应位置的元素累加到一起==》 该位置的词所出现过的文档数
    import cn.doitedu.ml.util.ArraySumUDAF
    spark.udf.register("arr_sum",ArraySumUDAF)
    val docCntDF = tf_df_one.selectExpr("arr_sum(flag) as doc_cnt")
    docCntDF.show(10,false)

    // 接着,将上面的结果:每个词所出现的文档数 ==》 IDF:  lg(文档总数/(1+词文档数))
    val docTotal: Long = df.count() // 文档总数
    spark.udf.register("cnt2idf",docCntArr2Idf)

    //val idfDF = docCntDF.selectExpr("cnt2idf(doc_cnt,"+docTotal+")")
    val idfDF = docCntDF.selectExpr("cnt2idf(doc_cnt,"+docTotal+")")  // cnt2idf(doc_cnt,4)

    println("idf特征向量结果: ----------")
    idfDF.show(10,false)
    // TODO 将idfDF这个表  和  tfDF表,综合相乘得到  tfidf表
    spark.close()
  }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52

执行结果

tf特征向量结果: ----------
+-----+----------------------------------------------------------------------------------------------------------------------------------+
|docid|tfarr                                                                                                                             |
+-----+----------------------------------------------------------------------------------------------------------------------------------+
|1    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 1.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]|
|2    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]|
|3    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0]|
|4    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]|
+-----+----------------------------------------------------------------------------------------------------------------------------------+

+-----+----------------------------------------------------------------------------------------------------------------------------------+
|docid|flag                                                                                                                              |
+-----+----------------------------------------------------------------------------------------------------------------------------------+
|1    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]|
|2    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]|
|3    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]|
|4    |[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]|
+-----+----------------------------------------------------------------------------------------------------------------------------------+

+----------------------------------------------------------------------------------------------------------------------------------+
|doc_cnt                                                                                                                           |
+----------------------------------------------------------------------------------------------------------------------------------+
|[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 3.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]|
+----------------------------------------------------------------------------------------------------------------------------------+

idf特征向量结果: ----------
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|UDF:cnt2idf(doc_cnt, cast(4 as bigint))                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625, -0.0010843812922198969, 0.12349349573411905, 2.6020599913279625, 0.5977386175453198, 0.5977386175453198, 0.5977386175453198, 0.5977386175453198, 2.6020599913279625, 2.6020599913279625, 2.6020599913279625]|
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

需要用到的工具类TF_IDF_Util

package cn.doitedu.ml.util

import scala.collection.mutable

object TF_IDF_Util {


  /**
    * 一个udf函数,它可以输入一篇文档,输出一个tf值的特征向量(数组)
    */
  val doc2tf = (doc:String,n:Int)=>{

    //准备一个长度为n,初始值为0的数组 [0,0,0,0,0,0,0,0,0,0,.....]
    val tfArr = Array.fill(n)(0.0)


    // 字符串:"a a a a a a x x y" => 字符串数组:[a,a,a,a,a,a,x,x,y]
    val tmp1: Array[String] = doc.split(" ")

    //数组:[a,a,a,a,a,a,x,x,y] =>  HashMap:{(a,[a,a,a,a,a,a]),(x,[x,x]),(y,[y])}
    val tmp2: Map[String, Array[String]] = tmp1.groupBy(e=>e)

    // (a,[a,a,a,a,a,a]) => (a,6)
    // (x,[x,x])         => (x,2)
    // (y,[y])           => (y,1)
    val wc: Map[String, Int] = tmp2.map(tp=>(tp._1,tp._2.size))

    // 遍历wc这个hashmap,将其中的每一个词的词频映射到向量的这个位置: 词.hashcode%n
    for((w,c)<-wc){
      // 用hash映射求得该词所映射的特征位置脚标
      val index = (w.hashCode&Integer.MAX_VALUE)%n
      // 再把特征向量中该脚标上的特征值,替换为这个词w的词频c
      tfArr(index) = c

    }

    tfArr
  }


  /**
    * 一个udf函数,它可以输入一个double数组,然后将数组中非0值替换成1,返回
    */
  val arr2One = (arr:mutable.WrappedArray[Double])=>{
    arr.map(d=>if(d != 0.0) 1.0 else 0.0)
  }


  /**
    * 将一个“词文档数” 数组 ==》 词idf值数组
    */
  val docCntArr2Idf = (doc_cnt:mutable.WrappedArray[Double],doc_total:Long)=>{

    doc_cnt.map(d=> Math.log10(doc_total/(0.01+d)))
  }

}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

需要用到的工具类ArraySumUDAF

package cn.doitedu.ml.util

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

object ArraySumUDAF extends UserDefinedAggregateFunction{
  // 函数的输入参数,有几个字段,分别是什么类型
  // 比如:arr_sum(flag,other) 需要两个字段
  override def inputSchema: StructType = {
    new StructType().add("flag",DataTypes.createArrayType(DataTypes.DoubleType))
  }

  // buffer 是在聚合函数运算过程中,用于存储局部聚合结果的缓存
  override def bufferSchema: StructType = new StructType().add("buffer",DataTypes.createArrayType(DataTypes.DoubleType))

  // 最后返回结果的数据类型,在本需求中,还是一个Double数组
  override def dataType: DataType = DataTypes.createArrayType(DataTypes.DoubleType)

  // 我们的聚合运算逻辑,是否总是能返回确定结果!
  override def deterministic: Boolean = true

  // 对buffer进行初始化,在本需求逻辑中,可以先给一个长度为0的空double数组
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0,Array.emptyDoubleArray)

  // 此方法,就是局部聚合的逻辑所在地,大的逻辑就是,根据输入的一行数据input,来更新局部缓存buffer中的数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    // 从输入的行中,取出flag数组字段
    val inputArr = input.getSeq[Double](0)

    var bufferArr = buffer.getSeq[Double](0)
    // 如果是第一次对buffer做更新操作,那么buffer中的缓存数组应该长度为0,则给他换成跟输入数组长度一致的数组
    if(bufferArr.size<1) bufferArr = Array.fill(inputArr.size)(0.0)

    // 然后,将输入的数组中各个元素按对应位置累加到buffer的数组中
    bufferArr = inputArr.zip(bufferArr).map(tp=>tp._1 + tp._2)

    // 将局部聚合结果,更新到buffer中
    buffer.update(0,bufferArr)

  }

  // 全局聚合逻辑所在地:它是将各个partition的局部聚合结果,一条一条往buffer上累加
  // buffer2代表的是每一个局部聚合结果;  buffer1代表的是本次聚合的存储所在地
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    var accuArr = buffer1.getSeq[Double](0)
    val inputArr = buffer2.getSeq[Double](0)

    // 如果是第一次对buffer1做更新操作,那么buffer1中的缓存数组应该长度为0,则给他换成跟输入数组长度一致的数组
    if(accuArr.size<1) accuArr = Array.fill(inputArr.size)(0.0)

    // 然后,将输入的数组中各个元素按对应位置累加到buffer的数组中
    accuArr = inputArr.zip(accuArr).map(tp=>tp._1 + tp._2)

    // 将聚合结果,更新到buffer1中
    buffer1.update(0,accuArr)
  }

  // 最后向外部返回结果的方法,这个方法中的buffer缓存,就是merge方法中的buffer1缓存
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[Double](0)
  }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

以上是为了方便理解的手撕TF-IDF,下面是调用人家写好的tfidf算法

package TF_IDF

import cn.doitedu.commons.util.SparkUtil
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.{HashingTF, IDF}
object TFIDF_Mllib {

  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkUtil.getSparkSession(this.getClass.getSimpleName)
    import spark.implicits._


    /**
      * docid,doc
      * 1,    a a a a a a x x y  ->[0,0,0,0,6,0,2,0,0,1,0,0]
      * 2,    b b b x y          ->
      * 3,    c c x y            ->
      * 4,    d x                ->
      */
    // 加载原始数据
    val df = spark.read.option("header","true").csv("userprofile/data/demo/tfidf/docs.txt")

    val wordsDF = df.selectExpr("docid","split(doc,' ') as words")


    // 将分词后的字段,变成hash映射TF值向量
    val tfUtil = new HashingTF()
      .setInputCol("words")
      .setOutputCol("tf_vec")
      .setNumFeatures(26)
    val tfDF = tfUtil.transform(wordsDF)
    tfDF.show(10,false)


    // 利用tf集合,算出idf向量,然后再用idf去加工原来的tf,得到tfidf
    val idfUtil = new IDF()
      .setInputCol("tf_vec")
      .setOutputCol("tfidf_vec")
    val model = idfUtil.fit(tfDF)
    val tfidf = model.transform(tfDF)
    tfidf.show(10,false)

    spark.close()

  }

}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

算法实现:

+-----+---------------------------+----------------------------+
|docid|words                      |tf_vec                      |
+-----+---------------------------+----------------------------+
|1    |[a, a, a, a, a, a, x, x, y]|(26,[3,4,18],[1.0,6.0,2.0]) |
|2    |[b, b, b, x, y]            |(26,[3,18,25],[1.0,1.0,3.0])|
|3    |[c, c, x, y]               |(26,[3,12,18],[1.0,2.0,1.0])|
|4    |[d, x]                     |(26,[2,18],[1.0,1.0])       |
+-----+---------------------------+----------------------------+

+-----+---------------------------+----------------------------+-----------------------------------------------------------+
|docid|words                      |tf_vec                      |tfidf_vec                                                  |
+-----+---------------------------+----------------------------+-----------------------------------------------------------+
|1    |[a, a, a, a, a, a, x, x, y]|(26,[3,4,18],[1.0,6.0,2.0]) |(26,[3,4,18],[0.22314355131420976,5.497744391244931,0.0])  |
|2    |[b, b, b, x, y]            |(26,[3,18,25],[1.0,1.0,3.0])|(26,[3,18,25],[0.22314355131420976,0.0,2.7488721956224653])|
|3    |[c, c, x, y]               |(26,[3,12,18],[1.0,2.0,1.0])|(26,[3,12,18],[0.22314355131420976,1.8325814637483102,0.0])|
|4    |[d, x]                     |(26,[2,18],[1.0,1.0])       |(26,[2,18],[0.9162907318741551,0.0])                       |
+-----+---------------------------+----------------------------+-----------------------------------------------------------+
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/499221
推荐阅读
相关标签
  

闽ICP备14008679号