import json
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
import codecs
from keras.layers import *
from keras.models import Model
import keras.backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback
from tqdm import tqdm
import jieba
import editdistance
import re
import numpy as np
import tensorflow as tf
import keras
import pandas as pd

print(tf.__version__)
print(keras.__version__)
1.13.1
2.2.4
''' { "table_id": "a1b2c3d4", # 相应表格的id "question": "世茂茂悦府新盘容积率大于1,请问它的套均面积是多少?", # 自然语言问句 "sql":{ # 真实SQL "sel": [7], # SQL选择的列 "agg": [0], # 选择的列相应的聚合函数, '0'代表无 "cond_conn_op": 0, # 条件之间的关系 "conds": [ [1, 2, "世茂茂悦府"], # 条件列, 条件类型, 条件值,col_1 == "世茂茂悦府" [6, 0, "1"] ] } } # 其中条件运算符、聚合符、连接符分别如下 op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="} agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"} conn_sql_dict = {0:"", 1:"and", 2:"or"} ''' maxlen = 160 num_agg = 7 # agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM", 6:"不被select"} num_op = 5 # {0:">", 1:"<", 2:"==", 3:"!=", 4:"不被select"} num_cond_conn_op = 3 # conn_sql_dict = {0:"", 1:"and", 2:"or"} learning_rate = 5e-5 min_learning_rate = 1e-5 config_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\bert_config.json' checkpoint_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\bert_model.ckpt' dict_path = 'E:\\zym_test\\test\\nlp\\chinese_wwm_ext_L-12_H-768_A-12\\vocab.txt'
def read_data(data_file, table_file):
    data, tables = [], {}
    with open(data_file, encoding='UTF-8') as f:
        for l in f:
            data.append(json.loads(l))
    with open(table_file, encoding='UTF-8') as f:
        for l in f:
            l = json.loads(l)
            # Each table record has: rows, name, title, header, common, id, types.
            # rows holds the actual cell values, name is the table name, header is the
            # list of column names, types says whether a column is text, real, etc.
            # Build a new per-table dict:
            #   headers    -> the original header list
            #   header2id  -> {column name: column index}
            #   content    -> {column name: set of values in that column}
            #   all_values -> set of every value in the table
            d = {}
            d['headers'] = l['header']
            d['header2id'] = {j: i for i, j in enumerate(d['headers'])}
            d['content'] = {}
            d['all_values'] = set()
            rows = np.array(l['rows'])
            # Fill content column by column; the set removes duplicate values
            for i, h in enumerate(d['headers']):
                d['content'][h] = set(rows[:, i])
                # collect all values of the table (set.update deduplicates)
                d['all_values'].update(d['content'][h])
            # hasattr() checks whether an object has a given attribute;
            # here it drops values without a length (empty cells)
            d['all_values'] = set([i for i in d['all_values'] if hasattr(i, '__len__')])
            # index the table by its id
            tables[l['id']] = d
    return data, tables
train_data, train_tables = read_data('E:/zym_test/test/nlp/data/train/train.json','E:/zym_test/test/nlp/data/train/train.tables.json')
valid_data, valid_tables = read_data('E:/zym_test/test/nlp/data/val/val.json','E:/zym_test/test/nlp/data/val/val.tables.json')
test_data, test_tables = read_data('E:/zym_test/test/nlp/data/test/test.json','E:/zym_test/test/nlp/data/test/test.tables.json')
train_data[0:4]
[{'table_id': '4d29d0513aaa11e9b911f40f24344a08', 'question': '二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀', 'sql': {'agg': [5], 'cond_conn_op': 2, 'sel': [2], 'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}}, {'table_id': '4d29d0513aaa11e9b911f40f24344a08', 'question': '你好,你知道今年第四周密室逃生,还有那部大黄蜂它们票房总的占比吗', 'sql': {'agg': [5], 'cond_conn_op': 2, 'sel': [2], 'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}}, {'table_id': '4d29d0513aaa11e9b911f40f24344a08', 'question': '我想你帮我查一下第四周大黄蜂,还有密室逃生这两部电影票房的占比加起来会是多少来着', 'sql': {'agg': [5], 'cond_conn_op': 2, 'sel': [2], 'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}}, {'table_id': '4d25e6403aaa11e9bdbbf40f24344a08', 'question': '有几家传媒公司16年为了融资收购其他资产而进行定增的呀', 'sql': {'agg': [4], 'cond_conn_op': 1, 'sel': [1], 'conds': [[6, 2, '2016'], [7, 2, '融资收购其他资产']]}}]
train_tables[ '4d29d0513aaa11e9b911f40f24344a08' ]
{'headers': ['影片名称', '周票房(万)', '票房占比(%)', '场均人次'], 'header2id': {'影片名称': 0, '周票房(万)': 1, '票房占比(%)': 2, '场均人次': 3}, 'content': {'影片名称': {'“大”人物', '一条狗的回家路', '大黄蜂', '家和万事惊', '密室逃生', '掠食城市', '死侍2:我爱我家', '海王', '白蛇:缘起', '钢铁飞龙之奥特曼崛起'}, '周票房(万)': {'10503.8', '10637.3', '3322.9', '356.6', '360.0', '500.3', '5841.4', '595.5', '635.2', '6426.6'}, '票房占比(%)': {'0.9', '1.2', '1.4', '1.5', '14.2', '15.6', '25.4', '25.8', '8.1'}, '场均人次': {'25.0', '3.0', '4.0', '5.0', '6.0', '7.0'}}, 'all_values': {'0.9', '1.2', '1.4', '1.5', '10503.8', '10637.3', '14.2', '15.6', '25.0', '25.4', '25.8', '3.0', '3322.9', '356.6', '360.0', '4.0', '5.0', '500.3', '5841.4', '595.5', '6.0', '635.2', '6426.6', '7.0', '8.1', '“大”人物', '一条狗的回家路', '大黄蜂', '家和万事惊', '密室逃生', '掠食城市', '死侍2:我爱我家', '海王', '白蛇:缘起', '钢铁飞龙之奥特曼崛起'}}
train_tables['4d25e6403aaa11e9bdbbf40f24344a08']
{'headers': ['证券代码', '证券简称', '最新收盘价', '定增价除权后至今价格', '增发价格', '倒挂率', '定增年度', '增发目的'], 'header2id': {'证券代码': 0, '证券简称': 1, '最新收盘价': 2, '定增价除权后至今价格': 3, '增发价格': 4, '倒挂率': 5, '定增年度': 6, '增发目的': 7}, 'content': {'证券代码': {'300148.SZ', '300182.SZ', '300269.SZ'}, '证券简称': {'天舟文化', '捷成股份', '联建光电'}, '最新收盘价': {'4.09', '4.69', '5.48'}, '定增价除权后至今价格': {'11.16', '11.29', '12.48', '21.88', '23.07', '9.91'}, '增发价格': {'14.78', '15.09', '16.34', '16.988', '22.09', '23.3004'}, '倒挂率': {'23.75', '25.05', '36.65', '37.58', '41.26', '41.54'}, '定增年度': {'2016.0'}, '增发目的': {'融资收购其他资产', '配套融资'}}, 'all_values': {'11.16', '11.29', '12.48', '14.78', '15.09', '16.34', '16.988', '2016.0', '21.88', '22.09', '23.07', '23.3004', '23.75', '25.05', '300148.SZ', '300182.SZ', '300269.SZ', '36.65', '37.58', '4.09', '4.69', '41.26', '41.54', '5.48', '9.91', '天舟文化', '捷成股份', '联建光电', '融资收购其他资产', '配套融资'}}
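To tie the label format from the docstring above to the first training example just printed, here is a small illustrative sketch (not part of the original pipeline); the helper name `sql_to_text` is made up, and it simply renders a labelled record into a readable SQL string using the op/agg/conn mappings.

# Illustrative only: turn one labelled record into a readable SQL string.
op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
conn_sql_dict = {0: "", 1: "and", 2: "or"}

def sql_to_text(sql, headers):
    # SELECT part: wrap each selected column in its aggregation function (if any)
    sel_parts = []
    for col, agg in zip(sql['sel'], sql['agg']):
        name = headers[col]
        sel_parts.append('%s(%s)' % (agg_sql_dict[agg], name) if agg else name)
    # WHERE part: join the conditions with the labelled connector
    conn = ' %s ' % conn_sql_dict[sql['cond_conn_op']] if sql['cond_conn_op'] else ' '
    where = conn.join('%s %s "%s"' % (headers[c], op_sql_dict[o], v)
                      for c, o, v in sql['conds'])
    return 'SELECT %s WHERE %s' % (', '.join(sel_parts), where)

d = train_data[0]
print(sql_to_text(d['sql'], train_tables[d['table_id']]['headers']))
# -> SELECT SUM(票房占比(%)) WHERE 影片名称 == "大黄蜂" or 影片名称 == "密室逃生"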
# Encode each character.
# Read the vocab file and build a dict mapping every token to its index.
token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)
token_dict
{'[PAD]': 0, '[unused1]': 1, '[unused2]': 2, '[unused3]': 3, '[unused4]': 4, '[unused5]': 5, '[unused6]': 6, '[unused7]': 7, '[unused8]': 8, '[unused9]': 9, '[unused10]': 10, '[unused11]': 11, '[unused12]': 12, '[unused13]': 13, '[unused14]': 14, '[unused15]': 15, '[unused16]': 16, '[unused17]': 17, '[unused18]': 18, '[unused19]': 19, '[unused20]': 20, '[unused21]': 21, '[unused22]': 22, '[unused23]': 23, '[unused24]': 24, '[unused25]': 25, '[unused26]': 26, '[unused27]': 27, '[unused28]': 28, '[unused29]': 29, '[unused30]': 30, '[unused31]': 31, '[unused32]': 32, '[unused33]': 33, '[unused34]': 34, '[unused35]': 35, '[unused36]': 36, '[unused37]': 37, '[unused38]': 38, '[unused39]': 39, '[unused40]': 40, '[unused41]': 41, '[unused42]': 42, '[unused43]': 43, '[unused44]': 44, '[unused45]': 45, '[unused46]': 46, '[unused47]': 47, '[unused48]': 48, '[unused49]': 49, '[unused50]': 50, '[unused51]': 51, '[unused52]': 52, '[unused53]': 53, '[unused54]': 54, '[unused55]': 55, '[unused56]': 56, '[unused57]': 57, '[unused58]': 58, '[unused59]': 59, '[unused60]': 60, '[unused61]': 61, '[unused62]': 62, '[unused63]': 63, '[unused64]': 64, '[unused65]': 65, '[unused66]': 66, '[unused67]': 67, '[unused68]': 68, '[unused69]': 69, '[unused70]': 70, '[unused71]': 71, '[unused72]': 72, '[unused73]': 73, '[unused74]': 74, '[unused75]': 75, '[unused76]': 76, '[unused77]': 77, '[unused78]': 78, '[unused79]': 79, '[unused80]': 80, '[unused81]': 81, '[unused82]': 82, '[unused83]': 83, '[unused84]': 84, '[unused85]': 85, '[unused86]': 86, '[unused87]': 87, '[unused88]': 88, '[unused89]': 89, '[unused90]': 90, '[unused91]': 91, '[unused92]': 92, '[unused93]': 93, '[unused94]': 94, '[unused95]': 95, '[unused96]': 96, '[unused97]': 97, '[unused98]': 98, '[unused99]': 99, '[UNK]': 100, '[CLS]': 101, '[SEP]': 102, '[MASK]': 103, '<S>': 104, '<T>': 105, '!': 106, '"': 107, '#': 108, '$': 109, '%': 110, '&': 111, "'": 112, '(': 113, ')': 114, '*': 115, '+': 116, ',': 117, '-': 118, '.': 119, '/': 120, '0': 121, '1': 122, '2': 123, '3': 124, '4': 125, '5': 126, '6': 127, '7': 128, '8': 129, '9': 130, ':': 131, ';': 132, '<': 133, '=': 134, '>': 135, '?': 136, '@': 137, '[': 138, '\\': 139, ']': 140, '^': 141, '_': 142, 'a': 143, 'b': 144, 'c': 145, 'd': 146, 'e': 147, 'f': 148, 'g': 149, 'h': 150, 'i': 151, 'j': 152, 'k': 153, 'l': 154, 'm': 155, 'n': 156, 'o': 157, 'p': 158, 'q': 159, 'r': 160, 's': 161, 't': 162, 'u': 163, 'v': 164, 'w': 165, 'x': 166, 'y': 167, 'z': 168, '{': 169, '|': 170, '}': 171, '~': 172, '£': 173, '¤': 174, '¥': 175, '§': 176, '©': 177, '«': 178, '®': 179, '°': 180, '±': 181, '²': 182, '³': 183, 'µ': 184, '·': 185, '¹': 186, 'º': 187, '»': 188, '¼': 189, '×': 190, 'ß': 191, 'æ': 192, '÷': 193, 'ø': 194, 'đ': 195, 'ŋ': 196, 'ɔ': 197, 'ə': 198, 'ɡ': 199, 'ʰ': 200, 'ˇ': 201, 'ˈ': 202, 'ˊ': 203, 'ˋ': 204, 'ˍ': 205, 'ː': 206, '˙': 207, '˚': 208, 'ˢ': 209, 'α': 210, 'β': 211, 'γ': 212, 'δ': 213, 'ε': 214, 'η': 215, 'θ': 216, 'ι': 217, 'κ': 218, 'λ': 219, 'μ': 220, 'ν': 221, 'ο': 222, 'π': 223, 'ρ': 224, 'ς': 225, 'σ': 226, 'τ': 227, 'υ': 228, 'φ': 229, 'χ': 230, 'ψ': 231, 'ω': 232, 'а': 233, 'б': 234, 'в': 235, 'г': 236, 'д': 237, 'е': 238, 'ж': 239, 'з': 240, 'и': 241, 'к': 242, 'л': 243, 'м': 244, 'н': 245, 'о': 246, 'п': 247, 'р': 248, 'с': 249, 'т': 250, 'у': 251, 'ф': 252, 'х': 253, 'ц': 254, 'ч': 255, 'ш': 256, 'ы': 257, 'ь': 258, 'я': 259, 'і': 260, 'ا': 261, 'ب': 262, 'ة': 263, 'ت': 264, 'د': 265, 'ر': 266, 'س': 267, 'ع': 268, 'ل': 269, 'م': 270, 'ن': 271, 'ه': 272, 'و': 273, 'ي': 274, '۩': 
275, 'ก': 276, 'ง': 277, 'น': 278, 'ม': 279, 'ย': 280, 'ร': 281, 'อ': 282, 'า': 283, 'เ': 284, '๑': 285, '་': 286, 'ღ': 287, 'ᄀ': 288, 'ᄁ': 289, 'ᄂ': 290, 'ᄃ': 291, 'ᄅ': 292, 'ᄆ': 293, 'ᄇ': 294, 'ᄈ': 295, 'ᄉ': 296, 'ᄋ': 297, 'ᄌ': 298, 'ᄎ': 299, 'ᄏ': 300, 'ᄐ': 301, 'ᄑ': 302, 'ᄒ': 303, 'ᅡ': 304, 'ᅢ': 305, 'ᅣ': 306, 'ᅥ': 307, 'ᅦ': 308, 'ᅧ': 309, 'ᅨ': 310, 'ᅩ': 311, 'ᅪ': 312, 'ᅬ': 313, 'ᅭ': 314, 'ᅮ': 315, 'ᅯ': 316, 'ᅲ': 317, 'ᅳ': 318, 'ᅴ': 319, 'ᅵ': 320, 'ᆨ': 321, 'ᆫ': 322, 'ᆯ': 323, 'ᆷ': 324, 'ᆸ': 325, 'ᆺ': 326, 'ᆻ': 327, 'ᆼ': 328, 'ᗜ': 329, 'ᵃ': 330, 'ᵉ': 331, 'ᵍ': 332, 'ᵏ': 333, 'ᵐ': 334, 'ᵒ': 335, 'ᵘ': 336, '‖': 337, '„': 338, '†': 339, '•': 340, '‥': 341, '‧': 342, '': 13503, '‰': 344, '′': 345, '″': 346, '‹': 347, '›': 348, '※': 349, '‿': 350, '⁄': 351, 'ⁱ': 352, '⁺': 353, 'ⁿ': 354, '₁': 355, '₂': 356, '₃': 357, '₄': 358, '€': 359, '℃': 360, '№': 361, '™': 362, 'ⅰ': 363, 'ⅱ': 364, 'ⅲ': 365, 'ⅳ': 366, 'ⅴ': 367, '←': 368, '↑': 369, '→': 370, '↓': 371, '↔': 372, '↗': 373, '↘': 374, '⇒': 375, '∀': 376, '−': 377, '∕': 378, '∙': 379, '√': 380, '∞': 381, '∟': 382, '∠': 383, '∣': 384, '∥': 385, '∩': 386, '∮': 387, '∶': 388, '∼': 389, '∽': 390, '≈': 391, '≒': 392, '≡': 393, '≤': 394, '≥': 395, '≦': 396, '≧': 397, '≪': 398, '≫': 399, '⊙': 400, '⋅': 401, '⋈': 402, '⋯': 403, '⌒': 404, '①': 405, '②': 406, '③': 407, '④': 408, '⑤': 409, '⑥': 410, '⑦': 411, '⑧': 412, '⑨': 413, '⑩': 414, '⑴': 415, '⑵': 416, '⑶': 417, '⑷': 418, '⑸': 419, '⒈': 420, '⒉': 421, '⒊': 422, '⒋': 423, 'ⓒ': 424, 'ⓔ': 425, 'ⓘ': 426, '─': 427, '━': 428, '│': 429, '┃': 430, '┅': 431, '┆': 432, '┊': 433, '┌': 434, '└': 435, '├': 436, '┣': 437, '═': 438, '║': 439, '╚': 440, '╞': 441, '╠': 442, '╭': 443, '╮': 444, '╯': 445, '╰': 446, '╱': 447, '╳': 448, '▂': 449, '▃': 450, '▅': 451, '▇': 452, '█': 453, '▉': 454, '▋': 455, '▌': 456, '▍': 457, '▎': 458, '■': 459, '□': 460, '▪': 461, '▫': 462, '▬': 463, '▲': 464, '△': 465, '▶': 466, '►': 467, '▼': 468, '▽': 469, '◆': 470, '◇': 471, '○': 472, '◎': 473, '●': 474, '◕': 475, '◠': 476, '◢': 477, '◤': 478, '☀': 479, '★': 480, '☆': 481, '☕': 482, '☞': 483, '☺': 484, '☼': 485, '♀': 486, '♂': 487, '♠': 488, '♡': 489, '♣': 490, '♥': 491, '♦': 492, '♪': 493, '♫': 494, '♬': 495, '✈': 496, '✔': 497, '✕': 498, '✖': 499, '✦': 500, '✨': 501, '✪': 502, '✰': 503, '✿': 504, '❀': 505, '❤': 506, '➜': 507, '➤': 508, '⦿': 509, '、': 510, '。': 511, '〃': 512, '々': 513, '〇': 514, '〈': 515, '〉': 516, '《': 517, '》': 518, '「': 519, '」': 520, '『': 521, '』': 522, '【': 523, '】': 524, '〓': 525, '〔': 526, '〕': 527, '〖': 528, '〗': 529, '〜': 530, '〝': 531, '〞': 532, 'ぁ': 533, 'あ': 534, 'ぃ': 535, 'い': 536, 'う': 537, 'ぇ': 538, 'え': 539, 'お': 540, 'か': 541, 'き': 542, 'く': 543, 'け': 544, 'こ': 545, 'さ': 546, 'し': 547, 'す': 548, 'せ': 549, 'そ': 550, 'た': 551, 'ち': 552, 'っ': 553, 'つ': 554, 'て': 555, 'と': 556, 'な': 557, 'に': 558, 'ぬ': 559, 'ね': 560, 'の': 561, 'は': 562, 'ひ': 563, 'ふ': 564, 'へ': 565, 'ほ': 566, 'ま': 567, 'み': 568, 'む': 569, 'め': 570, 'も': 571, 'ゃ': 572, 'や': 573, 'ゅ': 574, 'ゆ': 575, 'ょ': 576, 'よ': 577, 'ら': 578, 'り': 579, 'る': 580, 'れ': 581, 'ろ': 582, 'わ': 583, 'を': 584, 'ん': 585, '゜': 586, 'ゝ': 587, 'ァ': 588, 'ア': 589, 'ィ': 590, 'イ': 591, 'ゥ': 592, 'ウ': 593, 'ェ': 594, 'エ': 595, 'ォ': 596, 'オ': 597, 'カ': 598, 'キ': 599, 'ク': 600, 'ケ': 601, 'コ': 602, 'サ': 603, 'シ': 604, 'ス': 605, 'セ': 606, 'ソ': 607, 'タ': 608, 'チ': 609, 'ッ': 610, 'ツ': 611, 'テ': 612, 'ト': 613, 'ナ': 614, 'ニ': 615, 'ヌ': 616, 'ネ': 617, 'ノ': 618, 'ハ': 619, 'ヒ': 620, 'フ': 621, 'ヘ': 622, 'ホ': 623, 'マ': 624, 'ミ': 625, 'ム': 626, 'メ': 627, 'モ': 628, 'ャ': 629, 'ヤ': 
630, 'ュ': 631, 'ユ': 632, 'ョ': 633, 'ヨ': 634, 'ラ': 635, 'リ': 636, 'ル': 637, 'レ': 638, 'ロ': 639, 'ワ': 640, 'ヲ': 641, 'ン': 642, 'ヶ': 643, '・': 644, 'ー': 645, 'ヽ': 646, 'ㄅ': 647, 'ㄆ': 648, 'ㄇ': 649, 'ㄉ': 650, 'ㄋ': 651, 'ㄌ': 652, 'ㄍ': 653, 'ㄎ': 654, 'ㄏ': 655, 'ㄒ': 656, 'ㄚ': 657, 'ㄛ': 658, 'ㄞ': 659, 'ㄟ': 660, 'ㄢ': 661, 'ㄤ': 662, 'ㄥ': 663, 'ㄧ': 664, 'ㄨ': 665, 'ㆍ': 666, '㈦': 667, '㊣': 668, '㎡': 669, '㗎': 670, '一': 671, '丁': 672, '七': 673, '万': 674, '丈': 675, '三': 676, '上': 677, '下': 678, '不': 679, '与': 680, '丐': 681, '丑': 682, '专': 683, '且': 684, '丕': 685, '世': 686, '丘': 687, '丙': 688, '业': 689, '丛': 690, '东': 691, '丝': 692, '丞': 693, '丟': 694, '両': 695, '丢': 696, '两': 697, '严': 698, '並': 699, '丧': 700, '丨': 701, '个': 702, '丫': 703, '中': 704, '丰': 705, '串': 706, '临': 707, '丶': 708, '丸': 709, '丹': 710, '为': 711, '主': 712, '丼': 713, '丽': 714, '举': 715, '丿': 716, '乂': 717, '乃': 718, '久': 719, '么': 720, '义': 721, '之': 722, '乌': 723, '乍': 724, '乎': 725, '乏': 726, '乐': 727, '乒': 728, '乓': 729, '乔': 730, '乖': 731, '乗': 732, '乘': 733, '乙': 734, '乜': 735, '九': 736, '乞': 737, '也': 738, '习': 739, '乡': 740, '书': 741, '乩': 742, '买': 743, '乱': 744, '乳': 745, '乾': 746, '亀': 747, '亂': 748, '了': 749, '予': 750, '争': 751, '事': 752, '二': 753, '于': 754, '亏': 755, '云': 756, '互': 757, '五': 758, '井': 759, '亘': 760, '亙': 761, '亚': 762, '些': 763, '亜': 764, '亞': 765, '亟': 766, '亡': 767, '亢': 768, '交': 769, '亥': 770, '亦': 771, '产': 772, '亨': 773, '亩': 774, '享': 775, '京': 776, '亭': 777, '亮': 778, '亲': 779, '亳': 780, '亵': 781, '人': 782, '亿': 783, '什': 784, '仁': 785, '仃': 786, '仄': 787, '仅': 788, '仆': 789, '仇': 790, '今': 791, '介': 792, '仍': 793, '从': 794, '仏': 795, '仑': 796, '仓': 797, '仔': 798, '仕': 799, '他': 800, '仗': 801, '付': 802, '仙': 803, '仝': 804, '仞': 805, '仟': 806, '代': 807, '令': 808, '以': 809, '仨': 810, '仪': 811, '们': 812, '仮': 813, '仰': 814, '仲': 815, '件': 816, '价': 817, '任': 818, '份': 819, '仿': 820, '企': 821, '伉': 822, '伊': 823, '伍': 824, '伎': 825, '伏': 826, '伐': 827, '休': 828, '伕': 829, '众': 830, '优': 831, '伙': 832, '会': 833, '伝': 834, '伞': 835, '伟': 836, '传': 837, '伢': 838, '伤': 839, '伦': 840, '伪': 841, '伫': 842, '伯': 843, '估': 844, '伴': 845, '伶': 846, '伸': 847, '伺': 848, '似': 849, '伽': 850, '佃': 851, '但': 852, '佇': 853, '佈': 854, '位': 855, '低': 856, '住': 857, '佐': 858, '佑': 859, '体': 860, '佔': 861, '何': 862, '佗': 863, '佘': 864, '余': 865, '佚': 866, '佛': 867, '作': 868, '佝': 869, '佞': 870, '佟': 871, '你': 872, '佢': 873, '佣': 874, '佤': 875, '佥': 876, '佩': 877, '佬': 878, '佯': 879, '佰': 880, '佳': 881, '併': 882, '佶': 883, '佻': 884, '佼': 885, '使': 886, '侃': 887, '侄': 888, '來': 889, '侈': 890, '例': 891, '侍': 892, '侏': 893, '侑': 894, '侖': 895, '侗': 896, '供': 897, '依': 898, '侠': 899, '価': 900, '侣': 901, '侥': 902, '侦': 903, '侧': 904, '侨': 905, '侬': 906, '侮': 907, '侯': 908, '侵': 909, '侶': 910, '侷': 911, '便': 912, '係': 913, '促': 914, '俄': 915, '俊': 916, '俎': 917, '俏': 918, '俐': 919, '俑': 920, '俗': 921, '俘': 922, '俚': 923, '保': 924, '俞': 925, '俟': 926, '俠': 927, '信': 928, '俨': 929, '俩': 930, '俪': 931, '俬': 932, '俭': 933, '修': 934, '俯': 935, '俱': 936, '俳': 937, '俸': 938, '俺': 939, '俾': 940, '倆': 941, '倉': 942, '個': 943, '倌': 944, '倍': 945, '倏': 946, '們': 947, '倒': 948, '倔': 949, '倖': 950, '倘': 951, '候': 952, '倚': 953, '倜': 954, '借': 955, '倡': 956, '値': 957, '倦': 958, '倩': 959, '倪': 960, '倫': 961, '倬': 962, '倭': 963, '倶': 964, '债': 965, '值': 966, '倾': 967, '偃': 968, '假': 969, '偈': 970, '偉': 971, '偌': 972, '偎': 973, '偏': 974, '偕': 975, '做': 976, '停': 977, '健': 978, '側': 979, '偵': 980, '偶': 981, '偷': 982, '偻': 983, '偽': 984, '偿': 985, 
'傀': 986, '傅': 987, '傍': 988, '傑': 989, '傘': 990, '備': 991, '傚': 992, '傢': 993, '傣': 994, '傥': 995, '储': 996, '傩': 997, '催': 998, '傭': 999, ...}
# Re-implement the tokenizer so that the tokenized text keeps the same length as the raw text.
# The built-in _tokenize drops spaces and sometimes merges several characters into one token,
# so the token list no longer lines up with the original string, which makes
# sequence-labelling tasks painful.
class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]')  # map spaces to the untrained [unused1] token
            else:
                R.append('[UNK]')      # everything else becomes [UNK]
        return R

# build the tokenizer from the vocab
tokenizer = OurTokenizer(token_dict)
tokenizer
<__main__.OurTokenizer at 0x2430e637908>
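A quick sanity check (added here, not in the original notebook) that the character-level tokenizer keeps the alignment between tokens and the raw question: `tokenize` adds [CLS] and [SEP], so the token list should be exactly two longer than the input string, even when the text contains spaces.

q = '二零一九年第四周 大黄蜂'      # contains a space on purpose
tokens = tokenizer.tokenize(q)
print(tokens)                       # ['[CLS]', '二', ..., '[unused1]', ..., '[SEP]']
assert len(tokens) == len(q) + 2    # one token per character, plus [CLS]/[SEP]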
# Pad the sequences in a batch so that they all have the same length.
def seq_padding(X, padding=0, maxlen=None):
    if maxlen is None:
        L = [len(x) for x in X]   # length of every sequence
        ML = max(L)               # length of the longest sequence
    else:
        ML = maxlen
    # np.concatenate stitches arrays together: sequences shorter than ML are padded
    # with `padding`; sequences longer than ML are truncated to ML.
    return np.array([
        np.concatenate([x[:ML], [padding] * (ML - len(x))]) if len(x[:ML]) < ML else x[:ML]
        for x in X
    ])
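A small usage sketch (added for illustration) showing the padding behaviour on toy token-id lists:

batch = [[101, 753, 102], [101, 753, 7439, 671, 102]]
print(seq_padding(batch))
# [[ 101  753  102    0    0]
#  [ 101  753 7439  671  102]]
print(seq_padding(batch, maxlen=4).shape)   # (2, 4): longer rows are truncated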
def most_similar(s, slist):
    """Return the candidate in slist closest to s (used when there is no exact match)."""
    if len(slist) == 0:
        return s
    scores = [editdistance.eval(s, t) for t in slist]  # minimum edit distance
    return slist[np.argmin(scores)]
def most_similar_2(w, s):
    """Find the span of sentence s that is closest to w.
    Word segmentation plus n-grams are used to pin down the boundary as precisely as possible.
    """
    sw = jieba.lcut(s)
    sl = list(sw)
    sl.extend([''.join(i) for i in zip(sw, sw[1:])])          # bigrams of adjacent words
    sl.extend([''.join(i) for i in zip(sw, sw[1:], sw[2:])])  # trigrams of adjacent words
    return most_similar(w, sl)
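A quick check (added here, not in the original) of how the fuzzy matcher recovers a condition value from a question; the strings come from the first training example shown above.

q = '二零一九年第四周大黄蜂和密室逃生这两部影片的票房总占比是多少呀'
print(most_similar_2('大黄蜂', q))                             # -> '大黄蜂' (exact span found among the n-grams)
print(most_similar('2019', ['2019年', '第四周', '大黄蜂']))    # -> '2019年' (smallest edit distance)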
d = train_data[0]
# Encode "二零一九年" together with a second sentence; [CLS] and [SEP] are added automatically.
# x1 holds the token ids, x2 marks which segment (sentence) each position belongs to.
x1, x2 = tokenizer.encode('二零一九年', "我是傻子")
print(x1)
print(x2)
print(len(x1))  # besides the token ids, BERT's input also needs segment (and position) information
print(len(x2))
[101, 753, 7439, 671, 736, 2399, 102, 2769, 3221, 1004, 2094, 102]
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
12
12
# Batch generator
class data_generator:
    def __init__(self, data, tables, batch_size=32):
        # store the data and work out the number of steps per epoch
        self.data = data
        self.tables = tables
        self.batch_size = batch_size
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            np.random.shuffle(idxs)
            X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
            # Each sample is a dict holding the question, its table_id and the gold SQL
            # ('agg', 'cond_conn_op', 'sel', 'conds').
            for i in idxs:
                d = self.data[i]
                # look up the headers of the table this question refers to
                t = self.tables[d['table_id']]['headers']
                # encode the question
                x1, x2 = tokenizer.encode(d['question'])
                # mask over the question tokens: [0, 1, 1, ..., 1, 0] ([CLS]/[SEP] are 0)
                xm = [0] + [1] * len(d['question']) + [0]
                h = []
                for j in t:
                    # encode every header and append it after the question;
                    # h records the position of each header's [CLS] token in the concatenated sequence
                    _x1, _x2 = tokenizer.encode(j)
                    h.append(len(x1))
                    # extend() appends all elements of another sequence to the list
                    x1.extend(_x1)
                    x2.extend(_x2)
                # one mask entry per header
                hm = [1] * len(h)
                sel = []
                for j in range(len(h)):
                    # if column j is selected, store its aggregation function,
                    # otherwise use the extra class num_agg - 1 ("not selected")
                    if j in d['sql']['sel']:
                        j = d['sql']['sel'].index(j)
                        sel.append(d['sql']['agg'][j])
                    else:
                        sel.append(num_agg - 1)
                # connector between conditions: "", "and" or "or"
                conn = [d['sql']['cond_conn_op']]
                # csel: 0 doubles as padding and as "column 0"; the padding part is masked during training
                csel = np.zeros(len(d['question']) + 2, dtype='int32')
                # cop: positions that are not part of a condition value get the extra class num_op - 1
                cop = np.zeros(len(d['question']) + 2, dtype='int32') + num_op - 1
                for j in d['sql']['conds']:
                    # if the condition value is not a literal substring of the question,
                    # replace it with the most similar span; skip the condition if that still fails
                    if j[2] not in d['question']:
                        j[2] = most_similar_2(j[2], d['question'])
                    if j[2] not in d['question']:
                        continue
                    # tag every character of the value span with its condition column and operator
                    k = d['question'].index(j[2])
                    csel[k + 1: k + 1 + len(j[2])] = j[0]
                    cop[k + 1: k + 1 + len(j[2])] = j[1]
                if len(x1) > maxlen:
                    continue
                X1.append(x1)      # BERT input: token ids
                X2.append(x2)      # BERT input: segment ids
                XM.append(xm)      # mask over the question
                H.append(h)        # positions of the headers
                HM.append(hm)      # mask over the headers
                SEL.append(sel)    # aggregation class of each column
                CONN.append(conn)  # connector class
                CSEL.append(csel)  # condition column per question character
                COP.append(cop)    # condition operator per question character (also marks the value spans)
                if len(X1) == self.batch_size:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    XM = seq_padding(XM, maxlen=X1.shape[1])
                    H = seq_padding(H)
                    HM = seq_padding(HM)
                    SEL = seq_padding(SEL)
                    CONN = seq_padding(CONN)
                    CSEL = seq_padding(CSEL, maxlen=X1.shape[1])
                    COP = seq_padding(COP, maxlen=X1.shape[1])
                    yield [X1, X2, XM, H, HM, SEL, CONN, CSEL, COP], None
                    X1, X2, XM, H, HM, SEL, CONN, CSEL, COP = [], [], [], [], [], [], [], [], []
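To see what the generator produces, a small inspection snippet (added for illustration, assuming `train_data`/`train_tables` are loaded as above) pulls one batch and prints the array shapes:

gen = data_generator(train_data, train_tables, batch_size=4)
batch_inputs, _ = next(iter(gen))
for name, arr in zip(['X1', 'X2', 'XM', 'H', 'HM', 'SEL', 'CONN', 'CSEL', 'COP'], batch_inputs):
    print(name, arr.shape)
# Expected: X1/X2/XM/CSEL/COP share the padded sequence length,
# H/HM/SEL share the number of headers, CONN is (batch, 1).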
def seq_gather(x):
    """seq has shape [None, seq_len, s_size] and idxs has shape [None, n];
    for the i-th sequence in seq, pick out the vectors at positions idxs[i],
    giving an output of shape [None, n, s_size].
    """
    seq, idxs = x
    idxs = K.cast(idxs, 'int32')
    return K.tf.batch_gather(seq, idxs)
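The gather semantics are easy to check with plain NumPy (a sketch added here, not in the original): for each item in the batch, the rows listed in `idxs` are picked out of `seq`.

seq = np.arange(2 * 4 * 3).reshape(2, 4, 3)   # (batch=2, seq_len=4, s_size=3)
idxs = np.array([[0, 2], [1, 3]])             # positions to gather, per batch item
gathered = np.stack([seq[b][idxs[b]] for b in range(len(seq))])
print(gathered.shape)                          # (2, 2, 3), the same shape tf.batch_gather would give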
bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\keras\backend\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
for l in bert_model.layers:
    l.trainable = True
x1_in = Input(shape=(None,), dtype='int32')
x2_in = Input(shape=(None,))
xm_in = Input(shape=(None,))
h_in = Input(shape=(None,), dtype='int32')
hm_in = Input(shape=(None,))
sel_in = Input(shape=(None,), dtype='int32')
conn_in = Input(shape=(1,), dtype='int32')
csel_in = Input(shape=(None,), dtype='int32')
cop_in = Input(shape=(None,), dtype='int32')
x1, x2, xm, h, hm, sel, conn, csel, cop = (
    x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in
)
hm = Lambda(lambda x: K.expand_dims(x, 1))(hm)  # header mask, shape = (None, 1, h_len)

x = bert_model([x1_in, x2_in])

# [CLS] vector -> condition connector ("", "and", "or")
x4conn = Lambda(lambda x: x[:, 0])(x)
pconn = Dense(num_cond_conn_op, activation='softmax')(x4conn)

# vectors at the header positions -> aggregation class per column (incl. "not selected")
x4h = Lambda(seq_gather)([x, h])
psel = Dense(num_agg, activation='softmax')(x4h)

# every question token -> condition operator (incl. "not a condition")
pcop = Dense(num_op, activation='softmax')(x)

# additive interaction between question tokens and headers -> condition column per token
x = Lambda(lambda x: K.expand_dims(x, 2))(x)
x4h = Lambda(lambda x: K.expand_dims(x, 1))(x4h)
pcsel_1 = Dense(256)(x)
pcsel_2 = Dense(256)(x4h)
pcsel = Lambda(lambda x: x[0] + x[1])([pcsel_1, pcsel_2])
pcsel = Activation('tanh')(pcsel)
pcsel = Dense(1)(pcsel)
pcsel = Lambda(lambda x: x[0][..., 0] - (1 - x[1]) * 1e10)([pcsel, hm])  # mask out padded headers
pcsel = Activation('softmax')(pcsel)
model = Model(
    [x1_in, x2_in, h_in, hm_in],
    [psel, pconn, pcop, pcsel]
)
train_model = Model(
    [x1_in, x2_in, xm_in, h_in, hm_in, sel_in, conn_in, csel_in, cop_in],
    [psel, pconn, pcop, pcsel]
)
xm = xm                                               # question mask, shape = (None, x_len)
hm = hm[:, 0]                                         # header mask, shape = (None, h_len)
cm = K.cast(K.not_equal(cop, num_op - 1), 'float32')  # condition mask, shape = (None, x_len)

psel_loss = K.sparse_categorical_crossentropy(sel_in, psel)
psel_loss = K.sum(psel_loss * hm) / K.sum(hm)
pconn_loss = K.sparse_categorical_crossentropy(conn_in, pconn)
pconn_loss = K.mean(pconn_loss)
pcop_loss = K.sparse_categorical_crossentropy(cop_in, pcop)
pcop_loss = K.sum(pcop_loss * xm) / K.sum(xm)
pcsel_loss = K.sparse_categorical_crossentropy(csel_in, pcsel)
pcsel_loss = K.sum(pcsel_loss * xm * cm) / K.sum(xm * cm)
loss = psel_loss + pconn_loss + pcop_loss + pcsel_loss

train_model.add_loss(loss)
train_model.compile(optimizer=Adam(learning_rate))
train_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None)          0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)          0
__________________________________________________________________________________________________
model_2 (Model)                 (None, None, 768)     101677056   input_1[0][0]
                                                                  input_2[0][0]
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, None)          0
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, None, 768)     0           model_2[1][0]
                                                                  input_4[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, None, 1, 768)  0           model_2[1][0]
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 1, None, 768)  0           lambda_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, None, 1, 256)  196864      lambda_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 1, None, 256)  196864      lambda_5[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, None, None, 2  0           dense_4[0][0]
                                                                  dense_5[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 2  0           lambda_6[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer)            (None, None)          0
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, None, None, 1  257         activation_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1, None)       0           input_5[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 768)           0           model_2[1][0]
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, None, None)    0           dense_6[0][0]
                                                                  lambda_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, None, 7)       5383        lambda_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 3)              2307        lambda_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, None, 5)       3845        model_2[1][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None)    0           lambda_7[0][0]
==================================================================================================
Total params: 102,082,576
Trainable params: 102,082,576
Non-trainable params: 0
__________________________________________________________________________________________________
def nl2sql(question, table):
    """Turn a question plus the table headers into a SQL dict."""
    x1, x2 = tokenizer.encode(question)
    h = []
    for i in table['headers']:
        _x1, _x2 = tokenizer.encode(i)
        h.append(len(x1))
        x1.extend(_x1)
        x2.extend(_x2)
    hm = [1] * len(h)
    psel, pconn, pcop, pcsel = model.predict([
        np.array([x1]),
        np.array([x2]),
        np.array([h]),
        np.array([hm])
    ])
    R = {'agg': [], 'sel': []}
    for i, j in enumerate(psel[0].argmax(1)):
        if j != num_agg - 1:  # class num_agg - 1 means "not selected"
            R['sel'].append(i)
            R['agg'].append(j)
    conds = []
    v_op = -1
    # decode the conditions by combining the per-token tagging (operator) and classification (column)
    for i, j in enumerate(pcop[0, :len(question) + 1].argmax(1)):
        if j != num_op - 1:
            if v_op != j:
                if v_op != -1:
                    v_end = v_start + len(v_str)
                    csel = pcsel[0][v_start: v_end].mean(0).argmax()
                    conds.append((csel, v_op, v_str))
                v_start = i
                v_op = j
                v_str = question[i - 1]
            else:
                v_str += question[i - 1]
        elif v_op != -1:
            v_end = v_start + len(v_str)
            csel = pcsel[0][v_start: v_end].mean(0).argmax()
            conds.append((csel, v_op, v_str))
            v_op = -1
    R['conds'] = set()
    for i, j, k in conds:
        if re.findall(r'[^\d\.]', k):
            j = 2  # a non-numeric value can only use the equality operator
        if j == 2:
            if k not in table['all_values']:
                # the value of an equality condition must appear in the table; otherwise pick the closest one
                k = most_similar(k, list(table['all_values']))
            h = table['headers'][i]
            # check that the value really belongs to the predicted column; if not, fix the column
            if k not in table['content'][h]:
                for r, v in table['content'].items():
                    if k in v:
                        i = table['header2id'][r]
                        break
        R['conds'].add((i, j, k))
    R['conds'] = list(R['conds'])
    if len(R['conds']) <= 1:  # with at most one condition, the connector is simply 0
        R['cond_conn_op'] = 0
    else:
        R['cond_conn_op'] = 1 + pconn[0, 1:].argmax()  # must not be 0
    return R
def is_equal(R1, R2):
    """Check whether two SQL dicts match exactly."""
    return (R1['cond_conn_op'] == R2['cond_conn_op']) &\
           (set(zip(R1['sel'], R1['agg'])) == set(zip(R2['sel'], R2['agg']))) &\
           (set([tuple(i) for i in R1['conds']]) == set([tuple(i) for i in R2['conds']]))
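A tiny check of the matching logic (added for illustration): the order of the selected columns and of the conditions does not matter, only the sets do.

gold = {'sel': [2], 'agg': [5], 'cond_conn_op': 2, 'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}
pred = {'sel': [2], 'agg': [5], 'cond_conn_op': 2, 'conds': [(0, 2, '密室逃生'), (0, 2, '大黄蜂')]}
print(is_equal(pred, gold))   # True: same selection, connector and condition set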
def evaluate(data, tables):
    right = 0.
    pbar = tqdm()
    F = open('evaluate_pred.json', 'w', encoding='utf-8')
    for i, d in enumerate(data):
        question = d['question']
        table = tables[d['table_id']]
        R = nl2sql(question, table)
        right += float(is_equal(R, d['sql']))
        pbar.update(1)
        pbar.set_description('< acc: %.5f >' % (right / (i + 1)))
        d['sql_pred'] = R
        s = json.dumps(d, ensure_ascii=False, indent=4)
        F.write(s + '\n')
    F.close()
    pbar.close()
    return right / len(data)
def test(data, tables, outfile='result.json'):
    pbar = tqdm()
    F = open(outfile, 'w', encoding='utf-8')
    for i, d in enumerate(data):
        question = d['question']
        table = tables[d['table_id']]
        R = nl2sql(question, table)
        pbar.update(1)
        s = json.dumps(R, ensure_ascii=False)
        F.write(s + '\n')
    F.close()
    pbar.close()
class Evaluate(Callback):
    def __init__(self):
        self.accs = []
        self.best = 0.
        self.passed = 0
        self.stage = 0

    def on_batch_begin(self, batch, logs=None):
        """Warm the learning rate up during the first epoch, then decay it to the minimum during the second."""
        if self.passed < self.params['steps']:
            lr = (self.passed + 1.) / self.params['steps'] * learning_rate
            K.set_value(self.model.optimizer.lr, lr)
            self.passed += 1
        elif self.params['steps'] <= self.passed < self.params['steps'] * 2:
            lr = (2 - (self.passed + 1.) / self.params['steps']) * (learning_rate - min_learning_rate)
            lr += min_learning_rate
            K.set_value(self.model.optimizer.lr, lr)
            self.passed += 1

    def on_epoch_end(self, epoch, logs=None):
        acc = self.evaluate()
        self.accs.append(acc)
        if acc > self.best:
            self.best = acc
            train_model.save_weights('best_model.weights')
        print('acc: %.5f, best acc: %.5f\n' % (acc, self.best))

    def evaluate(self):
        return evaluate(valid_data, valid_tables)
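To make the warm-up/decay schedule concrete, here is a small standalone sketch (added for illustration, with an assumed 1000 steps per epoch) that evaluates the same formulas at a few points:

steps = 1000  # assumed steps per epoch, for illustration only
for passed in [0, 499, 999, 1499, 1999]:
    if passed < steps:            # first epoch: linear warm-up from ~0 to learning_rate
        lr = (passed + 1.) / steps * learning_rate
    elif passed < steps * 2:      # second epoch: linear decay down to min_learning_rate
        lr = (2 - (passed + 1.) / steps) * (learning_rate - min_learning_rate) + min_learning_rate
    else:                         # afterwards the callback stops adjusting the learning rate
        lr = min_learning_rate
    print(passed, '%.2e' % lr)
# 0 -> 5.00e-08, 499 -> 2.50e-05, 999 -> 5.00e-05, 1499 -> 3.00e-05, 1999 -> 1.00e-05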
train_D = data_generator(train_data, train_tables)
evaluator = Evaluate()
if __name__ == '__main__':
    train_model.fit_generator(
        train_D.__iter__(),
        steps_per_epoch=len(train_D),
        epochs=15,
        callbacks=[evaluator]
    )
else:
    train_model.load_weights('best_model.weights')
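After loading `best_model.weights`, a single example can be decoded like this (a minimal sketch, not part of the original notebook), comparing the predicted SQL dict with the gold one:

d = valid_data[0]
table = valid_tables[d['table_id']]
pred = nl2sql(d['question'], table)
print(d['question'])
print('pred:', pred)        # {'agg': [...], 'sel': [...], 'conds': [...], 'cond_conn_op': ...}
print('gold:', d['sql'])
print('match:', is_equal(pred, d['sql']))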