
Implementing word2vec Training in PyTorch

Main Content

There are already plenty of explanations of how Word2Vec works online, so they are not repeated here. This post uses PyTorch to reproduce, as closely as possible, the word-vector training method from Distributed Representations of Words and Phrases and their Compositionality. The paper contains many implementation details, and these details are crucial to the quality of the word vectors. We cannot fully reproduce the paper's experimental results, mainly because of limited computational resources and various other details, but we can still roughly show how word vectors are trained.
The following details are not implemented here:

  • subsampling: see section 2.3 of the paper (a sketch of the idea is shown right after this list)
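For reference, section 2.3 of the paper discards each occurrence of a word w_i with probability P(w_i) = 1 - sqrt(t / f(w_i)), where f(w_i) is the word's unigram frequency and t is a threshold around 1e-5. Below is a minimal sketch of what this unimplemented step could look like; the function name and the raw-frequency argument are assumptions, not part of the repository.

import numpy as np

def subsample(text, word_freqs_raw, word_to_idx, t=1e-5):
    # word_freqs_raw: raw (un-smoothed) unigram frequencies, indexed by word id
    # drop each token with probability 1 - sqrt(t / f(w)), as in section 2.3 of the paper
    kept = []
    for w in text:
        f = word_freqs_raw[word_to_idx.get(w, len(word_to_idx) - 1)]
        p_drop = max(0.0, 1.0 - np.sqrt(t / f))
        if np.random.rand() >= p_drop:
            kept.append(w)
    return kept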

The training data is text8. All related code and data can be downloaded from the Word2Vec link (extraction code: p46t).
Run bash run_word2vec.sh in the project directory.

Data Preprocessing

  • Read all the text from the file and build a vocabulary from it.
  • Since the number of distinct words may be very large, keep only the MAX_VOCAB_SIZE most common words.
  • Add an UNK token to represent all infrequent words.
  • Record the word-to-index mapping, the index-to-word mapping, the word counts, the (normalized) word frequencies, and the total number of words.
    # requires: from collections import Counter; import numpy as np
    with open(args.data_dir,'r') as fin:
        text = fin.read()
    text = [w for w in text.lower().split()]
    # keep the (max_vocab_size - 1) most frequent words; the last slot is reserved for <unk>
    vocab = dict(Counter(text).most_common(args.max_vocab_size-1))
    vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))
    idx_to_word = [word for word in vocab.keys()]
    word_to_idx = {word:idx for idx,word in enumerate(idx_to_word)}

    # sampling distribution for negative sampling: unigram frequency raised to the 3/4 power
    word_counts = np.asarray([value for value in vocab.values()],dtype=np.float32)
    word_freqs = word_counts / np.sum(word_counts)
    word_freqs = word_freqs ** (3./4.)
    word_freqs = word_freqs / np.sum(word_freqs)
    vocab_size = len(idx_to_word)

    return text,idx_to_word,word_to_idx,word_freqs,vocab_size

Implementing the Dataloader

A dataloader needs to do the following:

  • Encode all the text as integers, then preprocess the text with subsampling.
  • Store the vocabulary, the word counts, and the normalized word frequencies.
  • Sample one center word per iteration.
  • Sample some negative words for the center word.
  • Return the word counts.

We use PyTorch's DataLoader directly; for usage see the Pytorch dataloader documentation. We need to define the following two functions:

  • __len__: returns the total number of items in the dataset
  • __getitem__: returns the item at the given index
class wordEmbeddingDataset(Dataset):
    def __init__(self,text, word_to_idx, idx_to_word, word_freqs,C,K):
        '''
        :param text: the corpus
        :param word_to_idx: word -> index mapping
        :param idx_to_word: index -> word mapping
        :param word_freqs: word frequencies raised to the 3/4 power, used for negative sampling
        :param C: number of context words on each side of the center word (skip-gram window)
        :param K: number of negative samples per positive word
        '''
        super(wordEmbeddingDataset, self).__init__()
        self.vocab_size = len(word_to_idx)
        # unknown words are mapped to the last index, which belongs to <unk>
        self.text_encoded = [word_to_idx.get(t,self.vocab_size-1) for t in text]
        self.text_encoded = torch.tensor(self.text_encoded).long()
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.tensor(word_freqs)
        self.C = C
        self.K = K

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx-self.C,idx)) + list(range(idx+1,idx+self.C+1))
        # indices that fall outside the text wrap around to the other end
        pos_indices = [i%len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # sample K negatives per positive word from the 3/4-power unigram distribution
        neg_words = torch.multinomial(self.word_freqs,self.K*pos_words.shape[0],replacement=True)

        return center_word,pos_words,neg_words
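A minimal sketch of how this dataset might be wrapped with PyTorch's DataLoader; the batch size and the C and K values here are illustrative assumptions, not the repository's settings.

from torch.utils.data import DataLoader

# illustrative hyperparameters: 3 context words per side, 15 negatives per positive word
dataset = wordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, C=3, K=15)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

for center_word, pos_words, neg_words in dataloader:
    # shapes: [batch], [batch, 2*C], [batch, 2*C*K]
    break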

Defining the Word2vec Model

The word2vec model itself is very simple: an input embedding (in_embed) and an output embedding (out_embed), trained with negative sampling. The objective is given below.
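For a center word $w_I$, an observed context word $w_O$, and $k$ negative words drawn from the noise distribution $P_n(w) \propto U(w)^{3/4}$, the negative-sampling objective from the paper is

$$\log \sigma\left({v'_{w_O}}^{\top} v_{w_I}\right) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n(w)}\left[\log \sigma\left(-{v'_{w_i}}^{\top} v_{w_I}\right)\right]$$

where $v$ are the input (in_embed) vectors and $v'$ the output (out_embed) vectors. The forward pass below computes the negative of this quantity, summed over the context window and averaged over the batch, so it can be minimized as a loss.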
The code is as follows:

# requires: import torch; import torch.nn as nn; import torch.nn.functional as F
class word2VecModel(nn.Module):
    def __init__(self,vocab_size,emb_size):
        super(word2VecModel,self).__init__()

        self.vocab_size = vocab_size
        self.emb_size = emb_size

        initrange = 0.5/self.emb_size
        self.in_embed = nn.Embedding(vocab_size,emb_size)
        self.in_embed.weight.data.uniform_(-initrange,initrange)

        self.out_embed = nn.Embedding(vocab_size,emb_size)
        self.out_embed.weight.data.uniform_(-initrange,initrange)

    def forward(self,center_words,pos_words,neg_words):
        '''
        :param center_words: center words, [batch_size]
        :param pos_words: words that appear in the context window around the center word, [batch_size, window_size * 2]
        :param neg_words: words that do not appear around the center word, drawn by negative sampling, [batch_size, window_size * 2 * K]
        '''
        batch_size = center_words.size(0)
        input_embedding = self.in_embed(center_words)  # [batch, emb]
        pos_embedding = self.out_embed(pos_words)      # [batch, 2C, emb]
        neg_embedding = self.out_embed(neg_words)      # [batch, 2C*K, emb]

        # dot product of the center-word embedding with each context / negative embedding
        log_pos = torch.matmul(pos_embedding, input_embedding.unsqueeze(2)).squeeze(2)   # [batch, 2C]
        log_neg = torch.matmul(neg_embedding, -input_embedding.unsqueeze(2)).squeeze(2)  # [batch, 2C*K]

        log_pos_loss = F.logsigmoid(log_pos).sum(1)
        log_neg_loss = F.logsigmoid(log_neg).sum(1)
        loss = log_neg_loss + log_pos_loss

        # negative-sampling objective: maximize it, so return its negation as the loss
        return -loss.mean()

    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()
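A quick sanity check of the model with random indices; all sizes below are illustrative assumptions. The forward pass should return a single scalar loss.

import torch

model = word2VecModel(vocab_size=30000, emb_size=100)
center = torch.randint(0, 30000, (8,))        # batch of 8 center words
pos = torch.randint(0, 30000, (8, 6))         # 2*C = 6 context words each
neg = torch.randint(0, 30000, (8, 6 * 15))    # K = 15 negatives per context word
loss = model(center, pos, neg)
print(loss.item())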

Model Training and Evaluation

  • The model is trained for a number of epochs.
  • In each epoch, all the data is split into batches.
  • The inputs and outputs of each batch are wrapped as CUDA tensors.
  • Forward pass.
  • Zero the model's current gradients.
  • Backward pass and update the model parameters.
  • Every fixed number of iterations, log the loss at the current iteration and evaluate the model on the validation datasets.
  • Save the model.
def train(args,model,dataloader,word_to_idx,idx_to_word):
    # requires (assumed): from torch.utils.tensorboard import SummaryWriter; from tqdm import tqdm, trange;
    # AdamW and get_linear_schedule_with_warmup (e.g. from the transformers library)
    LOG_FILE = "word-embedding.log"
    tb_writer = SummaryWriter('./runs')
    model.train()
    t_total = args.num_epoch * len(dataloader)
    optimizer = AdamW(model.parameters(),lr=args.learnning_rate,eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=args.warmup_steps,num_training_steps=t_total)
    train_iterator = trange(args.num_epoch,desc="epoch")
    tr_loss = 0.
    logg_loss = 0.
    global_step = 0
    for k in train_iterator:
        print("the {} epoch beginning!".format(k))
        epoch_iteration = tqdm(dataloader,desc="iteration")
        for step,batch in enumerate(epoch_iteration):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"center_words":batch[0],"pos_words":batch[1],"neg_words":batch[2]}
            loss = model(**inputs)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            global_step +=1

            tr_loss += loss.item()
            # log the average loss over the last 100 iterations
            if (step+1) % 100 == 0:
                loss_scalar = (tr_loss - logg_loss) / 100
                logg_loss = tr_loss
                with open(LOG_FILE, "a") as fout:
                    fout.write("epoch: {}, iter: {}, loss: {}, learning_rate: {}\n".format(k, step, loss_scalar,scheduler.get_lr()[0]))
                    print("epoch: {}, iter: {}, loss: {}, learning_rate: {}".format(k, step, loss_scalar,scheduler.get_lr()[0]))
                    tb_writer.add_scalar("learning_rate",scheduler.get_lr()[0],global_step)
                    tb_writer.add_scalar("loss",loss_scalar,global_step)

            # evaluate on the word-similarity datasets every 2000 iterations
            if (step+1) % 2000 == 0:
                embedding_weights = model.input_embeddings()
                sim_simlex = evaluate("./worddata/simlex-999.txt", embedding_weights,word_to_idx)
                sim_men = evaluate("./worddata/men.txt", embedding_weights,word_to_idx)
                sim_353 = evaluate("./worddata/wordsim353.csv", embedding_weights,word_to_idx)
                with open(LOG_FILE, "a") as fout:
                    print("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                                k, step, sim_simlex,sim_men,sim_353, find_nearest("monster",embedding_weights,word_to_idx,idx_to_word)))
                    fout.write("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                                k, step, sim_simlex, sim_men, sim_353, find_nearest("monster",embedding_weights,word_to_idx,idx_to_word)))

    # save the final embeddings and the model weights
    embedding_weights = model.input_embeddings()
    np.save("embedding-{}".format(args.embed_size), embedding_weights)
    torch.save(model.state_dict(), "embedding-{}.th".format(args.embed_size))
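The training loop calls find_nearest and evaluate, which are not shown above. Here is a minimal sketch of find_nearest based on cosine distance; the repository's actual implementation may differ, and topn is an assumed parameter.

from scipy.spatial.distance import cosine
import numpy as np

def find_nearest(word, embedding_weights, word_to_idx, idx_to_word, topn=10):
    # rank all words by cosine distance to the query word's vector
    index = word_to_idx[word]
    embedding = embedding_weights[index]
    cos_dis = np.array([cosine(e, embedding) for e in embedding_weights])
    return [idx_to_word[i] for i in cos_dis.argsort()[:topn]]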

Loss curve of the model during training: (figure omitted)

Results

Evaluation on the MEN, SimLex-999, and WordSim-353 datasets:
simlex999: SpearmanrResult(correlation=0.17249746449326459, pvalue=8.268870735375061e-08),
men: SpearmanrResult(correlation=0.427926614729899, pvalue=1.76732628946326e-115),
sim353: SpearmanrResult(correlation=0.4555634677353853, pvalue=9.452442365338771e-18)
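These numbers come from the evaluate function (not shown above), which compares cosine similarity of the learned vectors with human similarity ratings via Spearman correlation. A minimal sketch, under the assumption that each dataset row holds two words and a similarity score:

from scipy.stats import spearmanr
from scipy.spatial.distance import cosine
import pandas as pd

def evaluate(filename, embedding_weights, word_to_idx):
    # assumed file format: each row holds word1, word2, human similarity score
    sep = "," if filename.endswith(".csv") else "\t"
    data = pd.read_csv(filename, sep=sep)
    human_similarity, model_similarity = [], []
    for _, row in data.iterrows():
        word1, word2, score = row.iloc[0], row.iloc[1], row.iloc[2]
        if word1 in word_to_idx and word2 in word_to_idx:
            vec1 = embedding_weights[word_to_idx[word1]]
            vec2 = embedding_weights[word_to_idx[word2]]
            model_similarity.append(1.0 - cosine(vec1, vec2))  # cosine similarity
            human_similarity.append(float(score))
    return spearmanr(human_similarity, model_similarity)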
Nearest neighbors:
word:good, nearest:[‘good’, ‘bad’, ‘things’, ‘happiness’, ‘everything’, ‘pleasure’, ‘nothing’, ‘something’, ‘think’, ‘whatever’]
word:fresh, nearest:[‘fresh’, ‘salt’, ‘dry’, ‘grain’, ‘vegetables’, ‘eggs’, ‘fruit’, ‘sugar’, ‘milk’, ‘drinking’]
word:monster, nearest:[‘monster’, ‘giant’, ‘loch’, ‘ness’, ‘creature’, ‘beast’, ‘hero’, ‘wolf’, ‘sword’, ‘serpent’]
word:green, nearest:[‘green’, ‘blue’, ‘yellow’, ‘orange’, ‘red’, ‘purple’, ‘white’, ‘colored’, ‘brown’, ‘colors’]
word:like, nearest:[‘like’, ‘similar’, ‘such’, ‘resemble’, ‘teeth’, ‘sometimes’, ‘soft’, ‘unlike’, ‘honey’, ‘etc’]
word:america, nearest:[‘america’, ‘africa’, ‘australia’, ‘europe’, ‘canada’, ‘african’, ‘caribbean’, ‘pacific’, ‘carolina’, ‘americas’]
word:chicago, nearest:[‘chicago’, ‘illinois’, ‘boston’, ‘detroit’, ‘atlanta’, ‘cleveland’, ‘cincinnati’, ‘miami’, ‘houston’, ‘denver’]
word:work, nearest:[‘work’, ‘works’, ‘ideas’, ‘scientific’, ‘writing’, ‘haydn’, ‘seminal’, ‘philosophical’, ‘philosophy’, ‘writings’]
word:computer, nearest:[‘computer’, ‘computers’, ‘hardware’, ‘software’, ‘computing’, ‘digital’, ‘graphics’, ‘machines’, ‘portable’, ‘interface’]
word:language, nearest:[‘language’, ‘languages’, ‘dialects’, ‘dialect’, ‘spoken’, ‘vocabulary’, ‘syntax’, ‘grammar’, ‘alphabet’, ‘speakers’]
Relationships between words:
the nearest to <women-man+king>:
queen, prince, emperor, king, son, daughter, throne, iii, kings, wife, iv, duke, heir, vii, henry, princess, father, empress, anne, brother
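The analogy above can be reproduced by simple vector arithmetic on the learned embeddings and a nearest-neighbor lookup. A minimal sketch, reusing embedding_weights, word_to_idx, and idx_to_word from above; the helper name is an assumption.

import numpy as np
from scipy.spatial.distance import cosine

def find_nearest_vector(vector, embedding_weights, idx_to_word, topn=20):
    # rank all words by cosine distance to an arbitrary query vector
    cos_dis = np.array([cosine(e, vector) for e in embedding_weights])
    return [idx_to_word[i] for i in cos_dis.argsort()[:topn]]

query = (embedding_weights[word_to_idx["king"]]
         - embedding_weights[word_to_idx["man"]]
         + embedding_weights[word_to_idx["women"]])
print(find_nearest_vector(query, embedding_weights, idx_to_word))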
