The principles of Word2Vec are covered in plenty of online resources, so they will not be repeated here. This post uses PyTorch to reproduce, as closely as possible, the word-vector training method from Distributed Representations of Words and Phrases and their Compositionality.
The paper contains many implementation details, and these details matter a great deal for the quality of the resulting word vectors. We cannot fully reproduce the paper's experimental results, mainly because of limited compute and various other details, but we can still broadly show how to train word vectors.
A few of the paper's details are left unimplemented here.
The training data is text8. All related code and data can be downloaded from the Word2Vec link (extraction code: p46t).
To train, run bash run_word2vec.sh in the project directory.
# load the text8 corpus, build the vocabulary and the negative-sampling distribution
# (the enclosing function's name, load_data, is assumed; the original signature was lost)
from collections import Counter

import numpy as np


def load_data(args):
    with open(args.data_dir, 'r') as fin:
        text = fin.read()
    text = [w for w in text.lower().split()]
    # keep the max_vocab_size-1 most frequent words; everything else maps to <unk>
    vocab = dict(Counter(text).most_common(args.max_vocab_size - 1))
    vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))
    idx_to_word = [word for word in vocab.keys()]
    word_to_idx = {word: idx for idx, word in enumerate(idx_to_word)}
    # sampling distribution for negative sampling: unigram frequency raised to the 3/4 power
    word_counts = np.asarray([value for value in vocab.values()], dtype=np.float32)
    word_freqs = word_counts / np.sum(word_counts)
    word_freqs = word_freqs ** (3. / 4.)
    word_freqs = word_freqs / np.sum(word_freqs)
    vocab_size = len(idx_to_word)
    return text, idx_to_word, word_to_idx, word_freqs, vocab_size
The DataLoader is built on the following Dataset:
import torch
from torch.utils.data import Dataset


class wordEmbeddingDataset(Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, C, K):
        '''
        :param text: the corpus
        :param word_to_idx: word -> index mapping
        :param idx_to_word: index -> word mapping
        :param word_freqs: word frequencies raised to the 3/4 power, used for negative sampling
        :param C: number of context words on each side of the center word (skip-gram window)
        :param K: number of negative samples per positive word
        '''
        super(wordEmbeddingDataset, self).__init__()
        self.vocab_size = len(word_to_idx)
        # out-of-vocabulary words map to <unk>, which sits at index vocab_size-1
        self.text_encoded = [word_to_idx.get(t, self.vocab_size - 1) for t in text]
        self.text_encoded = torch.tensor(self.text_encoded).long()
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.tensor(word_freqs)
        self.C = C
        self.K = K

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx - self.C, idx)) + list(range(idx + 1, idx + self.C + 1))
        # indices that run past either end of the corpus wrap around to the other end
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # sample K negatives for every positive context word
        neg_words = torch.multinomial(self.word_freqs, self.K * pos_words.shape[0], replacement=True)
        return center_word, pos_words, neg_words
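A minimal sketch of how these pieces fit together. The window size C, negative-sample count K, and batch size below are illustrative values, not necessarily the ones used by run_word2vec.sh, and load_data refers to the data-loading function shown above (whose name is assumed):

from torch.utils.data import DataLoader

text, idx_to_word, word_to_idx, word_freqs, vocab_size = load_data(args)  # args from the project's argparse
dataset = wordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, C=3, K=15)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
center_word, pos_words, neg_words = next(iter(dataloader))
# shapes: [128], [128, 6], [128, 90]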
The word2vec model itself is very simple: just an in_embed and an out_embed trained with negative sampling. The objective function is:
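For a center word $w_I$ and an observed context word $w_O$, the negative-sampling objective from the paper (maximized per pair) is

$$\log \sigma\left({v'_{w_O}}^{\top} v_{w_I}\right) + \sum_{i=1}^{K} \mathbb{E}_{w_i \sim P_n(w)}\left[\log \sigma\left(-{v'_{w_i}}^{\top} v_{w_I}\right)\right]$$

where $v$ are the in_embed vectors, $v'$ the out_embed vectors, $\sigma$ the sigmoid function, and $P_n(w)$ the unigram distribution raised to the 3/4 power. The loss returned by the model below is the negative of this quantity, summed over the context window and averaged over the batch.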
The code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F


class word2VecModel(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super(word2VecModel, self).__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        initrange = 0.5 / self.emb_size
        self.in_embed = nn.Embedding(vocab_size, emb_size)
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        self.out_embed = nn.Embedding(vocab_size, emb_size)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, center_words, pos_words, neg_words):
        '''
        center_words: center words, [batch_size]
        pos_words: words that appear in the context window around the center word, [batch_size, window_size * 2]
        neg_words: words that do not appear around the center word, drawn by negative sampling, [batch_size, window_size * 2 * K]
        '''
        batch_size = center_words.size(0)
        input_embedding = self.in_embed(center_words)   # [batch, emb]
        pos_embedding = self.out_embed(pos_words)       # [batch, 2C, emb]
        neg_embedding = self.out_embed(neg_words)       # [batch, 2C*K, emb]
        log_pos = torch.matmul(pos_embedding, input_embedding.unsqueeze(2)).squeeze(2)   # [batch, 2C]
        log_neg = torch.matmul(neg_embedding, -input_embedding.unsqueeze(2)).squeeze(2)  # [batch, 2C*K]
        log_pos_loss = F.logsigmoid(log_pos).sum(1)
        log_neg_loss = F.logsigmoid(log_neg).sum(1)
        loss = log_pos_loss + log_neg_loss
        return -loss.mean()

    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()
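A quick sanity check of the forward pass on random indices (a sketch; the vocabulary size, embedding size, and batch shapes are illustrative):

import torch

model = word2VecModel(vocab_size=30000, emb_size=100)
center = torch.randint(0, 30000, (8,))    # [batch]
pos = torch.randint(0, 30000, (8, 6))     # [batch, 2C] with C = 3
neg = torch.randint(0, 30000, (8, 90))    # [batch, 2C*K] with K = 15
loss = model(center, pos, neg)
# with the tiny initial weights every logit is ~0, so the loss starts near 96*log(2) ≈ 66.5
print(loss.item())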
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
from transformers import AdamW, get_linear_schedule_with_warmup


def train(args, model, dataloader, word_to_idx, idx_to_word):
    LOG_FILE = "word-embedding.log"
    tb_writer = SummaryWriter('./runs')
    model.train()
    t_total = args.num_epoch * len(dataloader)
    optimizer = AdamW(model.parameters(), lr=args.learnning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                num_warmup_steps=args.warmup_steps,
                                                num_training_steps=t_total)
    train_iterator = trange(args.num_epoch, desc="epoch")
    tr_loss = 0.
    logg_loss = 0.
    global_step = 0
    for k in train_iterator:
        print("the {} epoch beginning!".format(k))
        epoch_iteration = tqdm(dataloader, desc="iteration")
        for step, batch in enumerate(epoch_iteration):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"center_words": batch[0], "pos_words": batch[1], "neg_words": batch[2]}
            loss = model(**inputs)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            global_step += 1
            tr_loss += loss.item()
            if (step + 1) % 100 == 0:
                # average loss over the last 100 steps
                loss_scalar = (tr_loss - logg_loss) / 100
                logg_loss = tr_loss
                with open(LOG_FILE, "a") as fout:
                    fout.write("epoch: {}, iter: {}, loss: {}, learn_rate: {}\n".format(
                        k, step, loss_scalar, scheduler.get_lr()[0]))
                print("epoch: {}, iter: {}, loss: {}, learning_rate: {}".format(
                    k, step, loss_scalar, scheduler.get_lr()[0]))
                tb_writer.add_scalar("learning_rate", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss", loss_scalar, global_step)
            if (step + 1) % 2000 == 0:
                # periodically evaluate on the word-similarity datasets
                embedding_weights = model.input_embeddings()
                sim_simlex = evaluate("./worddata/simlex-999.txt", embedding_weights, word_to_idx)
                sim_men = evaluate("./worddata/men.txt", embedding_weights, word_to_idx)
                sim_353 = evaluate("./worddata/wordsim353.csv", embedding_weights, word_to_idx)
                with open(LOG_FILE, "a") as fout:
                    print("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                        k, step, sim_simlex, sim_men, sim_353,
                        find_nearest("monster", embedding_weights, word_to_idx, idx_to_word)))
                    fout.write("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                        k, step, sim_simlex, sim_men, sim_353,
                        find_nearest("monster", embedding_weights, word_to_idx, idx_to_word)))
    embedding_weights = model.input_embeddings()
    np.save("embedding-{}".format(args.embed_size), embedding_weights)
    torch.save(model.state_dict(), "embedding-{}.th".format(args.embed_size))
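evaluate and find_nearest are helper functions from the project that are not shown in this post. Judging from the SpearmanrResult outputs further down, evaluate presumably correlates the model's cosine similarities with the human scores via scipy.stats.spearmanr. A minimal sketch of what find_nearest could look like, assuming cosine distance over the input embeddings (the project's actual implementation may differ):

import numpy as np
from scipy.spatial.distance import cosine


def find_nearest(word, embedding_weights, word_to_idx, idx_to_word, topk=10):
    index = word_to_idx[word]
    embedding = embedding_weights[index]
    # cosine distance to every word in the vocabulary; the query word itself comes out first
    cos_dis = np.array([cosine(e, embedding) for e in embedding_weights])
    return [idx_to_word[i] for i in cos_dis.argsort()[:topk]]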
The model's training loss curve:
Evaluation on the MEN, SimLex-999, and WordSim-353 datasets:
simlex999: SpearmanrResult(correlation=0.17249746449326459, pvalue=8.268870735375061e-08),
men: SpearmanrResult(correlation=0.427926614729899, pvalue=1.76732628946326e-115),
sim353: SpearmanrResult(correlation=0.4555634677353853, pvalue=9.452442365338771e-18)
Finding nearest neighbors:
word:good, nearest:['good', 'bad', 'things', 'happiness', 'everything', 'pleasure', 'nothing', 'something', 'think', 'whatever']
word:fresh, nearest:['fresh', 'salt', 'dry', 'grain', 'vegetables', 'eggs', 'fruit', 'sugar', 'milk', 'drinking']
word:monster, nearest:['monster', 'giant', 'loch', 'ness', 'creature', 'beast', 'hero', 'wolf', 'sword', 'serpent']
word:green, nearest:['green', 'blue', 'yellow', 'orange', 'red', 'purple', 'white', 'colored', 'brown', 'colors']
word:like, nearest:['like', 'similar', 'such', 'resemble', 'teeth', 'sometimes', 'soft', 'unlike', 'honey', 'etc']
word:america, nearest:['america', 'africa', 'australia', 'europe', 'canada', 'african', 'caribbean', 'pacific', 'carolina', 'americas']
word:chicago, nearest:['chicago', 'illinois', 'boston', 'detroit', 'atlanta', 'cleveland', 'cincinnati', 'miami', 'houston', 'denver']
word:work, nearest:['work', 'works', 'ideas', 'scientific', 'writing', 'haydn', 'seminal', 'philosophical', 'philosophy', 'writings']
word:computer, nearest:['computer', 'computers', 'hardware', 'software', 'computing', 'digital', 'graphics', 'machines', 'portable', 'interface']
word:language, nearest:['language', 'languages', 'dialects', 'dialect', 'spoken', 'vocabulary', 'syntax', 'grammar', 'alphabet', 'speakers']
Relations between words (a sketch of how this analogy is computed follows the result list below):
the nearest to <women - man + king>:
queen, prince, emperor, king, son, daughter, throne, iii, kings, wife, iv, duke, heir, vii, henry, princess, father, empress, anne, brother
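A sketch of how such an analogy query could be computed from the trained input embeddings; it reuses the cosine-distance idea from the find_nearest sketch above and assumes embedding_weights, word_to_idx, and idx_to_word are already in scope (the project's actual code may differ):

import numpy as np
from scipy.spatial.distance import cosine

# analogy vector: women - man + king
query = (embedding_weights[word_to_idx["women"]]
         - embedding_weights[word_to_idx["man"]]
         + embedding_weights[word_to_idx["king"]])
cos_dis = np.array([cosine(e, query) for e in embedding_weights])
print([idx_to_word[i] for i in cos_dis.argsort()[:20]])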