
NLP: A Transformer-Based Translation System

Paper: https://arxiv.org/abs/1706.03762
Project: https://github.com/audier/my_deep_project/tree/master/NLP/4.transformer

This post implements a translation system based on self-attention. Attention mechanisms have been a hot research direction for the last couple of years, and the self-attention mechanism proposed last year has become everyone's favorite. The most readable code available online has one small imperfection: the masks never actually take effect. I have been building a translation system myself recently, so I wrote this post to share my approach.
The code is based on the most readable project I found online: https://github.com/Kyubyong/transformer
However, its author made a mistake in the key masks and query masks; this post modifies the model and the multihead layer accordingly so that masking works as intended.

Please credit the source when reposting: https://blog.csdn.net/chinatelecom08

1. Data Processing

Data used in this post: https://github.com/audier/my_deep_project/tree/master/NLP/4.transformer

  • Read the data
  • Save it into inputs and outputs respectively

```python
from tqdm import tqdm

with open('cmn.txt', 'r', encoding='utf8') as f:
    data = f.readlines()

inputs = []
outputs = []
for line in tqdm(data[:10000]):
    [en, ch] = line.strip('\n').split('\t')
    inputs.append(en.replace(',', ' ,')[:-1].lower())
    outputs.append(ch[:-1])
```

```
100%|██████████| 10000/10000 [00:00<00:00, 473991.57it/s]
```
  • View the data format

```python
print(inputs[:10])
```

```
['hi', 'hi', 'run', 'wait', 'hello', 'i try', 'i won', 'oh no', 'cheers', 'he ran']
```

```python
print(outputs[:10])
```

```
['嗨', '你好', '你用跑的', '等等', '你好', '让我来', '我赢了', '不会吧', '乾杯', '他跑了']
```

1.1 English Tokenization

For English we can simply split on spaces, with one small adjustment: every uppercase letter is replaced with its lowercase form, which the .lower() call above already handles.

```python
for line in tqdm(data):
    [en, ch] = line.strip('\n').split('\t')
    inputs.append(en[:-1].lower())
    outputs.append(ch[:-1])
```
Here we only need to split the English sentences on spaces.

```python
inputs = [en.split(' ') for en in inputs]
print(inputs[:10])
```

```
[['hi'], ['hi'], ['run'], ['wait'], ['hello'], ['i', 'try'], ['i', 'won'], ['oh', 'no'], ['cheers'], ['he', 'ran']]
```

1.2 Chinese Word Segmentation

  • For Chinese word segmentation we pick the jieba tokenizer:

```python
import jieba
outputs = [[char for char in jieba.cut(line) if char != ' '] for line in outputs]
```

  • HanLP also works:

```python
from pyhanlp import *
outputs = [[term.word for term in HanLP.segment(line) if term.word != ' '] for line in outputs]
```
  • Or segment character by character? A minimal sketch of that option follows this list.

  • In the end I chose jieba:
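In case you want to try the character-level alternative, it can be as simple as the sketch below (not used in this post; `outputs_char` is a hypothetical name chosen to avoid clobbering `outputs`):

```python
# hypothetical character-level alternative: one token per character
outputs_char = [[char for char in line] for line in outputs]
```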

```python
import jieba
jieba_outputs = [[char for char in jieba.cut(line) if char != ' '] for line in outputs[-10:]]
print(jieba_outputs)
```

```
[['你', '不應', '該', '去', '那裡', '的'], ['你', '以前', '吸煙', ',', '不是', '嗎'], ['你現', '在', '最好', '回家'], ['你', '今天', '最好', '不要', '出門'], ['你', '滑雪', '比', '我', '好'], ['你', '正在', '把', '我', '杯子', '里', '的', '东西', '喝掉'], ['你', '并', '不', '满意', ',', '对', '吧'], ['你', '病', '了', ',', '该', '休息', '了'], ['你', '很', '勇敢', ',', '不是', '嗎'], ['你', '的', '意志力', '很強']]
```

```python
outputs = [[char for char in jieba.cut(line) if char != ' '] for line in tqdm(outputs)]
```

```
100%|██████████| 10000/10000 [00:00<00:00, 11981.68it/s]
```

1.3 Building the Vocabularies

Map the English and Chinese tokens to ids.

```python
def get_vocab(data, init=['<PAD>']):
    vocab = init
    for line in tqdm(data):
        for word in line:
            if word not in vocab:
                vocab.append(word)
    return vocab

SOURCE_CODES = ['<PAD>']
TARGET_CODES = ['<PAD>', '<GO>', '<EOS>']
encoder_vocab = get_vocab(inputs, init=SOURCE_CODES)
decoder_vocab = get_vocab(outputs, init=TARGET_CODES)
```

```
100%|██████████| 10000/10000 [00:00<00:00, 20585.73it/s]
100%|██████████| 10000/10000 [00:01<00:00, 7808.17it/s]
```

```python
print(encoder_vocab[:10])
print(decoder_vocab[:10])
```

```
['<PAD>', 'hi', 'run', 'wait', 'hello', 'i', 'try', 'won', 'oh', 'no']
['<PAD>', '<GO>', '<EOS>', '嗨', '你好', '你', '用', '跑', '的', '等等']
```

1.4 Data Generator

The data format needed to train the translation system is the same as the input to Google's GNMT; for the principles behind GNMT see: https://github.com/tensorflow/nmt
Roughly:

  • Encoder input: I am a student
  • Decoder input: (go) Je suis étudiant
  • Decoder output: Je suis étudiant (end)

That is, the decoder input starts with a start symbol, and the output sentence ends with an end symbol.

```python
encoder_inputs = [[encoder_vocab.index(word) for word in line] for line in inputs]
decoder_inputs = [[decoder_vocab.index('<GO>')] + [decoder_vocab.index(word) for word in line] for line in outputs]
decoder_targets = [[decoder_vocab.index(word) for word in line] + [decoder_vocab.index('<EOS>')] for line in outputs]
```

```python
print(decoder_inputs[:4])
print(decoder_targets[:4])
```

```
[[1, 3], [1, 4], [1, 5, 6, 7, 8], [1, 9]]
[[3, 2], [4, 2], [5, 6, 7, 8, 2], [9, 2]]
```
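A side note: `list.index()` rescans the whole vocabulary for every token, which gets slow on larger corpora. A minimal sketch of a dict-based lookup that produces exactly the same ids (`en2id` and `de2id` are hypothetical names):

```python
# optional speed-up: constant-time id lookups instead of list.index()
en2id = {word: idx for idx, word in enumerate(encoder_vocab)}
de2id = {word: idx for idx, word in enumerate(decoder_vocab)}
encoder_inputs = [[en2id[word] for word in line] for line in inputs]
decoder_inputs = [[de2id['<GO>']] + [de2id[word] for word in line] for line in outputs]
decoder_targets = [[de2id[word] for word in line] + [de2id['<EOS>']] for line in outputs]
```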

```python
import numpy as np

def get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4):
    batch_num = len(encoder_inputs) // batch_size
    for k in range(batch_num):
        begin = k * batch_size
        end = begin + batch_size
        en_input_batch = encoder_inputs[begin:end]
        de_input_batch = decoder_inputs[begin:end]
        de_target_batch = decoder_targets[begin:end]
        max_en_len = max([len(line) for line in en_input_batch])
        max_de_len = max([len(line) for line in de_input_batch])
        # pad each line to the longest line in the batch with <PAD> (id 0)
        en_input_batch = np.array([line + [0] * (max_en_len - len(line)) for line in en_input_batch])
        de_input_batch = np.array([line + [0] * (max_de_len - len(line)) for line in de_input_batch])
        de_target_batch = np.array([line + [0] * (max_de_len - len(line)) for line in de_target_batch])
        yield en_input_batch, de_input_batch, de_target_batch
```

```python
batch = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4)
next(batch)
```

```
(array([[1],
        [1],
        [2],
        [3]]), array([[1, 3, 0, 0, 0],
        [1, 4, 0, 0, 0],
        [1, 5, 6, 7, 8],
        [1, 9, 0, 0, 0]]), array([[3, 2, 0, 0, 0],
        [4, 2, 0, 0, 0],
        [5, 6, 7, 8, 2],
        [9, 2, 0, 0, 0]]))
```
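Note that a Python generator is exhausted after one full pass, so during training you recreate it at the start of every epoch. A minimal sketch (the loop body is a placeholder):

```python
# recreate the generator each epoch; it is consumed after one full pass
for epoch in range(10):
    batches = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size=4)
    for en_batch, de_in_batch, de_tgt_batch in batches:
        pass  # run the training op on this batch here
```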

2. Building the Model

The model structure is as follows:

(figure: Transformer model architecture)

All of the main building blocks it needs are given below.

Paper: https://arxiv.org/abs/1706.03762
For explanations of the paper, a quick search turns up plenty; read them alongside the original paper and the code.
Personally I find it much easier to understand with the code at hand.

```python
import tensorflow as tf
```

2.1 Building Blocks

The code below implements each functional component of the architecture in the figure.

Layer norm

It sits at the boxed positions in the diagram.

(figure: layer-norm positions in the Transformer block)

```python
def normalize(inputs,
              epsilon=1e-8,
              scope="ln",
              reuse=None):
    '''Applies layer normalization.
    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension has
        `batch_size`.
      epsilon: A floating number. A very small number for preventing ZeroDivision Error.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A tensor with the same shape and data dtype as `inputs`.
    '''
    with tf.variable_scope(scope, reuse=reuse):
        inputs_shape = inputs.get_shape()
        params_shape = inputs_shape[-1:]

        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        beta = tf.Variable(tf.zeros(params_shape))
        gamma = tf.Variable(tf.ones(params_shape))
        normalized = (inputs - mean) / ((variance + epsilon) ** (.5))
        outputs = gamma * normalized + beta

    return outputs
```
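As a quick sanity check of what `normalize` computes, here is the same per-position normalization done by hand in NumPy (a standalone sketch, not part of the model):

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0]])          # one position with 3 channels
mean = x.mean(axis=-1, keepdims=True)    # 2.0
var = x.var(axis=-1, keepdims=True)      # 2/3
print((x - mean) / np.sqrt(var + 1e-8))  # [[-1.2247, 0., 1.2247]]
```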

Embedding layer

It is worth mentioning that the positional encoding in this post is also represented by an embedding layer; the original paper says that either the fixed formula or a trained embedding works.

(figure: embeddings with positional encoding)
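For reference, a minimal NumPy sketch of the paper's fixed sinusoidal formula, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), which this post replaces with a learned embedding (`sinusoid_position_encoding` is a hypothetical helper, not used below):

```python
import numpy as np

def sinusoid_position_encoding(max_len, num_units):
    # even channels get sin, odd channels get cos of the same angle
    pos = np.arange(max_len)[:, None]        # (max_len, 1)
    i = np.arange(num_units)[None, :]        # (1, num_units)
    angles = pos / np.power(10000, (2 * (i // 2)) / num_units)
    pe = np.zeros((max_len, num_units))
    pe[:, 0::2] = np.sin(angles[:, 0::2])
    pe[:, 1::2] = np.cos(angles[:, 1::2])
    return pe                                # row `pos` encodes position pos
```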

````python
def embedding(inputs,
              vocab_size,
              num_units,
              zero_pad=True,
              scale=True,
              scope="embedding",
              reuse=None):
    '''Embeds a given tensor.
    Args:
      inputs: A `Tensor` with type `int32` or `int64` containing the ids
        to be looked up in `lookup table`.
      vocab_size: An int. Vocabulary size.
      num_units: An int. Number of embedding hidden units.
      zero_pad: A boolean. If True, all the values of the first row (id 0)
        should be constant zeros.
      scale: A boolean. If True, the outputs is multiplied by sqrt num_units.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.
    Returns:
      A `Tensor` with one more rank than inputs's. The last dimensionality
        should be `num_units`.

    For example,

    ```
    import tensorflow as tf

    inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))
    outputs = embedding(inputs, 6, 2, zero_pad=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(outputs))
    >>
    [[[ 0.          0.        ]
      [ 0.09754146  0.67385566]
      [ 0.37864095 -0.35689294]]
     [[-1.01329422 -1.09939694]
      [ 0.7521342   0.38203377]
      [-0.04973143 -0.06210355]]]
    ```

    ```
    import tensorflow as tf

    inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3)))
    outputs = embedding(inputs, 6, 2, zero_pad=False)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(outputs))
    >>
    [[[-0.19172323 -0.39159766]
      [-0.43212751 -0.66207761]
      [ 1.03452027 -0.26704335]]
     [[-0.11634696 -0.35983452]
      [ 0.50208133  0.53509563]
      [ 1.22204471 -0.96587461]]]
    ```
    '''
    with tf.variable_scope(scope, reuse=reuse):
        lookup_table = tf.get_variable('lookup_table',
                                       dtype=tf.float32,
                                       shape=[vocab_size, num_units],
                                       initializer=tf.contrib.layers.xavier_initializer())
        if zero_pad:
            # force row 0 (the <PAD> id) to all zeros
            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                      lookup_table[1:, :]), 0)
        outputs = tf.nn.embedding_lookup(lookup_table, inputs)

        if scale:
            outputs = outputs * (num_units ** 0.5)

    return outputs
````

Multihead layer

This is the core idea of self-attention, so make sure the principle is clear.

(figure: scaled dot-product attention)

The input attends to itself, but before that, linear projections map it into 8 different subspaces where attention is computed separately; the results are then concatenated back together.

(figure: multi-head attention)

This layer implements the following (a round of applause for Google):

(figure: attention equations from the paper)
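Before the full layer, here is a quick NumPy stand-in for the split-and-concat trick it uses to form heads: the channel axis is split into `num_heads` groups and the groups are stacked onto the batch axis (shapes only, no learned weights):

```python
import numpy as np

N, T, C, h = 2, 5, 16, 8                  # batch, time, channels, heads
Q = np.zeros((N, T, C))
Q_ = np.concatenate(np.split(Q, h, axis=2), axis=0)
print(Q_.shape)                           # (h*N, T, C/h) -> (16, 5, 2)
```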

```python
def multihead_attention(key_emb,
                        que_emb,
                        queries,
                        keys,
                        num_units=None,
                        num_heads=8,
                        dropout_rate=0,
                        is_training=True,
                        causality=False,
                        scope="multihead_attention",
                        reuse=None):
    '''Applies multihead attention.
    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q].
      keys: A 3d tensor with shape of [N, T_k, C_k].
      num_units: A scalar. Attention size.
      dropout_rate: A floating point number.
      is_training: Boolean. Controller of mechanism for dropout.
      causality: Boolean. If true, units that reference the future are masked.
      num_heads: An int. Number of heads.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A 3d tensor with shape of (N, T_q, C)
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Set the fall back option for num_units
        if num_units is None:
            num_units = queries.get_shape().as_list()[-1]

        # Linear projections
        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key Masking: computed from the raw key embeddings (key_emb),
        # so attention to <PAD> key positions is blocked
        key_masks = tf.sign(tf.abs(tf.reduce_sum(key_emb, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs) * (-2**32 + 1)
        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2**32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query Masking: zero out rows belonging to <PAD> query positions,
        # computed from the raw query embeddings (que_emb)
        query_masks = tf.sign(tf.abs(tf.reduce_sum(que_emb, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # broadcasting. (N, T_q, C)

        # Dropouts
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = normalize(outputs)  # (N, T_q, C)

    return outputs
```
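Since making the masks work is the whole point of this post, here is a tiny NumPy illustration of key masking (a toy example with made-up scores; -1e9 stands in for the -2**32 + 1 used above):

```python
import numpy as np

# attention scores for 2 query positions over 3 key positions,
# where key position 2 is <PAD>
scores = np.array([[0.5, 0.2, 0.9],
                   [0.1, 0.4, 0.3]])
key_mask = np.array([1., 1., 0.])         # 0 marks the padded key
masked = np.where(key_mask == 0, -1e9, scores)
weights = np.exp(masked) / np.exp(masked).sum(-1, keepdims=True)
print(weights.round(3))                   # last column ~0: no attention on <PAD>
```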

feedforward

Two fully connected layers, implemented as convolutions to speed up the computation; dense layers work as well (see the sketch after the code). You will find that every component the framework needs is now assembled, and we can summon the dragon.

(figure: position-wise feed-forward network)

```python
def feedforward(inputs,
                num_units=[2048, 512],
                scope="multihead_attention",
                reuse=None):
    '''Point-wise feed forward net.
    Args:
      inputs: A 3d tensor with shape of [N, T, C].
      num_units: A list of two integers.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A 3d tensor with the same shape and dtype as inputs
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Inner layer
        params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1,
                  "activation": tf.nn.relu, "use_bias": True}
        outputs = tf.layers.conv1d(**params)

        # Readout layer
        params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1,
                  "activation": None, "use_bias": True}
        outputs = tf.layers.conv1d(**params)

        # Residual connection
        outputs += inputs

        # Normalize
        outputs = normalize(outputs)

    return outputs
```
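As mentioned above, dense layers work just as well as the kernel-size-1 convolutions. A sketch of an equivalent dense version (`feedforward_dense` is a hypothetical name, not used in the model below):

```python
def feedforward_dense(inputs, num_units=[2048, 512], scope="ffn_dense", reuse=None):
    '''Point-wise feed forward net built from dense layers instead of conv1d.'''
    with tf.variable_scope(scope, reuse=reuse):
        outputs = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu)
        outputs = tf.layers.dense(outputs, num_units[1], activation=None)
        outputs += inputs          # residual connection
        outputs = normalize(outputs)
    return outputs
```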

label_smoothing

This helps training: zeros become small values close to zero and ones become values slightly below one. The docstring below makes it clear.

````python
def label_smoothing(inputs, epsilon=0.1):
    '''Applies label smoothing. See https://arxiv.org/abs/1512.00567.
    Args:
      inputs: A 3d tensor with shape of [N, T, V], where V is the number of vocabulary.
      epsilon: Smoothing rate.

    For example,

    ```
    import tensorflow as tf
    inputs = tf.convert_to_tensor([[[0, 0, 1],
                                    [0, 1, 0],
                                    [1, 0, 0]],
                                   [[1, 0, 0],
                                    [1, 0, 0],
                                    [0, 1, 0]]], tf.float32)

    outputs = label_smoothing(inputs)

    with tf.Session() as sess:
        print(sess.run([outputs]))

    >>
    [array([[[ 0.03333334,  0.03333334,  0.93333334],
             [ 0.03333334,  0.93333334,  0.03333334],
             [ 0.93333334,  0.03333334,  0.03333334]],
            [[ 0.93333334,  0.03333334,  0.03333334],
             [ 0.93333334,  0.03333334,  0.03333334],
             [ 0.03333334,  0.93333334,  0.03333334]]], dtype=float32)]
    ```
    '''
    K = inputs.get_shape().as_list()[-1]  # number of channels
    return ((1 - epsilon) * inputs) + (epsilon / K)
````

2.2 Assembling the Model

Looking at the model diagram once more, every component in it has now been built.
All that is left is to assemble them according to this structure!

(figure: Transformer model architecture)

The code is as follows:

```python
class Graph():
    def __init__(self, is_training=True):
        tf.reset_default_graph()
        # hyperparameters are read from a global `arg` namespace
        # (assumed to be defined elsewhere in the post)
        self.is_training = arg.is_training
        self.hidden_units = arg.hidden_units
        self.input_vocab_size = arg.input_vocab_size
        self.label_vocab_size = arg.label_vocab_size
        self.num_heads = arg.num_heads
        self.num_blocks = arg.num_blocks
        self.max_length = arg.max_length
        self.lr = arg.lr
        self.dropout_rate = arg.dropout_rate

        # input placeholders
        self.x = tf.placeholder(tf.int32, shape=(None, None))
        self.y = tf.placeholder(tf.int32, shape=(None, None))
        self.de_inp = tf.placeholder(tf.int32, shape=(None, None))

        # Encoder
        with tf.variable_scope("encoder"):
            ## embedding
            self.en_emb = embedding(self.x, vocab_size=self.input_vocab_size, num_units=self.hidden_units,
                                    scale=True, scope="enc_embed")
            self.enc = self.en_emb + embedding(
                tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]), 0), [tf.shape(self.x)[0], 1]),
                vocab_size=self.max_length, num_units=self.hidden_units, zero_pad=False, scale=False, scope="enc_pe")
            ## Dropout
            self.enc = tf.layers.dropout(self.enc,
                                         rate=self.dropout_rate,
                                         training=tf.convert_to_tensor(self.is_training))

            ## Blocks
            for i in range(self.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i)):
                    ### Multihead Attention
                    self.enc = multihead_attention(key_emb=self.en_emb,
                                                   que_emb=self.en_emb,
                                                   queries=self.enc,
                                                   keys=self.enc,
                                                   num_units=self.hidden_units,
                                                   num_heads=self.num_heads,
                                                   dropout_rate=self.dropout_rate,
                                                   is_training=self.is_training,
                                                   causality=False)

            ### Feed Forward
            self.enc = feedforward(self.enc, num_units=[4*self.hidden_units, self.hidden_units])

        # Decoder
        with tf.variable_scope("decoder"):
            ## embedding
            self.de_emb = embedding(self.de_inp, vocab_size=self.label_vocab_size, num_units=self.hidden_units,
                                    scale=True, scope="dec_embed")
            self.dec = self.de_emb + embedding(
                tf.tile(tf.expand_dims(tf.range(tf.shape(self.de_inp)[1]), 0), [tf.shape(self.de_inp)[0], 1]),
                vocab_size=self.max_length, num_units=self.hidden_units, zero_pad=False, scale=False, scope="dec_pe")
            ## Dropout
            self.dec = tf.layers.dropout(self.dec,
                                         rate=self.dropout_rate,
                                         training=tf.convert_to_tensor(self.is_training))

            ## Multihead Attention (self-attention)
            for i in range(self.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i)):
                    ### Multihead Attention
                    self.dec = multihead_attention(key_emb=self.de_emb,
                                                   que_emb=self.de_emb,
                                                   queries=self.dec,
                                                   keys=self.dec,
                                                   num_units=self.hidden_units,
                                                   num_heads=self.num_heads,
                                                   dropout_rate=self.dropout_rate,
                                                   is_training=self.is_training,
                                                   causality=True,
                                                   scope='self_attention')

            ## Multihead Attention (vanilla attention)
            for i in range(self.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i)):
                    ### Multihead Attention
                    self.dec = multihead_attention(key_emb=self.en_emb,
                                                   que_emb=self.de_emb,
                                                   queries=self.dec,
                                                keys<span class="token operator">=</span>self<span class="token punctuation">.</span>enc<span class="token punctuation">,</span> 
                                                num_units<span class="token operator">=</span>self<span class="token punctuation">.</span>hidden_units<span class="token punctuation">,</span> 
                                                num_heads<span class="token operator">=</span>self<span class="token punctuation">.</span>num_heads<span class="token punctuation">,</span> 
                                                dropout_rate<span class="token operator">=</span>self<span class="token punctuation">.</span>dropout_rate<span class="token punctuation">,</span>
                                                is_training<span class="token operator">=</span>self<span class="token punctuation">.</span>is_training<span class="token punctuation">,</span>
                                                causality<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span>
                                                scope<span class="token operator">=</span><span class="token string">'vanilla_attention'</span><span class="token punctuation">)</span> 

        <span class="token comment">### Feed Forward</span>
        self<span class="token punctuation">.</span>outputs <span class="token operator">=</span> feedforward<span class="token punctuation">(</span>self<span class="token punctuation">.</span>dec<span class="token punctuation">,</span> num_units<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">4</span><span class="token operator">*</span>self<span class="token punctuation">.</span>hidden_units<span class="token punctuation">,</span> self<span class="token punctuation">.</span>hidden_units<span class="token punctuation">]</span><span class="token punctuation">)</span>
            
    <span class="token comment"># Final linear projection</span>
    self<span class="token punctuation">.</span>logits <span class="token operator">=</span> tf<span class="token punctuation">.</span>layers<span class="token punctuation">.</span>dense<span class="token punctuation">(</span>self<span class="token punctuation">.</span>outputs<span class="token punctuation">,</span> self<span class="token punctuation">.</span>label_vocab_size<span class="token punctuation">)</span>
    self<span class="token punctuation">.</span>preds <span class="token operator">=</span> tf<span class="token punctuation">.</span>to_int32<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>self<span class="token punctuation">.</span>logits<span class="token punctuation">,</span> axis<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    self<span class="token punctuation">.</span>istarget <span class="token operator">=</span> tf<span class="token punctuation">.</span>to_float<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>not_equal<span class="token punctuation">(</span>self<span class="token punctuation">.</span>y<span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    self<span class="token punctuation">.</span>acc <span class="token operator">=</span> tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>to_float<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>equal<span class="token punctuation">(</span>self<span class="token punctuation">.</span>preds<span class="token punctuation">,</span> self<span class="token punctuation">.</span>y<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token operator">*</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span><span class="token operator">/</span> <span class="token punctuation">(</span>tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span><span class="token punctuation">)</span>
    tf<span class="token punctuation">.</span>summary<span class="token punctuation">.</span>scalar<span class="token punctuation">(</span><span class="token string">'acc'</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>acc<span class="token punctuation">)</span>
            
    <span class="token keyword">if</span> is_training<span class="token punctuation">:</span>  
        <span class="token comment"># Loss</span>
        self<span class="token punctuation">.</span>y_smoothed <span class="token operator">=</span> label_smoothing<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>one_hot<span class="token punctuation">(</span>self<span class="token punctuation">.</span>y<span class="token punctuation">,</span> depth<span class="token operator">=</span>self<span class="token punctuation">.</span>label_vocab_size<span class="token punctuation">)</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>softmax_cross_entropy_with_logits_v2<span class="token punctuation">(</span>logits<span class="token operator">=</span>self<span class="token punctuation">.</span>logits<span class="token punctuation">,</span> labels<span class="token operator">=</span>self<span class="token punctuation">.</span>y_smoothed<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>mean_loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>self<span class="token punctuation">.</span>loss<span class="token operator">*</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span> <span class="token operator">/</span> <span class="token punctuation">(</span>tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>self<span class="token punctuation">.</span>istarget<span class="token punctuation">)</span><span class="token punctuation">)</span>
           
        <span class="token comment"># Training Scheme</span>
        self<span class="token punctuation">.</span>global_step <span class="token operator">=</span> tf<span class="token punctuation">.</span>Variable<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">'global_step'</span><span class="token punctuation">,</span> trainable<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>optimizer <span class="token operator">=</span> tf<span class="token punctuation">.</span>train<span class="token punctuation">.</span>AdamOptimizer<span class="token punctuation">(</span>learning_rate<span class="token operator">=</span>self<span class="token punctuation">.</span>lr<span class="token punctuation">,</span> beta1<span class="token operator">=</span><span class="token number">0.9</span><span class="token punctuation">,</span> beta2<span class="token operator">=</span><span class="token number">0.98</span><span class="token punctuation">,</span> epsilon<span class="token operator">=</span><span class="token number">1e</span><span class="token operator">-</span><span class="token number">8</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>train_op <span class="token operator">=</span> self<span class="token punctuation">.</span>optimizer<span class="token punctuation">.</span>minimize<span class="token punctuation">(</span>self<span class="token punctuation">.</span>mean_loss<span class="token punctuation">,</span> global_step<span class="token operator">=</span>self<span class="token punctuation">.</span>global_step<span class="token punctuation">)</span>
               
        <span class="token comment"># Summary </span>
        tf<span class="token punctuation">.</span>summary<span class="token punctuation">.</span>scalar<span class="token punctuation">(</span><span class="token string">'mean_loss'</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>mean_loss<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>merged <span class="token operator">=</span> tf<span class="token punctuation">.</span>summary<span class="token punctuation">.</span>merge_all<span class="token punctuation">(</span><span class="token punctuation">)</span>
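
The loss above is a label-smoothed cross entropy, and both the loss and the accuracy are masked by istarget so that <PAD> positions (id 0) do not contribute. For reference, here is a minimal sketch of the label_smoothing step, assuming the same formulation as the reference repo (epsilon = 0.1 spread uniformly over the vocabulary):

def label_smoothing_sketch(inputs, epsilon=0.1):
    # inputs: one-hot targets of shape [N, T, V]; V is the label vocab size
    V = inputs.get_shape().as_list()[-1]
    # keep (1 - epsilon) on the true class and spread epsilon evenly over all V classes
    return ((1 - epsilon) * inputs) + (epsilon / V)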

3. Model training

Now we train the model we just built on the data we prepared.

3.1 Hyperparameter settings

def create_hparams():
    params = tf.contrib.training.HParams(
        num_heads = 8,
        num_blocks = 6,
        # vocab
        input_vocab_size = 50,
        label_vocab_size = 50,
        # embedding size
        max_length = 100,
        hidden_units = 512,
        dropout_rate = 0.2,
        lr = 0.0003,
        is_training = True)
    return params

arg = create_hparams()
arg.input_vocab_size = len(encoder_vocab)
arg.label_vocab_size = len(decoder_vocab)
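
Note that max_length = 100 sizes the positional-embedding table in the model (vocab_size=self.max_length), so any sentence longer than 100 tokens would index out of its range. A quick sanity check before training (a hypothetical guard, not part of the original post):

assert max(len(seq) for seq in encoder_inputs) <= arg.max_length
assert max(len(seq) for seq in decoder_inputs) <= arg.max_length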


3.2 Training

import os

epochs = 25
batch_size = 64

g = Graph(arg)

saver = tf.train.Saver()
with tf.Session() as sess:
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    # resume from an existing checkpoint if one is found
    if os.path.exists('logs/model.meta'):
        saver.restore(sess, 'logs/model')
    writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())
    for k in range(epochs):
        total_loss = 0
        batch_num = len(encoder_inputs) // batch_size
        batch = get_batch(encoder_inputs, decoder_inputs, decoder_targets, batch_size)
        for i in tqdm(range(batch_num)):
            encoder_input, decoder_input, decoder_target = next(batch)
            feed = {g.x: encoder_input, g.y: decoder_target, g.de_inp: decoder_input}
            cost, _ = sess.run([g.mean_loss, g.train_op], feed_dict=feed)
            total_loss += cost
            # write a TensorBoard summary every 10 steps
            if (k * batch_num + i) % 10 == 0:
                rs = sess.run(merged, feed_dict=feed)
                writer.add_summary(rs, k * batch_num + i)
        if (k + 1) % 5 == 0:
            print('epochs', k + 1, ': average loss = ', total_loss / batch_num)
    saver.save(sess, 'logs/model')
    writer.close()
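
Because the script restores logs/model whenever a checkpoint already exists, re-running this cell continues training from the saved weights; delete the logs directory to start from scratch. The mean_loss and acc summaries written above can be watched live with tensorboard --logdir tensorboard/lm.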

100%|██████████| 156/156 [00:31<00:00,  6.19it/s]
100%|██████████| 156/156 [00:24<00:00,  5.83it/s]
100%|██████████| 156/156 [00:24<00:00,  6.23it/s]
100%|██████████| 156/156 [00:24<00:00,  6.11it/s]
100%|██████████| 156/156 [00:24<00:00,  6.14it/s]

epochs 5 : average loss = 3.3463863134384155

100%|██████████| 156/156 [00:23<00:00, 6.27it/s]
100%|██████████| 156/156 [00:23<00:00, 5.86it/s]
100%|██████████| 156/156 [00:23<00:00, 6.33it/s]
100%|██████████| 156/156 [00:24<00:00, 6.08it/s]
100%|██████████| 156/156 [00:23<00:00, 6.29it/s]

epochs 10 : average loss = 2.0142565186207113

100%|██████████| 156/156 [00:24<00:00, 6.18it/s]
100%|██████████| 156/156 [00:24<00:00, 5.84it/s]
100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
100%|██████████| 156/156 [00:23<00:00, 6.38it/s]

epochs 15 : average loss = 1.5278632457439716

100%|██████████| 156/156 [00:24<00:00, 6.15it/s]
100%|██████████| 156/156 [00:24<00:00, 5.86it/s]
100%|██████████| 156/156 [00:24<00:00, 6.23it/s]
100%|██████████| 156/156 [00:23<00:00, 6.13it/s]
100%|██████████| 156/156 [00:23<00:00, 6.32it/s]

epochs 20 : average loss = 1.4216684783116365

100%|██████████| 156/156 [00:23<00:00, 6.26it/s]
100%|██████████| 156/156 [00:23<00:00, 5.89it/s]
100%|██████████| 156/156 [00:24<00:00, 6.26it/s]
100%|██████████| 156/156 [00:24<00:00, 6.10it/s]
100%|██████████| 156/156 [00:23<00:00, 6.35it/s]

epochs 25 : average loss = 1.3833287457625072
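
The average loss drops from 3.35 at epoch 5 to 1.38 at epoch 25 and is still trending down, so training for more epochs (or on more than the 10,000 sentence pairs used here) would likely improve the model further.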


3.3 Inference

Feed in a few English test sentences to see how well the model translates:

arg.is_training = False

g = Graph(arg)

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, 'logs/model')
    while True:
        line = input('输入测试拼音: ')
        if line == 'exit': break
        line = line.lower().replace(',', ' ,').strip('\n').split(' ')
        x = np.array([encoder_vocab.index(pny) for pny in line])
        x = x.reshape(1, -1)
        de_inp = [[decoder_vocab.index('<GO>')]]
        # greedy autoregressive decoding: feed the tokens generated so far
        # back into the decoder until the model emits <EOS>
        while True:
            y = np.array(de_inp)
            preds = sess.run(g.preds, {g.x: x, g.de_inp: y})
            if preds[0][-1] == decoder_vocab.index('<EOS>'):
                break
            de_inp[0].append(preds[0][-1])
        got = ''.join(decoder_vocab[idx] for idx in de_inp[0][1:])
        print(got)

INFO:tensorflow:Restoring parameters from logs/model
输入测试拼音: You could be right, I suppose
我猜想你可能是对的
输入测试拼音: You don't believe Tom, do you
你不信任汤姆,对吗
输入测试拼音: Tom has lived here since 2003
汤姆自从2003年就住在这里
输入测试拼音: Tom asked if I'd found my key
湯姆問我找到我的鑰匙了吗
输入测试拼音: They have a very nice veranda
他们有一个非常漂亮的暖房
输入测试拼音: She was married to a rich man
她嫁給了一個有錢的男人
输入测试拼音: My parents sent me a postcard
我父母給我寄了一張明信片
输入测试拼音: Just put yourself in my shoes
你站在我的立場上考慮看看
输入测试拼音: It was a very stupid decision
这是一个十分愚蠢的决定
输入测试拼音: I'm really sorry to hear that
听到这样的消息我真的很难过
输入测试拼音: His wife is one of my friends
他的妻子是我的一個朋友
输入测试拼音: He thought of a good solution
他想到了一個解決的好辦法
输入测试拼音: exit
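
One caveat about the greedy decoding loop above: if the model never emits <EOS>, the inner while loop never terminates. A defensive variant (a sketch, not in the original code) caps generation at arg.max_length steps:

de_inp = [[decoder_vocab.index('<GO>')]]
for _ in range(arg.max_length - 1):  # hard cap so decoding always stops
    preds = sess.run(g.preds, {g.x: x, g.de_inp: np.array(de_inp)})
    if preds[0][-1] == decoder_vocab.index('<EOS>'):
        break
    de_inp[0].append(preds[0][-1])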

 
 
 
 

The results are quite good, and training is much faster than an RNN-based encoder-decoder; Google really did great work here.

Please credit the source when reposting: https://blog.csdn.net/chinatelecom08

If you like this project, please give it a star!
https://github.com/audier
