Based on what we learned earlier, you may have noticed a problem with the attention scoring function above: it contains no learnable parameters. We therefore introduce the notion of "multiple heads": the input vectors are passed through different linear projections, and the resulting vectors serve as the "queries", "keys", and "values". These linear projections are learnable projection matrices.
To allow the multiple heads to be computed in parallel, we first define the two transpose functions below. Specifically, transpose_output reverses the operation of transpose_qkv.
# Import the required libraries
import os
import math
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
from torch.utils import data
import re
import collections
import random
# Modules implemented last week
def masked_softmax(X, valid_lens):
    """Perform a softmax operation by masking elements on the last axis"""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements on the last axis are replaced with a very large
        # negative value, so that their softmax output is 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Use dropout for regularization
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        # Here queries, keys, values have shape:
        # (batch_size*num_heads, no. of q/k-v pairs, num_hiddens/num_heads)
        d = queries.shape[-1]
        # Scaled dot product
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # Shape of scores: (batch_size*num_heads, no. of queries, no. of keys)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of attention_weights: (batch_size*num_heads, no. of queries, no. of k-v pairs)
        # Result shape: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
        return torch.bmm(self.dropout(self.attention_weights), values), scores
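As a quick sanity check of masked_softmax (a minimal sketch of my own, not part of the original notebook), we can feed it random scores with valid lengths 2 and 3: entries beyond the valid length should receive (numerically) zero weight.

# Assumed example: columns beyond the valid length get zero attention weight
probs = masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
print(probs)  # last 2 columns of the first example and the last column of the second are 0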
def transpose_qkv(X, num_heads):
    # Shape of input X: (batch_size, no. of queries or key-value pairs, num_hiddens)
    # After reshape: (batch_size, no. of queries or key-value pairs, num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # After permute: (batch_size, num_heads, no. of queries or key-value pairs,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # Final output shape: (batch_size*num_heads, no. of queries or key-value pairs,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    # Shape of input X: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # (batch_size, no. of queries or key-value pairs, num_heads, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # (batch_size, no. of queries or key-value pairs, num_heads*num_hiddens/num_heads)
    return X.reshape(X.shape[0], X.shape[1], -1)
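As a small shape check (again my own sketch, not from the original text), transpose_output should exactly undo transpose_qkv:

# Hypothetical round-trip check of the two transpose helpers
X = torch.arange(2 * 3 * 8, dtype=torch.float32).reshape(2, 3, 8)
num_heads = 4
X_split = transpose_qkv(X, num_heads)          # (2*4, 3, 8/4) = (8, 3, 2)
X_back = transpose_output(X_split, num_heads)  # back to (2, 3, 8)
print(X_split.shape, X_back.shape, torch.equal(X, X_back))  # ..., True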
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        # After the transformation, queries, keys, values have shape:
        # (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) num_heads times,
            # then the second item, and so on
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        # Shape of output: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
        output, score = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat), score.reshape(
            -1, self.num_heads, score.shape[1], score.shape[2]).permute(1, 0, 2, 3)
The output of the multi-head attention module has shape (batch_size, num_queries, num_hiddens).
num_hiddens, num_heads = 100, 5
# Using num_hiddens everywhere is just for convenience; strictly speaking these should be
# the feature sizes of the queries, keys and values plus num_hiddens
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
You could print the scores of every head directly with a for loop, but that takes a lot of space, so only one head is printed here. Each head's score has shape (batch_size, num_queries, num_kvpairs), and the result below confirms this.
batch_size, num_queries = 2, 3
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
# Shape of X: (batch_size, no. of queries, query feature size)
X = torch.ones((batch_size, num_queries, num_hiddens))
# Shape of Y: (batch_size, no. of key-value pairs, key/value feature size)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# The keys and values are set to be the same here
att, score = attention(X, Y, Y, valid_lens)
print(att.shape, score.shape)
# Print the scores of every head
# for i in range(score.shape[0]):
#     print(f'heads{i}:{score[i]}')
#     print(f'heads{i}.shape:{score[i].shape}')
print(f'heads{0}:{score[0]}')
print(f'heads{0}.shape:{score[0].shape}')
torch.Size([2, 3, 100]) torch.Size([5, 2, 3, 6])
heads0:tensor([[[ 1.1521e-01,  1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06]],

        [[ 1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 1.1521e-01,  1.1521e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06]]],
       grad_fn=<SelectBackward0>)
heads0.shape:torch.Size([2, 3, 6])
The shapes of the results match expectations.
num_hiddens, num_heads = 500, 5
batch_size, num_queries = 3, 6
num_kvpairs, valid_lens = 8, torch.tensor([3, 2, 3])
# Shape of X: (batch_size, no. of queries, query feature size)
X = torch.rand((batch_size, num_queries, num_hiddens))
# Shape of Y: (batch_size, no. of key-value pairs, key/value feature size)
Y = torch.rand((batch_size, num_kvpairs, num_hiddens))
attention1 = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
# The keys and values are set to be the same here
att,score=attention1(X, Y, Y, valid_lens)
print(f'att shape: {att.shape}, multi-head score shape: {score.shape}')
print(f'heads{0}:{score[0]}')
print(f'heads{0}.shape:{score[0].shape}')
att shape: torch.Size([3, 6, 500]), multi-head score shape: torch.Size([5, 3, 6, 8])
heads0:tensor([[[-2.2675e-02, -2.0155e-02, -1.6584e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 3.4621e-02,  2.7170e-02, -1.0288e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [ 4.4613e-02,  1.7425e-02, -4.0215e-02, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-2.1214e-02, -7.5245e-02, -1.2768e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-6.9279e-02, -1.0662e-01, -2.2013e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-9.0299e-02, -6.2465e-02, -2.1125e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06]],

        [[-2.0877e-01, -3.1303e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.1978e-01, -2.2805e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.6171e-01, -2.3794e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.6123e-01, -3.0902e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-5.1692e-02, -1.3835e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-3.8984e-02, -2.4244e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06]],

        [[-1.0766e-01, -1.5127e-01, -1.3136e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-1.9279e-01, -2.0913e-01, -1.3926e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-5.1507e-02, -6.8220e-02, -1.5038e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-6.9059e-02, -1.1010e-01, -1.2919e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-5.1092e-02, -7.4127e-02, -8.4521e-02, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],
         [-9.9756e-02, -1.8430e-01, -2.1404e-01, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06]]],
       grad_fn=<SelectBackward0>)
heads0.shape:torch.Size([3, 6, 8])
With the attention mechanism in place, we feed the token sequence into attention pooling so that the same set of tokens acts simultaneously as queries, keys, and values. Specifically, each query attends to all key-value pairs and produces one attention output. Because the queries, keys, and values all come from the same input, this is called self-attention.
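A minimal sketch of this idea (my own example, reusing the MultiHeadAttention module defined above; the sizes are arbitrary): passing the same tensor as queries, keys, and values turns it into self-attention.

# Hypothetical self-attention check: queries = keys = values = X
self_attn = MultiHeadAttention(100, 100, 100, 100, 5, 0.5)
self_attn.eval()
X = torch.ones((2, 4, 100))     # (batch_size, num_tokens, num_hiddens)
valid_lens = torch.tensor([3, 2])
out, score = self_attn(X, X, X, valid_lens)
print(out.shape, score.shape)   # torch.Size([2, 4, 100]) torch.Size([5, 2, 4, 4])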
When processing a token sequence, a recurrent neural network handles the tokens one at a time, whereas self-attention gives up this sequential processing in favor of parallel computation. To make use of the order of the sequence, we inject absolute or relative position information by adding a positional encoding to the input representations. Positional encodings can either be learned or fixed.
Next, we describe a fixed positional encoding based on sine and cosine functions. Suppose the input X has shape (batch size, number of tokens, feature dimension). The positional encoding uses a positional embedding matrix P of the same shape, and the output is X + P.
The entries of the positional embedding matrix P are computed as:
$p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right)$ , $p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right)$
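For example, taking $d = 32$ (the encoding dimension used below), position $i = 1$ and column pair $j = 0$ gives $p_{1,0} = \sin(1/10000^{0}) = \sin(1) \approx 0.8415$ and $p_{1,1} = \cos(1) \approx 0.5403$, which are exactly the first two entries of the second row in the output printed further below.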
class PositionalEncoding(nn.Module):
    """Positional encoding"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough matrix P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
batch_size, encoding_dim, num_steps =2, 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((batch_size, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
print('Output features:', X)
print('Feature shape:', X.shape)
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
Output features: tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,  1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,  3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,  1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,  1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,  1.0492e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,  1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,  3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,  1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,  1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,  1.0492e-02,  9.9994e-01]]])
Feature shape: torch.Size([2, 60, 32])
# Modules implemented last week
def masked_softmax(X, valid_lens):
    """Perform a softmax operation by masking elements on the last axis"""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements on the last axis are replaced with a very large
        # negative value, so that their softmax output is 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

def split_head_reshape(X, heads_num, head_size):
    # Shape of X: (batch_size, seq_len, hidden_size)
    batch_size, seq_len, hidden_size = X.shape
    X = X.reshape(batch_size, seq_len, heads_num, head_size)
    # (batch_size, heads_num, seq_len, head_size)
    X = X.permute(0, 2, 1, 3)
    # Each head becomes its own batch entry: (batch_size*heads_num, seq_len, head_size)
    X = X.reshape(batch_size * heads_num, seq_len, head_size)
    return X

class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # Use dropout for regularization
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, inputs_size, heads_num, dropout=0.0):
        super(MultiHeadSelfAttention, self).__init__()
        self.inputs_size = inputs_size
        self.qsize, self.ksize, self.vsize = inputs_size, inputs_size, inputs_size
        self.heads_num = heads_num
        self.head_size = inputs_size // heads_num
        self.Q_proj = nn.Linear(self.qsize, inputs_size)
        self.K_proj = nn.Linear(self.ksize, inputs_size)
        self.V_proj = nn.Linear(self.vsize, inputs_size)
        self.out_proj = nn.Linear(inputs_size, inputs_size)
        self.attention = DotProductAttention(0.5)
        self.positional_encoding = PositionalEncoding(inputs_size, 0)

    def forward(self, X, valid_lens):
        # Add the positional encoding (PositionalEncoding already returns X + P)
        X = self.positional_encoding(X)
        self.batch_size, self.seq_len, self.hidden_size = X.shape
        Q, K, V = self.Q_proj(X), self.K_proj(X), self.V_proj(X)
        # After the split, Q, K, V have shape (batch_size*heads_num, seq_len, head_size)
        Q, K, V = [split_head_reshape(item, self.heads_num, self.head_size)
                   for item in [Q, K, V]]
        if valid_lens is not None:
            # Copy each valid length heads_num times so that every head sees it
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.heads_num, dim=0)
        out = self.attention(Q, K, V, valid_lens)
        # Merge the heads back: (batch_size, seq_len, heads_num*head_size)
        out = out.reshape(self.batch_size, self.heads_num, self.seq_len, self.head_size)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(self.batch_size, self.seq_len, out.shape[2] * out.shape[3])
        out = self.out_proj(out)
        return out
x= torch.rand((2, 2, 4))
valid_lens = torch.tensor([3]).repeat(x.shape[0])
head = MultiHeadSelfAttention(4, 1)
context = head(x, valid_lens)
print(context.shape)
torch.Size([2, 2, 4])
We now want to solve a binary text classification task: each example is a piece of text whose label is 0 or 1.
First we define a pooling module. In general, the model output has shape (batch size, number of time steps, feature dimension). The pooling module averages the features of the model output over all valid positions to obtain a representation of the whole text, which is then fed into a classifier for the final prediction. The pooling module therefore reduces the dimensionality of the output.
class AveragePooling(nn.Module):
    def __init__(self):
        super(AveragePooling, self).__init__()

    def forward(self, inputs, valid_length):
        valid_length = valid_length.unsqueeze(-1)
        print("valid_length.shape", valid_length.shape)
        # max_len must be the padded sequence length
        max_len = inputs.shape[1]
        # max_len = valid_length.shape[1]
        # Mask out the padded positions
        mask = torch.arange(max_len) < valid_length
        print("mask.shape:", mask.shape, "mask:", mask)
        mask = mask.unsqueeze(-1)
        print("mask.unsq:", mask.shape)
        inputs = torch.multiply(inputs, mask)
        print("inputs.shape:", inputs.shape)
        # Average over the valid positions only
        mean_outputs = torch.divide(torch.sum(inputs, dim=1), valid_length)
        print("mean_outputs.shape:", mean_outputs.shape)
        return mean_outputs
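As a standalone quick check (my own sketch; the numbers are arbitrary), the module should average only over the first valid_length positions and return one vector per sequence:

# Hypothetical usage of AveragePooling
pool = AveragePooling()
inputs = torch.ones((2, 3, 4))           # (batch_size, seq_len, feature_dim)
valid_length = torch.tensor([2, 3])      # number of valid positions per sequence
print(pool(inputs, valid_length).shape)  # torch.Size([2, 4])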
Next, we implement a network that combines an LSTM with multi-head attention.
class lstmWithAttention(nn.Module):
    def __init__(self, hidden_size, embedding_size, vocab_size, n_classes, attention=None):
        super(lstmWithAttention, self).__init__()
        self.hidden_size = hidden_size
        # Dimension of the word vectors
        self.embedding_size = embedding_size
        # Size of the vocabulary, i.e. the number of words it contains
        self.vocab_size = vocab_size
        self.n_classes = n_classes
        # The modules used in the forward computation
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        # bug2: the LSTM must not omit batch_first=True
        self.lstm = nn.LSTM(embedding_size, hidden_size, bidirectional=True, batch_first=True)
        # self.lstm = nn.LSTM(embedding_size, hidden_size, bidirectional=True)  # original code
        self.attention = attention
        self.pooling = AveragePooling()
        self.cls_fn = nn.Linear(self.hidden_size * 2, self.n_classes)

    def forward(self, inputs):
        # A fixed valid length of 10 is repeated batch_size times
        valid_lens = torch.tensor([10]).repeat(inputs.shape[0])
        print("valid_lens", valid_lens.shape)
        # Map word indices to word vectors;
        # shape of embedded_input: (batch_size, seq_len, embedding_size)
        embedded_input = self.embedding(inputs)
        print("embedded_input", embedded_input.shape)
        # Shape of last_layers_hiddens: (batch_size, seq_len, hidden_size*2)
        last_layers_hiddens, state = self.lstm(embedded_input)
        print("last_layers_hiddens", last_layers_hiddens.shape)
        # Multi-head self-attention: (batch_size, seq_len, hidden_size*2)
        last_layers_hiddens = self.attention(last_layers_hiddens, valid_lens)
        # bug3: the pooling layer code had problems
        last_layers_hiddens = self.pooling(last_layers_hiddens, valid_lens)
        print('pooling:', last_layers_hiddens.shape)
        logits = self.cls_fn(last_layers_hiddens)
        return logits
hidden_size = 32
embedding_size = 32
n_classes = 2
n_layers = 1
# vocab_size is the size of the vocabulary, i.e. the number of words; assumed to be 50000 here
vocab_size = 50000
head = MultiHeadSelfAttention(inputs_size=64, heads_num=2)
# Define the model
model = lstmWithAttention(hidden_size, embedding_size, vocab_size, n_classes, attention=head)
print(model)
lstmWithAttention(
  (embedding): Embedding(50000, 32)
  (lstm): LSTM(32, 32, batch_first=True, bidirectional=True)
  (attention): MultiHeadSelfAttention(
    (Q_proj): Linear(in_features=64, out_features=64, bias=True)
    (K_proj): Linear(in_features=64, out_features=64, bias=True)
    (V_proj): Linear(in_features=64, out_features=64, bias=True)
    (out_proj): Linear(in_features=64, out_features=64, bias=True)
    (attention): DotProductAttention(
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (positional_encoding): PositionalEncoding(
      (dropout): Dropout(p=0, inplace=False)
    )
  )
  (pooling): AveragePooling()
  (cls_fn): Linear(in_features=64, out_features=2, bias=True)
)
x = torch.randint(low=0, high=50000, size=(50, 20))
x.shape
torch.Size([50, 20])
model(x).shape
valid_lens torch.Size([50])
embedded_input torch.Size([50, 20, 32])
last_layers_hiddens torch.Size([50, 20, 64])
valid_length.shape torch.Size([50, 1])
mask.shape: torch.Size([50, 20]) mask: tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False],
        ...,
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         False, False, False, False, False, False, False, False, False, False]])
mask.unsq: torch.Size([50, 20, 1])
inputs.shape: torch.Size([50, 20, 64])
mean_outputs.shape: torch.Size([50, 64])
pooling: torch.Size([50, 64])
torch.Size([50, 2])
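Since every text carries a 0/1 label, the logits of shape (50, 2) can be scored with a cross-entropy loss. A minimal sketch (the labels below are random placeholders, not real data):

# Hypothetical loss computation for the binary classification task
labels = torch.randint(low=0, high=2, size=(50,))  # placeholder 0/1 labels
loss = nn.CrossEntropyLoss()(model(x), labels)
print(loss.item())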