当前位置:   article > 正文

加速samout和loss 更低

加速samout和loss 更低
import torch
import numpy as np


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads
        self.win = win
        self.hidden = hidden_dim
        self.mask = torch.triu(torch.ones([win, win])).to(device)
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        # self.head.to(device)
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

        window = torch.ones([1, w]).to(device)

        out = self.head(input_data)

        out = out.unsqueeze(-1) @ window

        out = out.permute([0, 2, 1, 3])

        one_list = []
        if state is None:
            state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
            state = state.to(device)
        for i in range(0, s, w):

            state.reshape([state.shape[0], -1])
            j = w + i
            one = out[:, :, i:j]
            _, _, r, c = one.shape
            if r != self.win:

                one = torch.where(self.mask[:r, :] == 1, one, torch.Tensor([-float('inf')]).to(device))

            else:
                one = torch.where(self.mask == 1, one, torch.Tensor([-float('inf')]).to(device))

            if i == 0:

                one = torch.concat([one, state @ window], axis=2)
                state, _ = torch.max(one, axis=2, keepdim=True)


            else:

                state1, _ = torch.max(one, axis=2, keepdim=True)

                # state = torch.sin(self.state(state1.reshape([state1.shape[0], -1]))*state.reshape([state.shape[0], -1]))
                state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))
                state = state1.permute([0, 2, 1]).unsqueeze(-2) + state
                # state = state.reshape(state1.shape)

                one = torch.concat([one, state], axis=2)
                state, _ = torch.max(one, axis=2, keepdim=True)

            one = state.reshape([b, k, h, w])

            state = state[..., -1:]
            if r != self.win:
                one = one[..., :r]

            one = one.permute([0, 3, 1, 2])
            one_list.append(one)

        out = torch.concat(one_list, 1)

        out = out.reshape([b, s, -1])

        return out, state


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)

        x2 = self.relu(self.gate(x))

        x = x1 * x2

        x = self.ffn2(x)
        return x


class MemoryBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()

        # 使用Xavier初始化权重
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)

        # self.mem = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.mem = torch.eye(hidden_dim).to(device)
        # torch.nn.init.xavier_uniform_(self.mem)

        # self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        # x = self.sig(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        # self.head_state = torch.nn.Linear(hidden_size, num_layers, False)
        self.layer_nor = torch.nn.LayerNorm(hidden_size)
        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def state_forward(self, state, pos, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        i = 0
        for ii, decoder_layer in enumerate(self.decoder_layers):
            x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
            i += 1
        return x, state

    def pos_forward(self, x):
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos

        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x0, x1):
        x0, state0 = self.one_forward(x0)
        x1, state1 = self.one_forward(x1)
        # x2, state2 = self.one_forward(x2)
        return x0, x1

    def one_forward(self, x, state=None, seq_len=None):
        x = self.em(x)

        pos = self.pos_forward(x)

        x, state = self.state_forward(state, pos, x)

        return self.head(x), state


device = "cuda"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device), torch.randint(0, 200, [2, 4 * 13]).to(device),
        torch.randint(0, 200, [2, 2 * 13]).to(device))
    #

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195

这是一个用于实现Self-Attention机制的PyTorch模型。它包含了许多不同的类和函数来实现不同的功能。

  • MaxState类:定义了一个用于处理输入数据的自注意力模块。它根据输入数据的大小、头数和窗口大小来初始化。它使用线性层来对输入数据进行变换,并使用最大化函数来计算注意力权重。它还包含一些辅助函数来处理输入数据的形状和尺寸。

  • FeedForward类:定义了一个前馈神经网络模块。它使用线性层和ReLU激活函数来实现非线性变换。

  • MemoryBlock类:定义了一个记忆块模块。它使用参数化的线性变换和一个固定的记忆矩阵来实现输入数据的记忆功能。

  • DecoderLayer类:定义了一个解码器层模块。它包含一个自注意力模块和一个前馈神经网络模块,以及一些规范化层用于调整输入数据的尺寸。

  • SamOut类:定义了一个自注意力模型。它包含一个嵌入层、一个位置编码层和多个解码器层。它使用这些层来处理输入数据,并最终输出预测结果。

然后,我们在主函数中实例化了SamOut类,并进行了简单的前向传播操作,以测试模型的功能。

请注意,代码中可能存在一些错误或不完整的部分,因此如果您想将其用于实际应用,请进行适当的修改和完善。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/1019505
推荐阅读
相关标签
  

闽ICP备14008679号