The model implemented in this post comes from the paper "Dynamic Memory Networks for Visual and Textual Question Answering".


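The experiments need a fair amount of data: the images come from COCO, the text annotations come from VQA 1.0, and a pretrained VGG16 is also used, so there is quite a lot to prepare.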




Neural network architectures with memory and attention mechanisms exhibit certain reasoning capabilities required for question answering. One such architecture, the dynamic memory network (DMN), obtained high accuracy on a variety of language tasks. However, it was not shown whether the architecture achieves strong results for question answering when supporting facts are not marked during training or whether it could be applied to other modalities such as images. Based on an analysis of the DMN, we propose several improvements to its memory and input modules. Together with these changes we introduce a novel input module for images in order to be able to answer visual questions. Our new DMN+ model improves the state of the art on both the Visual Question Answering dataset and the bAbI-10k text question-answering dataset without supporting fact supervision.


Memory mechanisms can handle reasoning to some extent, and attention mechanisms have been applied successfully to machine translation and image captioning. The Dynamic Memory Network (DMN) is a neural network model that uses both. The paper analyzes the components of the DMN, mainly the input module and the memory module, in order to improve question answering. The authors' main contributions are:

We propose a new input module which uses a two level encoder with a sentence reader and input fusion layer to allow for information flow between sentences.

For the memory, we propose a modification to gated recurrent units (GRU). The new GRU formulation incorporates attention gates that are computed using global knowledge over the facts.

In addition, we introduce a new input module to represent images.



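1. Dynamic Memory Networks (DMN)

The DMN is a question answering model made up of several modules, such as the input representation module and the memory module. Some of them are shown in Figure 1 above: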

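Input Module: processes the input data about which the question is asked and turns it into a set of vectors called facts, written F = [f_{1}, ..., f_{N}], where N is the total number of facts. These vectors are used by the later modules. For text QA the input module is built from GRU units. With x_{i} the input at step i and h_{i-1} the previous hidden state, the current hidden state h_{i} = GRU(x_{i}, h_{i-1}) is computed as:

u_{i} = \sigma(W^{(u)} x_{i} + U^{(u)} h_{i-1} + b^{(u)})    (1)
r_{i} = \sigma(W^{(r)} x_{i} + U^{(r)} h_{i-1} + b^{(r)})    (2)
\tilde{h}_{i} = \tanh(W x_{i} + r_{i} \circ U h_{i-1} + b^{(h)})    (3)
h_{i} = u_{i} \circ \tilde{h}_{i} + (1 - u_{i}) \circ h_{i-1}    (4)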

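Question Module: computes a vector representation q of the question, where q is the final hidden state of a GRU run over the words of the question.

Episodic Memory Module: the episodic memory aims to retrieve, from the input facts, the information needed to answer the question q. To improve its understanding of the question and the input, especially when the question requires reasoning, the module makes several passes over the input and updates the episode memory after each pass. It has two parts: an attention mechanism and a memory update mechanism. The attention mechanism produces a context vector c^{t}, a summary of the input relevant to pass t, where relevance is inferred from the question q and the previous episode memory m^{t-1}. The memory update mechanism generates the new episode memory m^{t} from the context vector c^{t} and the previous episode memory m^{t-1}.

Answer Module: generates the model's predicted answer from q and m^{t}. For a simple answer such as a single word, a linear layer with softmax activation is enough. For a complex answer such as a sentence, an RNN can decode the concatenation a = [q; m^{t}] into an ordered sequence of words. Training uses the cross-entropy loss.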
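
For the single-word case, the answer module amounts to a softmax classifier over the concatenation of q and m^{t}. The NumPy sketch below is only an illustration under assumed toy shapes; the names `answer_distribution`, `cross_entropy`, `W` and `b` are hypothetical and are not taken from the paper or from the repository.

```python
import numpy as np

def answer_distribution(q, m, W, b):
    """Softmax over the vocabulary computed from the concatenation a = [q; m]."""
    a = np.concatenate([q, m])          # shape (2d,)
    logits = W @ a + b                  # shape (V,)
    logits = logits - logits.max()      # subtract the max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()

def cross_entropy(probs, target_index):
    """Negative log-likelihood of the correct answer word."""
    return -np.log(probs[target_index])

d, vocab_size = 4, 10                   # assumed toy sizes
q = np.ones(d)                          # question vector
m = np.ones(d)                          # final episode memory
W = np.zeros((vocab_size, 2 * d))       # zero weights give a uniform distribution
b = np.zeros(vocab_size)
probs = answer_distribution(q, m, W, b)
loss = cross_entropy(probs, target_index=3)
```

With zero weights every word is equally likely, so the loss equals log(vocab_size); a trained layer would concentrate the probability on the correct answer.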

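2. Improved Dynamic Memory Networks (DMN+)

The proposed model changes two components: the input representation, and the attention mechanism together with the memory update. Text QA and visual QA (VQA) differ only in the input representation.

(1) Input Module for Text QA

In the DMN, a single GRU processes the words of the text and the sentence representations are extracted from it. This works well on bAbI-1k with supporting facts but poorly on bAbI-10k without supporting facts. The authors suggest two reasons. First, the GRU only gives a sentence context from the sentences before it, not from those after it, which prevents information from propagating from later sentences. Second, the supporting sentences may be too far apart at the word level for the word-level GRU to let them interact.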

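Input Fusion Layer: DMN+ replaces the single GRU with two components. The first is the sentence reader, which encodes the words into sentence embeddings. The second is the input fusion layer, which allows interaction between sentences and resembles a hierarchical neural auto-encoder architecture. The input fusion layer is a bi-directional GRU, so information flows from both earlier and later sentences. Because gradients no longer have to propagate through the words between sentences, the fusion layer also lets distant supporting sentences interact more directly.

Fig 2 below shows the input module. The sentence reader uses a positional encoder and the input fusion layer uses a bi-directional GRU. Each sentence encoding f_{i} is the output of an encoding scheme applied to the word tokens [w_{1}^{i}, ..., w_{M_{i}}^{i}], where M_{i} is the length of the sentence:

The sentence reader can use any of several encoding schemes. The authors chose positional encoding. GRUs and LSTMs were not used because they need more computation and tend to overfit unless auxiliary tasks, such as reconstructing the original sentence, are added.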

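With positional encoding, the sentence representation is f_{i} = \sum_{j=1}^{M} l_{j} \circ w_{j}^{i}, where \circ is element-wise multiplication and l_{j} is a column vector with entries l_{jd} = (1 - j/M) - (d/D)(1 - 2j/M). Here d is the embedding index and D is the embedding dimension.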
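
The positional encoding formula can be checked with a few lines of NumPy. This is a minimal sketch, assuming 1-based indices for both the word position j and the embedding index d; the function name `positional_encoding` is made up for the example.

```python
import numpy as np

def positional_encoding(word_vectors):
    """Encode one sentence from its word embeddings.

    word_vectors: array of shape (M, D), one row per word.
    Returns the sentence vector f_i of shape (D,).
    """
    M, D = word_vectors.shape
    j = np.arange(1, M + 1, dtype=np.float64)[:, None]   # word positions 1..M (assumed 1-based)
    d = np.arange(1, D + 1, dtype=np.float64)[None, :]   # embedding indices 1..D (assumed 1-based)
    l = (1.0 - j / M) - (d / D) * (1.0 - 2.0 * j / M)    # l[j-1, d-1] holds l_{jd}
    return np.sum(l * word_vectors, axis=0)              # sum over words of l_j ∘ w_j

sentence = np.ones((3, 4))             # toy sentence: 3 words, 4-dimensional embeddings
encoding = positional_encoding(sentence)
```

Unlike a plain bag-of-words sum, the weights l_{jd} depend on the position j, so reordering the words changes the sentence vector.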


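(2) Input Module for VQA


Local region feature extraction: image features are extracted with VGG-19. The input image is first resized to 448×448, then the output of the last pooling layer is taken, with dimensionality d = 512×14×14. The pooling layer divides the image into a 14×14 grid, so the image is represented by 196 local regions, each a 512-dimensional vector.

Visual feature embedding: VQA involves both images and text, so the authors add a linear layer with tanh activation that projects the local region vectors into the textual feature space used by the question vector q.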

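Input fusion layer: the local feature vectors extracted above carry no global information, and without it their representational power is limited. To address this the authors add an input fusion layer. First, the input facts F are produced by traversing the image grid in a snake-like order. Then a bi-directional GRU is applied over F to produce globally aware input facts. The bi-directional GRU lets information propagate between neighboring image patches.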
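
The snake-like traversal can be sketched as follows. This is an illustration only: it assumes a row-wise traversal in which every other row is reversed, and the helper name `snake_traverse` is hypothetical.

```python
import numpy as np

def snake_traverse(feature_map):
    """Flatten an (H, W, C) feature grid into (H*W, C) facts.

    Even rows are read left to right and odd rows right to left, so that
    consecutive facts always come from neighboring grid cells.
    """
    rows = [row if r % 2 == 0 else row[::-1] for r, row in enumerate(feature_map)]
    return np.concatenate(rows, axis=0)

grid = np.arange(9).reshape(3, 3, 1)    # toy 3x3 grid with one channel
facts = snake_traverse(grid)
order = facts[:, 0].tolist()            # visiting order of the cells

vgg_features = np.zeros((14, 14, 512))  # shape of the pooled VGG feature map
vgg_facts = snake_traverse(vgg_features)
```

For the 14×14×512 feature map this yields the 196 facts of dimension 512 that are fed to the bi-directional GRU.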

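(3) The Episodic Memory Module

The episodic memory module retrieves information from the input facts by focusing attention on a subset of them. Attention is implemented by associating a single scalar attention gate g_{i}^{t} with each fact during pass t. The gate is computed from interactions between the fact, the question representation and the episode memory state:

z_{i}^{t} = [\overleftrightarrow{f_{i}} \circ q ; \overleftrightarrow{f_{i}} \circ m^{t-1} ; |\overleftrightarrow{f_{i}} - q| ; |\overleftrightarrow{f_{i}} - m^{t-1}|]
Z_{i}^{t} = W^{(2)} \tanh(W^{(1)} z_{i}^{t} + b^{(1)}) + b^{(2)}
g_{i}^{t} = \exp(Z_{i}^{t}) / \sum_{k=1}^{N} \exp(Z_{k}^{t})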



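Soft attention: the context vector c^{t} is a weighted sum of the sorted list of fact vectors \overleftrightarrow{F} with the corresponding attention gates g_{i}^{t}, that is, c^{t} = \sum_{i=1}^{N} g_{i}^{t} \overleftrightarrow{f_{i}}. This approach has two advantages. First, it is easy to compute. Second, if the softmax activation is spiky, it can approximate a hard attention function by selecting a single fact for the context vector while remaining differentiable.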
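
The gate computation and the soft attention sum can be put together in a short NumPy sketch. It assumes, following my reading of the paper, that each fact is scored by a two-layer network over the feature vector [f∘q; f∘m; |f−q|; |f−m|] followed by a softmax over the facts. All names, shapes and the random weights are illustrative assumptions.

```python
import numpy as np

def attention_gates(facts, q, m_prev, W1, b1, W2, b2):
    """facts: (N, d); q and m_prev: (d,). Returns softmax gates of shape (N,)."""
    z = np.concatenate([facts * q, facts * m_prev,
                        np.abs(facts - q), np.abs(facts - m_prev)], axis=1)  # (N, 4d)
    scores = np.tanh(z @ W1 + b1) @ W2 + b2    # one scalar score per fact, shape (N,)
    scores = scores - scores.max()             # numerical stability
    e = np.exp(scores)
    return e / e.sum()

def soft_attention(facts, gates):
    """Context vector c^t: gate-weighted sum of the facts."""
    return gates @ facts                       # shape (d,)

rng = np.random.default_rng(0)
N, d, hidden = 5, 4, 8                         # assumed toy sizes
facts = rng.normal(size=(N, d))
q = rng.normal(size=d)
m_prev = q.copy()                              # assumption: memory starts as the question vector
W1, b1 = rng.normal(size=(4 * d, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(size=hidden), 0.0
gates = attention_gates(facts, q, m_prev, W1, b1, W2, b2)
context = soft_attention(facts, gates)
```

Because the weighted sum is symmetric in the facts, the context vector does not change if the facts and their gates are permuted together.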

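Attention based GRU: for more complex queries the attention should be sensitive to both the position and the ordering of the input facts, and an RNN is a better fit for that. The authors propose a modified GRU that embeds information from the attention mechanism. The update gate u_{i} in Equation (1) decides how much of each dimension of the hidden state to retain and how much to update with the input x_{i}. Because u_{i} is computed from only the current input and the previous hidden state, it has no knowledge of the question or of the previous episode memory. Replacing u_{i} with the attention gate g_{i}^{t} gives h_{i} = g_{i}^{t} \circ \tilde{h}_{i} + (1 - g_{i}^{t}) \circ h_{i-1}.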
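
A minimal NumPy sketch of such an attention based GRU is given below. It assumes that the scalar attention gate g_{i}^{t} takes the place of the update gate and that the final hidden state serves as the context vector c^{t}; the parameter names and shapes are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attention_gru(facts, gates, params):
    """Run the modified GRU over the facts and return the final hidden state.

    facts: (N, d) fact vectors; gates: (N,) scalar attention gates.
    params: (Wr, Ur, br, W, U, bh) with Wr, W of shape (n_h, d),
            Ur, U of shape (n_h, n_h) and br, bh of shape (n_h,).
    """
    Wr, Ur, br, W, U, bh = params
    h = np.zeros(U.shape[0])
    for f, g in zip(facts, gates):
        r = sigmoid(Wr @ f + Ur @ h + br)               # reset gate, unchanged
        h_tilde = np.tanh(W @ f + r * (U @ h) + bh)     # candidate state
        h = g * h_tilde + (1.0 - g) * h                 # attention gate replaces the update gate
    return h

rng = np.random.default_rng(0)
d, n_h, N = 3, 4, 5                                     # assumed toy sizes
facts = rng.normal(size=(N, d))
params = (rng.normal(size=(n_h, d)), rng.normal(size=(n_h, n_h)), np.zeros(n_h),
          rng.normal(size=(n_h, d)), rng.normal(size=(n_h, n_h)), np.zeros(n_h))
only_last = np.array([0.0, 0.0, 0.0, 0.0, 1.0])         # attend to the last fact only
context = attention_gru(facts, only_last, params)
ignored = attention_gru(facts, np.zeros(N), params)     # all gates closed
```

A fact with gate 0 leaves the hidden state untouched, while a gate of 1 overwrites it with the candidate state, so the order in which attended facts appear matters.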




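Episode Memory Updates

Prior work suggests that using different weights for each pass through the episodic memory may work better. When the model has only one set of weights shared by all episodic passes over the input, it is called a tied model, as in the "Mem Weights" row of Table 1: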
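
The untied alternative can be sketched as below. The sketch assumes a ReLU update of the form m^{t} = ReLU(W^{t} [m^{t-1}; c^{t}; q] + b) with a separate weight matrix per pass, which is my reading of the paper; the context vectors are fixed toy values here, whereas in the real model each c^{t} comes from the attention mechanism and depends on m^{t-1}. All names are hypothetical.

```python
import numpy as np

def relu_memory_update(m_prev, c, q, W, b):
    """One memory update: ReLU of a linear layer over [m_prev; c; q]."""
    return np.maximum(0.0, W @ np.concatenate([m_prev, c, q]) + b)

d, num_passes = 3, 2                                    # assumed toy sizes
q = np.ones(d)
contexts = [np.ones(d), 2.0 * np.ones(d)]               # toy context vectors, one per pass
# Untied model: every pass has its own weights and bias.
weights = [np.zeros((d, 3 * d)) for _ in range(num_passes)]
biases = [np.ones(d), -np.ones(d)]

m = q.copy()                                            # assumption: memory starts as q
memories = []
for c, W, b in zip(contexts, weights, biases):
    m = relu_memory_update(m, c, q, W, b)
    memories.append(m)
```

A tied model would reuse a single `W` and `b` for every pass instead of the per-pass lists.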



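3. Datasets


bAbI-10k: a synthetic dataset with 20 tasks. Each example consists of a set of facts, a question, the answer, and the supporting facts that lead to the answer.

DAQUAR-ALL visual dataset: the DAtaset for QUestion Answering on Real-world images (DAQUAR) has 795 training images and 654 test images, with 6,795 training questions and 5,673 test questions.

VQA: the VQA 1.0 dataset.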

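4. Experiments


Here the authors analyze the DMN and its variants. ODMN is the original DMN. DMN2 replaces the input module with the input fusion layer. DMN3 builds on DMN2 and replaces soft attention with the attention based GRU. DMN+ builds on DMN3 and, on each pass, updates the memory with a linear layer with ReLU activation and untied weights. All results are shown in Table 1.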


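For the text QA dataset (bAbI-10k), the optimizer is Adam with a learning rate of 0.001 and a batch size of 128. Training runs for up to 256 epochs and stops early if the validation loss has not improved within 20 epochs. Word embeddings are initialized from a random uniform distribution over [-\sqrt{3}, \sqrt{3}]; all other weights use Xavier initialization. The hidden size is d=80, and l2 regularization is applied to the weights. Dropout keeps 90% of the input. The episodic memory module makes 3 passes. The tasks and their results are shown in the figure below: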



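For VQA, the optimizer is Adam with a learning rate of 0.003 and a batch size of 100. Training runs for up to 256 epochs and stops early if the validation loss has not improved within 10 epochs. Weights are initialized from a random uniform distribution over [-0.08, 0.08]. The hidden size is d=512 and dropout is 0.5. The results are mainly shown in Fig 6.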




    |------ base_model.py
    |------ config.py
    |------ dataset.py
    |------ episodic_memory.py
    |------ main.py
    |------ model.py
    |------ vgg16_no_fc.npy    # needs to be downloaded separately
    |------ utils
        |------ vqa
            |------ vqa.py
            |------ vqaEval.py
        |------ vocabulary.py
        |------ misc.py
        |------ nn.py
    |------ train    # needs to be downloaded separately
        |------ images
            |------ image01.jpg
            |------ image02.jpg
            |------ ......
        |------ mscoco_train2014_annotations.json
        |------ OpenEnded_mscoco_train2014_questions.json
    |------ val    # needs to be downloaded separately
        |------ images
            |------ image01.jpg
            |------ image02.jpg
            |------ ......
        |------ mscoco_val2014_annotations.json
        |------ OpenEnded_mscoco_val2014_questions.json
    |------ test
        |------ images

1. Environment


python 3.7

GPU: GTX 1050TI 4G

tensorflow 1.14

numpy 1.16.2

opencv 3.4.1

Natural Language Toolkit (NLTK) 3.4

Pandas 0.24.2

Matplotlib 3.0.3

tqdm 4.31.1

2. Dataset preparation










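The download link that the author provides for the pretrained VGG16 model does not open, so I found the file elsewhere and uploaded it to my own netdisk. Link: https://pan.baidu.com/s/1jPzXKZIXbNnknT7Nubh3yw  Extraction code: nms0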


3. Code




    class Config(object):
        """ Wrapper class for various (hyper)parameters. """
        def __init__(self):
            # Model architecture
            self.cnn = 'vgg16'               # 'vgg16' or 'resnet50'
            self.max_question_length = 30
            self.dim_embedding = 512
            self.num_gru_units = 512
            self.memory_step = 3
            self.memory_update = 'relu'      # 'gru' or 'relu'
            self.attention = 'gru'           # 'gru' or 'soft'; set to 'soft' for ablation experiments
            self.tie_memory_weight = False
            self.question_encoding = 'gru'   # 'gru' or 'positional'
            self.embed_fact = False

            # Weight initialization and regularization
            self.fc_kernel_initializer_scale = 0.08
            self.fc_kernel_regularizer_scale = 1e-6
            self.fc_activity_regularizer_scale = 0.0
            self.conv_kernel_regularizer_scale = 1e-6
            self.conv_activity_regularizer_scale = 0.0
            self.fc_drop_rate = 0.5
            self.gru_drop_rate = 0.3

            # Optimization
            self.num_epochs = 100
            self.batch_size = 4
            self.optimizer = 'Adam'          # 'Adam', 'RMSProp', 'Momentum' or 'SGD'
            self.initial_learning_rate = 0.0001
            self.learning_rate_decay_factor = 1.0
            self.num_steps_per_decay = 10000
            self.clip_gradients = 10.0
            self.momentum = 0.0
            self.use_nesterov = True
            self.decay = 0.9
            self.centered = True
            self.beta1 = 0.9
            self.beta2 = 0.999
            self.epsilon = 1e-5

            # Checkpoints and summaries
            self.save_period = 1000
            self.save_dir = './models/'
            self.summary_dir = './summary/'

            # Vocabulary
            self.vocabulary_file = './vocabulary.csv'

            # Training
            self.train_image_dir = './train/images/'
            self.train_question_file = './train/OpenEnded_mscoco_train2014_questions.json'
            self.train_answer_file = './train/mscoco_train2014_annotations.json'
            self.temp_train_annotation_file = './train/anns.csv'
            self.temp_train_data_file = './train/data.npy'

            # Evaluation
            self.eval_image_dir = './val/images/'
            self.eval_question_file = './val/OpenEnded_mscoco_val2014_questions.json'
            self.eval_answer_file = './val/mscoco_val2014_annotations.json'
            self.temp_eval_annotation_file = './val/anns.csv'
            self.temp_eval_data_file = './val/data.npy'
            self.eval_result_dir = './val/results/'
            self.eval_result_file = './val/results.json'
            self.save_eval_result_as_image = False

            # Testing
            self.test_image_dir = './test/images/'
            self.test_question_file = './test/questions.csv'
            self.temp_test_info_file = './test/info.csv'
            self.test_result_dir = './test/results/'
            self.test_result_file = './test/results.csv'




    import numpy as np
    import cv2

    class ImageLoader(object):
        def __init__(self, mean_file):
            self.bgr = True
            self.scale_shape = np.array([224, 224], np.int32)
            self.crop_shape = np.array([224, 224], np.int32)
            # Per-channel mean, averaged over the spatial dimensions
            self.mean = np.load(mean_file).mean(1).mean(1)

        def load_image(self, image_file):
            """ Load and preprocess an image. """
            image = cv2.imread(image_file)

            if self.bgr:
                # cv2 loads images as BGR; reverse the channel axis to get RGB
                temp = image.swapaxes(0, 2)
                temp = temp[::-1]
                image = temp.swapaxes(0, 2)

            image = cv2.resize(image, (self.scale_shape[0], self.scale_shape[1]))
            # Center crop (a no-op while scale_shape equals crop_shape)
            offset = (self.scale_shape - self.crop_shape) / 2
            offset = offset.astype(np.int32)
            image = image[offset[0]:offset[0]+self.crop_shape[0],
                          offset[1]:offset[1]+self.crop_shape[1]]
            image = image - self.mean
            return image

        def load_images(self, image_files):
            """ Load and preprocess a list of images. """
            images = []
            for image_file in image_files:
                images.append(self.load_image(image_file))
            images = np.array(images, np.float32)
            return images


    import tensorflow as tf
    import tensorflow.contrib.layers as layers

    class NN(object):
        def __init__(self, config):
            self.config = config
            self.is_train = (config.phase == 'train')
            self.train_cnn = self.is_train and config.train_cnn
            self.prepare()

        def prepare(self):
            """ Setup the weight initializers and regularizers. """
            config = self.config

            self.conv_kernel_initializer = layers.xavier_initializer()

            if self.train_cnn and config.conv_kernel_regularizer_scale > 0:
                self.conv_kernel_regularizer = layers.l2_regularizer(
                    scale = config.conv_kernel_regularizer_scale)
            else:
                self.conv_kernel_regularizer = None

            if self.train_cnn and config.conv_activity_regularizer_scale > 0:
                self.conv_activity_regularizer = layers.l1_regularizer(
                    scale = config.conv_activity_regularizer_scale)
            else:
                self.conv_activity_regularizer = None

            self.fc_kernel_initializer = tf.random_uniform_initializer(
                minval = -config.fc_kernel_initializer_scale,
                maxval = config.fc_kernel_initializer_scale)

            if self.is_train and config.fc_kernel_regularizer_scale > 0:
                self.fc_kernel_regularizer = layers.l2_regularizer(
                    scale = config.fc_kernel_regularizer_scale)
            else:
                self.fc_kernel_regularizer = None

            if self.is_train and config.fc_activity_regularizer_scale > 0:
                self.fc_activity_regularizer = layers.l1_regularizer(
                    scale = config.fc_activity_regularizer_scale)
            else:
                self.fc_activity_regularizer = None

        def conv2d(self,
                   inputs,
                   filters,
                   kernel_size = (3, 3),
                   strides = (1, 1),
                   activation = tf.nn.relu,
                   use_bias = True,
                   name = None):
            """ 2D Convolution layer. """
            if activation is not None:
                activity_regularizer = self.conv_activity_regularizer
            else:
                activity_regularizer = None
            return tf.layers.conv2d(
                inputs = inputs,
                filters = filters,
                kernel_size = kernel_size,
                strides = strides,
                padding='same',
                activation = activation,
                use_bias = use_bias,
                trainable = self.train_cnn,
                kernel_initializer = self.conv_kernel_initializer,
                kernel_regularizer = self.conv_kernel_regularizer,
                activity_regularizer = activity_regularizer,
                name = name)

        def max_pool2d(self,
                       inputs,
                       pool_size = (2, 2),
                       strides = (2, 2),
                       name = None):
            """ 2D Pooling layer. """
            return tf.layers.max_pooling2d(
                inputs = inputs,
                pool_size = pool_size,
                strides = strides,
                padding='same',
                name = name)

        def dense(self,
                  inputs,
                  units,
                  activation = tf.tanh,
                  use_bias = True,
                  name = None):
            """ Fully-connected layer. """
            if activation is not None:
                activity_regularizer = self.fc_activity_regularizer
            else:
                activity_regularizer = None
            return tf.layers.dense(
                inputs = inputs,
                units = units,
                activation = activation,
                use_bias = use_bias,
                trainable = self.is_train,
                kernel_initializer = self.fc_kernel_initializer,
                kernel_regularizer = self.fc_kernel_regularizer,
                activity_regularizer = activity_regularizer,
                name = name)

        def dropout(self,
                    inputs,
                    name = None):
            """ Dropout layer. """
            return tf.layers.dropout(
                inputs = inputs,
                rate = self.config.fc_drop_rate,
                training = self.is_train)

        def batch_norm(self,
                       inputs,
                       name = None):
            """ Batch normalization layer. """
            return tf.layers.batch_normalization(
                inputs = inputs,
                training = self.train_cnn,
                trainable = self.train_cnn,
                name = name
            )

        def gru(self):
            """ GRU layer. """
            gru = tf.nn.rnn_cell.GRUCell(
                num_units = self.config.num_gru_units,
                kernel_initializer = self.fc_kernel_initializer)
            if self.is_train:
                gru = tf.nn.rnn_cell.DropoutWrapper(
                    gru,
                    input_keep_prob = 1.0 - self.config.gru_drop_rate,
                    output_keep_prob = 1.0 - self.config.gru_drop_rate,
                    state_keep_prob = 1.0 - self.config.gru_drop_rate)
            return gru


    import os
    import numpy as np
    import pandas as pd
    from nltk.tokenize import word_tokenize

    class Vocabulary(object):
        def __init__(self, save_file = None):
            self.words = []
            self.word2idx = {}
            self.size = 0
            self.word_counts = {}
            self.word_frequencies = []
            if save_file is not None:
                self.load(save_file)
            else:
                # Index 0 is reserved for out-of-vocabulary words
                self.add_words(["<unknown>"])

        def add_words(self, words):
            """ Add new words to the vocabulary. """
            for w in words:
                if w not in self.word2idx:
                    self.words.append(w)
                    self.word2idx[w] = self.size
                    self.size += 1
                self.word_counts[w] = self.word_counts.get(w, 0) + 1

        def compute_frequency(self):
            """ Compute the frequency of each word. """
            self.word_frequencies = []
            for w in self.words:
                self.word_frequencies.append(self.word_counts[w])
            self.word_frequencies = np.array(self.word_frequencies, np.float32)
            self.word_frequencies /= np.sum(self.word_frequencies)
            # Log-frequencies, shifted so that the most frequent word is 0
            self.word_frequencies = np.log(self.word_frequencies)
            self.word_frequencies -= np.max(self.word_frequencies)

        def word_to_idx(self, word):
            """ Translate a word into its index. """
            return self.word2idx[word] if word in self.word2idx else 0

        def process_sentence(self, sentence):
            """ Tokenize a sentence, and translate each token into its index
                in the vocabulary. """
            words = word_tokenize(sentence.lower())
            word_idxs = [self.word_to_idx(w) for w in words]
            return word_idxs

        def save(self, save_file):
            """ Save the vocabulary to a file. """
            data = pd.DataFrame({'word': self.words,
                                 'index': list(range(self.size)),
                                 'frequency': self.word_frequencies})
            data.to_csv(save_file)

        def load(self, save_file):
            """ Load the vocabulary from a file. """
            assert os.path.exists(save_file)
            data = pd.read_csv(save_file)
            self.words = data['word'].values
            self.size = len(self.words)
            self.word2idx = {self.words[i]:i for i in range(self.size)}
            self.word_frequencies = data['frequency'].values



    __author__ = 'aagrawal'
    __version__ = '0.9'

    # Interface for accessing the VQA dataset.

    # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
    # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).

    # The following functions are defined:
    #  VQA        - VQA class that loads VQA annotation file and prepares data structures.
    #  getQuesIds - Get question ids that satisfy given filter conditions.
    #  getImgIds  - Get image ids that satisfy given filter conditions.
    #  loadQA     - Load questions and answers with the specified question ids.
    #  showQA     - Display the specified questions and answers.
    #  loadRes    - Load result file and create result object.

    # Help on each function can be accessed by: "help(VQA.function)"

    import json
    import datetime
    import copy
    from tqdm import tqdm
    from nltk.tokenize import word_tokenize

    class VQA:
        def __init__(self, annotation_file=None, question_file=None):
            """
            Constructor of VQA helper class for reading and visualizing questions and answers.
            :param annotation_file (str): location of VQA annotation file
            :param question_file (str): location of VQA question file
            :return:
            """
            # load dataset
            self.dataset = {}
            self.questions = {}
            self.qa = {}
            self.qqa = {}
            self.imgToQA = {}
            self.max_ques_len = 0
            if annotation_file is not None and question_file is not None:
                print('loading VQA annotations and questions into memory...')
                time_t = datetime.datetime.utcnow()
                with open(annotation_file, 'r') as f:
                    dataset = json.load(f)
                with open(question_file, 'r') as f:
                    questions = json.load(f)
                print(datetime.datetime.utcnow() - time_t)
                self.dataset = dataset
                self.questions = questions
                self.process_dataset()
                self.createIndex()

        def createIndex(self):
            # create index
            print('creating index...')
            imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
            qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
            qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
            max_ques_len = 0
            for ann in self.dataset['annotations']:
                imgToQA[ann['image_id']] += [ann]
                qa[ann['question_id']] = ann
            for ques in self.questions['questions']:
                qqa[ques['question_id']] = ques
                max_ques_len = max(max_ques_len, len(word_tokenize(ques['question'])))

            # create class members
            self.qa = qa
            self.qqa = qqa
            self.imgToQA = imgToQA
            self.max_ques_len = max_ques_len

        def info(self):
            """
            Print information about the VQA annotation file.
            :return:
            """
            for key, value in list(self.dataset['info'].items()):
                print('%s: %s'%(key, value))

        def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
            """
            Get question ids that satisfy given filter conditions. default skips that filter
            :param  imgIds    (int array) : get question ids for given imgs
                    quesTypes (str array) : get question ids for given question types
                    ansTypes  (str array) : get question ids for given answer types
            :return: ids (int array) : integer array of question ids
            """
            imgIds = imgIds if type(imgIds) == list else [imgIds]
            quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
            ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

            if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
                anns = self.dataset['annotations']
            else:
                if not len(imgIds) == 0:
                    anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
                else:
                    anns = self.dataset['annotations']
                anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
                anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
            ids = [ann['question_id'] for ann in anns]
            return ids

        def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
            """
            Get image ids that satisfy given filter conditions. default skips that filter
            :param  quesIds   (int array) : get image ids for given question ids
                    quesTypes (str array) : get image ids for given question types
                    ansTypes  (str array) : get image ids for given answer types
            :return: ids (int array) : integer array of image ids
            """
            quesIds = quesIds if type(quesIds) == list else [quesIds]
            quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
            ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]

            if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
                anns = self.dataset['annotations']
            else:
                if not len(quesIds) == 0:
                    # self.qa maps a question id to a single annotation dict,
                    # so collect the dicts directly instead of summing lists
                    anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]
                else:
                    anns = self.dataset['annotations']
                anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
                anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
            ids = [ann['image_id'] for ann in anns]
            return ids

        def loadQA(self, ids=[]):
            """
            Load questions and answers with the specified question ids.
            :param ids (int array) : integer ids specifying question ids
            :return: qa (object array) : loaded qa objects
            """
            if type(ids) == list:
                return [self.qa[id] for id in ids]
            elif type(ids) == int:
                return [self.qa[ids]]

        def showQA(self, anns):
            """
            Display the specified annotations.
            :param anns (array of object): annotations to display
            :return: None
            """
            if len(anns) == 0:
                return 0
            for ann in anns:
                quesId = ann['question_id']
                print("Question: %s" %(self.qqa[quesId]['question']))
                for ans in ann['answers']:
                    print("Answer %d: %s" %(ans['answer_id'], ans['answer']))

        def loadRes(self, resFile, quesFile):
            """
            Load result file and return a result object.
            :param resFile (str) : file name of result file
            :param quesFile (str) : file name of question file
            :return: res (obj) : result api object
            """
            res = VQA()
            with open(quesFile) as f:
                res.questions = json.load(f)
            res.dataset['info'] = copy.deepcopy(self.questions['info'])
            res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
            res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
            res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
            res.dataset['license'] = copy.deepcopy(self.questions['license'])

            print('Loading and preparing results...     ')
            time_t = datetime.datetime.utcnow()
            with open(resFile) as f:
                anns = json.load(f)
            assert type(anns) == list, 'results is not an array of objects'
            annsQuesIds = [ann['question_id'] for ann in anns]
            assert set(annsQuesIds) == set(self.getQuesIds()), \
                'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
            for ann in anns:
                quesId = ann['question_id']
                if res.dataset['task_type'] == 'Multiple Choice':
                    assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
                qaAnn = self.qa[quesId]
                ann['image_id'] = qaAnn['image_id']
                ann['question_type'] = qaAnn['question_type']
                ann['answer_type'] = qaAnn['answer_type']
            print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))

            res.dataset['annotations'] = anns
            res.createIndex()
            return res

        def process_dataset(self):
            for ann in self.dataset['annotations']:
                count = {}
                for ans in ann['answers']:
                    ans['answer'] = ans['answer'].lower()    # lowercase every answer
                    count[ans['answer']] = count.get(ans['answer'], 0) + 1    # count how often each answer occurs
                sorted_ans = sorted(list(count.items()),
                                    key=lambda x: x[1],
                                    reverse=True)    # sort by number of occurrences
                best_ans, best_ans_count = sorted_ans[0]    # most frequent answer and its count
                ann['best_answer'] = best_ans
                ann['best_answer_count'] = best_ans_count

            for ques in self.questions['questions']:
                q = ques['question']
                q = q.replace('?', '')    # strip the question mark
                q = q.lower()
                ques['question'] = q

        def filter_by_ques_len(self, max_ques_len):
            print("Filtering the questions by length...")
            keep_ques = {}
            for ques in tqdm(self.questions['questions']):
                if len(word_tokenize(ques['question'])) <= max_ques_len:
                    keep_ques[ques['question_id']] = \
                        keep_ques.get(ques['question_id'], 0) + 1

            self.dataset['annotations'] = \
                [ann for ann in self.dataset['annotations'] \
                if keep_ques.get(ann['question_id'],0)>0]
            self.questions['questions'] = \
                [ques for ques in self.questions['questions'] \
                if keep_ques.get(ques['question_id'],0)>0]

            self.createIndex()

        def filter_by_ans_len(self, max_ans_len, min_freq=5):
            print("Filtering the answers by length...")
            keep_ques = {}
            for ann in tqdm(self.dataset['annotations']):
                if len(word_tokenize(ann['best_answer'])) <= max_ans_len \
                    and ann['best_answer_count']>=min_freq:
                    keep_ques[ann['question_id']] = \
                        keep_ques.get(ann['question_id'], 0) + 1

            self.dataset['annotations'] = \
                [ann for ann in self.dataset['annotations'] \
                if keep_ques.get(ann['question_id'],0)>0]
            self.questions['questions'] = \
                [ques for ques in self.questions['questions'] \
                if keep_ques.get(ques['question_id'],0)>0]

            self.createIndex()


  1. import sys
  2. import re
  3. from tqdm import tqdm
  4. class VQAEval:
  5. def __init__(self, vqa, vqaRes, n=2):
  6. self.n = n
  7. self.accuracy = {}
  8. self.evalQA = {}
  9. self.evalQuesType = {}
  10. self.evalAnsType = {}
  11. self.vqa = vqa
  12. self.vqaRes = vqaRes
  13. self.params = {'question_id': vqa.getQuesIds()}
  14. self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
  15. "couldn'tve": "couldn’t’ve", "couldnt’ve": "couldn’t’ve", "didnt": "didn’t", "doesnt": "doesn’t", "dont": "don’t", "hadnt": "hadn’t", \
  16. "hadnt’ve": "hadn’t’ve", "hadn'tve": "hadn’t’ve", "hasnt": "hasn’t", "havent": "haven’t", "hed": "he’d", "hed’ve": "he’d’ve", \
  17. "he’dve": "he’d’ve", "hes": "he’s", "howd": "how’d", "howll": "how’ll", "hows": "how’s", "Id’ve": "I’d’ve", "I’dve": "I’d’ve", \
  18. "Im": "I’m", "Ive": "I’ve", "isnt": "isn’t", "itd": "it’d", "itd’ve": "it’d’ve", "it’dve": "it’d’ve", "itll": "it’ll", "let’s": "let’s", \
  19. "maam": "ma’am", "mightnt": "mightn’t", "mightnt’ve": "mightn’t’ve", "mightn’tve": "mightn’t’ve", "mightve": "might’ve", \
  20. "mustnt": "mustn’t", "mustve": "must’ve", "neednt": "needn’t", "notve": "not’ve", "oclock": "o’clock", "oughtnt": "oughtn’t", \
  21. "ow’s’at": "’ow’s’at", "’ows’at": "’ow’s’at", "’ow’sat": "’ow’s’at", "shant": "shan’t", "shed’ve": "she’d’ve", "she’dve": "she’d’ve", \
  22. "she’s": "she’s", "shouldve": "should’ve", "shouldnt": "shouldn’t", "shouldnt’ve": "shouldn’t’ve", "shouldn’tve": "shouldn’t’ve", \
  23. "somebody’d": "somebodyd", "somebodyd’ve": "somebody’d’ve", "somebody’dve": "somebody’d’ve", "somebodyll": "somebody’ll", \
  24. "somebodys": "somebody’s", "someoned": "someone’d", "someoned’ve": "someone’d’ve", "someone’dve": "someone’d’ve", \
  25. "someonell": "someone’ll", "someones": "someone’s", "somethingd": "something’d", "somethingd’ve": "something’d’ve", \
  26. "something’dve": "something’d’ve", "somethingll": "something’ll", "thats": "that’s", "thered": "there’d", "thered’ve": "there’d’ve", \
  27. "there’dve": "there’d’ve", "therere": "there’re", "theres": "there’s", "theyd": "they’d", "theyd’ve": "they’d’ve", \
  28. "they’dve": "they’d’ve", "theyll": "they’ll", "theyre": "they’re", "theyve": "they’ve", "twas": "’twas", "wasnt": "wasn’t", \
  29. "wed’ve": "we’d’ve", "we’dve": "we’d’ve", "weve": "we've", "werent": "weren’t", "whatll": "what’ll", "whatre": "what’re", \
  30. "whats": "what’s", "whatve": "what’ve", "whens": "when’s", "whered": "where’d", "wheres": "where's", "whereve": "where’ve", \
  31. "whod": "who’d", "whod’ve": "who’d’ve", "who’dve": "who’d’ve", "wholl": "who’ll", "whos": "who’s", "whove": "who've", "whyll": "why’ll", \
  32. "whyre": "why’re", "whys": "why’s", "wont": "won’t", "wouldve": "would’ve", "wouldnt": "wouldn’t", "wouldnt’ve": "wouldn’t’ve", \
  33. "wouldn’tve": "wouldn’t’ve", "yall": "y’all", "yall’ll": "y’all’ll", "y’allll": "y’all’ll", "yall’d’ve": "y’all’d’ve", \
  34. "y’alld’ve": "y’all’d’ve", "y’all’dve": "y’all’d’ve", "youd": "you’d", "youd’ve": "you’d’ve", "you’dve": "you’d’ve", \
  35. "youll": "you’ll", "youre": "you’re", "youve": "you’ve"}
  36. self.manualMap = {'none': '0',
  37. 'zero': '0',
  38. 'one': '1',
  39. 'two': '2',
  40. 'three': '3',
  41. 'four': '4',
  42. 'five': '5',
  43. 'six': '6',
  44. 'seven': '7',
  45. 'eight': '8',
  46. 'nine': '9',
  47. 'ten': '10'}
  48. self.articles = ['a', 'an', 'the']
  49. self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
  50. self.commaStrip = re.compile("(\d)(\,)(\d)")
  51. self.punct = [';', r"/", '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', '>', '<', '@', '`', ',', '?', '!']

    def evaluate(self, quesIds=None):
        if quesIds is None:
            quesIds = [quesId for quesId in self.params['question_id']]
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        step = 0
        for quesId in tqdm(quesIds):
            # Normalize the predicted answer
            resAns = res[quesId]['answer']
            resAns = resAns.replace('\n', ' ')
            resAns = resAns.replace('\t', ' ')
            resAns = resAns.strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)
            gtAcc = []
            gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
            if len(set(gtAnswers)) > 1:
                for ansDic in gts[quesId]['answers']:
                    ansDic['answer'] = self.processPunctuation(ansDic['answer'])
            # Leave one human answer out at a time and score against the rest
            for gtAnsDatum in gts[quesId]['answers']:
                otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
                matchingAns = [item for item in otherGTAns if item['answer']==resAns]
                acc = min(1, float(len(matchingAns))/3)
                gtAcc.append(acc)
            quesType = gts[quesId]['question_type']
            ansType = gts[quesId]['answer_type']
            avgGTAcc = float(sum(gtAcc))/len(gtAcc)
            accQA.append(avgGTAcc)
            if quesType not in accQuesType:
                accQuesType[quesType] = []
            accQuesType[quesType].append(avgGTAcc)
            if ansType not in accAnsType:
                accAnsType[ansType] = []
            accAnsType[ansType].append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            step = step + 1

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")
        self.showAccuracy(accQA, accQuesType, accAnsType)

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) is not None):
                outText = outText.replace(p, '')
            else:
                outText = outText.replace(p, ' ')
        # The listing passed re.UNICODE as the third argument, which Pattern.sub
        # reads as `count` (a cap on substitutions), not as a flag; it is dropped here.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            # Map number words to digits; .get() avoids growing manualMap
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = ' '.join(outText)
        return outText

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
        self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
        self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}

    def showAccuracy(self, accQA, accQuesType, accAnsType):
        print("Overall accuracy = %f" %(self.accuracy['overall']))
        print("Accuracy per question type:")
        for quesType in accQuesType:
            print("quesType: %s accuracy = %f" %(quesType, self.accuracy['perQuestionType'][quesType]))
        print("Accuracy per answer type:")
        for ansType in accAnsType:
            print("ansType: %s accuracy = %f" %(ansType, self.accuracy['perAnswerType'][ansType]))

    def setEvalQA(self, quesId, acc):
        self.evalQA[quesId] = round(100*acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100*acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100*acc, self.n)

    def updateProgress(self, progress):
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength*progress))
        text = "\rFinished Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
        print(text)
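
The scoring rule inside `evaluate` can be hard to see in the full listing, so here is a minimal standalone sketch of it. It assumes the answers are already normalized strings; the function name `vqa_accuracy` and the sample data are hypothetical and not part of the original code.

```python
def vqa_accuracy(predicted, human_answers):
    """Leave-one-out consensus accuracy, mirroring the loop in evaluate().

    For each human answer left out in turn, the prediction scores
    min(1, matches / 3) against the remaining answers; the scores are averaged.
    """
    scores = []
    for i in range(len(human_answers)):
        others = human_answers[:i] + human_answers[i + 1:]
        matches = sum(1 for ans in others if ans == predicted)
        scores.append(min(1.0, matches / 3.0))
    return sum(scores) / len(scores)


# Hypothetical example: 10 annotators, two said "2" and eight said "3".
human = ["2"] * 2 + ["3"] * 8
print(vqa_accuracy("2", human))  # about 0.6: a minority answer gets partial credit
print(vqa_accuracy("3", human))  # 1.0: at least three other annotators always agree
print(vqa_accuracy("4", human))  # 0.0: nobody gave this answer
```

An answer shared by only a few annotators still earns partial credit, which is why the metric is more forgiving than exact match against a single ground truth.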



import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from nltk.tokenize import word_tokenize  # if word_tokenize raises an error, reinstall nltk
from utils.vocabulary import Vocabulary
from utils.vqa.vqa import VQA

# Dataset class
class DataSet(object):
    def __init__(self,
                 image_files,
                 question_word_idxs,
                 question_lens,
                 question_ids,
                 batch_size,
                 answer_idxs = None,
                 is_train = False,
                 shuffle = False):
        self.image_files = np.array(image_files)
        self.question_word_idxs = np.array(question_word_idxs)
        self.question_lens = np.array(question_lens)
        self.question_ids = np.array(question_ids)
        self.answer_idxs = np.array(answer_idxs)
        self.batch_size = batch_size
        self.is_train = is_train
        self.shuffle = shuffle
        self.setup()  # build the dataset once all arguments are stored

    def setup(self):
        """ Setup the dataset. """
        self.count = len(self.question_ids)  # number of questions
        self.num_batches = int(np.ceil(self.count * 1.0 / self.batch_size))
        # Number of padding samples needed to fill the last batch
        self.fake_count = self.num_batches * self.batch_size - self.count
        self.idxs = list(range(self.count))
        self.reset()  # rewind (and optionally reshuffle)

    def reset(self):
        """ Reset the dataset. """
        self.current_idx = 0
        if self.shuffle:
            np.random.shuffle(self.idxs)

    def next_batch(self):
        """ Fetch the next batch. """
        assert self.has_next_batch()

        if self.has_full_next_batch():
            start, end = self.current_idx, self.current_idx + self.batch_size
            current_idxs = self.idxs[start:end]
        else:
            # Last, incomplete batch: pad it with randomly chosen samples
            start, end = self.current_idx, self.count
            current_idxs = self.idxs[start:end]
            current_idxs += list(np.random.choice(self.count, self.fake_count))

        image_files = self.image_files[current_idxs]
        question_word_idxs = self.question_word_idxs[current_idxs]
        question_lens = self.question_lens[current_idxs]

        if self.is_train:
            answer_idxs = self.answer_idxs[current_idxs]
            self.current_idx += self.batch_size
            return image_files, question_word_idxs, question_lens, answer_idxs
        else:
            self.current_idx += self.batch_size
            return image_files, question_word_idxs, question_lens

    def has_next_batch(self):
        """ Determine whether there is a batch left. """
        return self.current_idx < self.count

    def has_full_next_batch(self):
        """ Determine whether there is a full batch left. """
        return self.current_idx + self.batch_size <= self.count

# Prepare the training data
def prepare_train_data(config):
    """ Prepare the data for training the model. """
    vqa = VQA(config.train_answer_file, config.train_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    print("Reading the questions and answers...")
    annotations = process_vqa(vqa,
                              'COCO_train2014',
                              config.train_image_dir,
                              config.temp_train_annotation_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    answers = annotations['answer'].values
    print("Questions and answers read.")
    print("Number of questions = %d" %(len(question_ids)))

    print("Building the vocabulary...")
    vocabulary = Vocabulary()
    if not os.path.exists(config.vocabulary_file):
        for question in tqdm(questions):
            vocabulary.add_words(word_tokenize(question))
        for answer in tqdm(answers):
            vocabulary.add_words(word_tokenize(answer))
        vocabulary.compute_frequency()
        vocabulary.save(config.vocabulary_file)
    else:
        vocabulary.load(config.vocabulary_file)
    print("Vocabulary built.")
    print("Number of words = %d" %(vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions and answers...")
    if not os.path.exists(config.temp_train_data_file):
        question_word_idxs, question_lens = process_questions(questions,
                                                              vocabulary,
                                                              config)
        answer_idxs = process_answers(answers, vocabulary)
        data = {'question_word_idxs': question_word_idxs,
                'question_lens': question_lens,
                'answer_idxs': answer_idxs}
        np.save(config.temp_train_data_file, data)
    else:
        # The cache is a pickled dict, so recent NumPy needs allow_pickle=True
        data = np.load(config.temp_train_data_file, allow_pickle=True).item()
        question_word_idxs = data['question_word_idxs']
        question_lens = data['question_lens']
        answer_idxs = data['answer_idxs']
    print("Questions and answers processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files,
                      question_word_idxs,
                      question_lens,
                      question_ids,
                      config.batch_size,
                      answer_idxs,
                      True,
                      True)
    print("Dataset built.")
    return dataset, config

# Prepare the evaluation data
def prepare_eval_data(config):
    """ Prepare the data for evaluating the model. """
    vqa = VQA(config.eval_answer_file, config.eval_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    print("Reading the questions...")
    annotations = process_vqa(vqa,
                              'COCO_val2014',
                              config.eval_image_dir,
                              config.temp_eval_annotation_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    print("Questions read.")
    print("Number of questions = %d" %(len(question_ids)))

    print("Building the vocabulary...")
    if os.path.exists(config.vocabulary_file):
        vocabulary = Vocabulary(config.vocabulary_file)
    else:
        vocabulary = build_vocabulary(config)
    print("Vocabulary built.")
    print("Number of words = %d" %(vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions...")
    if not os.path.exists(config.temp_eval_data_file):
        question_word_idxs, question_lens = process_questions(questions,
                                                              vocabulary,
                                                              config)
        data = {'question_word_idxs': question_word_idxs,
                'question_lens': question_lens}
        np.save(config.temp_eval_data_file, data)
    else:
        # The cache is a pickled dict, so recent NumPy needs allow_pickle=True
        data = np.load(config.temp_eval_data_file, allow_pickle=True).item()
        question_word_idxs = data['question_word_idxs']
        question_lens = data['question_lens']
    print("Questions processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files,
                      question_word_idxs,
                      question_lens,
                      question_ids,
                      config.batch_size)
    print("Dataset built.")
    return vqa, dataset, vocabulary, config

# Prepare the test data
def prepare_test_data(config):
    """ Prepare the data for testing the model. """
    print("Reading the questions...")
    annotations = pd.read_csv(config.test_question_file)
    images = annotations['image'].unique()
    image_files = [os.path.join(config.test_image_dir, f) for f in images]

    # Attach the full image path to every question row
    temp = pd.DataFrame({'image': images, 'image_file': image_files})
    annotations = pd.merge(annotations, temp)
    annotations.to_csv(config.temp_test_info_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    print("Questions read.")
    print("Number of questions = %d" %(len(question_ids)))

    print("Building the vocabulary...")
    if os.path.exists(config.vocabulary_file):
        vocabulary = Vocabulary(config.vocabulary_file)
    else:
        vocabulary = build_vocabulary(config)
    print("Vocabulary built.")
    print("Number of words = %d" %(vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions...")
    question_word_idxs, question_lens = process_questions(questions,
                                                          vocabulary,
                                                          config)
    print("Questions processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files,
                      question_word_idxs,
                      question_lens,
                      question_ids,
                      config.batch_size)
    print("Dataset built.")
    return dataset, vocabulary, config

# Process the VQA annotations
def process_vqa(vqa, label, image_dir, annotation_file):
    """ Build a temporary annotation file for training or evaluation. """
    question_ids = list(vqa.qa.keys())
    image_ids = [vqa.qa[k]['image_id'] for k in question_ids]
    image_files = [os.path.join(image_dir, label+"_000000"+("%06d" %k)+".jpg")
                   for k in image_ids]
    questions = [vqa.qqa[k]['question'] for k in question_ids]
    answers = [vqa.qa[k]['best_answer'] for k in question_ids]
    annotations = pd.DataFrame({'question_id': question_ids,
                                'image_id': image_ids,
                                'image_file': image_files,
                                'question': questions,
                                'answer': answers})
    annotations.to_csv(annotation_file)
    return annotations

# Process the questions
def process_questions(questions, vocabulary, config):
    """ Tokenize the questions and translate each token into its index
        in the vocabulary, and get the number of tokens. """
    question_word_idxs = []
    question_lens = []
    for q in tqdm(questions):
        word_idxs = vocabulary.process_sentence(q)
        current_length = len(word_idxs)
        # Zero-pad every question to max_question_length
        current_word_idxs = np.zeros((config.max_question_length), np.int32)
        current_word_idxs[:current_length] = np.array(word_idxs)
        question_word_idxs.append(current_word_idxs)
        question_lens.append(current_length)
    return np.array(question_word_idxs), np.array(question_lens)

# Process the answers
def process_answers(answers, vocabulary):
    """ Translate the answers into their indices in the vocabulary. """
    answer_idxs = []
    for answer in tqdm(answers):
        # Only the first token of each answer is used
        answer_idxs.append(vocabulary.word_to_idx(word_tokenize(answer)[0]))
    return np.array(answer_idxs)

# Build the vocabulary
def build_vocabulary(config):
    """ Build the vocabulary from the training data and save it to a file. """
    vqa = VQA(config.train_answer_file, config.train_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    question_ids = list(vqa.qa.keys())
    questions = [vqa.qqa[k]['question'] for k in question_ids]
    answers = [vqa.qa[k]['best_answer'] for k in question_ids]

    vocabulary = Vocabulary()
    for question in tqdm(questions):
        vocabulary.add_words(word_tokenize(question))
    for answer in tqdm(answers):
        vocabulary.add_words(word_tokenize(answer))
    vocabulary.compute_frequency()
    vocabulary.save(config.vocabulary_file)
    return vocabulary
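
The batching arithmetic in `DataSet.setup` and `DataSet.next_batch` is easier to follow on a small example. The sketch below reproduces it with the standard library only; the helper names `batch_plan` and `iter_batches` are hypothetical, and `random.Random` stands in for `np.random.choice`.

```python
import math
import random


def batch_plan(count, batch_size):
    """Return (num_batches, fake_count) the way DataSet.setup() computes them."""
    num_batches = math.ceil(count / batch_size)
    fake_count = num_batches * batch_size - count
    return num_batches, fake_count


def iter_batches(count, batch_size, seed=0):
    """Yield index batches; the last one is padded with random 'fake' samples."""
    rng = random.Random(seed)
    num_batches, fake_count = batch_plan(count, batch_size)
    idxs = list(range(count))
    for b in range(num_batches):
        start = b * batch_size
        batch = idxs[start:start + batch_size]
        if len(batch) < batch_size:
            # Incomplete final batch: fill it up so every batch has the same size
            batch += [rng.randrange(count) for _ in range(fake_count)]
        yield batch


print(batch_plan(10, 4))  # (3, 2): three batches, two padding samples
for batch in iter_batches(10, 4):
    print(batch)
```

Every batch has a fixed size because the placeholders in the model are built with a fixed `config.batch_size`. That is also why `eval` and `test` later drop the last `fake_count` predictions.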



import os
import numpy as np
import pandas as pd
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm
import json
import copy
import string
from utils.nn import NN
from utils.misc import ImageLoader
from utils.vqa.vqa import VQA
from utils.vqa.vqaEval import VQAEval

class BaseModel(object):
    def __init__(self, config):
        self.config = config
        self.is_train = (config.phase == 'train')
        self.train_cnn = self.is_train and config.train_cnn
        self.image_loader = ImageLoader('./utils/ilsvrc_2012_mean.npy')
        self.image_shape = [224, 224, 3]
        self.global_step = tf.Variable(0,
                                       name = 'global_step',
                                       trainable = False)
        self.nn = NN(config)
        self.build()

    def build(self):
        raise NotImplementedError()

    def get_feed_dict(self, batch):
        raise NotImplementedError()

    def train(self, sess, train_data):
        """ Train the model using the VQA training data. """
        print("Training the model...")
        config = self.config

        if not os.path.exists(config.summary_dir):
            os.mkdir(config.summary_dir)
        train_writer = tf.summary.FileWriter(config.summary_dir, sess.graph)

        for epoch_no in tqdm(list(range(config.num_epochs)), desc='epoch'):
            for idx in tqdm(list(range(train_data.num_batches)), desc='batch'):
                batch = train_data.next_batch()
                feed_dict = self.get_feed_dict(batch)
                _, summary, global_step = sess.run([self.opt_op,
                                                    self.summary,
                                                    self.global_step],
                                                   feed_dict = feed_dict)
                # Checkpoint every save_period steps
                if (global_step + 1) % config.save_period == 0:
                    self.save()
                train_writer.add_summary(summary, global_step)
            train_data.reset()

        print("Training complete.")

    def eval(self, sess, eval_gt_vqa, eval_data, vocabulary):
        """ Evaluate the model using the VQA validation data. """
        print("Evaluating the model...")
        config = self.config

        if not os.path.exists(config.eval_result_dir):
            os.mkdir(config.eval_result_dir)

        question_ids = eval_data.question_ids
        answers = []

        # Compute the answers to the questions
        idx = 0
        for k in tqdm(list(range(eval_data.num_batches))):
            batch = eval_data.next_batch()
            image_files, question_word_idxs, question_lens = batch
            feed_dict = self.get_feed_dict(batch)
            result = sess.run(self.prediction, feed_dict = feed_dict)

            # The last batch is padded with fake samples; skip their predictions
            fake_cnt = 0 if k<eval_data.num_batches-1 \
                         else eval_data.fake_count
            for l in range(eval_data.batch_size-fake_cnt):
                answer = vocabulary.words[result[l]]
                answers.append(answer)

                # Save the result in an image file
                if config.save_eval_result_as_image:
                    image_file = image_files[l]
                    image_name = image_file.split(os.sep)[-1]
                    image_name = os.path.splitext(image_name)[0]
                    q_word_idxs = question_word_idxs[l]
                    q_len = question_lens[l]
                    q_words = [vocabulary.words[q_word_idxs[i]]
                               for i in range(q_len)]
                    if q_words[-1] != '?':
                        q_words.append('?')
                    Q = 'Q: ' + ''.join([' '+w if not w.startswith("'")
                                         and w not in string.punctuation
                                         else w for w in q_words]).strip()
                    A = 'A: ' + answer
                    image = plt.imread(image_file)
                    plt.imshow(image)
                    plt.axis('off')
                    plt.title(Q+'\n'+A)
                    # Write into eval_result_dir (created above), as test() does
                    # with test_result_dir; the listing saved to the working directory
                    plt.savefig(os.path.join(config.eval_result_dir,
                                             image_name + '_' + str(question_ids[idx])
                                             + '_result.jpg'))
                idx += 1

        # int() turns NumPy integers into plain ints, which json can serialize
        results = [{'question_id': int(question_id), 'answer': answer}
                   for question_id, answer in zip(question_ids, answers)]
        # Text mode: in Python 3, json.dump writes str, not bytes
        with open(config.eval_result_file, 'w') as fp:
            json.dump(results, fp)

        # Evaluate these answers
        eval_res_vqa = eval_gt_vqa.loadRes(config.eval_result_file,
                                           config.eval_question_file)
        scorer = VQAEval(eval_gt_vqa, eval_res_vqa)
        scorer.evaluate()
        print("Evaluation complete.")

    def test(self, sess, test_data, vocabulary):
        """ Test the model using any given images and questions. """
        print("Testing the model...")
        config = self.config

        if not os.path.exists(config.test_result_dir):
            os.mkdir(config.test_result_dir)

        question_ids = test_data.question_ids
        answers = []

        # Compute the answers to the questions
        idx = 0
        for k in tqdm(list(range(test_data.num_batches))):
            batch = test_data.next_batch()
            image_files, question_word_idxs, question_lens = batch
            feed_dict = self.get_feed_dict(batch)
            result = sess.run(self.prediction, feed_dict = feed_dict)

            # The last batch is padded with fake samples; skip their predictions
            fake_cnt = 0 if k < test_data.num_batches-1 \
                         else test_data.fake_count
            for l in range(test_data.batch_size-fake_cnt):
                answer = vocabulary.words[result[l]]
                answers.append(answer)

                # Save the result in an image file
                image_file = image_files[l]
                image_name = image_file.split(os.sep)[-1]
                image_name = os.path.splitext(image_name)[0]
                q_word_idxs = question_word_idxs[l]
                q_len = question_lens[l]
                q_words = [vocabulary.words[q_word_idxs[i]]
                           for i in range(q_len)]
                if q_words[-1] != '?':
                    q_words.append('?')
                Q = 'Q: ' + ''.join([' '+w if not w.startswith("'")
                                     and w not in string.punctuation
                                     else w for w in q_words]).strip()
                A = 'A: ' + answer
                image = plt.imread(image_file)
                plt.imshow(image)
                plt.axis('off')
                plt.title(Q+'\n'+A)
                plt.savefig(os.path.join(config.test_result_dir,
                                         image_name + '_' + str(question_ids[idx])
                                         + '_result.jpg'))
                idx += 1

        # Save the answers to a file
        test_info = pd.read_csv(config.temp_test_info_file)
        results = pd.DataFrame({'question_id': question_ids,
                                'answer': answers})
        results = pd.merge(test_info, results)
        results.to_csv(config.test_result_file)
        print("Testing complete.")

    def save(self):
        """ Save the model. """
        config = self.config
        # Store every global variable in one dict, keyed by variable name
        data = {v.name: v.eval() for v in tf.global_variables()}
        save_path = os.path.join(config.save_dir, str(self.global_step.eval()))

        print((" Saving the model to %s..." % (save_path+".npy")))
        np.save(save_path, data)
        config_ = copy.copy(config)
        config_.global_step = self.global_step.eval()
        with open(os.path.join(config.save_dir, "config.pickle"), "wb") as info_file:
            pickle.dump(config_, info_file)
        print("Model saved.")

    def load(self, sess, model_file=None):
        """ Load the model. """
        config = self.config
        if model_file is not None:
            save_path = model_file
        else:
            # No file given: resume from the step recorded in config.pickle
            info_path = os.path.join(config.save_dir, "config.pickle")
            with open(info_path, "rb") as info_file:
                config = pickle.load(info_file)
            global_step = config.global_step
            save_path = os.path.join(config.save_dir,
                                     str(global_step)+".npy")

        print("Loading the model from %s..." %save_path)
        # The checkpoint is a pickled dict, so recent NumPy needs allow_pickle=True
        data_dict = np.load(save_path, allow_pickle=True).item()
        count = 0
        for v in tqdm(tf.global_variables()):
            if v.name in data_dict.keys():
                sess.run(v.assign(data_dict[v.name]))
                count += 1
        print("%d tensors loaded." %count)

    def load_cnn(self, session, data_path, ignore_missing=True):
        """ Load a pretrained CNN model. """
        print("Loading the CNN from %s..." %data_path)
        # The weight file is a pickled dict, so recent NumPy needs allow_pickle=True
        data_dict = np.load(data_path, allow_pickle=True).item()
        count = 0
        for op_name in tqdm(data_dict):
            with tf.variable_scope(op_name, reuse=True):
                # .items() replaces the Python 2-only .iteritems()
                for param_name, data in data_dict[op_name].items():
                    try:
                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                        count += 1
                    except ValueError:
                        # No variable with this name in the graph; skip it
                        pass
        print("%d tensors loaded." %count)
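
Both `eval` and `test` rebuild a readable question from its tokens with a dense one-line join. The standalone sketch below isolates that expression; the function name `detokenize` and the sample tokens are hypothetical.

```python
import string


def detokenize(words):
    """Join tokens with spaces, except before punctuation and clitics like 's."""
    pieces = []
    for w in words:
        if w.startswith("'") or w in string.punctuation:
            pieces.append(w)        # attach directly to the previous token
        else:
            pieces.append(' ' + w)  # ordinary word: separate with a space
    return ''.join(pieces).strip()


tokens = ['what', 'is', 'the', 'man', "'s", 'hat', 'color', '?']
print(detokenize(tokens))  # what is the man's hat color?
```

This undoes the splitting done by `word_tokenize`, which separates clitics and punctuation into their own tokens.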



import tensorflow as tf
from utils.nn import NN

class AttnGRU(object):
    """ Attention-based GRU (used by the Episodic Memory Module). """
    def __init__(self, config):
        self.nn = NN(config)
        self.num_units = config.num_gru_units

    def __call__(self, inputs, state, attention):
        with tf.variable_scope('attn_gru'):
            # Reset gate r
            r_input = tf.concat([inputs, state], axis = 1)
            r_input = self.nn.dropout(r_input)
            r = self.nn.dense(r_input,
                              units = self.num_units,
                              activation = None,
                              use_bias = False,
                              name = 'fc1')
            b = tf.get_variable('fc1/bias',
                                shape = [self.num_units],
                                initializer = tf.constant_initializer(1.0))
            r = tf.nn.bias_add(r, b)
            r = tf.sigmoid(r)

            # Candidate state c
            c_input = tf.concat([inputs, r*state], axis = 1)
            c_input = self.nn.dropout(c_input)
            c = self.nn.dense(c_input,
                              units = self.num_units,
                              activation = tf.tanh,
                              name = 'fc2')

            # The attention weight takes the place of the GRU update gate
            new_state = attention * c + (1 - attention) * state
        return new_state

class EpisodicMemory(object):
    """ Episodic Memory Module. """
    def __init__(self, config, num_facts, question, facts):
        self.nn = NN(config)
        self.num_units = config.num_gru_units
        self.num_facts = num_facts
        self.question = question
        self.facts = facts
        self.attention = config.attention
        if self.attention == 'gru':
            self.attn_gru = AttnGRU(config)

    def new_fact(self, memory):
        """ Get the context vector by using either soft attention or
            attention-based GRU. """
        fact_list = tf.unstack(self.facts, axis = 1)
        mixed_fact = tf.zeros_like(fact_list[0])

        with tf.variable_scope('attend'):
            attentions = self.attend(memory)

        if self.attention == 'gru':
            # Scan the facts in order, gating each update by its attention weight
            with tf.variable_scope('attn_gate') as scope:
                attentions = tf.unstack(attentions, axis = 1)
                for ctx, att in zip(fact_list, attentions):
                    mixed_fact = self.attn_gru(ctx,
                                               mixed_fact,
                                               tf.expand_dims(att, 1))
                    scope.reuse_variables()
        else:
            # Soft attention: attention-weighted sum of the facts
            mixed_fact = tf.reduce_sum(self.facts*tf.expand_dims(attentions, 2),
                                       axis = 1)
        return mixed_fact

    def attend(self, memory):
        """ Get the attention weights. """
        c = self.facts
        q = tf.tile(tf.expand_dims(self.question, 1), [1, self.num_facts, 1])
        m = tf.tile(tf.expand_dims(memory, 1), [1, self.num_facts, 1])

        # Interaction features between each fact, the question and the memory
        z = tf.concat([c*q, c*m, tf.abs(c-q), tf.abs(c-m)], 2)
        z = tf.reshape(z, [-1, 4*self.num_units])

        z = self.nn.dropout(z)
        z1 = self.nn.dense(z,
                           units = self.num_units,
                           activation = tf.tanh,
                           name = 'fc1')
        z1 = self.nn.dropout(z1)
        z2 = self.nn.dense(z1,
                           units = 1,
                           activation = None,
                           use_bias = False,
                           name = 'fc2')
        z2 = tf.reshape(z2, [-1, self.num_facts])

        attentions = tf.nn.softmax(z2)
        return attentions
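
To see why `new_fact` offers two attention modes, the following pure-Python sketch compares them on toy vectors. It is a simplification under stated assumptions: the candidate state is taken to be the fact itself, whereas the real `AttnGRU` computes it with a reset gate and a `tanh` dense layer; the names `soft_attention`, `gated_update` and `gated_scan` are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_attention(facts, weights):
    """Attention-weighted sum of the facts (order-independent)."""
    dim = len(facts[0])
    return [sum(w * f[d] for f, w in zip(facts, weights)) for d in range(dim)]


def gated_update(state, candidate, gate):
    """new_state = gate * candidate + (1 - gate) * state, as in AttnGRU."""
    return [gate * c + (1.0 - gate) * s for s, c in zip(state, candidate)]


def gated_scan(facts, weights):
    """Scan the facts in order; simplified: the candidate is the fact itself."""
    state = [0.0] * len(facts[0])
    for fact, gate in zip(facts, weights):
        state = gated_update(state, fact, gate)
    return state


facts = [[1.0, 0.0], [0.0, 1.0]]
weights = softmax([0.0, 0.0])                 # equal scores give [0.5, 0.5]

print(soft_attention(facts, weights))         # [0.5, 0.5]
print(soft_attention(facts[::-1], weights))   # [0.5, 0.5]: order does not matter
print(gated_scan(facts, weights))             # [0.25, 0.5]
print(gated_scan(facts[::-1], weights))       # [0.5, 0.25]: order matters
```

The weighted sum discards the position and ordering of the facts, while the gated scan keeps them. That is the motivation the paper gives for the attention-based GRU.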



import tensorflow as tf
import numpy as np
from base_model import BaseModel
from episodic_memory import EpisodicMemory

class QuestionAnswerer(BaseModel):
    def build(self):
        """ Build the model. """
        self.build_cnn()
        self.build_rnn()
        if self.is_train:
            self.build_optimizer()
            self.build_summary()

    def build_cnn(self):
        """ Build the CNN. """
        print("Building the CNN...")
        if self.config.cnn =='vgg16':
            self.build_vgg16()
        else:
            self.build_resnet50()
        print("CNN built.")

    def build_vgg16(self):
        """ Build the VGG16 net. """
        config = self.config

        images = tf.placeholder(
            dtype = tf.float32,
            shape = [config.batch_size] + self.image_shape)

        conv1_1_feats = self.nn.conv2d(images, 64, name = 'conv1_1')
        conv1_2_feats = self.nn.conv2d(conv1_1_feats, 64, name = 'conv1_2')
        pool1_feats = self.nn.max_pool2d(conv1_2_feats, name = 'pool1')

        conv2_1_feats = self.nn.conv2d(pool1_feats, 128, name = 'conv2_1')
        conv2_2_feats = self.nn.conv2d(conv2_1_feats, 128, name = 'conv2_2')
        pool2_feats = self.nn.max_pool2d(conv2_2_feats, name = 'pool2')

        conv3_1_feats = self.nn.conv2d(pool2_feats, 256, name = 'conv3_1')
        conv3_2_feats = self.nn.conv2d(conv3_1_feats, 256, name = 'conv3_2')
        conv3_3_feats = self.nn.conv2d(conv3_2_feats, 256, name = 'conv3_3')
        pool3_feats = self.nn.max_pool2d(conv3_3_feats, name = 'pool3')

        conv4_1_feats = self.nn.conv2d(pool3_feats, 512, name = 'conv4_1')
        conv4_2_feats = self.nn.conv2d(conv4_1_feats, 512, name = 'conv4_2')
        conv4_3_feats = self.nn.conv2d(conv4_2_feats, 512, name = 'conv4_3')
        pool4_feats = self.nn.max_pool2d(conv4_3_feats, name = 'pool4')

        conv5_1_feats = self.nn.conv2d(pool4_feats, 512, name = 'conv5_1')
        conv5_2_feats = self.nn.conv2d(conv5_1_feats, 512, name = 'conv5_2')
        conv5_3_feats = self.nn.conv2d(conv5_2_feats, 512, name = 'conv5_3')

        # conv5_3 is a 14x14x512 map; flatten it into 196 facts of dimension 512
        self.permutation = self.get_permutation(14, 14)
        conv5_3_feats_flat = self.flatten_feats(conv5_3_feats, 512)
        self.conv_feats = conv5_3_feats_flat
        self.conv_feat_shape = [196, 512]
        self.images = images

    def build_resnet50(self):
        """ Build the ResNet50. """
        config = self.config

        images = tf.placeholder(
            dtype = tf.float32,
            shape = [config.batch_size] + self.image_shape)

        conv1_feats = self.nn.conv2d(images,
                                     filters = 64,
                                     kernel_size = (7, 7),
                                     strides = (2, 2),
                                     activation = None,
                                     name = 'conv1')
        conv1_feats = self.nn.batch_norm(conv1_feats, 'bn_conv1')
        conv1_feats = tf.nn.relu(conv1_feats)
        pool1_feats = self.nn.max_pool2d(conv1_feats,
                                         pool_size = (3, 3),
                                         strides = (2, 2),
                                         name = 'pool1')

        res2a_feats = self.resnet_block(pool1_feats, 'res2a', 'bn2a', 64, 1)
        res2b_feats = self.resnet_block2(res2a_feats, 'res2b', 'bn2b', 64)
        res2c_feats = self.resnet_block2(res2b_feats, 'res2c', 'bn2c', 64)

        res3a_feats = self.resnet_block(res2c_feats, 'res3a', 'bn3a', 128)
        res3b_feats = self.resnet_block2(res3a_feats, 'res3b', 'bn3b', 128)
        res3c_feats = self.resnet_block2(res3b_feats, 'res3c', 'bn3c', 128)
        res3d_feats = self.resnet_block2(res3c_feats, 'res3d', 'bn3d', 128)

        res4a_feats = self.resnet_block(res3d_feats, 'res4a', 'bn4a', 256)
        res4b_feats = self.resnet_block2(res4a_feats, 'res4b', 'bn4b', 256)
        res4c_feats = self.resnet_block2(res4b_feats, 'res4c', 'bn4c', 256)
        res4d_feats = self.resnet_block2(res4c_feats, 'res4d', 'bn4d', 256)
        res4e_feats = self.resnet_block2(res4d_feats, 'res4e', 'bn4e', 256)
        res4f_feats = self.resnet_block2(res4e_feats, 'res4f', 'bn4f', 256)

        res5a_feats = self.resnet_block(res4f_feats, 'res5a', 'bn5a', 512)
        res5b_feats = self.resnet_block2(res5a_feats, 'res5b', 'bn5b', 512)
        res5c_feats = self.resnet_block2(res5b_feats, 'res5c', 'bn5c', 512)

        # res5c is a 7x7x2048 map; flatten it into 49 facts of dimension 2048
        self.permutation = self.get_permutation(7, 7)
        res5c_feats_flat = self.flatten_feats(res5c_feats, 2048)
        self.conv_feats = res5c_feats_flat
        self.conv_feat_shape = [49, 2048]
        self.images = images

    def resnet_block(self, inputs, name1, name2, c, s=2):
        """ A basic block of ResNet. """
        # Projection shortcut (1x1 conv) so both branches have 4*c channels
        branch1_feats = self.nn.conv2d(inputs,
                                       filters = 4*c,
                                       kernel_size = (1, 1),
                                       strides = (s, s),
                                       activation = None,
                                       use_bias = False,
                                       name = name1+'_branch1')
        branch1_feats = self.nn.batch_norm(branch1_feats, name2+'_branch1')

        branch2a_feats = self.nn.conv2d(inputs,
                                        filters = c,
                                        kernel_size = (1, 1),
                                        strides = (s, s),
                                        activation = None,
                                        use_bias = False,
                                        name = name1+'_branch2a')
        branch2a_feats = self.nn.batch_norm(branch2a_feats, name2+'_branch2a')
        branch2a_feats = tf.nn.relu(branch2a_feats)

        branch2b_feats = self.nn.conv2d(branch2a_feats,
                                        filters = c,
                                        kernel_size = (3, 3),
                                        strides = (1, 1),
                                        activation = None,
                                        use_bias = False,
                                        name = name1+'_branch2b')
        branch2b_feats = self.nn.batch_norm(branch2b_feats, name2+'_branch2b')
        branch2b_feats = tf.nn.relu(branch2b_feats)

        branch2c_feats = self.nn.conv2d(branch2b_feats,
                                        filters = 4*c,
                                        kernel_size = (1, 1),
                                        strides = (1, 1),
                                        activation = None,
                                        use_bias = False,
                                        name = name1+'_branch2c')
        branch2c_feats = self.nn.batch_norm(branch2c_feats, name2+'_branch2c')

        outputs = branch1_feats + branch2c_feats
        outputs = tf.nn.relu(outputs)
        return outputs

    def resnet_block2(self, inputs, name1, name2, c):
        """ Another basic block of ResNet. """
        branch2a_feats = self.nn.conv2d(inputs,
                                        filters = c,
                                        kernel_size = (1, 1),
                                        strides = (1, 1),
                                        activation = None,
                                        use_bias = False,
                                        name = name1+'_branch2a')
        branch2a_feats = self.nn.batch_norm(branch2a_feats, name2+'_branch2a')
        branch2a_feats = tf.nn.relu(branch2a_feats)

        branch2b_feats = self.nn.conv2d(branch2a_feats,
                                        filters = c,
                                        kernel_size = (3, 3),
                                        strides = (1, 1),
                                        activation = None,
                                        use_bias = False,
                                        name = name1+'_branch2b')
        branch2b_feats = self.nn.batch_norm(branch2b_feats, name2+'_branch2b')
        branch2b_feats = tf.nn.relu(branch2b_feats)

        branch2c_feats = self.nn.conv2d(branch2b_feats,
                                        filters = 4*c,
                                        kernel_size = (1, 1),
                                        strides = (1, 1),
                                        activation = None,
                                        use_bias = False,
                                        name = name1+'_branch2c')
        branch2c_feats = self.nn.batch_norm(branch2c_feats, name2+'_branch2c')

        # Identity shortcut: the input already has 4*c channels
        outputs = inputs + branch2c_feats
        outputs = tf.nn.relu(outputs)
        return outputs
    def get_permutation(self, height, width):
        """ Get the permutation corresponding to the snake-like walk described
            in the paper. Used to flatten the convolutional feats. """
        permutation = np.zeros(height*width, np.int32)
        for i in range(height):
            for j in range(width):
                # Even rows are read left to right, odd rows right to left
                permutation[i*width+j] = i*width+j if i%2==0 \
                                         else (i+1)*width-j-1
        return permutation

    def flatten_feats(self, feats, channels):
        """ Flatten the feats. """
        temp1 = tf.reshape(feats, [self.config.batch_size, -1, channels])
        temp1 = tf.transpose(temp1, [1, 0, 2])
        temp2 = tf.gather(temp1, self.permutation)
        temp2 = tf.transpose(temp2, [1, 0, 2])
        return temp2

    def build_rnn(self):
        """ Build the RNN. """
        print("Building the RNN...")
        config = self.config

        facts = self.conv_feats
        num_facts, dim_fact = self.conv_feat_shape

        # Setup the placeholders
        question_word_idxs = tf.placeholder(
            dtype = tf.int32,
            shape = [config.batch_size, config.max_question_length])
        question_lens = tf.placeholder(
            dtype = tf.int32,
            shape = [config.batch_size])
        if self.is_train:
            answer_idxs = tf.placeholder(
                dtype = tf.int32,
                shape = [config.batch_size])
            if config.question_encoding == 'positional':
                position_weights = tf.placeholder(
                    dtype = tf.float32,
                    shape = [config.batch_size,
                             config.max_question_length,
                             config.dim_embedding])

        # Setup the word embedding
        with tf.variable_scope("word_embedding"):
            embedding_matrix = tf.get_variable(
                name = 'weights',
                shape = [config.vocabulary_size, config.dim_embedding],
                initializer = self.nn.fc_kernel_initializer,
                regularizer = self.nn.fc_kernel_regularizer,
                trainable = self.is_train)

        # Encode the questions
        with tf.variable_scope('question_encoding'):
            question_embeddings = tf.nn.embedding_lookup(
                embedding_matrix,
                question_word_idxs)

            if config.question_encoding == 'positional':
                # use positional encoding
                self.build_position_weights()
                question_encodings = question_embeddings * position_weights
                question_encodings = tf.reduce_sum(question_encodings,
                                                   axis = 1)
            else:
                # use GRU encoding
                outputs, _ = tf.nn.dynamic_rnn(
                    self.nn.gru(),
                    inputs = question_embeddings,
                    dtype = tf.float32)

                question_encodings = []
                for k in range(config.batch_size):
                    question_encoding = tf.slice(outputs,
                                                 [k, question_lens[k]-1, 0],
                                                 [1, 1, config.num_gru_units])
                    question_encodings.append(tf.squeeze(question_encoding))
                question_encodings = tf.stack(question_encodings, axis = 0)

        # Encode the facts
        with tf.variable_scope('input_fusion'):
            if config.embed_fact:
                facts = tf.reshape(facts, [-1, dim_fact])
                facts = self.nn.dropout(facts)
                facts = self.nn.dense(
                    facts,
                    units = config.dim_embedding,
                    activation = tf.tanh,
                    name = 'fc')
                facts = tf.reshape(facts, [-1, num_facts, config.dim_embedding])

            outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                self.nn.gru(),
                self.nn.gru(),
                inputs = facts,
                dtype = tf.float32)
            outputs_fw, outputs_bw = outputs
            fact_encodings = outputs_fw + outputs_bw

        # Episodic Memory Update
        with tf.variable_scope('episodic_memory'):
            episode = EpisodicMemory(config,
                                     num_facts,
                                     question_encodings,
                                     fact_encodings)
            memory = tf.identity(question_encodings)

            if config.tie_memory_weight:
                scope_list = ['layer'] * config.memory_step
            else:
                scope_list = ['layer'+str(t) for t in range(config.memory_step)]

            for t in range(config.memory_step):
                with tf.variable_scope(scope_list[t], reuse = tf.AUTO_REUSE):
                    fact = episode.new_fact(memory)
                    if config.memory_update == 'gru':
                        gru = self.nn.gru()
                        memory = gru(fact, memory)[0]
                    else:
                        expanded_memory = tf.concat(
                            [memory, fact, question_encodings],
                            axis = 1)
                        expanded_memory = self.nn.dropout(expanded_memory)
                        memory = self.nn.dense(
                            expanded_memory,
                            units = config.num_gru_units,
                            activation = tf.nn.relu,
                            name = 'fc')

        # Compute the result
        with tf.variable_scope('result'):
            expanded_memory = tf.concat([memory, question_encodings],
                                        axis = 1)
            expanded_memory = self.nn.dropout(expanded_memory)
            logits = self.nn.dense(expanded_memory,
                                   units = config.vocabulary_size,
                                   activation = None,
                                   name = 'logits')
            prediction = tf.argmax(logits, axis = 1)

        # Compute the loss and accuracy if necessary
        if self.is_train:
            cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = answer_idxs,
                logits = logits)
            cross_entropy_loss = tf.reduce_mean(cross_entropy_loss)

            reg_loss = tf.losses.get_regularization_loss()
            total_loss = cross_entropy_loss + reg_loss

            ground_truth = tf.cast(answer_idxs, tf.int64)
            prediction_correct = tf.where(
                tf.equal(prediction, ground_truth),
                tf.cast(tf.ones_like(prediction), tf.float32),
                tf.cast(tf.zeros_like(prediction), tf.float32))
            accuracy = tf.reduce_mean(prediction_correct)

        self.question_word_idxs = question_word_idxs
        self.question_lens = question_lens
        self.prediction = prediction
        if self.is_train:
            self.answer_idxs = answer_idxs
            if config.question_encoding == 'positional':
                self.position_weights = position_weights
            self.total_loss = total_loss
            self.cross_entropy_loss = cross_entropy_loss
            self.reg_loss = reg_loss
            self.accuracy = accuracy

        print("RNN built.")

    def build_position_weights(self):
        """ Setup the weights for the positional encoding of questions. """
        config = self.config
        D = config.dim_embedding
        pos_weights = []
        for M in range(config.max_question_length):
            cur_pos_weights = []
            for j in range(config.max_question_length):
                if j <= M:
                    temp = [1.0-(j+1.0)/(M+1.0) \
                            -((d+1.0)/D)*(1-2.0*(j+1.0)/(M+1.0)) \
                            for d in range(D)]
                else:
                    temp = [0.0] * D
                cur_pos_weights.append(temp)
            pos_weights.append(cur_pos_weights)
        self.pos_weights = np.array(pos_weights, np.float32)

    def build_optimizer(self):
        """ Setup the training operation. """
        config = self.config

        learning_rate = tf.constant(config.initial_learning_rate)
        if config.learning_rate_decay_factor < 1.0:
            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps = config.num_steps_per_decay,
                    decay_rate = config.learning_rate_decay_factor,
                    staircase = True)
            learning_rate_decay_fn = _learning_rate_decay_fn
        else:
            learning_rate_decay_fn = None

        with tf.variable_scope('optimizer', reuse = tf.AUTO_REUSE):
            if config.optimizer == 'Adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate = config.initial_learning_rate,
                    beta1 = config.beta1,
                    beta2 = config.beta2,
                    epsilon = config.epsilon
                    )
            elif config.optimizer == 'RMSProp':
                optimizer = tf.train.RMSPropOptimizer(
                    learning_rate = config.initial_learning_rate,
                    decay = config.decay,
                    momentum = config.momentum,
                    centered = config.centered,
                    epsilon = config.epsilon
                    )
            elif config.optimizer == 'Momentum':
                optimizer = tf.train.MomentumOptimizer(
                    learning_rate = config.initial_learning_rate,
                    momentum = config.momentum,
                    use_nesterov = config.use_nesterov
                    )
            else:
                optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate = config.initial_learning_rate
                    )

            opt_op = tf.contrib.layers.optimize_loss(
                loss = self.total_loss,
                global_step = self.global_step,
                learning_rate = learning_rate,
                optimizer = optimizer,
                clip_gradients = config.clip_gradients,
                learning_rate_decay_fn = learning_rate_decay_fn)

        self.opt_op = opt_op

    def build_summary(self):
        """ Build the summary (for TensorBoard visualization). """
        with tf.name_scope("variables"):
            for var in tf.trainable_variables():
                with tf.name_scope(var.name[:var.name.find(":")]):
                    self.variable_summary(var)

        with tf.name_scope("metrics"):
            tf.summary.scalar("cross_entropy_loss", self.cross_entropy_loss)
            tf.summary.scalar("reg_loss", self.reg_loss)
            tf.summary.scalar("total_loss", self.total_loss)
            tf.summary.scalar("accuracy", self.accuracy)

        self.summary = tf.summary.merge_all()

    def variable_summary(self, var):
        """ Build the summary for a variable. """
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

    def get_feed_dict(self, batch):
        """ Get the feed dictionary for the current batch. """
        config = self.config
        if self.is_train:
            # training phase
            image_files, question_word_idxs, question_lens, answer_idxs = batch
            images = self.image_loader.load_images(image_files)
            if config.question_encoding == 'positional':
                position_weights = [self.pos_weights[question_lens[i]-1, :, :]
                                    for i in range(config.batch_size)]
                position_weights = np.array(position_weights, np.float32)
                return {self.images: images,
                        self.question_word_idxs: question_word_idxs,
                        self.question_lens: question_lens,
                        self.answer_idxs: answer_idxs,
                        self.position_weights: position_weights}
            else:
                return {self.images: images,
                        self.question_word_idxs: question_word_idxs,
                        self.question_lens: question_lens,
                        self.answer_idxs: answer_idxs}
        else:
            # evaluation or testing phase
            image_files, question_word_idxs, question_lens = batch
            images = self.image_loader.load_images(image_files)
            return {self.images: images,
                    self.question_word_idxs: question_word_idxs,
                    self.question_lens: question_lens}
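
Two small helpers in the listing above, the snake-like permutation used to flatten the convolutional feature map and the positional-encoding weights for questions, can be checked in isolation without TensorFlow. The following is a minimal NumPy sketch, not the author's code: the function names `snake_permutation` and `position_weights` are made up for illustration, and `position_weights` is assumed to correspond to one slice `pos_weights[length-1]` of the table built by `build_position_weights`.

```python
import numpy as np


def snake_permutation(height, width):
    """Row-major indices of a height x width grid, visited in a snake-like walk.

    Even rows are read left to right and odd rows right to left, mirroring
    get_permutation in the listing above.
    """
    permutation = np.zeros(height * width, np.int32)
    for i in range(height):
        for j in range(width):
            if i % 2 == 0:
                permutation[i * width + j] = i * width + j
            else:
                permutation[i * width + j] = (i + 1) * width - j - 1
    return permutation


def position_weights(length, max_length, dim):
    """Positional weights for one question of `length` words (hypothetical helper).

    Entry (j, d), with 1-based j and d, is
        (1 - j/length) - (d/dim) * (1 - 2*j/length).
    Rows past `length` stay zero so that padding contributes nothing.
    """
    j = np.arange(1, length + 1, dtype=np.float32)[:, None]   # word position
    d = np.arange(1, dim + 1, dtype=np.float32)[None, :]      # embedding dimension
    weights = np.zeros((max_length, dim), np.float32)
    weights[:length] = (1.0 - j / length) - (d / dim) * (1.0 - 2.0 * j / length)
    return weights


# A 3 x 4 grid is walked as 0 1 2 3 / 7 6 5 4 / 8 9 10 11.
print(snake_permutation(3, 4))
# A 2-word question padded to 3 positions, with a 2-dimensional embedding.
print(position_weights(2, 3, 2))
```

Multiplying the word embeddings by these weights and summing over the position axis gives the question encoding used in the 'positional' branch of `build_rnn`.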



#!/usr/bin/python
import tensorflow as tf

from config import Config
from model import QuestionAnswerer
from dataset import prepare_train_data, prepare_eval_data, prepare_test_data

FLAGS = tf.app.flags.FLAGS

tf.flags.DEFINE_string('phase', 'train',
                       'The phase can be train, eval or test')

tf.flags.DEFINE_boolean('load', False,
                        'Turn on to load a pretrained model from either '
                        'the latest checkpoint or a specified file')

tf.flags.DEFINE_string('model_file', None,
                       'If specified, load a pretrained model from this file')

tf.flags.DEFINE_boolean('load_cnn', False,
                        'Turn on to load a pretrained CNN model')

tf.flags.DEFINE_string('cnn_model_file', './vgg16_no_fc.npy',
                       'File containing a pretrained CNN model')

tf.flags.DEFINE_boolean('train_cnn', False,
                        'Turn on to train both CNN and RNN. '
                        'Otherwise, only RNN is trained')

def main(argv):
    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn

    with tf.Session() as sess:
        if FLAGS.phase == 'train':
            # training phase
            data, config = prepare_train_data(config)
            model = QuestionAnswerer(config)
            sess.run(tf.global_variables_initializer())
            if FLAGS.load:
                model.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                model.load_cnn(sess, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            model.train(sess, data)

        elif FLAGS.phase == 'eval':
            # evaluation phase
            vqa, data, vocabulary, config = prepare_eval_data(config)
            model = QuestionAnswerer(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.eval(sess, vqa, data, vocabulary)

        else:
            # testing phase
            data, vocabulary, config = prepare_test_data(config)
            model = QuestionAnswerer(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.test(sess, data, vocabulary)

if __name__ == '__main__':
    tf.app.run()
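
The script above relies on `tf.app.flags`, which is a TensorFlow 1.x API. As a rough illustration of what the six options mean, here is a sketch that parses the same option names with the standard-library `argparse`. This is an assumption-laden stand-in rather than part of the original project: `build_parser` is a made-up name, and boolean options are modeled as plain switches (`--load`), which is narrower than what `tf.flags.DEFINE_boolean` accepts.

```python
import argparse


def build_parser():
    """Build a parser mirroring the flags of the script above (illustrative only)."""
    parser = argparse.ArgumentParser(description="Run the VQA model")
    parser.add_argument("--phase", default="train",
                        choices=["train", "eval", "test"],
                        help="The phase can be train, eval or test")
    parser.add_argument("--load", action="store_true",
                        help="Load a pretrained model from either the latest "
                             "checkpoint or a specified file")
    parser.add_argument("--model_file", default=None,
                        help="If specified, load a pretrained model from this file")
    parser.add_argument("--load_cnn", action="store_true",
                        help="Load a pretrained CNN model")
    parser.add_argument("--cnn_model_file", default="./vgg16_no_fc.npy",
                        help="File containing a pretrained CNN model")
    parser.add_argument("--train_cnn", action="store_true",
                        help="Train both CNN and RNN; otherwise only the RNN")
    return parser


# Parse an explicit argument list instead of sys.argv so the example is
# reproducible; the model path keeps the placeholder used in this article.
args = build_parser().parse_args(
    ["--phase", "eval", "--model_file", "./models/xxxxxx.npy"])
print(args.phase, args.model_file, args.load, args.cnn_model_file)
```

The defaults match the flag definitions in the listing: training phase, no checkpoint loading, and `./vgg16_no_fc.npy` as the CNN weight file.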

4. Training



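My GPU (a 1050 Ti) has little memory, and running the code as-is caused an out-of-memory error, so I reduced the batch size to 4. The dataset is large and training is slow: after about 6 days training had still not finished, having completed only one and a half epochs. The figure below shows the state after one day of training: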



To monitor the training process, run the following in a command prompt: tensorboard --logdir='./summary/'



To evaluate a trained model, two arguments need to be set: --phase=eval --model_file='./models/xxxxxx.npy'



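A trained model can also be used to answer arbitrary questions about JPEG images. Put the images under 'test/images', create a CSV file (containing the image, the question, and the question id) under 'test', and then run with the following arguments: --phase=test --model_file='./models/xxxxxx.npy'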
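
The CSV file for the test phase can be produced with a few lines of Python. The sketch below is only a guess at a workable layout: the text says the file holds the image, the question, and the question id, but the exact column names (`image`, `question`, `question_id`) and the file name `questions.csv` used here are assumptions and should be checked against what `prepare_test_data` actually reads.

```python
import csv
import os


def write_question_csv(rows, csv_path):
    """Write (image, question, question_id) rows to csv_path with a header row.

    The header names are assumptions, not taken from the project.
    """
    directory = os.path.dirname(csv_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "question", "question_id"])  # assumed column names
        writer.writerows(rows)
    return csv_path


# Example call (hypothetical file and image names):
# write_question_csv(
#     [("test/images/cat.jpg", "What animal is this?", 1),
#      ("test/images/cat.jpg", "What color is the animal?", 2)],
#     os.path.join("test", "questions.csv"))
```

Each row pairs one image with one question, so asking several questions about the same image simply repeats the image path with a different question id.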


5. Results










1. VQA is quite hard: both writing the model and reading the data require great care.

