当前位置:   article > 正文

昇思25天学习打卡营第12天 | ShuffleNet图像分类

昇思25天学习打卡营第12天 | ShuffleNet图像分类

ShuffleNet网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速。

如下图所示,ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。

shufflenet1

模型架构

ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过对ResNet的Bottleneck单元进行改进,在较小的计算量的情况下达到了较高的准确率。

Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示,相比于普通的卷积操作,分组卷积的情况下,每一组的卷积核大小为in_channels/g*k*k,一共有g组,所有组共有(in_channels/g*k*k)*out_channels个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量

shufflenet2

Depthwise Convolution(深度可分离卷积)将组数g分为和输入通道相等的in_channels,然后对每一个in_channels做卷积操作,每个卷积核只处理一个通道,记卷积核大小为1*k*k,则卷积核参数量为:in_channels*k*k,得到的feature maps通道数与输入通道数相等

Pointwise Group Convolution(逐点分组卷积)在分组卷积的基础上,令每一组的卷积核大小为 1×11×1,卷积核参数量为(in_channels/g*1*1)*out_channels。

  1. from mindspore import nn
  2. import mindspore.ops as ops
  3. from mindspore import Tensor
  4. class GroupConv(nn.Cell):
  5. def __init__(self, in_channels, out_channels, kernel_size,
  6. stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
  7. super(GroupConv, self).__init__()
  8. self.groups = groups
  9. self.convs = nn.CellList()
  10. for _ in range(groups):
  11. self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
  12. kernel_size=kernel_size, stride=stride, has_bias=has_bias,
  13. padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))
  14. def construct(self, x):
  15. features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
  16. outputs = ()
  17. for i in range(self.groups):
  18. outputs = outputs + (self.convs[i](features[i].astype("float32")),)
  19. out = ops.cat(outputs, axis=1)
  20. return out

Channel Shuffle

Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积GConv层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。

为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

shufflenet3

如下图所示,对于g组,每组有n个通道的特征图,首先reshape成g行n列的矩阵,再将矩阵转置成n行g列,最后进行flatten操作,得到新的排列。这些操作都是可微分可导的且计算简单,在解决了信息交互的同时符合了ShuffleNet轻量级网络设计的轻量特征。

shufflenet4

为了阅读方便,将Channel Shuffle的代码实现放在下方ShuffleNet模块的代码中。

ShuffleNet模块

如下图所示,ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), (c)的更改:

  1. 将开始和最后的1×11×1卷积模块(降维、升维)改成Point Wise Group Convolution;

  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle;

  3. 降采样模块中,3×33×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为2的3×33×3平均池化,并把相加改成拼接。

shufflenet5

  1. class ShuffleV1Block(nn.Cell):
  2. def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
  3. super(ShuffleV1Block, self).__init__()
  4. self.stride = stride
  5. pad = ksize // 2
  6. self.group = group
  7. if stride == 2:
  8. outputs = oup - inp
  9. else:
  10. outputs = oup
  11. self.relu = nn.ReLU()
  12. branch_main_1 = [
  13. GroupConv(in_channels=inp, out_channels=mid_channels,
  14. kernel_size=1, stride=1, pad_mode="pad", pad=0,
  15. groups=1 if first_group else group),
  16. nn.BatchNorm2d(mid_channels),
  17. nn.ReLU(),
  18. ]
  19. branch_main_2 = [
  20. nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
  21. pad_mode='pad', padding=pad, group=mid_channels,
  22. weight_init='xavier_uniform', has_bias=False),
  23. nn.BatchNorm2d(mid_channels),
  24. GroupConv(in_channels=mid_channels, out_channels=outputs,
  25. kernel_size=1, stride=1, pad_mode="pad", pad=0,
  26. groups=group),
  27. nn.BatchNorm2d(outputs),
  28. ]
  29. self.branch_main_1 = nn.SequentialCell(branch_main_1)
  30. self.branch_main_2 = nn.SequentialCell(branch_main_2)
  31. if stride == 2:
  32. self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')
  33. def construct(self, old_x):
  34. left = old_x
  35. right = old_x
  36. out = old_x
  37. right = self.branch_main_1(right)
  38. if self.group > 1:
  39. right = self.channel_shuffle(right)
  40. right = self.branch_main_2(right)
  41. if self.stride == 1:
  42. out = self.relu(left + right)
  43. elif self.stride == 2:
  44. left = self.branch_proj(left)
  45. out = ops.cat((left, right), 1)
  46. out = self.relu(out)
  47. return out
  48. def channel_shuffle(self, x):
  49. batchsize, num_channels, height, width = ops.shape(x)
  50. group_channels = num_channels // self.group
  51. x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
  52. x = ops.transpose(x, (0, 2, 1, 3, 4))
  53. x = ops.reshape(x, (batchsize, num_channels, height, width))
  54. return x

