当前位置:   article > 正文

pytorch 中 torch.nn.Parameter()_nn.parameter expand

nn.parameter expand

一、官方教程:

网址:

torch.nn.parameter — PyTorch 1.11.0 documentation

二、代码解读:

torch.nn.Parameter(torch.Tensor) 是一个类,继承了torch.Tensor这个类,有两个参数:

  • data(Tensor): 给定一个tensor;
  • requires_grad: 指定是否需要梯度,默认为True;

通俗的解释:

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换就变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化

三、实际应用:

  • 比如,在GST中,需要定义多个Token当作计算attention的K、V,这里就用到了torch.nn.Parameter(),作为模型的一部分不断地修改优化。
  • 主要是在模型类的 __init__()中,声明一下并标准化。
  1. self.embed = nn.Parameter(torch.FloatTensor(8, 64))
  2. init.normal_(self.embed, mean=0, std=0.5)
  • 具体如下: 
  1. class STL(nn.Module):
  2. '''
  3. inputs --- [N, E//2]
  4. '''
  5. def __init__(self,model_config):
  6. super().__init__()
  7. self.embed = nn.Parameter(torch.FloatTensor(model_config["gst"]["n_style_token"], model_config["gst"]["E"] // model_config["gst"]["attn_head"]))
  8. d_q = model_config["gst"]["E"] // 2
  9. d_k = model_config["gst"]["E"] // model_config["gst"]["attn_head"]
  10. self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=model_config["gst"]["E"], num_heads=model_config["gst"]["attn_head"])
  11. init.normal_(self.embed, mean=0, std=0.5)
  12. def forward(self, inputs):
  13. N = inputs.size(0)
  14. query = inputs.unsqueeze(1) # [N, 1, E//2]
  15. keys = F.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
  16. style_embed = self.attention(query, keys)
  17. return style_embed
  • 此外,在使用attention时,如果需要自定义并随机初始化一个Q,也是同样的道理。 

部分参考:

torch.nn.Parameter()_chenzy_hust的博客-CSDN博客_nn.parameter()

PyTorch里面的torch.nn.Parameter() - 简书 (jianshu.com)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/1007164
推荐阅读
相关标签
  

闽ICP备14008679号