赞
踩
网址:
torch.nn.parameter — PyTorch 1.11.0 documentation
torch.nn.Parameter(torch.Tensor) 是一个类,继承了torch.Tensor这个类,有两个参数:
通俗的解释:
首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor
转换成可以训练的类型parameter
并将这个parameter
绑定到这个module
里面(net.parameter()
中就有这个绑定的parameter
,所以在参数优化的时候可以进行优化的),所以经过类型转换就变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。
- self.embed = nn.Parameter(torch.FloatTensor(8, 64))
- init.normal_(self.embed, mean=0, std=0.5)
- class STL(nn.Module):
- '''
- inputs --- [N, E//2]
- '''
-
- def __init__(self,model_config):
-
- super().__init__()
- self.embed = nn.Parameter(torch.FloatTensor(model_config["gst"]["n_style_token"], model_config["gst"]["E"] // model_config["gst"]["attn_head"]))
- d_q = model_config["gst"]["E"] // 2
- d_k = model_config["gst"]["E"] // model_config["gst"]["attn_head"]
- self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=model_config["gst"]["E"], num_heads=model_config["gst"]["attn_head"])
-
- init.normal_(self.embed, mean=0, std=0.5)
-
- def forward(self, inputs):
- N = inputs.size(0)
- query = inputs.unsqueeze(1) # [N, 1, E//2]
- keys = F.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
- style_embed = self.attention(query, keys)
-
- return style_embed
部分参考:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。