构建ShuffleNet网络

ShuffleNet网络结构如下图所示,以输入图像224×224224×224,组数3(g = 3)为例,首先通过数量24,卷积核大小为3×33×3,stride为2的卷积层,输出特征图大小为112×112112×112,channel为24;然后通过stride为2的最大池化层,输出特征图大小为56×5656×56,channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图(c)),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为1×1×9601×1×960,再经过全连接层和softmax,得到分类概率。

shufflenet6

  1. class ShuffleNetV1(nn.Cell):
  2. def __init__(self, n_class=1000, model_size='2.0x', group=3):
  3. super(ShuffleNetV1, self).__init__()
  4. print('model size is ', model_size)
  5. self.stage_repeats = [4, 8, 4]
  6. self.model_size = model_size
  7. if group == 3:
  8. if model_size == '0.5x':
  9. self.stage_out_channels = [-1, 12, 120, 240, 480]
  10. elif model_size == '1.0x':
  11. self.stage_out_channels = [-1, 24, 240, 480, 960]
  12. elif model_size == '1.5x':
  13. self.stage_out_channels = [-1, 24, 360, 720, 1440]
  14. elif model_size == '2.0x':
  15. self.stage_out_channels = [-1, 48, 480, 960, 1920]
  16. else:
  17. raise NotImplementedError
  18. elif group == 8:
  19. if model_size == '0.5x':
  20. self.stage_out_channels = [-1, 16, 192, 384, 768]
  21. elif model_size == '1.0x':
  22. self.stage_out_channels = [-1, 24, 384, 768, 1536]
  23. elif model_size == '1.5x':
  24. self.stage_out_channels = [-1, 24, 576, 1152, 2304]
  25. elif model_size == '2.0x':
  26. self.stage_out_channels = [-1, 48, 768, 1536, 3072]
  27. else:
  28. raise NotImplementedError
  29. input_channel = self.stage_out_channels[1]
  30. self.first_conv = nn.SequentialCell(
  31. nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
  32. nn.BatchNorm2d(input_channel),
  33. nn.ReLU(),
  34. )
  35. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  36. features = []
  37. for idxstage in range(len(self.stage_repeats)):
  38. numrepeat = self.stage_repeats[idxstage]
  39. output_channel = self.stage_out_channels[idxstage + 2]
  40. for i in range(numrepeat):
  41. stride = 2 if i == 0 else 1
  42. first_group = idxstage == 0 and i == 0
  43. features.append(ShuffleV1Block(input_channel, output_channel,
  44. group=group, first_group=first_group,
  45. mid_channels=output_channel // 4, ksize=3, stride=stride))
  46. input_channel = output_channel
  47. self.features = nn.SequentialCell(features)
  48. self.globalpool = nn.AvgPool2d(7)
  49. self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)
  50. def construct(self, x):
  51. x = self.first_conv(x)
  52. x = self.maxpool(x)
  53. x = self.features(x)
  54. x = self.globalpool(x)
  55. x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
  56. x = self.classifier(x)
  57. return x

模型训练和评估

采用CIFAR-10数据集对ShuffleNet进行预训练。

训练集准备与加载

采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。如下示例使用mindspore.dataset.Cifar10Dataset接口下载并加载CIFAR-10的训练集。目前仅支持二进制版本(CIFAR-10 binary version)。

  1. from download import download
  2. url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
  3. download(url, "./dataset", kind="tar.gz", replace=True)
  1. import mindspore as ms
  2. from mindspore.dataset import Cifar10Dataset
  3. from mindspore.dataset import vision, transforms
  4. def get_dataset(train_dataset_path, batch_size, usage):
  5. image_trans = []
  6. if usage == "train":
  7. image_trans = [
  8. vision.RandomCrop((32, 32), (4, 4, 4, 4)),
  9. vision.RandomHorizontalFlip(prob=0.5),
  10. vision.Resize((224, 224)),
  11. vision.Rescale(1.0 / 255.0, 0.0),
  12. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  13. vision.HWC2CHW()
  14. ]
  15. elif usage == "test":
  16. image_trans = [
  17. vision.Resize((224, 224)),
  18. vision.Rescale(1.0 / 255.0, 0.0),
  19. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  20. vision.HWC2CHW()
  21. ]
  22. label_trans = transforms.TypeCast(ms.int32)
  23. dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)
  24. dataset = dataset.map(image_trans, 'image')
  25. dataset = dataset.map(label_trans, 'label')
  26. dataset = dataset.batch(batch_size, drop_remainder=True)
  27. return dataset
  28. dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
  29. batches_per_epoch = dataset.get_dataset_size()

模型训练

