赞
踩
很早之前,我曾经写过一个古体诗生成器(详情可以戳TensorFlow练手项目二:基于循环神经网络(RNN)的古诗生成器),那个时候用的还是Python 2.7和TensorFlow 1.4。
随着框架的迭代,API 的变更,老项目已经很难无障碍运行起来了。有不少朋友在老项目下提出了各种问题,于是,我就萌生了使用TensorFlow 2.0重写项目的想法。
这不,终于抽空,重写了这个项目。
完整的项目已经放到了GitHub上:
AaronJny/DeepLearningExamples/tf2-rnn-poetry-generator (https://github.com/AaronJny/DeepLearningExamples/tree/master/tf2-rnn-poetry-generator)
先对项目做个简单展示。项目主要包含如下功能:
随机生成一首古体诗:
金鹤有僧心,临天寄旧身。
石松惊枕树,红鸟发禅新。
不到风前远,何人怨夕时。
明期多尔处,闲此不依迟。
水泉临鸟声,北去暮空行。
林阁多开雪,楼庭起洞城。
夜来疏竹外,柳鸟暗苔清。
寂寂重阳里,悠悠一钓矶。
续写一首古体诗(以"床前明月光,"为例):
床前明月光,翠席覆银丝。
岁气分龙阁,无人入鸟稀。
圣明无泛物,云庙逐雕旗。
永夜重江望,南风正送君。
床前明月光,清水入寒云。
远景千山雨,萧花入翠微。
影云虚雪润,花影落云斜。
独去江飞夜,谁能作一花。
随机生成一首藏头诗(以"海阔天空"为例):
海口多无定,
阔庭何所难。
天山秋色上,
空石昼尘连。
海庭愁不定,
阔处到南关。
天阙青秋上,
空城雁渐催。
下面开始讲解项目实现过程。
转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103806954
跟老项目一样,我们仍然使用四万首唐诗的文本作为训练集(已经上传,可以直接从GitHub上下载)。我们打开文本,看一下数据格式:
能够看到,文本中每行是一首诗,且使用冒号分割,前面是标题,后面是正文,且诗的长度不一。
我们对数据的处理流程大致如下:
代码如下:
# -*- coding: utf-8 -*- # @File : dataset.py # @Author : AaronJny # @Time : 2019/12/30 # @Desc : 构建数据集 from collections import Counter import math import numpy as np import tensorflow as tf import settings # 禁用词 disallowed_words = settings.DISALLOWED_WORDS # 句子最大长度 max_len = settings.MAX_LEN # 最小词频 min_word_frequency = settings.MIN_WORD_FREQUENCY # mini batch 大小 batch_size = settings.BATCH_SIZE # 加载数据集 with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f: lines = f.readlines() # 将冒号统一成相同格式 lines = [line.replace(':', ':') for line in lines] # 数据集列表 poetry = [] # 逐行处理读取到的数据 for line in lines: # 有且只能有一个冒号用来分割标题 if line.count(':') != 1: continue # 后半部分不能包含禁止词 __, last_part = line.split(':') ignore_flag = False for dis_word in disallowed_words: if dis_word in last_part: ignore_flag = True break if ignore_flag: continue # 长度不能超过最大长度 if len(last_part) > max_len - 2: continue poetry.append(last_part.replace('\n', '')) # 统计词频 counter = Counter() for line in poetry: counter.update(line) # 过滤掉低频词 _tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency] # 按词频排序 _tokens = sorted(_tokens, key=lambda x: -x[1]) # 去掉词频,只保留词列表 _tokens = [token for token, count in _tokens] # 将特殊词和数据集中的词拼接起来 _tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens # 创建词典 token->id映射关系 token_id_dict = dict(zip(_tokens, range(len(_tokens)))) # 使用新词典重新建立分词器 tokenizer = Tokenizer(token_id_dict) # 混洗数据 np.random.shuffle(poetry)
代码很简单,注释也很清晰,就不一行一行说了。有几点需要注意一下:
class Tokenizer:
"""
分词器
"""
def __init__(self, token_dict):
# 词->编号的映射
self.token_dict = token_dict
# 编号->词的映射
self.token_dict_rev =
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。