Implementing seq2seq Automatic Summarization with TensorFlow in Python

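Introduction
In this article we use TensorFlow's Seq2Seq+Attention model to show how to train a model that automatically generates Chinese news headlines. Automatic summarization has long been an active research topic. Extractive methods, which pull the important sentences directly out of the text, are comparatively simple; TextRank is one such algorithm. Abstractive methods, which generate new sentences, are more complex, and their results are often unsatisfactory. The currently popular Seq2Seq model, proposed by Sutskever et al., uses an Encoder-Decoder structure: the source sentence is first encoded into a fixed-size vector of dimension d, and the decoder then generates the target sentence one token at a time. With the attention mechanism added, the decoder can access the encoder's hidden state for every input token while it generates the target sequence, which improves the accuracy of the generated sequence.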
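To make the attention step concrete, here is a minimal pure-Python sketch of a single decoder step with additive attention, the form softmax(v^T tanh(W h_i + U s)) given in the docstring of attention_decoder in the library code listed later in this article. The function name additive_attention, the toy matrices, and the list-based layout are illustrative assumptions, not part of TensorFlow.

```python
import math


def additive_attention(encoder_states, decoder_state, W, U, v):
    """Return (weights, context) for a single decoder step.

    encoder_states: list of src_len vectors, each of length hidden.
    decoder_state: vector of length hidden.
    W, U: hidden x attn matrices (lists of rows); v: vector of length attn.
    """
    def matvec(vec, mat):
        # Row vector times matrix: result[j] = sum_i vec[i] * mat[i][j].
        return [sum(x * row[j] for x, row in zip(vec, mat))
                for j in range(len(mat[0]))]

    query = matvec(decoder_state, U)
    scores = []
    for h in encoder_states:
        key = matvec(h, W)
        # Score for this source position: v^T tanh(W h + U s).
        scores.append(sum(vj * math.tanh(k + q)
                          for vj, k, q in zip(v, key, query)))

    # Softmax over source positions (shifted by the max for stability).
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Context vector: attention-weighted sum of the encoder hidden states.
    hidden = len(encoder_states[0])
    context = [sum(w * h[i] for w, h in zip(weights, encoder_states))
               for i in range(hidden)]
    return weights, context


# Toy values, chosen only for illustration.
encoder_states = [[0.1, 0.2], [0.4, -0.3], [-0.5, 0.9]]
decoder_state = [0.3, -0.1]
W = [[0.5, -0.2], [0.1, 0.7]]
U = [[0.3, 0.3], [-0.6, 0.2]]
v = [1.0, -1.0]
weights, context = additive_attention(encoder_states, decoder_state, W, U, v)
```

The weights sum to 1 across source positions, and the context vector is what the decoder combines with its own output at each step.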

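Data preparation and preprocessing
We use the public Sohu news corpus (SogouCS), which covers news from June to July 2012 and holds over 1M of corpus data, including the headline and body text of each news item. The dataset can be downloaded from Sogou Labs: http://www.sogou.com/labs/resource/cs.php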

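Preprocessing is critical, because which information the encoder processes during the encoding stage directly affects the quality of the whole model.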

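We mainly replace or remove the following kinds of content:
Special characters: removed, e.g. “「,」,¥,…”;
Bracketed content: removed, e.g. emoticon tags such as 【嘻嘻】 and 【哈哈】;
Dates: replaced with the tag TAG_DATE, e.g. ***年*月*日, ****年*月, and so on;
Hyperlinks (URLs): replaced with the tag TAG_URL;
Full-width English text: replaced with the tag TAG_NAME_EN;
Numbers: replaced with the tag TAG_NUMBER;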
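The replacement rules above can be sketched with regular expressions. This is an illustrative sketch rather than the article's actual preprocessing code: the function name normalize_text, the exact patterns, and the order in which the rules are applied are all assumptions. URLs are handled first so that their digits and letters are not caught by the later rules.

```python
import re

# Hypothetical patterns; the real preprocessing code may differ.
URL_RE = re.compile(r"https?://[A-Za-z0-9./?=&_%#:~+\-]+")
DATE_RE = re.compile(r"\d{1,4}年\d{1,2}月(?:\d{1,2}日)?")
FULLWIDTH_EN_RE = re.compile(r"[\uFF21-\uFF3A\uFF41-\uFF5A]+")
NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")
BRACKET_RE = re.compile(r"【[^】]*】")
SPECIAL_RE = re.compile(r"[「」¥…]")


def normalize_text(text):
    """Apply the tag replacements, then drop brackets and special characters."""
    text = URL_RE.sub("TAG_URL", text)
    text = DATE_RE.sub("TAG_DATE", text)
    text = FULLWIDTH_EN_RE.sub("TAG_NAME_EN", text)
    text = NUMBER_RE.sub("TAG_NUMBER", text)
    text = BRACKET_RE.sub("", text)
    text = SPECIAL_RE.sub("", text)
    return text


sample = "【嘻嘻】2012年6月15日，ＳＯＨＵ报道「增长」3.5%，详见http://www.sohu.com/a/1…"
cleaned = normalize_text(sample)
print(cleaned)
```

Running the tag replacements before the number rule matters: a date such as 2012年6月15日 would otherwise be broken into three separate TAG_NUMBER tokens.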
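After preprocessing the text, we prepare the training corpus: the source sequence is the news body, and the target sequence to predict is the news headline. We truncate the body to MAX_LENGTH_ENC = 120 tokens after word segmentation, because an overly long body hurts training. The headline is truncated to 30 tokens (MIN_LENGTH_ENC = 30; despite the name, the value is an upper bound), so a generated headline has at most 30 tokens.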
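A minimal sketch of the truncation step, assuming the body and headline have already been segmented into token lists. MAX_LENGTH_ENC comes from the text; MAX_LENGTH_DEC and make_training_pair are hypothetical names introduced here for the 30-token headline limit.

```python
MAX_LENGTH_ENC = 120  # cap on body tokens, as stated in the text
MAX_LENGTH_DEC = 30   # hypothetical name for the 30-token headline cap


def make_training_pair(content_tokens, title_tokens):
    """Return a (source, target) pair truncated to the two length limits."""
    source = content_tokens[:MAX_LENGTH_ENC]
    target = title_tokens[:MAX_LENGTH_DEC]
    return source, target


# Toy input: a 200-token body and a 40-token headline.
body = ["w%d" % i for i in range(200)]
title = ["t%d" % i for i in range(40)]
source, target = make_training_pair(body, title)
```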

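When generating the training data, data_util.py does the following:

the create_vocabulary() method builds the vocabulary;
the data_to_token_ids() method converts the training data (content-train.txt) into the corresponding word-ID representation;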
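The two steps can be illustrated with simplified in-memory versions. This sketch assumes the four special tokens _PAD, _GO, _EOS and _UNK used by the TensorFlow translation tutorial's data utilities; the helper names build_vocab and tokens_to_ids are hypothetical, and the real create_vocabulary() and data_to_token_ids() work on files rather than Python lists.

```python
from collections import Counter

# Special tokens occupy the first vocabulary slots (assumed convention).
_PAD, _GO, _EOS, _UNK = "_PAD", "_GO", "_EOS", "_UNK"
START_VOCAB = [_PAD, _GO, _EOS, _UNK]
UNK_ID = START_VOCAB.index(_UNK)


def build_vocab(tokenized_sentences, max_size):
    """Map the most frequent tokens to integer IDs, special tokens first."""
    counts = Counter(tok for sent in tokenized_sentences for tok in sent)
    # Sort by descending frequency, then by token, so ties are deterministic.
    ordered = sorted(counts, key=lambda t: (-counts[t], t))
    words = (START_VOCAB + ordered)[:max_size]
    return {word: idx for idx, word in enumerate(words)}


def tokens_to_ids(tokens, vocab):
    """Convert tokens to IDs; out-of-vocabulary tokens map to UNK_ID."""
    return [vocab.get(tok, UNK_ID) for tok in tokens]


corpus = [["股市", "上涨"], ["股市", "下跌"], ["新闻", "标题"]]
vocab = build_vocab(corpus, max_size=6)
ids = tokens_to_ids(["股市", "上涨", "暴跌"], vocab)
print(ids)  # [4, 5, 3]
```

Capping the vocabulary size keeps the output softmax manageable; every token that falls outside the cap is represented by the single _UNK ID.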

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for creating sequence-to-sequence models in TensorFlow.

Sequence-to-sequence recurrent neural networks can learn complex functions
that map input sequences to output sequences. These models yield very good
results on a number of tasks, such as speech recognition, parsing, machine
translation, or even constructing automated replies to emails.

Before using this module, it is recommended to read the TensorFlow tutorial
on sequence-to-sequence models. It explains the basic concepts of this module
and shows an end-to-end example of how to build a translation model.
  https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html

Here is an overview of functions available in this module. They all use
a very similar interface, so after reading the above tutorial and using
one of them, others should be easy to substitute.

* Full sequence-to-sequence models.
  - basic_rnn_seq2seq: The most basic RNN-RNN model.
  - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights.
  - embedding_rnn_seq2seq: The basic model with input embedding.
  - embedding_tied_rnn_seq2seq: The tied model with input embedding.
  - embedding_attention_seq2seq: Advanced model with input embedding and
      the neural attention mechanism; recommended for complex tasks.

* Multi-task sequence-to-sequence models.
  - one2many_rnn_seq2seq: The embedding model with multiple decoders.

* Decoders (when you write your own encoder, you can use these to decode;
    e.g., if you want to write a model that generates captions for images).
  - rnn_decoder: The basic decoder based on a pure RNN.
  - attention_decoder: A decoder that uses the attention mechanism.

* Losses.
  - sequence_loss: Loss for a sequence model returning average log-perplexity.
  - sequence_loss_by_example: As above, but not averaging over all examples.

* model_with_buckets: A convenience function to create models with bucketing
    (see the tutorial above for an explanation of why and how to use it).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

# TODO(ebrevdo): Remove once _linear is fully deprecated.
linear = core_rnn_cell_impl._linear  # pylint: disable=protected-access


def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and added B.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """

  def loop_function(prev, _):
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = array_ops.stop_gradient(emb_prev)
    return emb_prev

  return loop_function


def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):
  """RNN decoder for the sequence-to-sequence model.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor with shape [batch_size x cell.state_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-st input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    scope: VariableScope for the created subgraph; defaults to "rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing generated outputs.
      state: The state of each cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
        (Note that in some cases, like basic RNN cell or GRU cell, outputs and
         states can be the same. They are different for LSTM cells though.)
  """
  with variable_scope.variable_scope(scope or "rnn_decoder"):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      output, state = cell(inp, state)
      outputs.append(output)
      if loop_function is not None:
        prev = output
  return outputs, state


def basic_rnn_seq2seq(encoder_inputs,
                      decoder_inputs,
                      cell,
                      dtype=dtypes.float32,
                      scope=None):
  """Basic RNN sequence-to-sequence model.

  This model first runs an RNN to encode encoder_inputs into a state vector,
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell type, but don't share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell in the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
    _, enc_state = core_rnn.static_rnn(cell, encoder_inputs, dtype=dtype)
    return rnn_decoder(decoder_inputs, enc_state, cell)


def tied_rnn_seq2seq(encoder_inputs,
                     decoder_inputs,
                     cell,
                     loop_function=None,
                     dtype=dtypes.float32,
                     scope=None):
  """RNN sequence-to-sequence model with tied encoder and decoder parameters.

  This model first runs an RNN to encode encoder_inputs into a state vector, and
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell and share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to i-th output
      in order to generate i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol), see rnn_decoder for details.
    dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell in each time-step. This is a list
        with length len(decoder_inputs) -- one item for each time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
    scope = scope or "tied_rnn_seq2seq"
    _, enc_state = core_rnn.static_rnn(
        cell, encoder_inputs, dtype=dtype, scope=scope)
    variable_scope.get_variable_scope().reuse_variables()
    return rnn_decoder(
        decoder_inputs,
        enc_state,
        cell,
        loop_function=loop_function,
        scope=scope)


def embedding_rnn_decoder(decoder_inputs,
                          initial_state,
                          cell,
                          num_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True,
                          scope=None):
  """RNN decoder with embedding and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    cell: core_rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each fed
      previous output will first be multiplied by W and added B.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)),
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell in each time-step. This is a list
        with length len(decoder_inputs) -- one item for each time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope:
    if output_projection is not None:
      dtype = scope.dtype
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = (embedding_ops.embedding_lookup(embedding, i)
               for i in decoder_inputs)
    return rnn_decoder(
        emb_inp, initial_state, cell, loop_function=loop_function)


def embedding_rnn_seq2seq(encoder_inputs,
                          decoder_inputs,
                          cell,
                          num_encoder_symbols,
                          num_decoder_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None):
  """Embedding RNN sequence-to-sequence model.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and added B.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_seq2seq"

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell in each time-step. This is a list
        with length len(decoder_inputs) -- one item for each time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
    if dtype is not None:
      scope.set_dtype(dtype)
    else:
      dtype = scope.dtype

    # Encoder.
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = core_rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

    if isinstance(feed_previous, bool):
      return embedding_rnn_decoder(
          decoder_inputs,
          encoder_state,
          cell,
          num_decoder_symbols,
          embedding_size,
          output_projection=output_projection,
          feed_previous=feed_previous)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse) as scope:
        outputs, state = embedding_rnn_decoder(
            decoder_inputs,
            encoder_state,
            cell,
            num_decoder_symbols,
            embedding_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


def embedding_tied_rnn_seq2seq(encoder_inputs,
                               decoder_inputs,
                               cell,
                               num_symbols,
                               embedding_size,
                               num_decoder_symbols=None,
                               output_projection=None,
                               feed_previous=False,
                               dtype=None,
                               scope=None):
  """Embedding RNN sequence-to-sequence model with tied (shared) parameters.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_symbols x input_size]). Then it runs an RNN to encode embedded
  encoder_inputs into a state vector. Next, it embeds decoder_inputs using
  the same embedding. Then it runs RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs. The decoder output is over symbols
  from 0 to num_decoder_symbols - 1 if num_decoder_symbols is not None;
  otherwise it is over 0 to num_symbols - 1.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    num_symbols: Integer; number of symbols for both encoder and decoder.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_decoder_symbols: Integer; number of output symbols for decoder. If
      provided, the decoder output is over symbols 0 to num_decoder_symbols - 1.
      Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that
      this assumes that the vocabulary is set up such that the first
      num_decoder_symbols of num_symbols are part of decoding.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and added B.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype to use for the initial RNN states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_symbols] containing the generated
        outputs where output_symbols = num_decoder_symbols if
        num_decoder_symbols is not None otherwise output_symbols = num_symbols.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(
      scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    if output_projection is not None:
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable(
        "embedding", [num_symbols, embedding_size], dtype=dtype)

    emb_encoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs
    ]
    emb_decoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs
    ]

    output_symbols = num_symbols
    if num_decoder_symbols is not None:
      output_symbols = num_decoder_symbols
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols)

    if isinstance(feed_previous, bool):
      loop_function = _extract_argmax_and_embed(embedding, output_projection,
                                                True) if feed_previous else None
      return tied_rnn_seq2seq(
          emb_encoder_inputs,
          emb_decoder_inputs,
          cell,
          loop_function=loop_function,
          dtype=dtype)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      loop_function = _extract_argmax_and_embed(
          embedding, output_projection, False) if feed_previous_bool else None
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = tied_rnn_seq2seq(
            emb_encoder_inputs,
            emb_decoder_inputs,
            cell,
            loop_function=loop_function,
            dtype=dtype)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    # Calculate zero-state to know its structure.
    static_batch_size = encoder_inputs[0].get_shape()[0]
    for inp in encoder_inputs[1:]:
      static_batch_size.merge_with(inp.get_shape()[0])
    batch_size = static_batch_size.value
    if batch_size is None:
      batch_size = array_ops.shape(encoder_inputs[0])[0]
    zero_state = cell.zero_state(batch_size, dtype)
    if nest.is_sequence(zero_state):
      state = nest.pack_sequence_as(
          structure=zero_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


def attention_decoder(decoder_inputs,
                      initial_state,
                      attention_states,
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):
  """RNN decoder with attention for the sequence-to-sequence model.

  In this context "attention" means that, during decoding, the RNN can look up
  information in the additional tensor attention_states, and it does this by
  focusing on a few entries from the tensor. This model has proven to yield
  especially good results in a number of sequence-to-sequence tasks. This
  implementation is based on http://arxiv.org/abs/1412.7449 (see below for
  details). It is recommended for complex sequence-to-sequence tasks.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    output_size: Size of the output vectors; if None, we use cell.output_size.
    num_heads: Number of attention heads that read from attention_states.
    loop_function: If not None, this function will be applied to i-th output
      in order to generate i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    dtype: The dtype to use for the RNN initial state (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors of
        shape [batch_size x output_size]. These represent the generated outputs.
        Output i is computed from input i (which is either the i-th element
        of decoder_inputs or loop_function(output {i-1}, i)) as follows.
        First, we run the cell on a combination of the input and previous
        attention masks:
          cell_output, new_state = cell(linear(input, prev_attn), prev_state).
        Then, we calculate new attention masks:
          new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
        and then we calculate the output:
          output = linear(cell_output, new_attn).
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: when num_heads is not positive, there are no inputs, shapes
      of attention_states are not set, or input size cannot be inferred
      from the input.
  """
  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With less than 1 heads, use a non-attention decoder.")
  if attention_states.get_shape()[2].value is None:
    raise ValueError("Shape[2] of attention_states must be known: %s" %
                     attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with variable_scope.variable_scope(
      scope or "attention_decoder", dtype=dtype) as scope:
    dtype = scope.dtype

    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    if attn_length is None:
      attn_length = array_ops.shape(attention_states)[1]
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
    hidden = array_ops.reshape(attention_states,
                               [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

    state = initial_state

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      if nest.is_sequence(query):  # If the query is a tuple, flatten it.
        query_list = nest.flatten(query)
        for q in query_list:  # Check that ndims == 2 if specified.
          ndims = q.get_shape().ndims
          if ndims:
            assert ndims == 2
        query = array_ops.concat_v2(query_list, 1)
      for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = linear(query, attention_vec_size, True)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                                  [2, 3])
          a = nn_ops.softmax(s)
          # Now calculate the attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.stack([batch_size, attn_size])
    attns = [
        array_ops.zeros(
            batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
    ]
    for a in attns:  # Ensure the second shape of attention vectors is set.
      a.set_shape([None, attn_size])
    if initial_state_attention:
      attns = attention(initial_state)
    for i, inp in enumerate(decoder_inputs):
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)
      x = linear([inp] + attns, input_size, True)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
      if i == 0 and initial_state_attention:
        with variable_scope.variable_scope(
            variable_scope.get_variable_scope(), reuse=True):
          attns = attention(state)
      else:
        attns = attention(state)

      with variable_scope.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + attns, output_size, True)
      if loop_function is not None:
        prev = output
      outputs.append(output)

  return outputs, state


def embedding_attention_decoder(decoder_inputs,
                                initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                embedding_size,
                                num_heads=1,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """RNN decoder with embedding and attention and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: core_rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_size: Size of the output vectors; if None, use cell.output_size.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has shape
      [num_symbols]; if provided and feed_previous=True, each fed previous
      output will first be multiplied by W and have B added.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)),
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    dtype: The dtype to use for the RNN initial states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  if output_size is None:
    output_size = cell.output_size
  if output_projection is not None:
    proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
    proj_biases.get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(
      scope or "embedding_attention_decoder", dtype=dtype) as scope:

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = [
        embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs
    ]
    return attention_decoder(
        emb_inp,
        initial_state,
        attention_states,
        cell,
        output_size=output_size,
        num_heads=num_heads,
        loop_function=loop_function,
        initial_state_attention=initial_state_attention)


def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """Embedding sequence-to-sequence model with attention.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. It keeps the outputs of this
  RNN at every step to use for attention later. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs attention decoder, initialized with the last
  encoder state, on embedded decoder_inputs and attending to encoder outputs.

  Warning: when output_projection is None, the size of the attention vectors
  and variables will be made proportional to num_decoder_symbols, can be large.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and added B.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial RNN state (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_seq2seq".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x num_decoder_symbols] containing the generated
        outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(
      scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype
    # Encoder.
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    encoder_outputs, encoder_state = core_rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # First calculate a concatenation of encoder outputs to put attention on.
    top_states = [
        array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs
    ]
    attention_states = array_ops.concat_v2(top_states, 1)

    # Decoder.
    output_size = None
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
      output_size = num_decoder_symbols

    if isinstance(feed_previous, bool):
      return embedding_attention_decoder(
          decoder_inputs,
          encoder_state,
          attention_states,
          cell,
          num_decoder_symbols,
          embedding_size,
          num_heads=num_heads,
          output_size=output_size,
          output_projection=output_projection,
          feed_previous=feed_previous,
          initial_state_attention=initial_state_attention)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse) as scope:
        outputs, state = embedding_attention_decoder(
            decoder_inputs,
            encoder_state,
            attention_states,
            cell,
            num_decoder_symbols,
            embedding_size,
            num_heads=num_heads,
            output_size=output_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False,
            initial_state_attention=initial_state_attention)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


def one2many_rnn_seq2seq(encoder_inputs,
                         decoder_inputs_dict,
                         cell,
                         num_encoder_symbols,
                         num_decoder_symbols_dict,
                         embedding_size,
                         feed_previous=False,
                         dtype=None,
                         scope=None):
  """One-to-many RNN sequence-to-sequence model (multi-task).

  This is a multi-task sequence-to-sequence model with one encoder and multiple
  decoders. Reference to multi-task sequence-to-sequence learning can be found
  here: http://arxiv.org/abs/1511.06114

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs_dict: A dictionary mapping decoder name (string) to
      the corresponding decoder_inputs; each decoder_inputs is a list of 1D
      Tensors of shape [batch_size]; num_decoders is defined as
      len(decoder_inputs_dict).
    cell: core_rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
      integer specifying number of symbols for the corresponding decoder;
      len(num_decoder_symbols_dict) must be equal to num_decoders.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
      decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "one2many_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs_dict, state_dict), where:
      outputs_dict: A mapping from decoder name (string) to a list of the same
        length as decoder_inputs_dict[name]; each element in the list is a 2D
        Tensor with shape [batch_size x num_decoder_symbols_dict[name]]
        containing the generated outputs.
      state_dict: A mapping from decoder name (string) to the final state of the
        corresponding decoder RNN; it is a 2D Tensor of shape
        [batch_size x cell.state_size].
  """
  outputs_dict = {}
  state_dict = {}

  with variable_scope.variable_scope(
      scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    # Encoder.
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = core_rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    for name, decoder_inputs in decoder_inputs_dict.items():
      num_decoder_symbols = num_decoder_symbols_dict[name]

      with variable_scope.variable_scope("one2many_decoder_" + str(
          name)) as scope:
        decoder_cell = core_rnn_cell.OutputProjectionWrapper(
            cell, num_decoder_symbols)
        if isinstance(feed_previous, bool):
          outputs, state = embedding_rnn_decoder(
              decoder_inputs,
              encoder_state,
              decoder_cell,
              num_decoder_symbols,
              embedding_size,
              feed_previous=feed_previous)
        else:
          # If feed_previous is a Tensor, we construct 2 graphs and use cond.
          def filled_embedding_rnn_decoder(feed_previous):
            """The current decoder with a fixed feed_previous parameter."""
            # pylint: disable=cell-var-from-loop
            reuse = None if feed_previous else True
            vs = variable_scope.get_variable_scope()
            with variable_scope.variable_scope(vs, reuse=reuse):
              outputs, state = embedding_rnn_decoder(
                  decoder_inputs,
                  encoder_state,
                  decoder_cell,
                  num_decoder_symbols,
                  embedding_size,
                  feed_previous=feed_previous)
            # pylint: enable=cell-var-from-loop
            state_list = [state]
            if nest.is_sequence(state):
              state_list = nest.flatten(state)
            return outputs + state_list

          outputs_and_state = control_flow_ops.cond(
              feed_previous, lambda: filled_embedding_rnn_decoder(True),
              lambda: filled_embedding_rnn_decoder(False))

          # Outputs length is the same as for decoder inputs.
          outputs_len = len(decoder_inputs)
          outputs = outputs_and_state[:outputs_len]
          state_list = outputs_and_state[outputs_len:]
          state = state_list[0]
          if nest.is_sequence(encoder_state):
            state = nest.pack_sequence_as(
                structure=encoder_state, flat_sequence=state_list)
      outputs_dict[name] = outputs
      state_dict[name] = state

  return outputs_dict, state_dict


def sequence_loss_by_example(logits,
                             targets,
                             weights,
                             average_across_timesteps=True,
                             softmax_loss_function=None,
                             name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
    name: Optional name for this operation, default: "sequence_loss_by_example".

  Returns:
    1D batch-sized float Tensor: The log-perplexity for each sequence.

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  if len(targets) != len(logits) or len(weights) != len(logits):
    raise ValueError("Lengths of logits, weights, and targets must be the same "
                     "%d, %d, %d." % (len(logits), len(weights), len(targets)))
  with ops.name_scope(name, "sequence_loss_by_example",
                      logits + targets + weights):
    log_perp_list = []
    for logit, target, weight in zip(logits, targets, weights):
      if softmax_loss_function is None:
        # TODO(irving,ebrevdo): This reshape is needed because
        # sequence_loss_by_example is called with scalars sometimes, which
        # violates our general scalar strictness policy.
        target = array_ops.reshape(target, [-1])
        crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
            logits=logit, labels=target)
      else:
        crossent = softmax_loss_function(target, logit)
      log_perp_list.append(crossent * weight)
    log_perps = math_ops.add_n(log_perp_list)
    if average_across_timesteps:
      total_size = math_ops.add_n(weights)
      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
      log_perps /= total_size
  return log_perps


def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):
  """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    average_across_batch: If set, divide the returned cost by the batch size.
    softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      It is passed through to sequence_loss_by_example, which calls it as
      softmax_loss_function(target, logit).
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  with ops.name_scope(name, "sequence_loss", logits + targets + weights):
    cost = math_ops.reduce_sum(
        sequence_loss_by_example(
            logits,
            targets,
            weights,
            average_across_timesteps=average_across_timesteps,
            softmax_loss_function=softmax_loss_function))
    if average_across_batch:
      batch_size = array_ops.shape(targets[0])[0]
      return cost / math_ops.cast(batch_size, cost.dtype)
    else:
      return cost


def model_with_buckets(encoder_inputs,
                       decoder_inputs,
                       targets,
                       weights,
                       buckets,
                       seq2seq,
                       softmax_loss_function=None,
                       per_example_loss=False,
                       name=None):
  """Create a sequence-to-sequence model with support for bucketing.

  The seq2seq argument is a function that defines a sequence-to-sequence model,
  e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
      x, y, core_rnn_cell.GRUCell(24))

  Args:
    encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
    decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
    targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
    weights: List of 1D batch-sized float-Tensors to weight the targets.
    buckets: A list of pairs of (input size, output size) for each bucket.
    seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
      agree with encoder_inputs and decoder_inputs, and returns a pair
      consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
    softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
    per_example_loss: Boolean. If set, the returned loss will be a batch-sized
      tensor of losses for each sequence in the batch. If unset, it will be
      a scalar with the averaged loss from all examples.
    name: Optional name for this operation, defaults to "model_with_buckets".

  Returns:
    A tuple of the form (outputs, losses), where:
      outputs: The outputs for each bucket. Its j'th element consists of a list
        of 2D Tensors. The shape of output tensors can be either
        [batch_size x output_size] or [batch_size x num_decoder_symbols]
        depending on the seq2seq model used.
      losses: List of scalar Tensors, representing losses for each bucket, or,
        if per_example_loss is set, a list of 1D batch-sized float Tensors.

  Raises:
    ValueError: If length of encoder_inputs, targets, or weights is smaller
      than the largest (last) bucket.
  """
  if len(encoder_inputs) < buckets[-1][0]:
    raise ValueError("Length of encoder_inputs (%d) must be at least that of "
                     "last bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last "
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  if len(weights) < buckets[-1][1]:
    raise ValueError("Length of weights (%d) must be at least that of last "
                     "bucket (%d)." % (len(weights), buckets[-1][1]))

  all_inputs = encoder_inputs + decoder_inputs + targets + weights
  losses = []
  outputs = []
  with ops.name_scope(name, "model_with_buckets", all_inputs):
    for j, bucket in enumerate(buckets):
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
        bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
                                    decoder_inputs[:bucket[1]])
        outputs.append(bucket_outputs)
        if per_example_loss:
          losses.append(
              sequence_loss_by_example(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))
        else:
          losses.append(
              sequence_loss(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))

  return outputs, losses
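
The attention read inside attention_decoder is easier to follow without the graph plumbing. The following is a minimal pure-Python sketch for a single example and a single head, not the library's implementation: the helper names softmax and attention_read are made up for illustration, and W1 is assumed to be the identity so that hidden_features equals hidden. It computes the same three steps as the listing: score each source position with v^T * tanh(W1*h_t + W2*state), normalize the scores with a softmax, and return the weighted sum of the encoder states.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_read(hidden, hidden_features, y, v):
    """One attention head for a single example (hypothetical helper).

    hidden:          encoder states h_t, one vector per source position.
    hidden_features: W1 * h_t for each position (computed once, up front).
    y:               W2 * decoder_state, the projected query.
    v:               the learned scoring vector.
    """
    # score_t = sum_k v[k] * tanh(hidden_features[t][k] + y[k])
    scores = [
        sum(v_k * math.tanh(f_k + y_k) for v_k, f_k, y_k in zip(v, feat, y))
        for feat in hidden_features
    ]
    weights = softmax(scores)
    size = len(hidden[0])
    # Context vector d = sum_t weights[t] * hidden[t]
    context = [
        sum(w * h[k] for w, h in zip(weights, hidden)) for k in range(size)
    ]
    return weights, context


hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hidden_features = hidden  # assumption: W1 is the identity, for illustration
weights, context = attention_read(
    hidden, hidden_features, y=[0.5, -0.5], v=[1.0, 1.0])
print(weights, context)
```

In the listing the same computation runs for the whole batch at once, and the W1 projection is done with a 1-by-1 convolution so that it is computed once rather than at every decoding step.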

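
The weighting in sequence_loss_by_example is what lets padded positions drop out of the loss. Below is a small pure-Python sketch of that arithmetic under simplifying assumptions: plain lists stand in for tensors, and the names cross_entropy and toy_sequence_loss_by_example are illustrative, not part of TensorFlow. Each time step contributes its cross-entropy multiplied by its weight, and the sum is divided by the total weight plus the same 1e-12 guard used in the listing.

```python
import math


def cross_entropy(logits_row, target_id):
    """Softmax cross-entropy for one example: logsumexp(logits) - logits[target]."""
    peak = max(logits_row)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits_row))
    return log_sum - logits_row[target_id]


def toy_sequence_loss_by_example(logits, targets, weights,
                                 average_across_timesteps=True):
    """Per-example weighted loss.

    logits[t][b] is a list of vocabulary scores, targets[t][b] is a token id,
    and weights[t][b] is a float (0.0 for padded positions).
    """
    batch_size = len(targets[0])
    losses = [0.0] * batch_size
    totals = [0.0] * batch_size
    for step_logits, step_targets, step_weights in zip(logits, targets, weights):
        for b in range(batch_size):
            step_loss = cross_entropy(step_logits[b], step_targets[b])
            losses[b] += step_loss * step_weights[b]
            totals[b] += step_weights[b]
    if average_across_timesteps:
        # Same epsilon as the listing: avoids 0/0 for all-padding rows.
        losses = [l / (t + 1e-12) for l, t in zip(losses, totals)]
    return losses


# Three time steps, batch of two, vocabulary of four, all-zero logits.
logits = [[[0.0] * 4, [0.0] * 4] for _ in range(3)]
targets = [[1, 0], [2, 0], [0, 0]]
weights = [[1.0, 0.0], [1.0, 0.0], [0.0, 0.0]]  # second example is all padding
per_example = toy_sequence_loss_by_example(logits, targets, weights)
print(per_example)
```

With uniform logits over four symbols every unpadded step costs log(4), so the first example averages to log(4) while the all-padding example contributes zero.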
 
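
model_with_buckets expects its inputs to be already padded to a bucket's (input size, output size). How that padding is done is not shown in this listing, so the sketch below is an assumption about one reasonable way to prepare a single example: the token ids PAD_ID, GO_ID and EOS_ID, the helper names, and the bucket sizes are all hypothetical. It picks the smallest bucket that fits, pads both sides, and builds the weights so that padded target positions get weight 0.

```python
PAD_ID, GO_ID, EOS_ID = 0, 1, 2  # hypothetical special-token ids


def pick_bucket(buckets, source_len, target_len):
    """Return the index of the smallest bucket that fits, or None."""
    for bucket_id, (enc_size, dec_size) in enumerate(buckets):
        # The decoder side needs room for GO and EOS around the target.
        if source_len <= enc_size and target_len + 2 <= dec_size:
            return bucket_id
    return None


def pad_to_bucket(source, target, bucket):
    """Pad one (source, target) pair to the given (enc_size, dec_size)."""
    enc_size, dec_size = bucket
    encoder_inputs = source + [PAD_ID] * (enc_size - len(source))
    decoder_inputs = [GO_ID] + target + [EOS_ID]
    decoder_inputs += [PAD_ID] * (dec_size - len(decoder_inputs))
    # The target at step t is the decoder input at step t + 1.
    targets = decoder_inputs[1:] + [PAD_ID]
    weights = [0.0 if t == PAD_ID else 1.0 for t in targets]
    return encoder_inputs, decoder_inputs, weights


buckets = [(5, 6), (10, 12)]  # illustrative sizes, not the article's settings
source, target = [11, 12, 13], [21, 22]
bucket_id = pick_bucket(buckets, len(source), len(target))
enc, dec, w = pad_to_bucket(source, target, buckets[bucket_id])
print(bucket_id, enc, dec, w)
```

Grouping examples by bucket keeps the unrolled graph short for short inputs, which is the reason model_with_buckets builds one sub-graph per bucket and shares variables across them with reuse=True.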