本节用随机初始化的参数做预训练。首先调用ShuffleNetV1定义网络,参数量选择"2.0x",并定义损失函数为交叉熵损失,学习率经过4轮的warmup后采用余弦退火,优化器采用Momentum。最后用train.model中的Model接口将模型、损失函数、优化器封装在model中,并用model.train()对网络进行训练。将ModelCheckpointCheckpointConfigTimeMonitorLossMonitor传入回调函数中,将会打印训练的轮数、损失和时间,并将ckpt文件保存在当前目录下。

  1. import time
  2. import mindspore
  3. import numpy as np
  4. from mindspore import Tensor, nn
  5. from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy
  6. def train():
  7. mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
  8. net = ShuffleNetV1(model_size="2.0x", n_class=10)
  9. loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
  10. min_lr = 0.0005
  11. base_lr = 0.05
  12. lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
  13. base_lr,
  14. batches_per_epoch*250,
  15. batches_per_epoch,
  16. decay_epoch=250)
  17. lr = Tensor(lr_scheduler[-1])
  18. optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
  19. loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
  20. model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
  21. callback = [TimeMonitor(), LossMonitor()]
  22. save_ckpt_path = "./"
  23. config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
  24. ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
  25. callback += [ckpt_callback]
  26. print("============== Starting Training ==============")
  27. start_time = time.time()
  28. # 由于时间原因,epoch = 5,可根据需求进行调整
  29. model.train(5, dataset, callbacks=callback)
  30. use_time = time.time() - start_time
  31. hour = str(int(use_time // 60 // 60))
  32. minute = str(int(use_time // 60 % 60))
  33. second = str(int(use_time % 60))
  34. print("total time:" + hour + "h " + minute + "m " + second + "s")
  35. print("============== Train Success ==============")
  36. if __name__ == '__main__':
  37. train()

模型评估

在CIFAR-10的测试集上对模型进行评估。

设置好评估模型的路径后加载数据集,并设置Top 1, Top 5的评估标准,最后用model.eval()接口对模型进行评估。

  1. from mindspore import load_checkpoint, load_param_into_net
  2. def test():
  3. mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
  4. dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
  5. net = ShuffleNetV1(model_size="2.0x", n_class=10)
  6. param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
  7. load_param_into_net(net, param_dict)
  8. net.set_train(False)
  9. loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
  10. eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
  11. 'Top_5_Acc': Top5CategoricalAccuracy()}
  12. model = Model(net, loss_fn=loss, metrics=eval_metrics)
  13. start_time = time.time()
  14. res = model.eval(dataset, dataset_sink_mode=False)
  15. use_time = time.time() - start_time
  16. hour = str(int(use_time // 60 // 60))
  17. minute = str(int(use_time // 60 % 60))
  18. second = str(int(use_time % 60))
  19. log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
  20. + "', time: " + hour + "h " + minute + "m " + second + "s"
  21. print(log)
  22. filename = './eval_log.txt'
  23. with open(filename, 'a') as file_object:
  24. file_object.write(log + '\n')
  25. if __name__ == '__main__':
  26. test()

模型预测

在CIFAR-10的测试集上对模型进行预测,并将预测结果可视化。

  1. import mindspore
  2. import matplotlib.pyplot as plt
  3. import mindspore.dataset as ds
  4. net = ShuffleNetV1(model_size="2.0x", n_class=10)
  5. show_lst = []
  6. param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
  7. load_param_into_net(net, param_dict)
  8. model = Model(net)
  9. dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
  10. dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
  11. dataset_show = dataset_show.batch(16)
  12. show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
  13. image_trans = [
  14. vision.RandomCrop((32, 32), (4, 4, 4, 4)),
  15. vision.RandomHorizontalFlip(prob=0.5),
  16. vision.Resize((224, 224)),
  17. vision.Rescale(1.0 / 255.0, 0.0),
  18. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  19. vision.HWC2CHW()
  20. ]
  21. dataset_predict = dataset_predict.map(image_trans, 'image')
  22. dataset_predict = dataset_predict.batch(16)
  23. class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
  24. # 推理效果展示(上方为预测的结果,下方为推理效果图片)
  25. plt.figure(figsize=(16, 5))
  26. predict_data = next(dataset_predict.create_dict_iterator())
  27. output = model.predict(ms.Tensor(predict_data['image']))
  28. pred = np.argmax(output.asnumpy(), axis=1)
  29. index = 0
  30. for image in show_images_lst:
  31. plt.subplot(2, 8, index+1)
  32. plt.title('{}'.format(class_dict[pred[index]]))
  33. index += 1
  34. plt.imshow(image)
  35. plt.axis("off")
  36. plt.show()

代码实现:

为什么这两天总是报错qaq

那我弄到本地吧,这个云平台qaq

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/爱喝兽奶帝天荒/article/detail/782933
推荐阅读
相关标签
  

闽ICP备14008679号