
NLP · Knowledge Graph · Graph QA in Practice


Knowledge Graph QA

Knowledge graph QA comes in several flavors: querying the tail entity given a head entity and a relation, querying the relation given the entities, and even multi-hop questions. Different cases call for slightly different methods. Let's start with the simplest case: finding the tail entity given the head entity and a relation.

1. Find the entity and the relation. This can be done with BIO-style NER, or directly with a classification model.
2. Entity linking: if several entities share the same name, disambiguation is needed.

NER

There are many NER approaches today; the basic architecture is an encoder plus a CRF layer.
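
As a concrete illustration, here is a minimal encoder+CRF sketch in PyTorch. It assumes the third-party pytorch-crf package (pip install pytorch-crf), which this project does not itself use, so treat it as an illustrative baseline rather than the article's model:

import torch
from torch import nn
from torchcrf import CRF  # third-party: pip install pytorch-crf


class BiLSTMCRF(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_tags):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # bidirectional LSTM encoder (hidden_size assumed even)
        self.encoder = nn.LSTM(embedding_size, hidden_size // 2,
                               batch_first=True, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def loss(self, x, tags, mask):
        out, _ = self.encoder(self.embedding(x))
        return -self.crf(self.hidden2tag(out), tags, mask=mask)  # negative log-likelihood

    def decode(self, x, mask):
        out, _ = self.encoder(self.embedding(x))
        return self.crf.decode(self.hidden2tag(out), mask=mask)  # best tag path per sequence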

AC automaton

1. Build the trie (prefix tree).
2. Add fail pointers to the trie. A node's fail pointer points to the root if the node is on the first level; otherwise it points to the child (for the same character) of the node that its parent's fail pointer points to. The sketch below makes this concrete.
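
A minimal pure-Python sketch of the trie plus BFS fail-pointer construction (my own illustrative code, not from the original project):

from collections import deque

def build_ac(patterns):
    # each trie node: children map, fail pointer, output patterns
    root = {'children': {}, 'fail': None, 'output': []}
    for pat in patterns:
        node = root
        for ch in pat:
            node = node['children'].setdefault(
                ch, {'children': {}, 'fail': None, 'output': []})
        node['output'].append(pat)

    # BFS: level-1 nodes fail to root; deeper nodes fail to the same-character
    # child of their parent's fail node (walking up fail links if needed)
    queue = deque()
    for child in root['children'].values():
        child['fail'] = root
        queue.append(child)
    while queue:
        node = queue.popleft()
        for ch, child in node['children'].items():
            fail = node['fail']
            while fail is not root and ch not in fail['children']:
                fail = fail['fail']
            child['fail'] = fail['children'].get(ch, root)
            # inherit matches that end at the fail node (e.g. "he" inside "she")
            child['output'] += child['fail']['output']
            queue.append(child)
    return root

root = build_ac(['she', 'he', 'say', 'shr', 'her'])
node = root['children']['s']['children']['h']['children']['e']
print(node['output'])  # ['she', 'he'] -- the "she" node also reports "he"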

Pattern strings: she he say shr her
Match string: yasherhs
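
The same example, run with the pyahocorasick library that the project's scripts use below (a quick sketch):

import ahocorasick

A = ahocorasick.Automaton()
for pat in ['she', 'he', 'say', 'shr', 'her']:
    A.add_word(pat, pat)
A.make_automaton()

for end_index, pat in A.iter('yasherhs'):
    print(end_index - len(pat) + 1, end_index, pat)
# hits: 'she' at 2-4, 'he' at 3-4, 'her' at 3-5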

Entity linking

Entity linking consists of two steps:
candidate entity generation and entity disambiguation.

Once candidate entities are found, the next step is entity disambiguation.

Entity disambiguation

For entity disambiguation we use a matching approach:
1. Use a Siamese network to compute similarity.
2. Embed the question and the candidate set and compute cosine similarity, as in the sketch below.
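
A minimal sketch of option 2. It assumes some sentence encoder encode() that maps text to a vector (a hypothetical helper; the article does not fix one):

import numpy as np

def cosine_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def disambiguate(question, candidates, encode):
    # encode() is assumed to map a string to a 1-D numpy vector
    q = encode(question)
    scores = [cosine_sim(q, encode(c)) for c in candidates]
    return candidates[int(np.argmax(scores))]  # most similar candidate wins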

Multi-hop QA

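The project handles multi-hop questions such as 浙江东阳东欣房地产开发有限公司的客户的供应商 by resolving one relation at a time and feeding the intermediate answer back in as the new subject (see kbqa() in main below). Expressed directly in Cypher via py2neo, a two-hop query looks like this sketch (same connection settings as the scripts below):

from py2neo import Graph

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
cypher = '''match (s:company)-[:`客户`]->(m:company)-[:`供应商`]->(o:company)
where s.name='浙江东阳东欣房地产开发有限公司' return o.name'''
print(graph.run(cypher).to_ndarray())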

neo4j_graph execution flow

1. Run the import_data.py script to import the data under company_data into Neo4j.

2. Run the gnn/saint.py script for node classification.

3. company.csv holds the attributes of each node. A quick sanity check of the import is sketched below.
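
After the import, the graph can be sanity-checked with a couple of Cypher counts (my own check, not part of the repo; same connection settings as below):

from py2neo import Graph

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
print(graph.run('match (n) return labels(n) as label, count(*) as cnt').to_ndarray())
print(graph.run('match ()-[r]->() return type(r) as rel, count(*) as cnt').to_ndarray())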

Structure diagram: (screenshot omitted)

company_data directory: (screenshot omitted)

A few example records (note: the data is fake and is used for learning purposes only; screenshots omitted).

Code and data

First, start the Neo4j graph database.

Steps: WIN+R, then cmd, then run neo4j.bat console. (Screenshots omitted.)

import_data

import os
from py2neo import Node, Subgraph, Graph, Relationship, NodeMatcher
from tqdm import tqdm
import pandas as pd
import numpy as np

#graph = Graph("http://127.0.0.1:7474", auth=("neo4j", "qwer"))

#graph = Graph("http://127.0.0.1:7474", auth=("neo4j", "qwer"))

#uri = 'bolt://localhost:7687'
#graph = Graph(uri, auth=("neo4j", "password"), port= 7687, secure=True)

#uri = uri = 'http://localhost:7687'
#graph = Graph(uri, auth=("neo4j", "qwer"), port= 7687, secure=True, name= "StellarGraph")

import py2neo
default_host = os.environ.get("STELLARGRAPH_NEO4J_HOST")

# Create the Neo4j Graph database object; the arguments can be edited to specify location and authentication

graph = py2neo.Graph(host=default_host, port=7687, user='neo4j', password='qwer')

def import_company():
    df = pd.read_csv('company_data/公司.csv')
    eid = df['eid'].values
    name = df['companyname'].values

    nodes = []
    data = list(zip(eid, name))
    for eid, name in tqdm(data):
        profit = np.random.randint(100000, 100000000, 1)[0]
        node = Node('company', name=name, profit=int(profit), eid=eid)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_person():
    df = pd.read_csv('company_data/人物.csv')
    pid = df['personcode'].values
    name = df['personname'].values

    nodes = []
    data = list(zip(pid, name))
    for pid, name in tqdm(data):
        age = np.random.randint(20, 70, 1)[0]
        node = Node('person', name=name, age=int(age), pid=str(pid))
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_industry():
    df = pd.read_csv('company_data/行业.csv')
    names = df['orgtype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('industry', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_assign():
    df = pd.read_csv('company_data/分红.csv')
    names = df['schemetype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('assign', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_violations():
    df = pd.read_csv('company_data/违规类型.csv')
    names = df['gooltype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('violations', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_bond():
    df = pd.read_csv('company_data/债券类型.csv')
    names = df['securitytype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('bond', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


# def import_dishonesty():
#     node = Node('dishonesty', name='失信')
#     graph.create(node)


def import_relation():
    df = pd.read_csv('company_data/公司-人物.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    pid = df['pid'].values
    post = df['post'].values
    relations = []
    data = list(zip(eid, pid, post))
    for e, p, po in tqdm(data):
        company = matcher.match('company', eid=e).first()
        person = matcher.match('person', pid=str(p)).first()
        if company is not None and person is not None:
            relations.append(Relationship(company, po, person))

    graph.create(Subgraph(relationships=relations))
    print('import company-person relation succeeded')

    df = pd.read_csv('company_data/公司-行业.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['industry'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        industry = matcher.match('industry', name=str(n)).first()
        if company is not None and industry is not None:
            relations.append(Relationship(company, '行业类型', industry))

    graph.create(Subgraph(relationships=relations))
    print('import company-industry relation succeeded')

    df = pd.read_csv('company_data/公司-分红.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['assign'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        assign = matcher.match('assign', name=str(n)).first()
        if company is not None and assign is not None:
            relations.append(Relationship(company, '分红方式', assign))

    graph.create(Subgraph(relationships=relations))
    print('import company-assign relation succeeded')

    df = pd.read_csv('company_data/公司-违规.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['violations'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        violations = matcher.match('violations', name=str(n)).first()
        if company is not None and violations is not None:
            relations.append(Relationship(company, '违规类型', violations))

    graph.create(Subgraph(relationships=relations))
    print('import company-violations relation succeeded')

    df = pd.read_csv('company_data/公司-债券.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['bond'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        bond = matcher.match('bond', name=str(n)).first()
        if company is not None and bond is not None:
            relations.append(Relationship(company, '债券类型', bond))

    graph.create(Subgraph(relationships=relations))
    print('import company-bond relation succeeded')

    # df = pd.read_csv('company_data/公司-失信.csv')
    # matcher = NodeMatcher(graph)
    # eid = df['eid'].values
    # rel = df['dishonesty'].values
    # relations = []
    # data = list(zip(eid, rel))
    # for e, r in tqdm(data):
    #     company = matcher.match('company', eid=e).first()
    #     dishonesty = matcher.match('dishonesty', name='失信').first()
    #     if company is not None and dishonesty is not None:
    #         if pd.notna(r):
    #             if int(r) == 0:
    #                 relations.append(Relationship(company, '无', dishonesty))
    #             elif int(r) == 1:
    #                 relations.append(Relationship(company, '有', dishonesty))
    #
    # graph.create(Subgraph(relationships=relations))
    # print('import company-dishonesty relation succeeded')


def import_company_relation():
    df = pd.read_csv('company_data/公司-供应商.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2) and e1 != e2:
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '供应商', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-supplier relation succeeded')

    df = pd.read_csv('company_data/公司-担保.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2) and e1 != e2:
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '担保', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-guarantee relation succeeded')

    df = pd.read_csv('company_data/公司-客户.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2):
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '客户', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-customer relation succeeded')


def delete_relation():
    cypher = 'match ()-[r]-() delete r'
    graph.run(cypher)


def delete_node():
    cypher = 'match (n) delete n'
    graph.run(cypher)


def import_data():
    import_company()
    import_company_relation()

    import_person()
    import_industry()
    import_assign()
    import_violations()
    import_bond()
    # import_dishonesty()

    import_relation()


def delete_data():
    delete_relation()
    delete_node()
    print('delete data succeeded')


if __name__ == '__main__':
    delete_data()
    import_data()

create_question_data

from py2neo import Graph
import numpy as np
import pandas as pd

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))

# import os
# import py2neo
# default_host = os.environ.get("STELLARGRAPH_NEO4J_HOST")
# graph = py2neo.Graph(host=default_host, port=7687, user='neo4j', password='qwer')

def create_attribute_question():
    company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
    person = graph.run('MATCH (n:person) RETURN n.name as name').to_ndarray()

    questions = []

    for c in company:
        c = c[0].strip()
        question = f"{c}的收益"
        questions.append(question)
        question = f"{c}的收入"
        questions.append(question)

    for p in person:
        p = p[0].strip()
        question = f"{p}的年龄是几岁"
        questions.append(question)
        question = f"{p}多大"
        questions.append(question)
        question = f"{p}几岁"
        questions.append(question)

    return questions


def create_entity_question():
    questions = []

    for _ in range(250):
        for op in ['大于', '等于', '小于', '是', '有']:
            profit = np.random.randint(10000, 10000000, 1)[0]
            question = f"收益{op}{profit}的公司有哪些"
            questions.append(question)
            profit = np.random.randint(10000, 10000000, 1)[0]
            question = f"哪些公司收益{op}{profit}"
            questions.append(question)

    for _ in range(250):
        for op in ['大于', '等于', '小于', '是', '有']:
            age = np.random.randint(20, 60, 1)[0]
            question = f"年龄{op}{age}的人有哪些"
            questions.append(question)
            age = np.random.randint(20, 60, 1)[0]
            question = f"哪些人年龄{op}{age}"
            questions.append(question)

    return questions


def create_relation_question():
    relation = graph.run('MATCH (n)-[r]->(m) RETURN n.name as name, type(r) as r').to_ndarray()

    questions = []

    for r in relation:
        if str(r[1]) in ['董事', '监事']:
            question = f"{r[0]}{r[1]}是谁"
            questions.append(question)
        else:
            question = f"{r[0]}{r[1]}"
            questions.append(question)
            question = f"{r[0]}{r[1]}是啥"
            questions.append(question)
            question = f"{r[0]}{r[1]}什么"
            questions.append(question)

    return questions


q1 = create_entity_question()
q2 = create_attribute_question()
q3 = create_relation_question()

df = pd.DataFrame()
df['question'] = q1 + q2 + q3
df['label'] = [0] * len(q1) + [1] * len(q2) + [2] * len(q3)

df.to_csv('question_classification.csv', encoding='utf_8_sig', index=False)
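
The resulting question_classification.csv covers three classes: label 0 for numeric entity questions (e.g. 收益大于50000的公司有哪些), label 1 for attribute questions built from node names (e.g. <company>的收益, <person>多大), and label 2 for relation questions built from the graph's edges (e.g. <company>董事是谁).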

data_process

import pandas as pd
import jieba
from collections import defaultdict
import numpy as np
import os

# hard-code the module directory so vocab.txt resolves when the module is
# imported from the repo root (os.path.dirname('kbqa') would return '')
path = 'kbqa'


def tokenize(text, use_jieba=True):
    if use_jieba:
        res = list(jieba.cut(text, cut_all=False))
    else:
        res = list(text)
    return res


# Build the vocabulary
def build_vocab(del_word_frequency=0):
    data = pd.read_csv('question_classification.csv')
    segment = data['question'].apply(tokenize)

    word_frequency = defaultdict(int)
    for row in segment:
        for i in row:
            word_frequency[i] += 1

    word_sort = sorted(word_frequency.items(), key=lambda x: x[1], reverse=True)  # sort by frequency, descending

    with open('vocab.txt', 'w', encoding='utf-8') as f:
        f.write('[PAD]' + "\n" + '[UNK]' + "\n")
        for d in word_sort:
            if d[1] > del_word_frequency:
                f.write(d[0] + "\n")


# Split into train and eval sets (holds out a fixed 2000 rows; the `split`
# argument is unused)
def split_data(df, split=0.7):
    df = df.sample(frac=1)
    length = len(df)
    train_data = df[0:length - 2000]
    eval_data = df[length - 2000:]

    return train_data, eval_data


vocab = {}
if os.path.exists(path + '/vocab.txt'):
    with open(path + '/vocab.txt', encoding='utf-8')as file:
        for line in file.readlines():
            vocab[line.strip()] = len(vocab)


# Convert a question into a list of vocab indices (1 = [UNK])
def seq2index(seq):
    seg = tokenize(seq)
    seg_index = []
    for s in seg:
        seg_index.append(vocab.get(s, 1))
    return seg_index


# Pad or truncate each sequence to a uniform length
def padding_seq(X, max_len=10):
    return np.array([
        np.concatenate([x, [0] * (max_len - len(x))]) if len(x) < max_len else x[:max_len] for x in X
    ])


if __name__ == '__main__':
    build_vocab(5)
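
A quick usage sketch of the helpers above (my own illustration; assumes vocab.txt has already been generated by build_vocab):

from kbqa.data_process import seq2index, padding_seq

x = padding_seq([seq2index('哪些公司收益大于50000')])
print(x.shape)  # (1, 10); out-of-vocab tokens map to index 1 ([UNK])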


ac_automaton

import ahocorasick
from py2neo import Graph

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
relation = graph.run('MATCH ()-[r]-() RETURN distinct type(r)').to_ndarray()

ac_company = ahocorasick.Automaton()
ac_relation = ahocorasick.Automaton()

for row in company:
    ac_company.add_word(row[0], row[0])
for row in relation:
    ac_relation.add_word(row[0], row[0])

ac_company.make_automaton()
ac_relation.make_automaton()

# haystack = '浙江东阳东欣房地产开发有限公司的客户的供应商'
haystack = '衡水中南锦衡房地产有限公司的债券类型'
# haystack = '临沂金丰公社农业服务有限公司的分红方式'
print('question:', haystack)

subject = ''
predicate = []

for end_index, original_value in ac_company.iter(haystack):
    start_index = end_index - len(original_value) + 1
    print('公司实体:', (start_index, end_index, original_value))
    assert haystack[start_index:start_index + len(original_value)] == original_value
    subject = original_value

for end_index, original_value in ac_relation.iter(haystack):
    start_index = end_index - len(original_value) + 1
    print('关系:', (start_index, end_index, original_value))
    assert haystack[start_index:start_index + len(original_value)] == original_value
    predicate.append(original_value)

res = None
for p in predicate:
    cypher = f'''match (s:company)-[p:`{p}`]-(o) where s.name='{subject}' return o.name'''
    print(cypher)
    res = graph.run(cypher).to_ndarray()
    subject = res[0][0]  # the answer becomes the subject of the next hop
print('answer:', res[0][0] if res is not None else 'no relation matched')


torch_utils

import torch
import time
import numpy as np
import six


class TrainHandler:

    def __init__(self,
                 train_loader,
                 valid_loader,
                 model,
                 criterion,
                 optimizer,
                 model_path,
                 batch_size=32,
                 epochs=5,
                 scheduler=None,
                 gpu_num=0):
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.model_path = model_path
        self.batch_size = batch_size
        self.epochs = epochs
        self.scheduler = scheduler
        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{gpu_num}')
            print(f'Training device is gpu:{gpu_num}')
        else:
            self.device = torch.device('cpu')
            print('Training device is cpu')
        self.model = model.to(self.device)

    def _train_func(self):
        train_loss = 0
        train_correct = 0
        for i, (x, y) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            x, y = x.to(self.device).long(), y.to(self.device)
            output = self.model(x)
            loss = self.criterion(output, y)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            train_correct += (output.argmax(1) == y).sum().item()

        if self.scheduler is not None:
            self.scheduler.step()

        return train_loss / len(self.train_loader), train_correct / len(self.train_loader.dataset)

    def _test_func(self):
        valid_loss = 0
        valid_correct = 0
        for x, y in self.valid_loader:
            x, y = x.to(self.device).long(), y.to(self.device)
            with torch.no_grad():
                output = self.model(x)
                loss = self.criterion(output, y)
                valid_loss += loss.item()
                valid_correct += (output.argmax(1) == y).sum().item()

        return valid_loss / len(self.valid_loader), valid_correct / len(self.valid_loader.dataset)

    def train(self):
        min_valid_loss = float('inf')

        for epoch in range(self.epochs):
            start_time = time.time()
            train_loss, train_acc = self._train_func()
            valid_loss, valid_acc = self._test_func()

            if min_valid_loss > valid_loss:
                min_valid_loss = valid_loss
                torch.save(self.model, self.model_path)
                print(f'\tSave model done valid loss: {valid_loss:.4f}')

            secs = int(time.time() - start_time)
            mins = secs / 60
            secs = secs % 60

            print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
            print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
            print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')


def torch_text_process():
    from torchtext import data

    def tokenizer(text):
        import jieba
        return list(jieba.cut(text))

    TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=20)
    LABEL = data.Field(sequential=False, use_vocab=False)
    all_dataset = data.TabularDataset.splits(path='',
                                             train='LCQMC.csv',
                                             format='csv',
                                             fields=[('sentence1', TEXT), ('sentence2', TEXT), ('label', LABEL)])[0]
    TEXT.build_vocab(all_dataset)
    train, valid = all_dataset.split(0.1)
    (train_iter, valid_iter) = data.BucketIterator.splits(datasets=(train, valid),
                                                          batch_sizes=(64, 128),
                                                          sort_key=lambda x: len(x.sentence1))
    return train_iter, valid_iter


def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='post', truncating='pre', value=0.):
    """Pads sequences to the same length.

    This function transforms a list of
    `num_samples` sequences (lists of integers)
    into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence otherwise.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` at the end.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.
    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.

    Pre-padding is the default.

    # Arguments
        sequences: List of lists, where each element is a sequence.
        maxlen: Int, maximum length of all sequences.
        dtype: Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, 'pre' or 'post':
            pad either before or after each sequence.
        truncating: String, 'pre' or 'post':
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value.

    # Returns
        x: Numpy array with shape `(len(sequences), maxlen)`

    # Raises
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    num_samples = len(sequences)

    lengths = []
    for x in sequences:
        try:
            lengths.append(len(x))
        except TypeError:
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))

    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
    if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
                         "You should set `dtype=object` for variable length strings."
                         .format(dtype, type(value)))

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s '
                             'is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x


if __name__ == '__main__':
    torch_text_process()


text_cnn

import torch
from torch import nn


class TextCNN(nn.Module):
    def __init__(self, vocab_len, embedding_size, n_class):
        super().__init__()

        self.embedding = nn.Embedding(vocab_len, embedding_size)

        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[3, embedding_size])
        self.cnn2 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[4, embedding_size])
        self.cnn3 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[5, embedding_size])

        # max_len is 10 (see data_process.padding_seq), so the conv output
        # lengths are 10-3+1=8, 10-4+1=7, 10-5+1=6; pool each down to 1
        self.max_pool1 = nn.MaxPool1d(kernel_size=8)
        self.max_pool2 = nn.MaxPool1d(kernel_size=7)
        self.max_pool3 = nn.MaxPool1d(kernel_size=6)

        self.drop_out = nn.Dropout(0.2)
        self.full_connect = nn.Linear(300, n_class)

    def forward(self, x):
        embedding = self.embedding(x)
        embedding = embedding.unsqueeze(1)

        cnn1_out = self.cnn1(embedding).squeeze(-1)
        cnn2_out = self.cnn2(embedding).squeeze(-1)
        cnn3_out = self.cnn3(embedding).squeeze(-1)

        out1 = self.max_pool1(cnn1_out)
        out2 = self.max_pool2(cnn2_out)
        out3 = self.max_pool3(cnn3_out)

        out = torch.cat([out1, out2, out3], dim=1).squeeze(-1)

        out = self.drop_out(out)
        out = self.full_connect(out)
        # out = torch.softmax(out, dim=-1).squeeze(dim=-1)
        return out
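
A quick shape check of the module (my own sketch; vocab size 1289 and max_len 10 come from the training script and data_process):

model = TextCNN(vocab_len=1289, embedding_size=256, n_class=3)
x = torch.randint(0, 1289, (4, 10))  # a batch of 4 padded questions
print(model(x).shape)  # torch.Size([4, 3])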


train

import torch
from torch.utils.data import TensorDataset, DataLoader
from kbqa.torch_utils import TrainHandler
from kbqa.data_process import *
from kbqa.text_cnn import TextCNN

# df = pd.read_csv('question_classification.csv')
# print(df['label'].value_counts())


def load_data(batch_size=32):
    df = pd.read_csv('kbqa/question_classification.csv')
    train_df, eval_df = split_data(df)
    train_x = train_df['question']
    train_y = train_df['label']
    valid_x = eval_df['question']
    valid_y = eval_df['label']

    train_x = padding_seq(train_x.apply(seq2index))
    train_y = np.array(train_y)
    valid_x = padding_seq(valid_x.apply(seq2index))
    valid_y = np.array(valid_y)

    train_data_set = TensorDataset(torch.from_numpy(train_x),
                                   torch.from_numpy(train_y))
    valid_data_set = TensorDataset(torch.from_numpy(valid_x),
                                   torch.from_numpy(valid_y))
    train_data_loader = DataLoader(dataset=train_data_set, batch_size=batch_size, shuffle=True)
    valid_data_loader = DataLoader(dataset=valid_data_set, batch_size=batch_size, shuffle=True)

    return train_data_loader, valid_data_loader


train_loader, valid_loader = load_data(batch_size=64)

model = TextCNN(1289, 256, 3)  # originally TextCNN(1141, 256, 3); 1289 matches the line count of vocab.txt
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
model_path = 'text_cnn.p'
handler = TrainHandler(train_loader, valid_loader, model,
                       criterion,
                       optimizer,
                       model_path,
                       batch_size=32,
                       epochs=5,
                       scheduler=None,
                       gpu_num=0)
handler.train()



main

import torch
from kbqa.data_process import *
import ahocorasick
from py2neo import Graph
import re
import traceback


model = torch.load('kbqa/text_cnn.p', map_location=torch.device('cpu'))
model.eval()

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
person = graph.run('MATCH (n:person) RETURN n.name as name').to_ndarray()
relation = graph.run('MATCH ()-[r]-() RETURN distinct type(r)').to_ndarray()

ac_company = ahocorasick.Automaton()
ac_person = ahocorasick.Automaton()
ac_relation = ahocorasick.Automaton()

for key in enumerate(company):
    ac_company.add_word(key[1][0], key[1][0])
for key in enumerate(person):
    ac_person.add_word(key[1][0], key[1][0])
for key in enumerate(relation):
    ac_relation.add_word(key[1][0], key[1][0])
ac_relation.add_word('年龄', '年龄')
ac_relation.add_word('年纪', '年纪')
ac_relation.add_word('收入', '收入')
ac_relation.add_word('收益', '收益')

ac_company.make_automaton()
ac_person.make_automaton()
ac_relation.make_automaton()


def classification_predict(s):
    s = seq2index(s)
    s = torch.from_numpy(padding_seq([s])).long() #.cuda().long()
    out = model(s)
    out = out.cpu().data.numpy()
    print(out)
    return out.argmax(1)[0]


def entity_link(text):
    subject = []
    subject_type = None
    for end_index, original_value in ac_company.iter(text):
        start_index = end_index - len(original_value) + 1
        print('实体:', (start_index, end_index, original_value))
        assert text[start_index:start_index + len(original_value)] == original_value
        subject.append(original_value)
        subject_type = 'company'
    for end_index, original_value in ac_person.iter(text):
        start_index = end_index - len(original_value) + 1
        print('实体:', (start_index, end_index, original_value))
        assert text[start_index:start_index + len(original_value)] == original_value
        subject.append(original_value)
        subject_type = 'person'

    return subject[0], subject_type


def get_op(text):
    pattern = re.compile(r'\d+')
    num = pattern.findall(text)
    op = None
    if '大于' in text:
        op = '>'
    elif '小于' in text:
        op = '<'
    elif '等于' in text or '是' in text:
        op = '='
    return op, float(num[0])


def kbqa(text):
    print('*' * 100)
    cls = classification_predict(text)
    print('question type:', cls)
    res = ''

    if cls == 0:
        op, num = get_op(text)
        subject_type = ''
        attribute = ''
        for w in ['年龄', '年纪']:
            if w in text:
                subject_type = 'person'
                attribute = 'age'
                break
        for w in ['收入', '收益']:
            if w in text:
                subject_type = 'company'
                attribute = 'profit'
                break
        cypher = f'match (n:{subject_type}) where n.{attribute}{op}{num} return n.name'
        print(cypher)
        res = graph.run(cypher).to_ndarray()
    elif cls == 1:
        # 查询属性
        subject, subject_type = entity_link(text)
        predicate = ''
        for w in ['年龄', '年纪']:
            if w in text and subject_type == 'person':
                predicate = 'age'
                break
        for w in ['收入', '收益']:
            if w in text and subject_type == 'company':
                predicate = 'profit'
                break
        cypher = f'''match (n:{subject_type}) where n.name='{subject}' return n.{predicate}'''
        print(cypher)
        res = graph.run(cypher).to_ndarray()
    elif cls == 2:
        subject = ''
        for end_index, original_value in ac_company.iter(text):
            start_index = end_index - len(original_value) + 1
            print('公司实体:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            subject = original_value
        predicate = []
        for end_index, original_value in ac_relation.iter(text):
            start_index = end_index - len(original_value) + 1
            print('关系:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            predicate.append(original_value)
        # Multi-hop: answer the first relation, then rewrite the remainder of
        # the question with the intermediate answer as the new subject and recurse
        for i, p in enumerate(predicate):
            cypher = f'''match (s:company)-[p:`{p}`]->(o) where s.name='{subject}' return o.name'''
            print(cypher)
            res = graph.run(cypher).to_ndarray()
            subject = res[0][0]
            if i == len(predicate) - 1:
                break
            new_index = text.index(p) + len(p)
            new_question = subject + str(text[new_index:])
            print('new question:', new_question)
            res = kbqa(new_question)
            break
    return res


if __name__ == '__main__':

    while 1:
        try:
            text = input('text:')
            res = kbqa(text)
            print(res)
        except Exception:
            print(traceback.format_exc())


Graph QA in practice: summary

Overall structure of the system: AC automaton + entity extraction + multi-hop QA

Note: entity names here are unique, so entity disambiguation is not actually needed; we simply use string matching.


Learning reference:
the 七月在线 (July Online) NLP advanced course

Code reference:
https://github.com/terrifyzhao/neo4j_graph
