当前位置:   article > 正文

tensorflow的一些代码分析(六) tensorflow实现word2vec_normed_embedding

normed_embedding

核心代码

核心代码主要就是描述模型,计算loss,根据loss优化参数等步骤。这里计算loss直接使用了tf封装好的tf.nn.nce_loss方法,比较方便。优化方法这里也是选的最简单的梯度下降法。具体的描述就放在代码里说好了

  1. self.graph = tf.Graph()
  2. self.graph = tf.Graph()
  3. with self.graph.as_default():
  4. # 首先定义两个用作输入的占位符,分别输入输入集(train_inputs)和标签集(train_labels)
  5. self.train_inputs = tf.placeholder(tf.int32, shape=[self.batch_size])
  6. self.train_labels = tf.placeholder(tf.int32, shape=[self.batch_size, 1])
  7. # 词向量矩阵,初始时为均匀随机正态分布
  8. self.embedding_dict = tf.Variable(
  9. tf.random_uniform([self.vocab_size,self.embedding_size],-1.0,1.0)
  10. )
  11. # 模型内部参数矩阵,初始为截断正太分布
  12. self.nce_weight = tf.Variable(tf.truncated_normal([self.vocab_size, self.embedding_size],
  13. stddev=1.0/math.sqrt(self.embedding_size)))
  14. self.nce_biases = tf.Variable(tf.zeros([self.vocab_size]))
  15. # 将输入序列向量化,具体可见我的【常用函数说明】那一篇
  16. embed = tf.nn.embedding_lookup(self.embedding_dict, self.train_inputs) # batch_size
  17. # 得到NCE损失(负采样得到的损失)
  18. self.loss = tf.reduce_mean(
  19. tf.nn.nce_loss(
  20. weights = self.nce_weight, # 权重
  21. biases = self.nce_biases, # 偏差
  22. labels = self.train_labels, # 输入的标签
  23. inputs = embed, # 输入向量
  24. num_sampled = self.num_sampled, # 负采样的个数
  25. num_classes = self.vocab_size # 类别数目
  26. )
  27. )
  28. # tensorboard 相关
  29. tf.scalar_summary('loss',self.loss) # 让tensorflow记录参数
  30. # 根据 nce loss 来更新梯度和embedding,使用梯度下降法(gradient descent)来实现
  31. self.train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(self.loss) # 训练操作
  32. # 计算与指定若干单词的相似度
  33. self.test_word_id = tf.placeholder(tf.int32,shape=[None])
  34. vec_l2_model = tf.sqrt( # 求各词向量的L2模
  35. tf.reduce_sum(tf.square(self.embedding_dict),1,keep_dims=True)
  36. )
  37. avg_l2_model = tf.reduce_mean(vec_l2_model)
  38. tf.scalar_summary('avg_vec_model',avg_l2_model)
  39. self.normed_embedding = self.embedding_dict / vec_l2_model
  40. # self.embedding_dict = norm_vec # 对embedding向量正则化
  41. test_embed = tf.nn.embedding_lookup(self.normed_embedding, self.test_word_id)
  42. self.similarity = tf.matmul(test_embed, self.normed_embedding, transpose_b=True)
  43. # 变量初始化操作
  44. self.init = tf.global_variables_initializer()
  45. # 汇总所有的变量记录
  46. self.merged_summary_op = tf.merge_all_summaries()
  47. # 保存模型的操作
  48. self.saver = tf.train.Saver()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58

外围代码

外围代码其实有很多,例如训练过程中变量的记录,模型的保存与读取等等,不过这与训练本身没什么关系,这里还是贴如何将句子转化成输入集和标签集的代码。对其他方面感兴趣的看官可以到github上看完整的代码。

  1. def train_by_sentence(self, input_sentence=[]):
  2. # input_sentence: [sub_sent1, sub_sent2, ...]
  3. # 每个sub_sent是一个单词序列,例如['这次','大选','让']
  4. sent_num = input_sentence.__len__()
  5. batch_inputs = []
  6. batch_labels = []
  7. for sent in input_sentence: # 输入有可能是多个句子,这里每个循环处理一个句子
  8. for i in range(sent.__len__()): # 处理单个句子中的每个单词
  9. start = max(0,i-self.win_len) # 窗口为 [-win_len,+win_len],总计长2*win_len+1
  10. end = min(sent.__len__(),i+self.win_len+1)
  11. # 将某个单词对应窗口中的其他单词转化为id计入label,该单词本身计入input
  12. for index in range(start,end):
  13. if index == i:
  14. continue
  15. else:
  16. input_id = self.word2id.get(sent[i])
  17. label_id = self.word2id.get(sent[index])
  18. if not (input_id and label_id): # 如果单词不在词典中,则跳过
  19. continue
  20. batch_inputs.append(input_id)
  21. batch_labels.append(label_id)
  22. if len(batch_inputs)==0: # 如果标签集为空,则跳过
  23. return
  24. batch_inputs = np.array(batch_inputs,dtype=np.int32)
  25. batch_labels = np.array(batch_labels,dtype=np.int32)
  26. batch_labels = np.reshape(batch_labels,[batch_labels.__len__(),1])
  27. # 生成供tensorflow训练用的数据
  28. feed_dict = {
  29. self.train_inputs: batch_inputs,
  30. self.train_labels: batch_labels
  31. }
  32. # 这句操控tf进行各项操作。数组中的选项,train_op等,是让tf运行的操作,feed_dict选项用来输入数据
  33. _, loss_val, summary_str = self.sess.run([self.train_op,self.loss,self.merged_summary_op], feed_dict=feed_dict)
  34. # train loss,记录这次训练的loss值
  35. self.train_loss_records.append(loss_val)
  36. # self.train_loss_k10 = sum(self.train_loss_records)/self.train_loss_records.__len__()
  37. self.train_loss_k10 = np.mean(self.train_loss_records) # 求loss均值
  38. if self.train_sents_num % 1000 == 0 :
  39. self.summary_writer.add_summary(summary_str,self.train_sents_num)
  40. print("{a} sentences dealed, loss: {b}"
  41. .format(a=self.train_sents_num,b=self.train_loss_k10))
  42. # train times
  43. self.train_words_num += batch_inputs.__len__()
  44. self.train_sents_num += input_sentence.__len__()
  45. self.train_times_num += 1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/367125
推荐阅读
相关标签
  

闽ICP备14008679号