Implementing ResNet in PyTorch

一、Introduction to ResNet

  1. ResNet was proposed in 2015 and took first place in both the classification and detection tasks of that year's ImageNet competition, as well as first place in detection and segmentation on the COCO dataset. Because it combines simplicity with practicality, many later methods have been built on top of ResNet-50 or ResNet-101, and it is widely used in detection, segmentation, recognition, and related fields.

  2. The ResNet residual block:

[Figure: structure of the ResNet residual block]

  3. ResNet architecture parameter table:

[Figure: table of layer configurations for the ResNet variants]

  4. Highlights of ResNet

    • Introduces the residual structure
    • Makes it possible to build extremely deep networks (beyond 1,000 layers)
    • Uses Batch Normalization to speed up training (dropping dropout)

二、The Core of ResNet: Residual Learning

  1. The residual

    The residual takes each layer's input as a reference (x) and learns a residual function relative to that input.

  2. The two branches of a residual learning block

    • identity mapping: the curved shortcut on the right-hand side of the residual block figure in (一、2); as the name suggests, it is the mapping of x onto itself
    • residual mapping: the other branch, the F(x) part, which is called the residual mapping
  3. The defining formula of residual learning

    y = F(x, {W_i}) + x
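
    To make the formula concrete, here is a minimal sketch (the class name ResidualSketch is purely illustrative, not part of the model below): F(x, {W_i}) is a small stack of layers, and the untouched input x is added back before the final activation.

    import torch.nn as nn

    # Minimal residual block sketch: y = F(x, {W_i}) + x
    class ResidualSketch(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.f = nn.Sequential(  # F(x, {W_i}): two conv-BN stages
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels))
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            return self.relu(self.f(x) + x)  # add the identity mapping, then activate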

三、Implementing the ResNet Network in Code

  1. The ResNet network model

    import torch
    import torch.nn as nn
    
    class BasicBlock(nn.Module):
        # expansion: ratio of the block's output channels to its base channel count
        expansion = 1

        def __init__(self, in_channel, out_channel, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=in_channel,
                                   out_channels=out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   bias=False)  # no bias: a BN layer follows
            self.bn1 = nn.BatchNorm2d(out_channel)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(in_channels=out_channel,
                                   out_channels=out_channel,
                                   kernel_size=3,
                                   stride=1,  # only conv1 strides; conv2 keeps the spatial size
                                   padding=1,
                                   bias=False)
            self.bn2 = nn.BatchNorm2d(out_channel)
            self.downsample = downsample    # downsample layer for the dashed (projection) shortcut
    
        def forward(self, x):
            # keep the shortcut-branch input; project it via downsample if the dimensions change
            identity = x
            if self.downsample is not None:
                identity = self.downsample(x)
    
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.conv2(x)
            x = self.bn2(x)
            x += identity
            x = self.relu(x)
    
            return x
    
    class Bottleneck(nn.Module):
        # expansion: the final 1x1 conv expands the channels fourfold
        expansion = 4

        def __init__(self, in_channel, out_channel, stride=1, downsample=None):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=in_channel,
                                   out_channels=out_channel,
                                   kernel_size=1,
                                   stride=1,
                                   bias=False)  # 1x1 conv squeezes the channels; no padding needed
            self.bn1 = nn.BatchNorm2d(out_channel)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(in_channels=out_channel,
                                   out_channels=out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   bias=False)
            self.bn2 = nn.BatchNorm2d(out_channel)
            self.conv3 = nn.Conv2d(in_channels=out_channel,
                                   out_channels=out_channel * self.expansion,
                                   kernel_size=1,
                                   stride=1,
                                   bias=False)  # 1x1 conv restores the (expanded) channel count
            self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
            self.downsample = downsample
    
        def forward(self, x):
            identity = x
            if self.downsample is not None:
                identity = self.downsample(x)
    
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)
    
            x = self.conv3(x)
            x = self.bn3(x)
            x += identity
            x = self.relu(x)
    
            return x
    
    class ResNet(nn.Module):
        def __init__(self, block, block_list, num_classes=1000, include_top=True):
            super(ResNet, self).__init__()
            self.include_top = include_top
            self.in_channel = 64
    
            self.conv1 = nn.Conv2d(in_channels=3,
                                   out_channels=self.in_channel,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(self.in_channel)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # padding keeps the expected 56x56 output
    
            self.layer_1 = self.make_layer(block, 64, block_list[0])
            self.layer_2 = self.make_layer(block, 128, block_list[1], stride=2)
            self.layer_3 = self.make_layer(block, 256, block_list[2], stride=2)
            self.layer_4 = self.make_layer(block, 512, block_list[3], stride=2)
            if self.include_top:
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling over the 2D feature map
                self.fc = nn.Linear(512 * block.expansion, num_classes)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
        def make_layer(self, block, channel, block_list, stride=1):
            downsample = None
            # a projection shortcut is needed whenever the spatial size or channel count changes
            if stride != 1 or self.in_channel != channel * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(channel * block.expansion))

            layers = []
            layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
            self.in_channel = channel * block.expansion

            for _ in range(1, block_list):
                layers.append(block(self.in_channel, channel))

            return nn.Sequential(*layers)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
    
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            x = self.layer_4(x)
    
            if self.include_top:
                x = self.avgpool(x)
                x = torch.flatten(x, 1)
                x = self.fc(x)
    
            return x

    def ResNet18(num_classes=1000, include_top=True):
        return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

    def ResNet34(num_classes=1000, include_top=True):
        return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

    def ResNet50(num_classes=1000, include_top=True):
        return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

    def ResNet101(num_classes=1000, include_top=True):
        return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

    def ResNet152(num_classes=1000, include_top=True):
        return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, include_top=include_top)
    
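    As a quick sanity check (a sketch, not part of the original model file), you can build one of the factory networks and confirm the output shape; this assumes it runs in the same module as the code above, where torch is already imported:

    net = ResNet34(num_classes=5)
    x = torch.randn(2, 3, 224, 224)  # a dummy batch of two 224x224 RGB images
    print(net(x).shape)              # expected: torch.Size([2, 5])
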
  2. Training ResNet (a 5-class flower classifier)

    from nlp.task.CIFAR10_try.ResNet import ResNet34  # local import of the model defined above; adjust the path to your project
    import os
    import json
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import transforms, datasets
    from tqdm import tqdm

    def main():
        # use the GPU if CUDA is available, otherwise fall back to the CPU
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))
        
        # data preprocessing and augmentation
        data_transform = {
            "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
            "val": transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
        # set up the dataset paths
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
        assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
        # load the training set
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=data_transform["train"])
        train_num = len(train_dataset)
    
        # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
        flower_list = train_dataset.class_to_idx
        cla_dict = dict((val, key) for key, val in flower_list.items())
        # write dict into json file
        json_str = json.dumps(cla_dict, indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)
    
        batch_size = 16
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers every process'.format(nw))
    
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=nw)
    
        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=data_transform["val"])
        val_num = len(validate_dataset)
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=nw)
    
        print("using {} images for training, {} images for validation.".format(train_num,
                                                                               val_num))
    
        net = ResNet34()
        model_weight_path = "./resnet34-pre.pth"  # official pretrained weights: https://download.pytorch.org/models/resnet34-333f7ec4.pth
        assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
        net.load_state_dict(torch.load(model_weight_path, map_location=device))
        in_channel = net.fc.in_features
        net.fc = nn.Linear(in_channel, 5)  # replace the 1000-way ImageNet head with a 5-way flower head
        net.to(device)
    
        # loss function and optimizer
        loss_function = nn.CrossEntropyLoss()
        params = [p for p in net.parameters() if p.requires_grad]
        optimizer = optim.Adam(params, lr=0.0001)
    
        epochs = 5
        best_acc = 0.0
        save_path = './resNet34.pth'
        train_steps = len(train_loader)
        for epoch in range(epochs):
            # train
            net.train()
            running_loss = 0.0
            train_bar = tqdm(train_loader)
            for step, data in enumerate(train_bar):
                images, labels = data
                optimizer.zero_grad()
                logits = net(images.to(device))
                loss = loss_function(logits, labels.to(device))
                loss.backward()
                optimizer.step()
    
                # print statistics
                running_loss += loss.item()
    
                train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                         epochs,
                                                                         loss)
    
            # validate
            net.eval()
            acc = 0.0  # accumulate accurate number / epoch
            with torch.no_grad():
                val_bar = tqdm(validate_loader)
                for val_data in val_bar:
                    val_images, val_labels = val_data
                    outputs = net(val_images.to(device))
                    # loss = loss_function(outputs, test_labels)
                    predict_y = torch.max(outputs, dim=1)[1]
                    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    
                    val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                               epochs)
    
            val_accurate = acc / val_num
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / train_steps, val_accurate))
    
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
    
        print('Finished Training')
    
    if __name__ == '__main__':
        main()
    
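    Once training has finished, inference is a short script. The following is a minimal sketch that assumes the resNet34.pth checkpoint and the class_indices.json file produced above; the image path 'tulip.jpg' is a hypothetical example:

    import json
    import torch
    from PIL import Image
    from torchvision import transforms
    from nlp.task.CIFAR10_try.ResNet import ResNet34  # same local import as in training

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    with open('class_indices.json') as f:
        class_indices = json.load(f)  # maps class index (as a string) to class name

    net = ResNet34(num_classes=5).to(device)
    net.load_state_dict(torch.load('./resNet34.pth', map_location=device))
    net.eval()  # switch BN layers to their running statistics

    img = transform(Image.open('tulip.jpg').convert('RGB')).unsqueeze(0)  # hypothetical test image
    with torch.no_grad():
        prob = torch.softmax(net(img.to(device)), dim=1)
    idx = int(prob.argmax(dim=1))
    print(class_indices[str(idx)], float(prob[0, idx]))
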

**Tips:** Batch Normalization makes each batch of feature maps follow a distribution with mean 0 and variance 1. When using BN, set training mode during training and evaluation mode during validation (in PyTorch, net.train() / net.eval()). Place the BN layer between the conv layer and the ReLU layer, and do not use a bias in the conv layer.
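
A small sketch of that conv → BN → ReLU ordering, and of toggling BN between its training and evaluation behavior:

    import torch
    import torch.nn as nn

    # conv -> BN -> ReLU: the conv omits its bias, since BN's own shift parameter makes it redundant
    layer = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True))

    x = torch.randn(8, 3, 32, 32)
    layer.train()   # training=True: BN normalizes with the current batch's statistics
    y_train = layer(x)
    layer.eval()    # training=False: BN uses its accumulated running mean/variance
    y_eval = layer(x)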
