当前位置:   article > 正文

fasttext 训练词向量 并 使用余弦相似度得出短文本的相似度_fasttext实现词相似度

fasttext实现词相似度
# -*- coding: utf-8 -*-
import os

import fasttext
import jieba
import numpy as np
import tqdm
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

base_path = os.path.dirname(os.path.abspath(__file__))
# 我这里使用的 sqlite 连接的某个数据库
# 下面用的表是聊天数据的表,其中我只会使用到 sentence 也就是聊天内容的字段
database_path = os.path.dirname(base_path)
database_dir = os.path.join(database_path, "datas", "data.db")

# 加载 jieba 分词词典
jieba.load_userdict(os.path.join(base_path, "lcut.txt"))
jieba.load_userdict(os.path.join(base_path, "500000-dict.txt"))

def get_data():
	”“”只需保证通过get_data,得到一个所有内容分词后的 txt 即可,词与词之间空格间隔“”“
    engine = create_engine('sqlite:///{}'.format(database_dir))

    Session = sessionmaker(bind=engine)

    session = Session() # type: sqlalchemy.orm.session.Session

    r = session.execute("select sentence from chat where role=0 limit 1000000")

    sentence_objs = r.fetchall()

    sentence_objs_len = len(sentence_objs)

    pass_words = ["message", "http", "系统提示", "戳"]

    with open("finance_news_cut.txt", "w", encoding='utf-8') as f:
        for sentence_obj in tqdm.tqdm(sentence_objs):
            sentence = sentence_obj["sentence"] # type: str

            if any([i in sentence for i in pass_words]):
                continue

            if sentence.isalnum():
                continue
                
            seg_sentence = jieba.cut(sentence.replace("\t", " ").replace("\n", " "))
            
            outline = " ".join(seg_sentence)
            outline = outline + " "

            f.write(outline)
            f.flush()


def train_model():
	”“”训练词向量模型并保存“”“
    model = fasttext.train_unsupervised('finance_news_cut.txt', )
    model.save_model("news_fasttext.model.bin")
    

def get_word_vector(word):
	”“”获取某词词向量“”“
    model = fasttext.load_model('news_fasttext.model.bin')

    word_vector = model.get_word_vector(word)

    return word_vector

def get_sentence_vector(sentence):
	”“”获取某句句向量“”“
    cut_words = jieba.lcut(sentence)

    sentence_vector = None
    
    for word in cut_words:
        word_vector = get_word_vector(word)

        if sentence_vector is not None:
            sentence_vector += word_vector
        else:
            sentence_vector = word_vector

    sentence_vector = sentence_vector / len(cut_words)

    return sentence_vector


def cos_sim(vector_a, vector_b):
    """
    计算两个向量之间的余弦相似度
    :param vector_a: 向量 a
    :param vector_b: 向量 b
    :return: sim
    """
    vector_a = np.mat(vector_a)
    vector_b = np.mat(vector_b)
    num = float(vector_a * vector_b.T)
    denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
    cos = num / denom
    sim = 0.5 + 0.5 * cos
    return sim

if __name__ == "__main__":
    a = get_sentence_vector("可以包邮吗")
    b = get_sentence_vector("能不能包邮")
    print(cos_sim(a, b))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/288929
推荐阅读
相关标签
  

闽ICP备14008679号