From a language model's hidden_states to logits: what transformation takes place?

Taking GPT-2 as the example

Load the model and run inference:

    from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
    import torch

    config = GPT2Config.from_pretrained("../model/gpt2")
    model = GPT2LMHeadModel.from_pretrained("../model/gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("../model/gpt2")

    prompt = "I thought this movie was glorious, I appreciated it. Conclusion: This movie is"
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model(inputs.input_ids, output_hidden_states=True)  # also return all hidden states
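
A quick way to see what this forward pass returned is to inspect the output object (a minimal check, reusing the objects created above):

    print(type(output).__name__)   # CausalLMOutputWithCrossAttentions
    print(output.keys())           # the fields that were actually populated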

What does output contain?

Looking at the source code of modeling_gpt2.py, in the import section:

    from ...modeling_outputs import (
        BaseModelOutputWithPastAndCrossAttentions,
        CausalLMOutputWithCrossAttentions,
        QuestionAnsweringModelOutput,
        SequenceClassifierOutputWithPast,
        TokenClassifierOutput,
    )

Digging further into modeling_outputs.py, we can find the class that this output is an instance of:

    class CausalLMOutputWithCrossAttentions(ModelOutput):
        """
        Base class for causal language model (or autoregressive) outputs.
        Args:
            loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
                Language modeling loss (for next-token prediction).
            logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
                Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
                one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
            cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`.
                Cross attentions weights after the attention softmax, used to compute the weighted average in the
                cross-attention heads.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
                value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
                setting. Only relevant if `config.is_decoder = True`.
                Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
                `past_key_values` input) to speed up sequential decoding.
        """

        loss: Optional[torch.FloatTensor] = None
        logits: torch.FloatTensor = None
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
        hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
        attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
        cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

So output gives access to loss, logits, hidden_states (which requires either passing output_hidden_states=True to model() or setting config.output_hidden_states=True in the config), and so on.
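
A couple of quick checks on these fields (reusing output from the forward pass above; the sequence length of 18 corresponds to the example prompt):

    print(output.loss)           # None, because no labels were passed
    print(output.logits.shape)   # torch.Size([1, 18, 50257])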

What is the relationship between hidden_states and logits?

hidden_states contains the embedding output plus the output of every transformer block, so hidden_states[-1] gives the output of the last layer (in GPT-2 this last entry has already passed through the final LayerNorm ln_f); applying one linear transformation to it yields the logits.
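
For this GPT-2 checkpoint (12 blocks), the tuple therefore has 13 entries, each of shape (batch_size, sequence_length, hidden_size). A minimal check, reusing the objects above:

    print(len(output.hidden_states))       # 13 = 1 embedding output + 12 GPT2Blocks
    print(output.hidden_states[0].shape)   # torch.Size([1, 18, 768]) - embedding output
    print(output.hidden_states[-1].shape)  # torch.Size([1, 18, 768]) - last block's output (after ln_f)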

Take GPT-2 (d_model = 768) and the prompt above (18 tokens) as the example. First, here is GPT-2's structure, which is simply what print(model) displays (see the earlier CSDN post on the GPT-2 structure for details):

    GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-11): 12 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): Conv1D()
              (c_proj): Conv1D()
              (attn_dropout): Dropout(p=0.1, inplace=False)
              (resid_dropout): Dropout(p=0.1, inplace=False)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): GPT2MLP(
              (c_fc): Conv1D()
              (c_proj): Conv1D()
              (act): NewGELUActivation()
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (lm_head): Linear(in_features=768, out_features=50257, bias=False)
    )

The linear layer on the last line, (lm_head): Linear(in_features=768, out_features=50257, bias=False), is exactly the transformation that takes hidden_states[-1] to the logits.
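
This layer can be inspected directly. A small check, reusing the model loaded above (the last line reflects GPT-2's default configuration, in which lm_head shares its weight with the token embedding wte):

    print(model.lm_head)               # Linear(in_features=768, out_features=50257, bias=False)
    print(model.lm_head.weight.shape)  # torch.Size([50257, 768])
    # with the default config, lm_head is tied to the token embedding wte:
    print(model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr())  # True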

hidden_states[-1] has shape [1, 18, 768] and the linear layer's weight model.lm_head.weight has shape [50257, 768]. Multiply the two matrices (transposing the weight):

    logits2 = torch.matmul(output.hidden_states[-1], model.lm_head.weight.transpose(0, 1))

The resulting logits2 is exactly the same as output.logits, with shape [1, 18, 50257], where 50257 is the vocabulary size.
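
This can be checked directly; and since lm_head is a bias-free nn.Linear, calling the module performs the same multiplication (a minimal sketch, reusing the tensors computed above):

    logits3 = model.lm_head(output.hidden_states[-1])         # same projection via the module call

    print(logits2.shape)                                      # torch.Size([1, 18, 50257])
    print(torch.allclose(output.logits, logits2, atol=1e-5))  # True
    print(torch.allclose(output.logits, logits3, atol=1e-5))  # True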

To sum up: getting from hidden_states[-1] to the logits takes only a single linear transformation.

From logits to the next token

Once we have the logits, we take the highest-scoring entry; the word at that index in the vocabulary is the next token. For predicting the next token, only the logits at the last position of the sequence matter:

    # take the logits at the last position of the sequence
    logits = output.logits[0, -1, :]     # shape: [50257]
    probs = torch.softmax(logits, dim=-1)
    print(probs.size())                  # torch.Size([50257])
    # index of the most likely next token
    next_token_index = torch.argmax(probs, dim=-1)
    # use the tokenizer to turn the index back into a word
    next_token = tokenizer.decode([next_token_index.item()])
    print(next_token)                    # "a"
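
Note that the softmax is not strictly required for the argmax itself, since softmax is monotonic; it only matters when actual probabilities are needed. For comparison, the same greedy choice can be reproduced with model.generate (a minimal sketch, reusing the model and inputs above):

    generated = model.generate(inputs.input_ids, max_new_tokens=1, do_sample=False)
    print(tokenizer.decode(generated[0, -1:]))   # same next token as the argmax above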
