当前位置:   article > 正文

n-gram的理解与实现_n-gram模型代码

n-gram模型代码

N − g r a m N-gram Ngram简单介绍

语言可以看作一个马尔科夫(MDP)过程,假设目前有一句话s,将s按照token的级别切分后得到token list: [w1,w2,w3,w4~wm], 那么这一句话的概率建模就可以描述为P(s)=p(w1)* p(w2|w1) * p(w3|w2,w1) * ~ p(wm| wm-1 , wn-2~ w1).

然而当m趋近于无穷大时,上述建模是难以实现的,所以引入n-gram

n-gram的思想是只看n个token而不管剩下m-n+1个token。实现思想如下,选定大小为n的滑窗,在整个句子上依次滑动,假设n为3,那么我们可以获得[w1,w2,w3]这个token list,
所以p(w3| w1, w2) = 统计所有[w1,w2,w3]数目/统计所有[w1, w2]数目。

经过上述转化后, P(s)= p(w3 |w1, w2) × \times × p(w4 | w2, w3) × \times × ~ × \times × p(wn | wn-1, wn-2), 从而大大节约了计算的时间。

但是根据大数定理,只有当文本容量十分大时,该方法才有意义。

from tqdm import tqdm
import os
import json
class n_gram_metric:
    def __init__(self,corpus,n,gram_path=None,save_gram=False):
        if gram_path!=None:
            self.gram_score=json.load(open(gram_path,'r'))
        else:
            self.gram_score=self.build_gram(corpus,n)
        print('gram score:', self.gram_score)
        self.num_grams=n
        if save_gram:
            if gram_path==None:
                gram_path='gram_score.json'
            with open(gram_path,'r') as f:
                json.dump(f,self.gram_score)
        
    def build_gram(self,corpus,n,return_raw_data=False):
        output={}

        cleaned_corpus=self.clean_method(corpus,10000)
        print('gram table building```')
        for i in tqdm(range(len(corpus)-1)):
            cur_window=cleaned_corpus[i:i+n]
            cur_gram=' '.join(cur_window)
            if cur_gram not in output:
                output[cur_gram]=1
            else:
                output[cur_gram]+=1

        token_unit_list={}
        print('n-1 gram table building```')
        for i in tqdm(range(len(corpus)-1)):
            cur_window=cleaned_corpus[i:i+n-1]
            cur_gram_last=' '.join(cur_window)
            if cur_gram_last not in token_unit_list:
                token_unit_list[cur_gram_last]=1
            else:
                token_unit_list[cur_gram_last]+=1

        print("gram_table:",output)
        print("n_minus1_table:",token_unit_list)

        for key in output:
            last_key=' '.join(key.split(' ')[:n-1])
            output[key]=output[key]/token_unit_list[' '.join(key.split(' ')[:n-1])]

        print('gram_rate:',output)
        if return_raw_data:

            return output,token_list
        else:
            return output
    def clean_method(self,corpus,cut_num=None):
        if cut_num!=None:
            return corpus.split()[:cut_num]
        return corpus.split()
    def encode2gram(self,seq):
        
        seq=self.clean_method(seq)
        
        score=1
        print('gram score computing````')
        for i in tqdm(range(len(seq)-self.num_grams+1)):
            cur_window=' '.join(seq[i:self.num_grams+i])
            score*=self.gram_score[cur_window]
        print('prob of what you input is orgainzed by human```:',score)
        return score
if __name__=='__main__':
    gram_compute=n_gram_metric(corpus='We evaluate the trained models on seven Semanti',n=2)
    print(gram_compute.encode2gram('We evaluate the trained models'))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/397488
推荐阅读
相关标签
  

闽ICP备14008679号