Original article: https://wangguisen.blog.csdn.net/article/details/126758368
PyTorch's default initialization for these layers is a uniform distribution; switching to a better-suited initialization scheme can speed up convergence.
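For reference, nn.LSTM documents that all of its weights and biases are drawn from U(-k, k) with k = 1/sqrt(hidden_size) (nn.Linear and nn.Conv1d use analogous uniform ranges). A quick sanity check of that claim:

import math
import torch.nn as nn

rnn = nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
k = 1.0 / math.sqrt(rnn.hidden_size)
w = rnn.weight_ih_l0
# Every default parameter value falls inside [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]
print(w.min().item() >= -k and w.max().item() <= k)  # True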
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidirectional, dropout):
        """
        Args:
            input_size: feature dimension of x
            hidden_size: feature dimension of the hidden state
            num_layers: number of stacked LSTM layers
        """
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        self.init_params()

    def init_params(self):
        # Orthogonal init for recurrent weights, Kaiming-normal for input
        # weights, zeros for biases, then forget-gate bias set to 1.
        # The in-place fill_ must not be tracked by autograd, so everything
        # runs under torch.no_grad().
        with torch.no_grad():
            for i in range(self.rnn.num_layers):
                nn.init.orthogonal_(getattr(self.rnn, f'weight_hh_l{i}'))
                nn.init.kaiming_normal_(getattr(self.rnn, f'weight_ih_l{i}'))
                nn.init.constant_(getattr(self.rnn, f'bias_hh_l{i}'), val=0)
                nn.init.constant_(getattr(self.rnn, f'bias_ih_l{i}'), val=0)
                # PyTorch orders the gate biases as (input, forget, cell, output),
                # so chunk(4)[1] is the forget gate.
                getattr(self.rnn, f'bias_hh_l{i}').chunk(4)[1].fill_(1)
                if self.rnn.bidirectional:
                    nn.init.orthogonal_(getattr(self.rnn, f'weight_hh_l{i}_reverse'))
                    nn.init.kaiming_normal_(getattr(self.rnn, f'weight_ih_l{i}_reverse'))
                    nn.init.constant_(getattr(self.rnn, f'bias_hh_l{i}_reverse'), val=0)
                    nn.init.constant_(getattr(self.rnn, f'bias_ih_l{i}_reverse'), val=0)
                    getattr(self.rnn, f'bias_hh_l{i}_reverse').chunk(4)[1].fill_(1)

    def forward(self, x, lengths):
        # x: [seq_len, batch_size, input_size]
        # lengths: [batch_size]; must be on the CPU and sorted in descending
        # order (or pass enforce_sorted=False to pack_padded_sequence)
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths)
        # packed_x, packed_output: PackedSequence objects
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        # cell:   [num_layers * num_directions, batch_size, hidden_size]
        packed_output, (hidden, cell) = self.rnn(packed_x)
        # output: [real_seq_len, batch_size, hidden_size * num_directions]
        # output_lengths: [batch_size]
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
        return hidden, output
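A minimal usage sketch; the sizes here are illustrative, and lengths are sorted in descending order because enforce_sorted defaults to True:

import torch

model = LSTM(input_size=100, hidden_size=64, num_layers=2, bidirectional=True, dropout=0.5)
x = torch.randn(20, 4, 100)              # [seq_len, batch_size, input_size]
lengths = torch.tensor([20, 18, 15, 9])  # descending, on CPU
hidden, output = model(x, lengths)
print(hidden.shape)  # [num_layers * 2, 4, 64]
print(output.shape)  # [20, 4, 128]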
import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features=in_features,
                                out_features=out_features)
        self.init_params()

    def init_params(self):
        # Kaiming-normal weights and zero bias instead of the default uniform init
        nn.init.kaiming_normal_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        x = self.linear(x)
        return x
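Usage is identical to a plain nn.Linear; for example (shapes are illustrative):

import torch

model = Linear(in_features=128, out_features=10)
x = torch.randn(32, 128)
print(model(x).shape)  # torch.Size([32, 10])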
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1d(nn.Module):
    def __init__(self, in_channels, out_channels, filter_sizes):
        super(Conv1d, self).__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=fs)
            for fs in filter_sizes
        ])
        self.init_params()

    def init_params(self):
        # Xavier-uniform weights and a small constant bias for every conv branch
        for m in self.convs:
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.1)

    def forward(self, x):
        # x: [batch_size, in_channels, seq_len]
        # returns one feature map per filter size
        return [F.relu(conv(x)) for conv in self.convs]
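A sketch of how this multi-branch convolution is typically consumed, e.g. as a TextCNN head. The max-over-time pooling and concatenation below are an assumption for illustration, not part of the original code:

import torch

conv = Conv1d(in_channels=128, out_channels=100, filter_sizes=[3, 4, 5])
x = torch.randn(32, 128, 50)        # [batch, embedding_dim, seq_len]
feature_maps = conv(x)              # three tensors of shape [32, 100, 50 - fs + 1]
# Max-over-time pooling on each branch, then concatenate into one feature vector
pooled = [torch.max(fm, dim=2).values for fm in feature_maps]
features = torch.cat(pooled, dim=1)
print(features.shape)               # torch.Size([32, 300])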