赞
踩
首先,在我看来,transformer就是nlp的编码解码思想(seq2seq model)加上注意力机制,下面简单介绍一下seq2seq model
Seq2Seq(Sequence to Sequence)模型是一种用于处理序列数据的深度学习模型,特别适用于处理不定长度的输入序列和输出序列的任务。该模型最初被设计用于机器翻译,但后来被广泛应用于其他自然语言处理任务,语音识别,文本生成等领域。
Seq2Seq模型由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。
1. 编码器(Encoder):
编码器负责将输入序列转换为一个固定长度的表示(context向量),这个表示将包含输入序列的信息。常用的编码器是循环神经网络(RNN)或者长短时记忆网络(LSTM)。编码器的输出通常是输入序列全体时间步的隐藏状态或者最终隐藏状态。
2. 解码器(Decoder):
解码器接收编码器的输出(context向量)并生成输出序列。解码器同样可以是RNN或者LSTM。在生成序列的过程中,解码器逐步地生成一个元素,同时利用先前生成的元素作为上下文来指导生成下一个元素。Seq2Seq模型的训练过程通常使用Teacher Forcing方法。在训练时,解码器的输入是已知的目标序列(ground truth),而在推理阶段(生成阶段),解码器的输入是其自己先前生成的元素。在Seq2Seq的训练中,损失函数通常使用交叉熵损失函数来衡量生成序列与目标序列之间的差异。
3. 应用场景:
机器翻译: 将一个语言的句子翻译成另一个语言的句子。
文本摘要: 生成输入文本的摘要。
对话系统: 应用于生成对话系统的回复。
语音识别: 将语音信号转换为文本。
总体而言,Seq2Seq模型是一种强大的框架,可用于处理输入和输出序列不定长的任务,并在自然语言处理等领域取得了许多成功的应用。
项目地址:GitHub - bai-shang/crnn_seq2seq_ocr_pytorch: Extremely simple implement for Chinese OCR by PyTorch.
数据集地址:GitHub - senlinuc/caffe_ocr: 主流ocr算法研究实验性的项目,目前实现了CNN+BLSTM+CTC架构
项目简单介绍:该项目实现了卷积循环神经网络 (CRNN),它是 CNN 和序列到序列模型的组合,关注基于图像的序列识别任务,例如场景文本识别和 OCR
编码器的主要作用是将输入的图像信息进行特征提取和编码,以便后续解码器进行处理。整个编码器由两个部分组成:卷积神经网络(CNN)和循环神经网络(RNN)。
卷积层 1:
输入通道:3(RGB图像)
输出通道:64
卷积核大小:(3, 3)
激活函数:ReLU
池化层(MaxPooling)
卷积层 2:
输入通道:64
输出通道:128
卷积核大小:(3, 3)
激活函数:ReLU
池化层(MaxPooling)
卷积层 3:
输入通道:128
输出通道:256
卷积核大小:(3, 3)
激活函数:ReLU
池化层(MaxPooling)
卷积层 4:
输入通道:256
输出通道:256
卷积核大小:(3, 3)
激活函数:ReLU
池化层(MaxPooling)
卷积层 5:
输入通道:256
输出通道:512
卷积核大小:(3, 3)
激活函数:ReLU
池化层(MaxPooling)
卷积层 6:
输入通道:512
输出通道:512
卷积核大小:(3, 3)
激活函数:ReLU
池化层(MaxPooling)
卷积层 7:
输入通道:512
输出通道:512
卷积核大小:(2, 2) #这是一个较小的卷积层,用于进一步提取图像的局部特征。
激活函数:ReLU
一个双向的循环神经网络(BidirectionalLSTM)。它的作用是对来自卷积神经网络的特征进行更深层次的编码和建模。这里使用了两个双向 LSTM 层,每个层的输出维度都是 256。每个 LSTM 层之后都跟随一个线性层(Linear),用于映射隐藏状态到更高维度,以增强模型的表达能力。
LSTM(Long Short-Term Memory):LSTM 是一种专门用于处理序列数据的循环神经网络结构,可以有效地捕捉长距离依赖关系。
双向操作:双向 LSTM 允许模型同时考虑输入序列的过去和未来信息,从而更好地捕捉上下文信息。
Linear 层:线性层用于将 LSTM 层的隐藏状态映射到更高维度的特征表示,以供解码器使用。
- class Encoder(nn.Module):
- def __init__(self, channel_size, hidden_size):
- super(Encoder, self).__init__()
- self.cnn = CNN(channel_size)
- self.rnn = nn.Sequential(
- BidirectionalLSTM(512, hidden_size, hidden_size),
- BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
-
- def forward(self, input):
- # conv features
- conv = self.cnn(input)
- b, c, h, w = conv.size()
- assert h == 1, "the height of conv must be 1"
-
- # rnn feature
- conv = conv.squeeze(2) # [b, c, 1, w] -> [b, c, w]
- conv = conv.permute(2, 0, 1) # [b, c, w] -> [w, b, c]
- output = self.rnn(conv)
- return output
-
- class CNN(nn.Module):
-
- def __init__(self, channel_size):
- super(CNN, self).__init__()
- self.cnn = nn.Sequential(
- nn.Conv2d(channel_size, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),
- nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),
- nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
- nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)),
- nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
- nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)),
- nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(True))
-
- def forward(self, input):
- # [n, channel_size, 32, 280] -> [n, 512, 1, 71]
- conv = self.cnn(input)
- return conv
-
- class BidirectionalLSTM(nn.Module):
-
- def __init__(self, input_size, hidden_size, output_size):
- super(BidirectionalLSTM, self).__init__()
-
- self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
- self.embedding = nn.Linear(hidden_size * 2, output_size)
-
- def forward(self, input):
- recurrent, _ = self.rnn(input)
- T, b, h = recurrent.size()
- t_rec = recurrent.view(T * b, h)
-
- output = self.embedding(t_rec) # [T * b, output_size]
- output = output.view(T, b, -1)
- return output
解码器将编码器提取的特征信息解码成最终的文本序列。在这里,解码器采用了带有注意力机制(AttnDecoderRNN)的循环神经网络(GRU)。
Embedding 层(Embedding):将输入的标签索引映射为密集的词嵌入(word embedding)。这一层的输出将作为后续的输入提供给 GRU。
注意力层(attn):注意力机制用于动态地关注输入序列中的不同部分。这里通过一个线性层将解码器当前的隐藏状态和编码器的输出进行结合,产生注意力分布。
注意力结合层(attn_combine):将注意力权重应用于编码器的输出,以加权求和的方式结合编码器的输出和当前解码器的输入。
Dropout 层(dropout):用于在训练过程中随机丢弃一些神经元,以防止过拟合。
GRU 层(GRU):GRU 是一种循环神经网络结构,用于处理序列数据。它接受当前时刻的输入和先前时刻的隐藏状态,并生成当前时刻的输出和新的隐藏状态。
线性输出层(out):将 GRU 层的输出映射到最终的输出空间,这里是对应词汇表的大小(5992),用于预测下一个标签的概率分布。
解码器具体代码实现
- class Decoder(nn.Module):
-
- def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=71):
- super(Decoder, self).__init__()
- self.hidden_size = hidden_size
- self.decoder = AttnDecoderRNN(hidden_size, output_size, dropout_p, max_length)
-
- def forward(self, input, hidden, encoder_outputs):
- return self.decoder(input, hidden, encoder_outputs)
-
- def initHidden(self, batch_size):
- result = torch.autograd.Variable(torch.zeros(1, batch_size, self.hidden_size))
- return result
-
- class AttnDecoderRNN(nn.Module):
-
- def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=71):
- super(AttnDecoderRNN, self).__init__()
- self.hidden_size = hidden_size
- self.output_size = output_size
- self.dropout_p = dropout_p
- self.max_length = max_length
-
- self.embedding = nn.Embedding(self.output_size, self.hidden_size)
- self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
- self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
- self.dropout = nn.Dropout(self.dropout_p)
- self.gru = nn.GRU(self.hidden_size, self.hidden_size)
- self.out = nn.Linear(self.hidden_size, self.output_size)
-
-
- def forward(self, input, hidden, encoder_outputs):
- embedded = self.embedding(input)
- embedded = self.dropout(embedded)
-
- attn_weights = F.softmax(self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1)
- attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs.permute(1, 0, 2))
-
- output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
- output = self.attn_combine(output).unsqueeze(0)
-
- output = F.relu(output)
- output, hidden = self.gru(output, hidden)
-
- output = F.log_softmax(self.out(output[0]), dim=1)
- return output, hidden, attn_weights
-
- def initHidden(self):
- return torch.zeros(1, 1, self.hidden_size, device=device)
开始训练
这个项目运行过程中会有一些问题,报错可以评论,如果我会的话,可以帮助解决...........
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。