赞
踩
当生成标题之后,就需要生成摘要,摘要生成也是通过,GTP-2生成,但师弟说有抽取法,目前还没有比较GTP-2生成法得到的摘要,但是通过人工来看的化GTP-2摘要生成效果并不是很好,目前也没有在垂直领域实验,只是跑通了大佬的模型,也还是一篇学习记录。
依旧是读和修改大佬代码原代码传送门,解决了数据集的问题,和一些模型参数问题。
目前GitHub上大佬没有提供数据集,并且没有现成符合这份代码的,首先我需要将数据集进行加工
数据集我用的是nlpcc_data.json,然后,将数据集进行调整,如果不想写代码调整可以直接把后缀改成.txt然后对照着大佬提供的应输入模板调整。如果想写的化第一步就是json转成txt对应文本如下:
这里强调一下,你写好了还是需要注意的,因为python里面的原因,这里都是用单引号,但是,原代码中有一段如图,是按照json,读取的,虽然也能度.txt但是注意里面的’summarization’单引号要换成双引号,直接在txt里面替换就好了。
import json
path='./data/nlpcc_data.json'
f = open(path,'r',encoding='utf-8')
m = json.load(f) # json.load() 这种方法是解析一个文件中的数据 # json.loads() 需要先将文件,读到一个变量作为字符串, 解析一个字符串中的数
with open("test.txt","w",encoding='utf-8') as f:
for index in range(len(m)):
print(m[index])
f.write("{'summarization': '"+m[index]['summarization']+"', 'article': '"+m[index]['article']+"'}\n")
print(type(m[0]))
print(m[1])
之后我们要对这个数据集进行相对应的筛选因为里面会有一些特殊字符和让json.loads读取出错的地方,因此就替换掉:
with open("../data/train_with_summary.txt", 'r', encoding='utf-8') as file:
i = 0
for line in file.readlines():
with open("../data/mini_set_test.txt", 'a', encoding='utf-8') as f:
lines = str(line)
countleft = lines.count('{', 0,len(lines))
countright = lines.count('}', 0, len(lines))
countdouble = lines.count('"', 0, len(lines))
if countleft == 1 and countright == 1 and countdouble==8:
f.write(str(line).replace("\\",""))
这些步骤之后就简单将原代码里面那一段拿出来,单独写一下,看看数据集读取对不对(这步就好像有些数据集检查一下,如果你确信自己改的数据集ok的就省略):
count = 0
with open('./data/mini_set_test.txt', 'r', encoding="utf-8") as file:
for line in tqdm(file.readlines()):
count +=1
file_line = json.loads(line)
print(str(count) + "ok"+line)
将数据集放到百度云盘中,数据集提取入口:链接:https://pan.baidu.com/s/1VYbTPXNG0qvx09mfC4xPIg
目前班子是3090ti,进行重新训练但只能微调batch_size但效果感觉还是不怎么好,我也不清楚是要领域的数据的问题,还是什么问题。我目前想到的办法是重新跑领域GTP-2 Chinese,让他变成垂直领域对应的中文文本。我训练的权重:
链接:https://pan.baidu.com/s/15LXzJOz22Qm7vb6Uu7NWQQ
提取码:yen2
所需要的中文权重(wiki_common_model)
链接:https://pan.baidu.com/s/1DKJBEo3J54X7nifqHZ-n3Q
def setup_train_args():
"""
设置训练参数
"""
parser = argparse.ArgumentParser()
parser.add_argument('--device', default='0', type=str, required=False, help='设置使用哪些显卡')
parser.add_argument('--no_cuda', default=False, help='使用GPU进行训练')
parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
help='选择模型参数')
parser.add_argument('--vocab_path', default='vocabulary/vocab_small.txt', type=str, required=False, help='选择词库')
parser.add_argument('--train_raw_path', default='data/mini_set_test.txt', type=str, required=False, help='原始训练语料')
parser.add_argument('--train_tokenized_path', default='data/train_tokenized.txt', type=str,
required=False,
help='将原始训练语料tokenize之后的数据的存放位置')
parser.add_argument('--log_path', default='data/training.log', type=str, required=False, help='训练日志存放位置')
parser.add_argument('--raw', default=True, help='是否对原始训练语料做tokenize。若尚未对原始训练语料进行tokenize,则指定该参数')
parser.add_argument('--epochs', default=100, type=int, required=False, help='训练的轮次')
parser.add_argument('--batch_size', default=6, type=int, required=False, help='训练batch size')
parser.add_argument('--lr', default=1.5e-6, type=float, required=False, help='学习率')
parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数')
parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss')
parser.add_argument('--gradient_accumulation', default=2, type=int, required=False, help='梯度积累')
parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
parser.add_argument('--dialogue_model_output_path', default='summary_model/', type=str, required=False,
help='对话模型输出路径')
parser.add_argument('--pretrained_model', default='wiki_common_model/', type=str, required=False, help='预训练的GPT2模型的路径')
parser.add_argument('--writer_dir', default='tensorboard_summary/', type=str, required=False, help='Tensorboard路径')
parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
parser.add_argument('--num_workers', type=int, default=1, help="dataloader加载数据时使用的线程数量")
parser.add_argument('--train_mmi', default=False, help="若指定该参数,则训练DialoGPT的MMI模型")
parser.add_argument('--train_mmi_tokenized_path', default='data/train_mmi_tokenized.txt', type=str,
required=False,
help='将原始训练语料的每段对话翻转,然后进行tokenize之后的数据的存放位置,用于训练MMI模型')
parser.add_argument('--mmi_model_output_path', default='mmi_model', type=str, required=False, help='MMI模型保存路径')
# parser.add_argument('--max_len', type=int, default=60, help='每个utterance的最大长度,超过指定长度则进行截断')
# parser.add_argument('--max_history_len', type=int, default=4, help="dialogue history的最大长度")
return parser.parse_args()
效果如图就说不上来的Emmmm
自我学习GitHub连接地址:
https://github.com/zhichen-roger/GTP-2_AbstractGeneration_learn.git
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。