当前位置:   article > 正文

TfidfVectorizer计算复现和细节探究

tfidfvectorizer

简介

tf-idf算法,我想很多人都知道它的由来和公式,更进一步,会在纸上用笔计算,但是在sklearn的实际实现中,却鲜有人去复现背后的计算细节和逻辑,去对比验算。本文将提出并解决以下细节问题:
1.TfidfVectorizer和TfidfTransformer是什么关系?
2.tf-idf中tf和idf在代码中分别是怎么实现计算的?
3.idf中的文档是怎么定义的?
4.为什么我用笔计算的和sklearn中计算出来的向量不一样?

问题一

直接上答案,TfidfVectorizer是由CountVectorizer和TfidfTransformer组成的,因为TfidfTransformer默认接受sparse matrix即稀疏矩阵作为输入,所以先要用CountVectorizer进行转换,变成矩阵后再输入进TfidfTransformer。

问题二三四

在这里插入图片描述
先上一张用烂的图,根据官方公式上,不难知道tf其实就是统计词频,后面的idf是(总文档数目/包含这个词的文档数目)再取了个对数。
在进行计算前呢,我们先说明sklearn中tf-idf转换的默认参数:

class sklearn.feature_extraction.text.TfidfTransformer(*, 
norm='l2', use_idf=True, smooth_idf=True, sublinear_tf=False)
  • 1
  • 2

Norm代表的是标准化,l2表示向量元素的平方和为1,l1表示向量元素的绝对值总和为1。
use_idf若为false,就不会进行idf的计算,得到的只是tf。
smooth_idf若为true,公式里log的分子分母各加一。
sub_linear_tf若为true,则把tf变成1+log(tf)。
所以为了保证计算的一致性,应该将它设置为如下:

from sklearn.feature_extraction.text import TfidfVectorizer
tfidf = TfidfVectorizer(norm=None,use_idf=True, smooth_idf=False,sublinear_tf=False)
  • 1
  • 2

把标准化和平滑idf关掉,这样idf的计算就为1+log(N/df)。

接下来上等同于TfidfVectorizer计算的代码,在这里我们使用pyts提供的gunpoint的数据集,也就是将时间序列转化为语言表征的方法(这里不多说了):

import numpy as np
from pyts.datasets import load_gunpoint
from pyts.bag_of_words import BagOfWords
from sklearn.preprocessing import LabelEncoder
X_train, X_test, y_train, y_test = load_gunpoint(return_X_y=True)
bow = BagOfWords(window_size=7, word_size=4) #生成四个字母的词袋模型
LA = bow.transform(X_train)
le = LabelEncoder()
y_ind = le.fit_transform(y_train) #标签标准化
X_class = [' '.join(LA[y_ind == classe]) for classe in range(le.classes_.size)] 
#根据两种标签分别循环
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

以上是装载数据集,总之我们得到了X_class,它大概长下面这样:
在这里插入图片描述
两行,分别是两类,每一类包含大量的四个字母组成的字符串。那么我们先来计算tf:

import scipy.sparse as sp
from sklearn.feature_extraction.text import CountVectorizer
vectorizer = CountVectorizer()
X_class = vectorizer.fit_transform(X_class)
  • 1
  • 2
  • 3
  • 4

调用CountVectorizer或者np.unique函数(查看上一篇文章)都可以得到频数矩阵,得到如下:
在这里插入图片描述
tf表示在所有的不重复元素(也就是字典中),该元素在该文档中出现的次数,这里说明一下文档的概念,根据sklearn的源代码注释:

X : sparse matrix of shape n_samples, n_features)
            A matrix of term/token counts.
  • 1
  • 2

n_samples,即行,为文档的数目,在这里为2行,也就是N=2,列是n_features,也就是字典包含的词的数目。

tf有了,N也有了,接下来就需要计算df,也就是包含该词的文档数:

