当前位置:   article > 正文

[Pytorch] PyTorch 中的 LSTM 模型

公式表示

Pytorch中LSTM的公式表示为:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

定义

Pytorch中LSTM的定义如下:

class torch.nn.LSTM(*args, **kwargs)


参数列表

- input_size: 输入特征的维度
- hidden_size: 隐藏状态 h 的维度
- num_layers: 堆叠的 LSTM 层数, 默认为 1
- bias: 是否使用偏置项, 默认为 True
- batch_first: 若为 True, 输入输出张量形状为 (batch, seq, feature), 默认为 False
- dropout: 除最后一层外, 每层输出上的 dropout 概率, 默认为 0
- bidirectional: 是否为双向 LSTM, 默认为 False
输入数据格式: 
input(seq_len, batch, input_size) 
h0(num_layers * num_directions, batch, hidden_size) 
c0(num_layers * num_directions, batch, hidden_size)

输出数据格式: 
output(seq_len, batch, hidden_size * num_directions) 
hn(num_layers * num_directions, batch, hidden_size) 
cn(num_layers * num_directions, batch, hidden_size)

实例:基于LSTM的词性标注模型

  1. import torch
  2. import gensim
  3. torch.manual_seed(2)
  4. datas=[('你 叫 什么 名字 ?','n v n n f'),('今天 天气 怎么样 ?','n n adj f'),]
  5. words=[ data[0].split() for data in datas]
  6. tags=[ data[1].split() for data in datas]
  7. id2word=gensim.corpora.Dictionary(words)
  8. word2id=id2word.token2id
  9. id2tag=gensim.corpora.Dictionary(tags)
  10. tag2id=id2tag.token2id
  11. def sen2id(inputs):
  12. return [word2id[word] for word in inputs]
  13. def tags2id(inputs):
  14. return [tag2id[word] for word in inputs]
  15. # print(sen2id('你 叫 什么 名字'.split()))
  16. def formart_input(inputs):
  17. return torch.autograd.Variable(torch.LongTensor(sen2id(inputs)))
  18. def formart_tag(inputs):
  19. return torch.autograd.Variable(torch.LongTensor(tags2id(inputs)),)
  20. class LSTMTagger(torch.nn.Module):
  21. def __init__(self,embedding_dim,hidden_dim,voacb_size,target_size):
  22. super(LSTMTagger,self).__init__()
  23. self.embedding_dim=embedding_dim
  24. self.hidden_dim=hidden_dim
  25. self.voacb_size=voacb_size
  26. self.target_size=target_size
  27. self.lstm=torch.nn.LSTM(self.embedding_dim,self.hidden_dim)
  28. self.log_softmax=torch.nn.LogSoftmax()
  29. self.embedding=torch.nn.Embedding(self.voacb_size,self.embedding_dim)
  30. self.hidden=(torch.autograd.Variable(torch.zeros(1,1,self.hidden_dim)),torch.autograd.Variable(torch.zeros(1,1,self.hidden_dim)))
  31. self.out2tag=torch.nn.Linear(self.hidden_dim,self.target_size)
  32. def forward(self,inputs):
  33. input=self.embedding((inputs))
  34. out,self.hidden=self.lstm(input.view(-1,1,self.embedding_dim),self.hidden)
  35. tags=self.log_softmax(self.out2tag(out.view(-1,self.hidden_dim)))
  36. return tags
  37. model=LSTMTagger(3,3,len(word2id),len(tag2id))
  38. loss_function=torch.nn.NLLLoss()
  39. optimizer=torch.optim.SGD(model.parameters(),lr=0.1)
  40. for _ in range(100):
  41. model.zero_grad()
  42. input=formart_input('你 叫 什么 名字'.split())
  43. tags=formart_tag('n n adj f'.split())
  44. out=model(input)
  45. loss=loss_function(out,tags)
  46. loss.backward(retain_variables=True)
  47. optimizer.step()
  48. print(loss.data[0])
  49. input=formart_input('你 叫 什么'.split())
  50. out=model(input)
  51. out=torch.max(out,1)[1]
  52. print([id2tag[out.data[i]] for i in range(0,out.size()[0])])

转自:https://blog.csdn.net/android_ruben/article/details/80206792

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/488335
推荐阅读
相关标签
  

闽ICP备14008679号