import torch
from torch import nn


class Residual(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a projection shortcut."""

    def __init__(self, in_channel, out_channel, stride=1):
        super(Residual, self).__init__()
        # Main path: 3x3 conv (with optional stride) -> BN -> ReLU -> 3x3 conv -> BN
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.relu = nn.ReLU(inplace=True)
        # Shortcut path: 1x1 conv matches the channel count and spatial size of the main path
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1, stride),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        out = self.bottleneck(x)
        identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
# Double the channel count, halve the height and width
net = Residual(3, 6, 2)
x = torch.randn(4, 3, 6, 6)
net(x).shape
torch.Size([4, 6, 3, 3])
# Keep the shape unchanged
net = Residual(3, 3)
x = torch.randn(4, 3, 6, 6)
net(x).shape
torch.Size([4, 3, 6, 6])
class Residual(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a projection shortcut."""

    def __init__(self, in_channel, out_channel, bottleneck_channel=None, stride=1):
        super(Residual, self).__init__()
        # Default the bottleneck width to half of the input channels
        if bottleneck_channel is None:
            bottleneck_channel = in_channel // 2
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channel, bottleneck_channel, 1, bias=False),
            nn.BatchNorm2d(bottleneck_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channel, bottleneck_channel, 3, stride, padding=1, bias=False),
            nn.BatchNorm2d(bottleneck_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channel, out_channel, 1, bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1, stride),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        out = self.bottleneck(x)
        identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
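The bottleneck variant can be sanity-checked the same way as the basic block above. This is a small sketch added for illustration; the bottleneck_channel=2 value is an arbitrary choice, not part of the original code.
# Go from 3 to 6 channels through a 2-channel bottleneck, halving the height and width
net = Residual(3, 6, bottleneck_channel=2, stride=2)
x = torch.randn(4, 3, 6, 6)
net(x).shape
torch.Size([4, 6, 3, 3])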
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    blk = nn.Sequential()
    for i in range(num_residuals):
        # The first residual of each stage (except the first stage) halves the spatial size
        blk.add_module(str(i), Residual(in_channels, out_channels,
                                        stride=2 if i == 0 and not first_block else 1))
        in_channels = out_channels
    return blk


class ResNet(nn.Module):
    def __init__(self, in_channel, class_num):
        super(ResNet, self).__init__()
        # Stem: 7x7 conv with stride 2 followed by a 3x3 max pool
        self.stem = nn.Sequential(
            nn.Conv2d(in_channel, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        # Four stages of two residual blocks each: 64 -> 64 -> 128 -> 256 -> 512 channels
        resnet_blocks = []
        in_channels = [64, 64, 128, 256, 512]
        for i in range(4):
            resnet_blocks += [resnet_block(in_channels[i], in_channels[i + 1], 2,
                                           first_block=(i == 0))]
        self.resnet_blocks = nn.Sequential(*resnet_blocks)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, class_num)

    def forward(self, x):
        out = self.stem(x)
        out = self.resnet_blocks(out)
        out = self.avg_pool(out).view(-1, 512)
        out = self.fc(out)
        return out
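To confirm the whole network wires up, a quick forward pass can be run. The 224x224 input size and 10 classes below are arbitrary choices added for illustration, not from the original post.
# 3-channel 224x224 input, 10 output classes
net = ResNet(3, 10)
x = torch.randn(1, 3, 224, 224)
net(x).shape
torch.Size([1, 10])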