赞
踩
import torch import numpy as np class MaxState(torch.nn.Module): def __init__(self, hidden_dim, heads, win): super(MaxState, self).__init__() assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads." self.head_size = hidden_dim // heads self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) self.head_num = heads self.win = win self.hidden = hidden_dim self.mask = torch.triu(torch.ones([win, win])).to(device) self.layer_nor = torch.nn.LayerNorm(hidden_dim) def forward(self, input_data, state=None): # self.head.to(device) b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win window = torch.ones([1, w]).to(device) out = self.head(input_data) out = out.unsqueeze(-1) @ window out = out.permute([0, 2, 1, 3]) one_list = [] if state is None: state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf") state = state.to(device) for i in range(0, s, w): state.reshape([state.shape[0], -1]) j = w + i one = out[:, :, i:j] _, _, r, c = one.shape if r != self.win: one = torch.where(self.mask[:r, :] == 1, one, torch.Tensor([-float('inf')]).to(device)) else: one = torch.where(self.mask == 1, one, torch.Tensor([-float('inf')]).to(device)) if i == 0: one = torch.concat([one, state @ window], axis=2) state, _ = torch.max(one, axis=2, keepdim=True) else: state1, _ = torch.max(one, axis=2, keepdim=True) # state = torch.sin(self.state(state1.reshape([state1.shape[0], -1]))*state.reshape([state.shape[0], -1])) state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]])) state = state1.permute([0, 2, 1]).unsqueeze(-2) + state # state = state.reshape(state1.shape) one = torch.concat([one, state], axis=2) state, _ = torch.max(one, axis=2, keepdim=True) one = state.reshape([b, k, h, w]) state = state[..., -1:] if r != self.win: one = one[..., :r] one = one.permute([0, 3, 1, 2]) one_list.append(one) out = torch.concat(one_list, 1) out = out.reshape([b, s, -1]) return out, state class FeedForward(torch.nn.Module): def __init__(self, hidden_size): super(FeedForward, self).__init__() self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2) self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size) self.gate = torch.nn.Linear(hidden_size, hidden_size * 2) self.relu = torch.nn.ReLU() def forward(self, x): x1 = self.ffn1(x) x2 = self.relu(self.gate(x)) x = x1 * x2 x = self.ffn2(x) return x class MemoryBlock(torch.nn.Module): def __init__(self, hidden_dim): super(MemoryBlock, self).__init__() # 使用Xavier初始化权重 self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim)) torch.nn.init.xavier_uniform_(self.fc) # self.mem = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim)) self.mem = torch.eye(hidden_dim).to(device) # torch.nn.init.xavier_uniform_(self.mem) # self.sig = torch.nn.Sigmoid() def forward(self, x): x = x @ (self.fc + self.mem) # x = self.sig(x) return x class DecoderLayer(torch.nn.Module): def __init__(self, hidden_size, num_heads): super(DecoderLayer, self).__init__() # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads) self.self_attention = MaxState(hidden_size, num_heads, 8) self.ffn = FeedForward(hidden_size) self.layer_norm = torch.nn.LayerNorm(hidden_size) def forward(self, x, state=None, seq_len=None): x1, state = self.self_attention(x, state) x = self.layer_norm(self.ffn(x1) + x) return x, state class SamOut(torch.nn.Module): def __init__(self, voc_size, hidden_size, num_heads, num_layers): super(SamOut, self).__init__() self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3) self.pos = torch.nn.Embedding(1024, hidden_size) self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)]) self.head = torch.nn.Linear(hidden_size, voc_size, False) # self.head_state = torch.nn.Linear(hidden_size, num_layers, False) self.layer_nor = torch.nn.LayerNorm(hidden_size) self.down = torch.nn.ModuleList( [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)]) self.layer_norm = torch.nn.LayerNorm(hidden_size) def state_forward(self, state, pos, x): if state is None: state = [None] * len(self.decoder_layers) i = 0 for ii, decoder_layer in enumerate(self.decoder_layers): x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1)) x1, state[i] = decoder_layer(x, state[i]) x = x1 + x i += 1 return x, state def pos_forward(self, x): if x.shape[1] >= 1024: pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0) pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos else: pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0) return pos def forward(self, x0, x1): x0, state0 = self.one_forward(x0) x1, state1 = self.one_forward(x1) # x2, state2 = self.one_forward(x2) return x0, x1 def one_forward(self, x, state=None, seq_len=None): x = self.em(x) pos = self.pos_forward(x) x, state = self.state_forward(state, pos, x) return self.head(x), state device = "cuda" if __name__ == '__main__': net = SamOut(235, 256, 16, 4) net.to(device) net(torch.randint(0, 200, [2, 8 * 13]).to(device), torch.randint(0, 200, [2, 4 * 13]).to(device), torch.randint(0, 200, [2, 2 * 13]).to(device)) #
这是一个用于实现Self-Attention机制的PyTorch模型。它包含了许多不同的类和函数来实现不同的功能。
MaxState类:定义了一个用于处理输入数据的自注意力模块。它根据输入数据的大小、头数和窗口大小来初始化。它使用线性层来对输入数据进行变换,并使用最大化函数来计算注意力权重。它还包含一些辅助函数来处理输入数据的形状和尺寸。
FeedForward类:定义了一个前馈神经网络模块。它使用线性层和ReLU激活函数来实现非线性变换。
MemoryBlock类:定义了一个记忆块模块。它使用参数化的线性变换和一个固定的记忆矩阵来实现输入数据的记忆功能。
DecoderLayer类:定义了一个解码器层模块。它包含一个自注意力模块和一个前馈神经网络模块,以及一些规范化层用于调整输入数据的尺寸。
SamOut类:定义了一个自注意力模型。它包含一个嵌入层、一个位置编码层和多个解码器层。它使用这些层来处理输入数据,并最终输出预测结果。
然后,我们在主函数中实例化了SamOut类,并进行了简单的前向传播操作,以测试模型的功能。
请注意,代码中可能存在一些错误或不完整的部分,因此如果您想将其用于实际应用,请进行适当的修改和完善。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。