
nn.Embedding "IndexError: index out of range in self" explained

Error details

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-383-d67388d2e4cc> in <module>
      1 output_emb = myEmbed(total_words = total_words, embedding_dim = 8)
      2 word_vector = torch.tensor(word_vector, dtype=torch.long).clone().detach()
----> 3 output = output_emb(word_vector)
      4 print(output)
      5 # word_vector

/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-382-10f2ec94e0ae> in forward(self, sentences_idx)
      4         self.embed = nn.Embedding(total_words,embedding_dim)
      5     def forward(self,sentences_idx):
----> 6         return self.embed(sentences_idx).clone().detach()

/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    124         return F.embedding(
    125             input, self.weight, self.padding_idx, self.max_norm,
--> 126             self.norm_type, self.scale_grad_by_freq, self.sparse)
    127 
    128     def extra_repr(self) -> str:

/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1812         # remove once script supports set_grad_enabled
   1813         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1814     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1815 
   1816 

IndexError: index out of range in self

Code that raises the error

1. Preprocess the data: count the unique words and map them to a dictionary of indices.
sentences = ['It is a good day.','how are you?','I want to study the nn.embedding.','I want to elmate my pox.','the experience that I have done today is my favriate experience.']
sentences = [sentence.split() for sentence in sentences]   # tokenize on whitespace
all_words = []
for sentence in sentences:
    all_words += [word for word in sentence]               # collect every token
no_repeat_words = set(all_words)                           # de-duplicated vocabulary
total_words = len(no_repeat_words)
word_to_idx = {word: i+1 for i, word in enumerate(no_repeat_words)}   # word indices start at 1
word_to_idx['<unk>'] = 0                                   # index 0 is reserved for padding/unknown
idx_to_word = {i+1: word for i, word in enumerate(no_repeat_words)}
print('all_words:', all_words)
print('no_repeat_words:', no_repeat_words)
print('idx_to_word:', idx_to_word)
print('word_to_idx:', word_to_idx)
print('total_words:', total_words)


>>>all_words: ['It', 'is', 'a', 'good', 'day.', 'how', 'are', 'you?', 'I', 'want', 'to', 'study', 'the', 'nn.embedding.', 'I', 'want', 'to', 'elmate', 'my', 'pox.', 'the', 'experience', 'that', 'I', 'have', 'done', 'today', 'is', 'my', 'favriate', 'experience.']
>>>no_repeat_words: {'a', 'want', 'nn.embedding.', 'It', 'experience.', 'my', 'today', 'study', 'favriate', 'is', 'have', 'I', 'day.', 'you?', 'how', 'elmate', 'experience', 'to', 'pox.', 'the', 'that', 'good', 'done', 'are'}
>>>idx_to_word: {1: 'a', 2: 'want', 3: 'nn.embedding.', 4: 'It', 5: 'experience.', 6: 'my', 7: 'today', 8: 'study', 9: 'favriate', 10: 'is', 11: 'have', 12: 'I', 13: 'day.', 14: 'you?', 15: 'how', 16: 'elmate', 17: 'experience', 18: 'to', 19: 'pox.', 20: 'the', 21: 'that', 22: 'good', 23: 'done', 24: 'are'}
>>>word_to_idx: {'a': 1, 'want': 2, 'nn.embedding.': 3, 'It': 4, 'experience.': 5, 'my': 6, 'today': 7, 'study': 8, 'favriate': 9, 'is': 10, 'have': 11, 'I': 12, 'day.': 13, 'you?': 14, 'how': 15, 'elmate': 16, 'experience': 17, 'to': 18, 'pox.': 19, 'the': 20, 'that': 21, 'good': 22, 'done': 23, 'are': 24, '<unk>': 0}
>>>total_words: 24
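
As a side note, a helper that builds the mapping and the vocabulary size in one place keeps them consistent; the function below is a hypothetical sketch, not part of the original code:

def build_vocab(tokenized_sentences):
    """Build word_to_idx / idx_to_word with '<unk>' at index 0 and a size that matches them."""
    vocab = ['<unk>'] + sorted({word for sentence in tokenized_sentences for word in sentence})
    word_to_idx = {word: i for i, word in enumerate(vocab)}
    idx_to_word = {i: word for word, i in word_to_idx.items()}
    return word_to_idx, idx_to_word, len(vocab)

vocab_to_idx, idx_to_vocab, vocab_size = build_vocab(sentences)
print(vocab_size)   # 25: the 24 unique words plus '<unk>'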
2. Word to vector: convert each sentence into a vector of indices.
word_vector = []
sentences_pad = []
print('Sentences before padding:', sentences)
max_len = max([len(sentence) for sentence in sentences])   # length of the longest sentence

# pad every shorter sentence with '<unk>' up to max_len
for sentence in sentences:
    if len(sentence) < max_len:
        sentence.extend("<unk>" for _ in range(max_len - len(sentence)))
    sentences_pad += [sentence]
# map every padded sentence to its list of word indices
for sentence in sentences_pad:
    word_vector += [[word_to_idx[word] for word in sentence]]
print('Sentences after padding:', sentences_pad)
print('Sentences as index vectors:', word_vector)

>>>Sentences before padding: [['It', 'is', 'a', 'good', 'day.'], ['how', 'are', 'you?'], ['I', 'want', 'to', 'study', 'the', 'nn.embedding.'], ['I', 'want', 'to', 'elmate', 'my', 'pox.'], ['the', 'experience', 'that', 'I', 'have', 'done', 'today', 'is', 'my', 'favriate', 'experience.']]
>>>Sentences after padding: [['It', 'is', 'a', 'good', 'day.', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>'], ['how', 'are', 'you?', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>'], ['I', 'want', 'to', 'study', 'the', 'nn.embedding.', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>'], ['I', 'want', 'to', 'elmate', 'my', 'pox.', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>'], ['the', 'experience', 'that', 'I', 'have', 'done', 'today', 'is', 'my', 'favriate', 'experience.']]
>>>Sentences as index vectors: [[4, 10, 1, 22, 13, 0, 0, 0, 0, 0, 0], [15, 24, 14, 0, 0, 0, 0, 0, 0, 0, 0], [12, 2, 18, 8, 20, 3, 0, 0, 0, 0, 0], [12, 2, 18, 16, 6, 19, 0, 0, 0, 0, 0], [20, 17, 21, 12, 11, 23, 7, 10, 6, 9, 5]]
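
As an aside, the manual padding loop above can be replaced by torch.nn.utils.rnn.pad_sequence, which pads a batch of variable-length index tensors in one call; a minimal sketch, assuming it is applied to the tokenized sentences before any manual '<unk>' padding:

import torch
from torch.nn.utils.rnn import pad_sequence

# one 1-D LongTensor of indices per (un-padded) sentence, padded to the longest with 0 ('<unk>')
index_tensors = [torch.tensor([word_to_idx[word] for word in sentence], dtype=torch.long)
                 for sentence in sentences]
word_vector_alt = pad_sequence(index_tensors, batch_first=True, padding_value=0)
print(word_vector_alt.shape)   # torch.Size([5, 11])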
3. Pass word_vector into nn.Embedding().
import torch
import torch.nn as nn

class myEmbed(nn.Module):
    def __init__(self, total_words, embedding_dim):
        super(myEmbed, self).__init__()
        self.embed = nn.Embedding(total_words, embedding_dim)   # a lookup table with total_words rows
    def forward(self, sentences_idx):
        return self.embed(sentences_idx).clone().detach()

output_emb = myEmbed(total_words = total_words, embedding_dim = 8)   # total_words is still 24 here
word_vector = torch.tensor(word_vector, dtype=torch.long).clone().detach()
output = output_emb(word_vector)
print(output)

>>> This raises the error shown under 'Error details' above.
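
The shape of the embedding weight makes the failure easy to see: the lookup table has only total_words rows, while the input contains an index equal to total_words. A quick inspection, reusing the objects above (sketch):

print(output_emb.embed.weight.shape)   # torch.Size([24, 8]) -> valid row indices are 0..23
print(word_vector.max().item())        # 24, one past the last valid row
# row 24 does not exist in the 24-row table, hence "IndexError: index out of range in self"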

Cause of the error

The error occurs at step '3. Pass word_vector into nn.Embedding()'. total_words was computed from the de-duplicated word set (24) before '<unk>' was added to the dictionary, so word_to_idx actually holds 25 entries with indices 0 through 24, and word_vector contains the index 24. nn.Embedding(total_words, embedding_dim) therefore builds a lookup table with only 24 rows (valid indices 0 to 23), and looking up index 24 overflows it. In short, num_embeddings was set too small: it must be at least the full dictionary size, i.e. large enough to cover every index that can appear in the input. The signature is:

class torch.nn.Embedding(num_embeddings, embedding_dim,
                         padding_idx=None, max_norm=None,
                         norm_type=2.0, scale_grad_by_freq=False,
                         sparse=False, _weight=None)

A simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them by index: the input is a list of indices and the output is the corresponding word embeddings.

1. num_embeddings (int) – size of the dictionary of embeddings, i.e. one row per distinct index the input may contain (here the de-duplicated vocabulary plus '<unk>');
2. embedding_dim (int) – the size of each embedding vector;
3. padding_idx (int, optional) – if given, the entry at padding_idx (initialized to zeros) is returned whenever that index appears, and it is not updated during training;
4. max_norm (float, optional) – if given, each embedding vector whose norm exceeds max_norm is renormalized to have norm max_norm;
5. norm_type (float, optional) – the p of the p-norm used for the max_norm option. Default 2;
6. scale_grad_by_freq (boolean, optional) – if given, gradients are scaled by the inverse frequency of the words in the mini-batch. Default False;
7. sparse (bool, optional) – if True, the gradient w.r.t. the weight matrix is a sparse tensor. See the PyTorch notes on sparse gradients for details.
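
Before calling the layer, the mismatch can be made explicit by comparing the largest index in the input with num_embeddings; a minimal sanity-check sketch (variable names follow the code above), which fails with a readable message for the buggy setup instead of raising the cryptic IndexError inside the forward pass:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=total_words, embedding_dim=8)   # 24 rows: valid indices 0..23
idx = torch.as_tensor(word_vector, dtype=torch.long)

print('largest index in the input:', idx.max().item())            # 24
print('rows in the embedding table:', emb.num_embeddings)         # 24
assert idx.max().item() < emb.num_embeddings, 'an index is out of range for this embedding table'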

Corrected code

# 1. Preprocess the data (unchanged from the code above)
import torch
import torch.nn as nn

sentences = ['It is a good day.','how are you?','I want to study the nn.embedding.','I want to elmate my pox.','the experience that I have done today is my favriate experience.']
sentences = [sentence.split() for sentence in sentences]
all_words = []
for sentence in sentences:
    all_words += [word for word in sentence]
no_repeat_words = set(all_words)
word_to_idx = {word: i+1 for i, word in enumerate(no_repeat_words)}
word_to_idx['<unk>'] = 0                      # index 0 is reserved for padding/unknown
idx_to_word = {i+1: word for i, word in enumerate(no_repeat_words)}

# 2. Word to vector: pad the sentences and convert them to index vectors
word_vector = []
sentences_pad = []
max_len = max([len(sentence) for sentence in sentences])
for sentence in sentences:
    if len(sentence) < max_len:
        sentence.extend("<unk>" for _ in range(max_len - len(sentence)))
    sentences_pad += [sentence]
for sentence in sentences_pad:
    word_vector += [[word_to_idx[word] for word in sentence]]

# 3. Pass the index vectors into nn.Embedding to get the word vectors
total_words = len(word_to_idx)                # the fix: 25, the full dictionary including '<unk>'
class myEmbed(nn.Module):
    def __init__(self, total_words, embedding_dim):
        super(myEmbed, self).__init__()
        self.embed = nn.Embedding(total_words, embedding_dim)
    def forward(self, sentences_idx):
        return self.embed(sentences_idx).clone().detach()
output_emb = myEmbed(total_words = total_words, embedding_dim = 8)
word_vector = torch.tensor(word_vector, dtype=torch.long).clone().detach()
output = output_emb(word_vector)
print(output)
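
An optional refinement, not used in the run shown under 'Output' below: since index 0 is only ever produced by the '<unk>' padding token, passing padding_idx=0 keeps those rows at zero and excludes them from gradient updates. A small sketch of that variant, reusing the names above:

# variant: mark index 0 as padding so its row stays all-zero and is never trained
embed_pad = nn.Embedding(num_embeddings=len(word_to_idx), embedding_dim=8, padding_idx=0)
padded_out = embed_pad(word_vector)
print(padded_out.shape)     # torch.Size([5, 11, 8]): 5 sentences, 11 tokens, 8-dim vectors
print(padded_out[0, -1])    # the last token of sentence 0 is '<unk>' -> an all-zero vector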

Output

tensor([[[-0.9028, -1.0990,  1.0646,  1.4747,  1.2577,  0.6634,  0.0188,
           0.6545],
         [-0.2176,  0.5252,  0.2574,  1.2822, -0.8745, -1.2112,  0.0584,
          -0.5189],
         [ 0.5240, -0.8862, -1.3594, -1.1795, -0.8441,  0.7830,  0.9485,
           0.5734],
         [ 1.6141,  0.2254, -0.1457,  0.7620, -1.8222,  0.4634, -0.8187,
           0.3283],
         [-0.3710,  0.8392, -0.6133,  0.6381, -1.7941,  0.2950,  0.3148,
           2.2896],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190]],

        [[-0.1860,  1.8636, -0.6865, -0.3979,  1.1691,  1.2467,  1.5026,
           0.2586],
         [-0.9084,  0.0882, -0.0631,  0.0667,  0.9071,  1.6767, -0.1515,
           1.1327],
         [-2.6057,  0.6494,  0.0483,  0.5032,  0.5448,  0.7419,  0.8697,
          -0.4805],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190]],

        [[-0.2740,  0.7465,  0.7614, -1.3599, -0.7212,  0.0880,  0.9135,
           1.8307],
         [ 0.3974, -0.0467, -0.8352,  0.2649,  1.9399, -2.1667,  0.3023,
          -1.7938],
         [-0.8383, -0.6372, -0.1922,  0.5328,  0.5292, -0.8630, -0.0764,
          -1.4630],
         [ 0.2232, -0.2855, -0.5257, -1.4286, -1.3177, -0.5152, -1.1457,
           0.3720],
         [-0.6988, -0.3652, -0.9142,  0.5403,  0.1923, -1.6566,  0.8366,
          -1.1495],
         [-0.1142, -1.0301,  1.1789,  0.4901, -0.2576,  0.4898,  0.4154,
           1.1342],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190]],

        [[-0.2740,  0.7465,  0.7614, -1.3599, -0.7212,  0.0880,  0.9135,
           1.8307],
         [ 0.3974, -0.0467, -0.8352,  0.2649,  1.9399, -2.1667,  0.3023,
          -1.7938],
         [-0.8383, -0.6372, -0.1922,  0.5328,  0.5292, -0.8630, -0.0764,
          -1.4630],
         [-1.1177, -0.8047,  0.2185, -0.3761,  0.8753,  2.1269,  1.4648,
          -0.1830],
         [ 0.4993,  0.5043, -0.4541, -0.2609,  2.4289,  1.5842, -1.9878,
           1.4654],
         [ 1.8740, -0.1214,  0.6446, -0.4646,  0.3363, -0.3854, -0.4768,
           0.7824],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190],
         [ 1.0500, -0.7410,  1.4759, -0.9487,  1.4232,  0.1392,  0.8788,
          -0.7190]],

        [[-0.6988, -0.3652, -0.9142,  0.5403,  0.1923, -1.6566,  0.8366,
          -1.1495],
         [ 0.4606,  0.2213, -0.6970, -0.1618, -1.8748, -0.4962,  0.5517,
          -0.4841],
         [ 0.0738,  0.8394, -1.1480, -0.3829, -0.0931,  1.1793,  0.2737,
          -0.9046],
         [-0.2740,  0.7465,  0.7614, -1.3599, -0.7212,  0.0880,  0.9135,
           1.8307],
         [ 1.2459,  0.6663,  1.6969, -0.2072, -1.9603, -1.4282,  0.8382,
          -0.3569],
         [-1.6661,  0.0275,  0.5090,  0.4771, -0.7955,  0.9199,  0.9401,
           0.8285],
         [ 0.2445,  0.0742,  1.6497, -0.0338,  1.8325,  0.1709,  0.7659,
          -0.7233],
         [-0.2176,  0.5252,  0.2574,  1.2822, -0.8745, -1.2112,  0.0584,
          -0.5189],
         [ 0.4993,  0.5043, -0.4541, -0.2609,  2.4289,  1.5842, -1.9878,
           1.4654],
         [ 0.1651, -0.1232,  1.1650, -1.3531,  0.1082,  0.1277, -1.0091,
          -1.3470],
         [-0.2381,  1.7149,  1.0614, -1.1837, -0.5192,  0.9356, -0.1343,
           0.9358]]])
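
Note how the same row (1.0500, -0.7410, 1.4759, ...) repeats wherever a sentence was padded: every '<unk>' position carries index 0, so every padded slot looks up the same embedding row (with padding_idx=0, as sketched earlier, those rows would be all zeros instead). A quick check of this, reusing the objects above (sketch):

unk_row = output_emb.embed.weight[0]           # the embedding row looked up for '<unk>' (index 0)
print(torch.allclose(output[0, 5], unk_row))   # True: position 5 of sentence 0 is padding
print(torch.allclose(output[1, 3], unk_row))   # True: position 3 of sentence 1 is padding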