df = np.bincount(X_class.indices, minlength=X_class.shape[1])
print(df)
[2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 1 2 2 1 2 1 1
 2 2 1 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 2 2 1 1 2 2 2
 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 1 1 2 2 2 2 2 1 2 2 1 1 2 2 2 1 2 2 2 2 2 2
 2 2 2 2 2 2 2 1 1 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 1 2 1 2]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

indice取词的索引,bincount统计索引出现次数,可以发现在两份文档中,大多数词汇都在两边都有,只有少部分只在一份里有,另一份没有。接下来计算idf(log(N/df)+1)

idf = np.log(2 / df)+1
print(idf)
[1.         1.         1.         1.         1.         1.
 1.         1.         1.69314718 1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.69314718 1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.69314718 1.         1.         1.69314718 1.         1.69314718
 1.69314718 1.         1.         1.69314718 1.         1.
 1.69314718 1.         1.         1.         1.69314718 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.69314718
 1.69314718 1.         1.         1.69314718 1.69314718 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.69314718 1.
 1.         1.         1.         1.         1.         1.69314718
 1.69314718 1.         1.         1.         1.         1.
 1.69314718 1.         1.         1.69314718 1.69314718 1.
 1.         1.         1.69314718 1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.69314718 1.69314718
 1.         1.         1.         1.69314718 1.         1.
 1.         1.         1.         1.         1.         1.69314718
 1.         1.         1.         1.69314718 1.         1.69314718
 1.        ]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

细心的人会发现,出现一次的idf都为1.69314718,这和我们电脑计算器按的log(2/1)=0.301不一样啊!
实际上,这里调用的np.log的log的底数为e,而我们电脑计算器的log底数为10,而且在我的印象中从小这个默认不写底数的log就是log10…所以这就是sklearn有点坑的地方,看了底层才知道原来是用loge实现的,搜了一下只有一篇博客提到,难以想象这么多介绍的文章里不把这个重点写出来,我表示震惊,疑惑,不解,好吧。

