赞
踩
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
这个函数接收一个张量mask,并将其变换为特定的形状。输入三个参数分别为:mask:大小为[bsz, seq_len]。dtype:数据类型。tgt_len:目标序列长度。以下是函数的运行方式。
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
将掩码中0和1的位置互换。
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
inverted_mask.to(torch.bool)将反转掩码转换为布尔类型
masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)用dtype这个类型的最小值填充其中为true的位置。
返回一个经过填充处理的反转掩码张量,形状为 [bsz, 1, tgt_len, src_len],数据类型为 dtype。
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
# eps一个很小的数,用于避免除零错误
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
# weight是一个可训练的参数,初始化为一个大小为 hidden_size 的全 1 张量
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
# 存储hidden_states的数据类型
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# hidden_states一般是batch size * sequence length * hidden size,这里的mean是按最后一维取平均
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
这段代码定义了一个RMSNorm类,用于实现归一化。
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。