As you may have noticed, the author will be posting fewer and fewer pure network security articles. But if you want to learn how artificial intelligence and security can be combined, you are in luck: the author is rebuilding a series called "When AI Meets Security," which will walk through papers and hands-on practice at the intersection of AI and security and share case studies covering malicious code detection, malicious request identification, intrusion detection, adversarial examples, and more. The aim is simply to help beginners and to share new knowledge in a more systematic way. This series will be more focused, more academic, and more in-depth; it is also a record of the author's gradual growth. Changing research fields is genuinely hard, and system security is a tough nut to crack, but I will give it a try and see how far I can take it over the next four years. The road is long; enjoy the process, and let's keep going together!
The previous article explained how to implement threat intelligence entity recognition, using a BiLSTM-CRF model to extract ATT&CK-related tactic and technique entities, an important foundation for building a security knowledge graph. This article turns to Chinese corpora: it introduces Chinese named entity recognition and builds a BiGRU-CRF model to implement it. This is an introductory piece; I hope it helps you, and please bear with any errors or shortcomings.
Since the previous article already covered ATT&CK threat intelligence collection, preprocessing, and BiLSTM-CRF entity recognition in detail, those parts are not repeated here; this article builds on the previous one and adds the following:
Version information:
The common frameworks are shown in the figure below:
As a beginner in network security, the author shares some basic self-study tutorials, mainly in the form of online notes; I hope you like them. Even more, I hope you will work through them and improve together with me; later posts will dig deeper into AI security and system security and share the related experiments. In short, I hope this series is helpful to fellow readers. Writing is not easy, so please be kind. If these articles help you, that is the greatest motivation for my writing; likes, comments, and private messages are all welcome. Let's keep going!
Recommended previous articles:
The author's GitHub resources:
Readers who follow threat intelligence will be familiar with MITRE's ATT&CK website; the previous article described how to collect the attack tactics and techniques of APT groups from it. The URL is:
Step 1: analyze the page source of the ATT&CK website to locate the APT group names, and collect them systematically.
Install the BeautifulSoup package; the code for this part is as follows:
01-get-aptentity.py
#encoding:utf-8
#By:Eastmount CSDN
import re
import requests
from lxml import etree
from bs4 import BeautifulSoup
import urllib.request

#-------------------------------------------------------------------------------------------
# Get the APT group names and links
# Set the browser User-Agent (a dictionary)
headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.149 Safari/537.36'
}
url = 'https://attack.mitre.org/groups/'

# Send the request to the server
r = requests.get(url=url, headers=headers).text

# Parse the DOM tree
html_etree = etree.HTML(r)
names = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/text()')
print(names)
print(len(names), names[0])
filename = []
for name in names:
    filename.append(name.strip())
print(filename)

# Links
urls = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/@href')
print(urls)
print(len(urls), urls[0])
print("\n")
The output is shown in the figure below and includes the APT group names and their corresponding URLs.
Step 2: visit each APT group's URL and collect its detailed information (the description text).
Step 3: collect the corresponding TTP (tactics, techniques, and procedures) information; its location in the page source is shown in the figure below.
Step 4: write the code that completes the threat intelligence collection. The complete code of 01-spider-mitre.py is as follows:
#encoding:utf-8
#By:Eastmount CSDN
import re
import requests
from lxml import etree
from bs4 import BeautifulSoup
import urllib.request

#-------------------------------------------------------------------------------------------
# Get the APT group names and links
# Set the browser User-Agent (a dictionary)
headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.149 Safari/537.36'
}
url = 'https://attack.mitre.org/groups/'

# Send the request to the server
r = requests.get(url=url, headers=headers).text

# Parse the DOM tree
html_etree = etree.HTML(r)
names = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/text()')
print(names)
print(len(names), names[0])

# Links
urls = html_etree.xpath('//*[@class="table table-bordered table-alternate mt-2"]/tbody/tr/td[2]/a/@href')
print(urls)
print(len(urls), urls[0])
print("\n")

#-------------------------------------------------------------------------------------------
# Get the detailed information
k = 0
while k < len(names):
    filename = str(names[k]).strip() + ".txt"
    url = "https://attack.mitre.org" + urls[k]
    print(url)

    # Fetch the page
    page = urllib.request.Request(url, headers=headers)
    page = urllib.request.urlopen(page)
    contents = page.read()
    soup = BeautifulSoup(contents, "html.parser")

    # Get the description text
    content = ""
    for tag in soup.find_all(attrs={"class":"description-body"}):
        #contents = tag.find("p").get_text()
        contents = tag.find_all("p")
        for con in contents:
            content += con.get_text().strip() + "###\n"   # mark sentence end (used for sentence splitting in part two)
    #print(content)

    # Get the technique information from the table
    for tag in soup.find_all(attrs={"class":"table techniques-used table-bordered mt-2"}):
        contents = tag.find("tbody").find_all("tr")
        for con in contents:
            value = con.find("p").get_text()   # the table has 4 or 5 columns, so take the p value
            #print(value)
            content += value.strip() + "###\n"   # mark sentence end (used for sentence splitting in part two)

    # Remove citation brackets [n] from the content
    result = re.sub(u"\\[.*?]", "", content)
    print(result)

    # Write to file
    filename = "Mitre//" + filename
    print(filename)
    f = open(filename, "w", encoding="utf-8")
    f.write(result)
    f.close()
    k += 1
The output is shown in the figure below; information on 100 groups was collected in total.
The content of each file looks like the figure below:
Data annotation is done in a brute-force way: we define the entity types of interest and label the text with the BIO scheme, driven by the ATT&CK tactics and techniques. The labels can later be corrected manually, and more entity types can be defined. A small sketch after the table below illustrates this kind of dictionary-based BIO tagging.
Entity type | Count | Examples |
---|---|---|
APT group | 128 | APT32, Lazarus Group |
Exploited vulnerability | 56 | CVE-2009-0927 |
Region / location | 72 | America, Europe |
Targeted industry | 34 | companies, finance |
Attack technique | 65 | C&C, RAT, DDoS |
Leveraged software | 48 | 7-Zip, Microsoft |
Operating system | 10 | Linux, Windows |
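To make the brute-force labeling concrete, here is a minimal sketch of dictionary-based BIO tagging. The helper name, the token sequence, and the entity-type abbreviations are hypothetical and not taken from the author's code; real labeling would match the ATT&CK-derived entity lists against the collected text (longest match first) and would still need manual correction.

# Hypothetical helper: brute-force BIO tagging by dictionary matching, longest match first
def bio_tag(tokens, entity_dict):
    # entity_dict maps an entity phrase (a tuple of tokens) to its type, e.g. ('Lazarus', 'Group'): 'APT'
    tags = ['O'] * len(tokens)
    i = 0
    while i < len(tokens):
        matched = False
        for phrase, etype in sorted(entity_dict.items(), key=lambda kv: -len(kv[0])):
            n = len(phrase)
            if n > 0 and tuple(tokens[i:i+n]) == phrase:
                tags[i] = 'B-' + etype
                for j in range(i + 1, i + n):
                    tags[j] = 'I-' + etype
                i += n
                matched = True
                break
        if not matched:
            i += 1
    return tags

# Example with hypothetical type names
tokens = ['APT32', 'delivered', 'a', 'RAT', 'to', 'finance', 'targets']
entity_dict = {('APT32',): 'APT', ('RAT',): 'TTP', ('finance',): 'IND'}
print(list(zip(tokens, bio_tag(tokens, entity_dict))))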
For more on annotation and preprocessing, please see the previous article.
Common data annotation tools:
A friendly reminder:
Because the website's layout keeps changing and being optimized, readers should master the basics of data collection and DOM-tree (XPath) element location so that the approach can adapt to whatever changes come. Readers can also try collecting everything, including the content behind the linked and redirected URLs; please explore and extend this on your own!
Assume we already have a collected and annotated Chinese dataset, typically split at the character (char) level as shown in the figure below. The dataset used here is a classical Chinese corpus, but Chinese threat intelligence is handled in the same way.
The dataset is divided into a training set and a test set.
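For reference, the per-character CSV layout assumed by the reading code later in this article is one character and one BIO tag per row, with an empty row marking a sentence boundary, roughly like this (the characters are taken from the sample output shown further below):

晉,S-LOC
樂,B-PER
王,I-PER
鮒,E-PER
曰,O
:,O

小,O
...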
Next, we need to read the CSV dataset and build a dictionary of Chinese characters. The key functions in the code below are read_csv, count_vocab, and build_vocab.
The complete code is as follows:
#encoding:utf-8
# By: Eastmount WuShuai 2024-02-05
import re
import os
import csv
import sys

train_data_path = "data/train.csv"
test_data_path = "data/test.csv"
char_vocab_path = "char_vocabs.txt"    # dictionary file
special_words = ['<PAD>', '<UNK>']     # special tokens
final_words = []                       # vocabulary (no duplicates)
final_labels = []                      # labels (no duplicates)

# Read the corpus file
def read_csv(filename):
    words = []
    labels = []
    with open(filename, encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if len(row) > 0:           # empty rows would raise an index error
                word, label = row[0], row[1]
                words.append(word)
                labels.append(label)
    return words, labels

# Build the vocabulary without duplicates
def count_vocab(words, labels):
    fp = open(char_vocab_path, 'a')    # note: 'a' appends, so run this file only once
    k = 0
    while k < len(words):
        word = words[k]
        label = labels[k]
        if word not in final_words:
            final_words.append(word)
            fp.writelines(word + "\n")
        if label not in final_labels:
            final_labels.append(label)
        k += 1
    fp.close()

# Read the data and build the character dictionary (first column)
def build_vocab():
    words, labels = read_csv(train_data_path)
    print(len(words), len(labels), words[:8], labels[:8])
    count_vocab(words, labels)
    print(len(final_words), len(final_labels))

    # Test set
    words, labels = read_csv(test_data_path)
    print(len(words), len(labels))
    count_vocab(words, labels)
    print(len(final_words), len(final_labels))
    print(final_labels)

    # Build the label dictionary
    label_dict = {}
    k = 0
    for value in final_labels:
        label_dict[value] = k
        k += 1
    print(label_dict)
    return label_dict

if __name__ == '__main__':
    build_vocab()
The output is shown below: the number of characters and labels in the training set, the first eight characters and their tags, the number of distinct characters, and the 14 entity label classes.
['晉', '樂', '王', '鮒', '曰', ':', '', '小']
['S-LOC', 'B-PER', 'I-PER', 'E-PER', 'O', 'O', '', 'O']
xxx 14
The label classes are as follows.
['S-LOC', 'B-PER', 'I-PER', 'E-PER', 'O', '', 'B-LOC',
'E-LOC', 'S-PER', 'S-TIM', 'B-TIM', 'E-TIM', 'I-TIM', 'I-LOC']
The entity classes are then encoded; the output is as follows:
{'S-LOC': 0, 'B-PER': 1, 'I-PER': 2, 'E-PER': 3, 'O': 4, '': 5, 'B-LOC': 6,
'E-LOC': 7, 'S-PER': 8, 'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12, 'I-LOC': 13}
Note: in entity recognition we could obtain the label classes by calling that function, as in the key code below. In practice, however, 'O' is usually encoded as 0, so it is better to redefine the label dictionary by hand; this makes the later code easier to write, especially because the automatically generated ordering above becomes inconsistent when the Chinese text is split into sentences.
# Original plan
from get_data import build_vocab   # function from stage one
label2idx = build_vocab()
# Actual approach
label2idx = {'O': 0,
'S-LOC': 1, 'B-LOC': 2, 'I-LOC': 3, 'E-LOC': 4,
'S-PER': 5, 'B-PER': 6, 'I-PER': 7, 'E-PER': 8,
'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12
}
....
sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
This finally produces the dictionary file char_vocabs.txt.
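Because count_vocab writes each newly seen character on its own line, char_vocabs.txt is simply a one-character-per-line file, roughly:

晉
樂
王
鮒
曰
...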
For the CRF layer, the author installed keras-contrib.
Step 1: running "pip install keras-contrib" directly may fail, and downloading it from the remote index may fail as well.
You may even hit the error ModuleNotFoundError: No module named 'keras_contrib'.
Step 2: the author cloned the package from GitHub and installed it locally.
git clone https://www.github.com/keras-team/keras-contrib.git
cd keras-contrib
python setup.py install
A successful installation looks like the figure below:
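Independent of the screenshot, a quick sanity check is to run the same imports that the NER code below relies on; if they succeed, keras-contrib is installed correctly.

# Sanity check for the keras-contrib installation
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
print(CRF, crf_loss, crf_viterbi_accuracy)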
Readers can download the code and packages from my resources.
You also need to install the keras and TensorFlow packages.
If TensorFlow downloads too slowly, you can switch to the Tsinghua University mirror; version 2.2 was installed here.
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorflow==2.2
Step 1: data preprocessing, including BIO tagging and dictionary conversion.
#encoding:utf-8
# By: Eastmount WuShuai 2024-02-05
# Reference: https://github.com/huanghao128/zh-nlp-demo
import re
import os
import csv
import sys
from get_data import build_vocab        # function from stage one

#------------------------------------------------------------------------
# Step 1: data preprocessing
#------------------------------------------------------------------------
train_data_path = "data/train.csv"
test_data_path = "data/test.csv"
val_data_path = "data/val.csv"
char_vocab_path = "char_vocabs.txt"      # dictionary file (generated once; only read here to avoid rewriting)
special_words = ['<PAD>', '<UNK>']       # special tokens
final_words = []                         # vocabulary (no duplicates)
final_labels = []                        # labels (no duplicates)

# BIO labels; the letter O is mapped to 0
#label2idx = build_vocab()
label2idx = {'O': 0,
             'S-LOC': 1, 'B-LOC': 2, 'I-LOC': 3, 'E-LOC': 4,
             'S-PER': 5, 'B-PER': 6, 'I-PER': 7, 'E-PER': 8,
             'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12
             }
print(label2idx)

# Map indices back to BIO labels
idx2label = {idx: label for label, idx in label2idx.items()}
print(idx2label)

# Read the character dictionary file
with open(char_vocab_path, "r") as fo:
    char_vocabs = [line.strip() for line in fo]
char_vocabs = special_words + char_vocabs
print(char_vocabs)

# Map characters to indices and back
idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}
vocab2idx = {char: idx for idx, char in idx2vocab.items()}
print(idx2vocab)
print(vocab2idx)
The output is as follows:
{'O': 0, 'S-LOC': 1, 'B-LOC': 2, 'I-LOC': 3, 'E-LOC': 4, 'S-PER': 5, 'B-PER': 6,
'I-PER': 7, 'E-PER': 8, 'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12}
{0: 'O', 1: 'S-LOC', 2: 'B-LOC', 3: 'I-LOC', 4: 'E-LOC', 5: 'S-PER', 6: 'B-PER',
7: 'I-PER', 8: 'E-PER', 9: 'S-TIM', 10: 'B-TIM', 11: 'E-TIM', 12: 'I-TIM'}
['<PAD>', '<UNK>', '晉', '樂', '王', '鮒', '曰', ':', '', '小', '旻', ...]
{0: '<PAD>', 1: '<UNK>', 2: '晉', 3: '樂', 4: '王', 5: '鮒', 6: '曰', 7: ':', 8: '', 9: '小', 10: '旻', ...}
{'<PAD>': 0, '<UNK>': 1, '晉': 2, '樂': 3, '王': 4, '鮒': 5, '曰': 6, ':': 7, '': 8, '小': 9, '旻': 10, ...}
Step 2: read the CSV data and convert each character and label to its index, storing the indices.
#------------------------------------------------------------------------
# Step 2: data reading
#------------------------------------------------------------------------
def read_corpus(corpus_path, vocab2idx, label2idx):
    datas, labels = [], []
    with open(corpus_path, encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        sent_, tag_ = [], []
        for row in reader:
            word, label = row[0], row[1]
            if word != "" and label != "":       # still inside a sentence
                sent_.append(word)
                tag_.append(label)
                """
                print(sent_)   #['晉', '樂', '王', '鮒', '曰', ':']
                print(tag_)    #['S-LOC', 'B-PER', 'I-PER', 'E-PER', 'O', 'O']
                """
            else:                                # sentence boundary; vocab2idx[0] => <PAD>
                sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
                tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
                """
                print(sent_ids, tag_ids)
                for idx, idy in zip(sent_ids, tag_ids):
                    print(idx2vocab[idx], idx2label[idy])
                #[2, 3, 4, 5, 6, 7] [1, 6, 7, 8, 0, 0]
                #晉 S-LOC 樂 B-PER 王 I-PER 鮒 E-PER 曰 O : O
                """
                datas.append(sent_ids)           # append sentence by sentence
                labels.append(tag_ids)
                sent_, tag_ = [], []
    return datas, labels

# Raw data
train_datas_, train_labels_ = read_corpus(train_data_path, vocab2idx, label2idx)
test_datas_, test_labels_ = read_corpus(test_data_path, vocab2idx, label2idx)

# Print a test sample (the fifth sentence)
print(len(train_datas_), len(train_labels_), len(test_datas_), len(test_labels_))
print(train_datas_[5])
print([idx2vocab[idx] for idx in train_datas_[5]])
print(train_labels_[5])
print([idx2label[idx] for idx in train_labels_[5]])
The output below shows the indices obtained for the characters and their BIO tags.
[2, 3, 4, 5, 6, 7] [1, 6, 7, 8, 0, 0]
晉 S-LOC 樂 B-PER 王 I-PER 鮒 E-PER 曰 O : O
For example, the sample at index 5 looks like this:
[46, 47, 48, 47, 49, 50, 51, 52, 53, 54, 55, 56]
['齊', '、', '衛', '、', '陳', '大', '夫', '其', '不', '免', '乎', '!']
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0]
['S-LOC', 'O', 'S-LOC', 'O', 'S-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
The corresponding corpus text is:
Step 3: pad the data and apply one-hot encoding.
#------------------------------------------------------------------------
# Step 3: padding and one-hot encoding
#------------------------------------------------------------------------
import keras
from keras.preprocessing import sequence

MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)

# padding data
print('padding sequences')
train_datas = sequence.pad_sequences(train_datas_, maxlen=MAX_LEN)
train_labels = sequence.pad_sequences(train_labels_, maxlen=MAX_LEN)
test_datas = sequence.pad_sequences(test_datas_, maxlen=MAX_LEN)
test_labels = sequence.pad_sequences(test_labels_, maxlen=MAX_LEN)
print('x_train shape:', train_datas.shape)
print('x_test shape:', test_datas.shape)

# encode one-hot
train_labels = keras.utils.to_categorical(train_labels, CLASS_NUMS)
test_labels = keras.utils.to_categorical(test_labels, CLASS_NUMS)
print('trainlabels shape:', train_labels.shape)
print('testlabels shape:', test_labels.shape)
The output is as follows:
padding sequences
x_train shape: (xxx, 100)
x_test shape: (xxx, 100)
trainlabels shape: (xxx, 100, 13)
testlabels shape: (xxx, 100, 13)
An example of an encoded and padded sentence:
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 2163 410 294
980 18]
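The zeros sit at the front because pad_sequences pads (and truncates) at the beginning of each sequence by default (padding='pre'). If you prefer padding at the end, the arguments can be set explicitly, as in the sketch below; note that the test code later slices the last len(sentence) positions of each prediction, which assumes pre-padding and would need adjusting accordingly.

# Optional: pad/truncate at the end of each sentence instead of the beginning
train_datas = sequence.pad_sequences(train_datas_, maxlen=MAX_LEN, padding='post', truncating='post')
test_datas = sequence.pad_sequences(test_datas_, maxlen=MAX_LEN, padding='post', truncating='post')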
Step 4: build the BiLSTM+CRF model.
#------------------------------------------------------------------------
# Step 4: build the BiLSTM+CRF model
#   pip install git+https://www.github.com/keras-team/keras-contrib.git
#   see the installation screenshots for details
#   ModuleNotFoundError: No module named 'keras_contrib'
#------------------------------------------------------------------------
import numpy as np
from keras.models import Sequential
from keras.models import Model
from keras.layers import Masking, Embedding, Bidirectional, LSTM, \
     Dense, Input, TimeDistributed, Activation
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
from keras import backend as K
from keras.models import load_model
from sklearn import metrics

EPOCHS = 2
EMBED_DIM = 128
HIDDEN_SIZE = 64
MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)
K.clear_session()
print(VOCAB_SIZE, CLASS_NUMS)   #3319 13

# Build the BiLSTM-CRF model
inputs = Input(shape=(MAX_LEN,), dtype='int32')
x = Masking(mask_value=0)(inputs)
x = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=False)(x)   # mask_zero set to False
x = Bidirectional(LSTM(HIDDEN_SIZE, return_sequences=True))(x)
x = TimeDistributed(Dense(CLASS_NUMS))(x)
outputs = CRF(CLASS_NUMS)(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()
The output, shown in the figure below, displays the structure of the model.
Step 5: model training and testing. Set the flag variable to "train" for training and "test" for testing.
flag = "train" if flag=="train": #模型训练 model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_viterbi_accuracy]) model.fit(train_datas, train_labels, epochs=EPOCHS, verbose=1, validation_split=0.1) score = model.evaluate(test_datas, test_labels, batch_size=256) print(model.metrics_names) print(score) model.save("bilstm_ner_model.h5") elif flag=="test": #训练模型 char_vocab_path = "char_vocabs_.txt" #字典文件 model_path = "bilstm_ner_model.h5" #模型文件 ner_labels = label2idx special_words = ['<PAD>', '<UNK>'] MAX_LEN = 100 #预测结果 model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False) y_pred = model.predict(test_datas) y_labels = np.argmax(y_pred, axis=2) #取最大值 z_labels = np.argmax(test_labels, axis=2) #真实值 word_labels = test_datas #真实值 k = 0 final_y = [] #预测结果对应的标签 final_z = [] #真实结果对应的标签 final_word = [] #对应的特征单词 while k<len(y_labels): y = y_labels[k] for idx in y: final_y.append(idx2label[idx]) #print("预测结果:", [idx2label[idx] for idx in y]) z = z_labels[k] for idx in z: final_z.append(idx2label[idx]) #print("真实结果:", [idx2label[idx] for idx in z]) word = word_labels[k] for idx in word: final_word.append(idx2vocab[idx]) k += 1 print("最终结果大小:", len(final_y),len(final_z)) n = 0 numError = 0 numRight = 0 while n<len(final_y): if final_y[n]!=final_z[n] and final_z[n]!='O': numError += 1 if final_y[n]==final_z[n] and final_z[n]!='O': numRight += 1 n += 1 print("预测错误数量:", numError) print("预测正确数量:", numRight) print("Acc:", numRight*1.0/(numError+numRight)) print("预测单词:", [idx2vocab[idx] for idx in test_datas_[5]]) print("真实结果:", [idx2label[idx] for idx in test_labels_[5]]) print("预测结果:", [idx2label[idx] for idx in y_labels[5]][-len(test_datas_[5]):])
The training output is as follows:
Epoch 1/2
32/8439 [..............................] - ETA: 6:51 - loss: 2.5549 - crf_viterbi_accuracy: 3.1250e-04
64/8439 [..............................] - ETA: 3:45 - loss: 2.5242 - crf_viterbi_accuracy: 0.1142
8439/8439 [==============================] - 118s 14ms/step - loss: 0.1833 - crf_viterbi_accuracy: 0.9591 - val_loss: 0.0688 - val_crf_viterbi_accuracy: 0.9820
Epoch 2/2
32/8439 [..............................] - ETA: 19s - loss: 0.0644 - crf_viterbi_accuracy: 0.9825
64/8439 [..............................] - ETA: 42s - loss: 0.0592 - crf_viterbi_accuracy: 0.9845
...
['loss', 'crf_viterbi_accuracy']
[0.043232945389307574, 0.9868513941764832]
The final test results are shown below. Since only a small amount of data was included in the author's dataset and no hyperparameter tuning or comparison was done, results with more real data will be better.
Wrong predictions: 2183
Correct predictions: 2209
Acc: 0.5029599271402551
Characters: ['冬', ',', '楚', '公', '子', '罷', '如', '晉', '聘', ',', '且', '涖', '盟', '。']
Ground truth: ['O', 'O', 'B-PER', 'I-PER', 'I-PER', 'E-PER', 'O', 'S-LOC', 'O', 'O', 'O', 'O', 'O', 'O']
Predicted: ['O', 'O', 'B-PER', 'E-PER', 'E-PER', 'E-PER', 'O', 'S-LOC', 'O', 'O', 'O', 'O', 'O', 'O']
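The accuracy above is computed per character over the non-'O' positions only. For NER it is also common to report entity-level precision, recall, and F1, which penalize partially correct spans such as the B-PER/E-PER confusion above. A minimal sketch using the seqeval package (an extra dependency, not part of the author's code) on per-sentence label lists could look like this:

# pip install seqeval
from seqeval.metrics import classification_report

# y_true / y_pred are lists of label sequences; in practice they would be built by
# mapping z_labels / y_labels back through idx2label for the non-padded tail of each sentence.
y_true = [['O', 'O', 'B-PER', 'I-PER', 'I-PER', 'E-PER', 'O', 'S-LOC']]
y_pred = [['O', 'O', 'B-PER', 'E-PER', 'E-PER', 'E-PER', 'O', 'S-LOC']]
print(classification_report(y_true, y_pred))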
Next we build the BiGRU-CRF version. The complete code is given below; the prediction results are also written to a CSV file.
#encoding:utf-8
# By: Eastmount WuShuai 2024-02-05
import re
import os
import csv
import sys
from get_data import build_vocab        # function from stage one

#------------------------------------------------------------------------
# Step 1: data preprocessing
#------------------------------------------------------------------------
train_data_path = "data/train.csv"
test_data_path = "data/test.csv"
val_data_path = "data/val.csv"
char_vocab_path = "char_vocabs.txt"      # dictionary file (generated once; only read here to avoid rewriting)
special_words = ['<PAD>', '<UNK>']       # special tokens
final_words = []                         # vocabulary (no duplicates)
final_labels = []                        # labels (no duplicates)

# BIO labels; the letter O is mapped to 0
#label2idx = build_vocab()
label2idx = {'O': 0,
             'S-LOC': 1, 'B-LOC': 2, 'I-LOC': 3, 'E-LOC': 4,
             'S-PER': 5, 'B-PER': 6, 'I-PER': 7, 'E-PER': 8,
             'S-TIM': 9, 'B-TIM': 10, 'E-TIM': 11, 'I-TIM': 12
             }

# Map indices back to BIO labels
idx2label = {idx: label for label, idx in label2idx.items()}

# Read the character dictionary file
with open(char_vocab_path, "r") as fo:
    char_vocabs = [line.strip() for line in fo]
char_vocabs = special_words + char_vocabs

# Map characters to indices and back
idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}
vocab2idx = {char: idx for idx, char in idx2vocab.items()}

#------------------------------------------------------------------------
# Step 2: data reading
#------------------------------------------------------------------------
def read_corpus(corpus_path, vocab2idx, label2idx):
    datas, labels = [], []
    with open(corpus_path, encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        sent_, tag_ = [], []
        for row in reader:
            word, label = row[0], row[1]
            if word != "" and label != "":       # still inside a sentence
                sent_.append(word)
                tag_.append(label)
            else:                                # sentence boundary; vocab2idx[0] => <PAD>
                sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
                tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
                datas.append(sent_ids)           # append sentence by sentence
                labels.append(tag_ids)
                sent_, tag_ = [], []
    return datas, labels

# Raw data
train_datas_, train_labels_ = read_corpus(train_data_path, vocab2idx, label2idx)
test_datas_, test_labels_ = read_corpus(test_data_path, vocab2idx, label2idx)

#------------------------------------------------------------------------
# Step 3: padding and one-hot encoding
#------------------------------------------------------------------------
import keras
from keras.preprocessing import sequence

MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)

# padding data
print('padding sequences')
train_datas = sequence.pad_sequences(train_datas_, maxlen=MAX_LEN)
train_labels = sequence.pad_sequences(train_labels_, maxlen=MAX_LEN)
test_datas = sequence.pad_sequences(test_datas_, maxlen=MAX_LEN)
test_labels = sequence.pad_sequences(test_labels_, maxlen=MAX_LEN)

# encode one-hot
train_labels = keras.utils.to_categorical(train_labels, CLASS_NUMS)
test_labels = keras.utils.to_categorical(test_labels, CLASS_NUMS)

#------------------------------------------------------------------------
# Step 4: build the BiGRU+CRF model
#------------------------------------------------------------------------
import numpy as np
from keras.models import Sequential
from keras.models import Model
from keras.layers import Masking, Embedding, Bidirectional, LSTM, GRU, \
     Dense, Input, TimeDistributed, Activation
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
from keras import backend as K
from keras.models import load_model
from sklearn import metrics

EPOCHS = 2
EMBED_DIM = 128
HIDDEN_SIZE = 64
MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)
K.clear_session()
print(VOCAB_SIZE, CLASS_NUMS)

# Build the BiGRU-CRF model
inputs = Input(shape=(MAX_LEN,), dtype='int32')
x = Masking(mask_value=0)(inputs)
x = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=False)(x)   # mask_zero set to False
x = Bidirectional(GRU(HIDDEN_SIZE, return_sequences=True))(x)
x = TimeDistributed(Dense(CLASS_NUMS))(x)
outputs = CRF(CLASS_NUMS)(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()

flag = "test"
if flag == "train":
    # Model training
    model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_viterbi_accuracy])
    model.fit(train_datas, train_labels, epochs=EPOCHS, verbose=1, validation_split=0.1)
    score = model.evaluate(test_datas, test_labels, batch_size=256)
    print(model.metrics_names)
    print(score)
    model.save("bigru_ner_model.h5")
elif flag == "test":
    # Model testing
    char_vocab_path = "char_vocabs_.txt"     # dictionary file
    model_path = "bigru_ner_model.h5"        # model file
    ner_labels = label2idx
    special_words = ['<PAD>', '<UNK>']
    MAX_LEN = 100

    # Prediction
    model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)
    y_pred = model.predict(test_datas)
    y_labels = np.argmax(y_pred, axis=2)       # take the argmax (predicted labels)
    z_labels = np.argmax(test_labels, axis=2)  # ground-truth labels
    word_labels = test_datas                   # input characters

    k = 0
    final_y = []       # predicted labels
    final_z = []       # ground-truth labels
    final_word = []    # corresponding characters
    while k < len(y_labels):
        y = y_labels[k]
        for idx in y:
            final_y.append(idx2label[idx])
        z = z_labels[k]
        for idx in z:
            final_z.append(idx2label[idx])
        word = word_labels[k]
        for idx in word:
            final_word.append(idx2vocab[idx])
        k += 1

    n = 0
    numError = 0
    numRight = 0
    while n < len(final_y):
        if final_y[n] != final_z[n] and final_z[n] != 'O':
            numError += 1
        if final_y[n] == final_z[n] and final_z[n] != 'O':
            numRight += 1
        n += 1
    print("Wrong predictions:", numError)
    print("Correct predictions:", numRight)
    print("Acc:", numRight*1.0/(numError+numRight))
    print("Characters:", [idx2vocab[idx] for idx in test_datas_[5]])
    print("Ground truth:", [idx2label[idx] for idx in test_labels_[5]])
    print("Predicted:", [idx2label[idx] for idx in y_labels[5]][-len(test_datas_[5]):])

    # Write the results to file
    fw = open("Final_BiGRU_CRF_Result.csv", "w", encoding="utf8", newline='')
    fwrite = csv.writer(fw)
    fwrite.writerow(['pre_label', 'real_label', 'word'])
    n = 0
    while n < len(final_y):
        fwrite.writerow([final_y[n], final_z[n], final_word[n]])
        n += 1
    fw.close()
The output is as follows:
['loss', 'crf_viterbi_accuracy']
[0.03543611364953834, 0.9894005656242371]
The generated file looks like the figure below:
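If you want to inspect the misclassified characters without opening the CSV by hand, a small sketch using pandas (an extra dependency, not part of the author's code) would be:

# Optional: load the saved results and look at the rows where the prediction differs from the label
import pandas as pd

df = pd.read_csv("Final_BiGRU_CRF_Result.csv")
errors = df[(df["pre_label"] != df["real_label"]) & (df["real_label"] != "O")]
print(errors.head(20))
print("error rows:", len(errors), "of", len(df))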
This article ends here; I hope it helps you. A follow-up post will bring in the classic BERT model. 2024 has been genuinely busy with projects, proposals, papers, graduation, and work; once things settle down I will write a few solid security posts. Thanks for the support and companionship, especially the encouragement and support of my family. Keep going!
Life is a series of crossroads, of trade-offs, of struggles between gains and losses. Different choices lead to different kinds of excitement. Tired and busy as I am, seeing little Luo makes it all worthwhile; thanks to my family for being there. I hope Xiao Luo grows up happy and healthy. Love you all; back to work, keep going!
(By: Eastmount, 2024-02-07, at night in Guiyang, http://blog.csdn.net/eastmount/ )