print(X_class.toarray() * idf)
array([[ 18.        ,  31.        ,   7.        ,   3.        ,
        127.        ,   1.        ,  13.        , 197.        ,
          0.        ,   8.        ,  10.        ,   4.        ,
         52.        ,  69.        ,   7.        ,  50.        ,
        105.        ,   1.        ,   6.        ,   8.        ,
          0.        ,   2.        ,  18.        ,  28.        ,
          1.        ,   6.        ,   3.        ,   3.        ,
          1.        ,  46.        ,   0.        ,   8.        ,
         25.        ,   0.        ,   1.        ,   0.        ,
          0.        ,  24.        ,   1.        ,   0.        ,
          3.        ,  28.        ,   0.        ,   2.        ,
          7.        ,   3.        ,   0.        ,   2.        ,
          2.        ,   2.        ,   4.        ,   1.        ,
          3.        ,   3.        ,   7.        ,   4.        ,
          1.        ,   2.        ,   4.        ,   2.        ,
          2.        ,   2.        ,   2.        ,  19.        ,
          4.        ,   0.        ,   0.        ,  16.        ,
          3.        ,   0.        ,   0.        ,   1.        ,
         16.        ,   1.        ,   6.        ,  24.        ,
          3.        ,   4.        ,   1.        ,   2.        ,
          3.        ,   1.        ,   0.        ,   6.        ,
          7.        ,   3.        ,   8.        ,   3.        ,
          2.        ,   0.        ,   0.        ,   4.        ,
          1.        ,   4.        ,  11.        ,   1.        ,
          0.        ,  27.        ,   7.        ,   0.        ,
          0.        ,  20.        ,   7.        ,   2.        ,
          1.69314718,  30.        ,   5.        ,  41.        ,
          1.        ,   4.        ,   4.        ,   7.        ,
          1.        ,  18.        ,  34.        ,   3.        ,
          3.        ,  10.        ,   1.69314718,   0.        ,
         78.        ,  62.        ,   5.        ,   0.        ,
         43.        ,  31.        ,   1.        ,   8.        ,
          3.        , 152.        ,  11.        ,   0.        ,
        112.        ,   5.        ,   4.        ,   0.        ,
         25.        ,   1.69314718,  13.        ],
       [ 20.        ,  28.        ,   5.        ,   3.        ,
        137.        ,   1.        ,  14.        , 204.        ,
          1.69314718,  17.        ,  11.        ,   8.        ,
         43.        ,  73.        ,   5.        ,  59.        ,
        123.        ,   2.        ,  10.        ,  12.        ,
          3.38629436,   2.        ,  25.        ,  24.        ,
          2.        ,   6.        ,   3.        ,   2.        ,
          3.        ,  49.        ,   3.38629436,  10.        ,
         23.        ,   3.38629436,   1.        ,   1.69314718,
         13.54517744,  47.        ,   2.        ,  13.54517744,
         12.        ,  41.        ,   1.69314718,   4.        ,
          7.        ,   4.        ,   1.69314718,   6.        ,
          1.        ,   1.        ,   7.        ,   8.        ,
          5.        ,   5.        ,  13.        ,  10.        ,
          1.        ,   1.        ,  10.        ,   2.        ,
          8.        ,   9.        ,   3.        ,  15.        ,
          9.        ,   3.38629436,   1.69314718,  11.        ,
          5.        ,   3.38629436,   3.38629436,   6.        ,
         15.        ,   1.        ,  11.        ,  16.        ,
          6.        ,  11.        ,   4.        ,   1.        ,
          8.        ,   2.        ,   1.69314718,   7.        ,
         12.        ,   2.        ,   6.        ,   9.        ,
          2.        ,   3.38629436,   3.38629436,   3.        ,
          4.        ,   6.        ,  11.        ,   8.        ,
          1.69314718,  36.        ,  12.        ,  11.85203026,
          3.38629436,  28.        ,   4.        ,   3.        ,
          0.        ,  34.        ,   5.        ,  49.        ,
          1.        ,   5.        ,   2.        ,   4.        ,
          2.        ,  22.        ,  27.        ,   5.        ,
         12.        ,   5.        ,   0.        ,   1.69314718,
         78.        ,  68.        ,   9.        ,   1.69314718,
         59.        ,  45.        ,   3.        ,  20.        ,
         11.        , 192.        ,  20.        ,   1.69314718,
        163.        ,   7.        ,  10.        ,   1.69314718,
         40.        ,   0.        ,  17.        ]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71

最后把两个相乘就可以得到最终结果,这和以下操作所得的向量是等价的:

from sklearn.feature_extraction.text import TfidfVectorizer
tfidf = TfidfVectorizer(norm=None,use_idf=True, smooth_idf=False,sublinear_tf=False)
print(tfidf.fit_transform(X_class))
  • 1
  • 2
  • 3

短短三行代码,里面包含的细节如此多,网上的多数文章都在不断重复的介绍官网都有写的API,和手算的内容,却没有几个人去验证,对比,可能大家都是调包侠吧。在这里依旧表示一下震惊,疑惑,不解。

总结

以后只发两种文章,一种深度的比如这篇文章,另一种就是小白都能懂和用的介绍型文章,当然我发现看的最多的竟然是故障解决型的文章,这种文章除非网上没有类似的解决方法,不然我是不会发的。我对自己的第一要求就是尽量不写重复的,包括介绍型,我会介绍一些比较新的库。包含深度的文章看心情发,反正也没什么人看,有的时候介绍类写着写着也会变成挖深度的。深度的东西写的很长,有的时候也会拆开来发。
尽管,可能很多人对博客的定义是一个个人总结的东西,我还是把这个博客作为知识分享的途径,尽量分享新的,有深度的,不重复的东西,毕竟独乐乐不如众乐乐。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/350072
推荐阅读
相关标签
  

闽ICP备14008679号