当前位置:   article > 正文

昇思25天学习打卡营第13天 | ShuffleNet图像分类

昇思25天学习打卡营第13天 | ShuffleNet图像分类
内容介绍:

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速。

ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过对ResNet的Bottleneck单元进行改进,在较小的计算量的情况下达到了较高的准确率。

Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示。相比于普通的卷积操作,分组卷积将输入通道分为g组,每一组的卷积核参数量为(in_channels/g)*k*k,所有组共有((in_channels/g)*k*k)*out_channels个参数,是正常卷积参数量的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量有所降低,但输出通道数仍等于卷积核的数量。

具体内容:

1. 导包

  1. from mindspore import nn
  2. import mindspore.ops as ops
  3. from mindspore import Tensor
  4. from download import download
  5. import mindspore as ms
  6. from mindspore.dataset import Cifar10Dataset
  7. from mindspore.dataset import vision, transforms
  8. import time
  9. import mindspore
  10. import numpy as np
  11. from mindspore import Tensor, nn
  12. from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy
  13. from mindspore import load_checkpoint, load_param_into_net
  14. import mindspore
  15. import matplotlib.pyplot as plt
  16. import mindspore.dataset as ds

2. 模型

  1. class GroupConv(nn.Cell):
  2. def __init__(self, in_channels, out_channels, kernel_size,
  3. stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
  4. super(GroupConv, self).__init__()
  5. self.groups = groups
  6. self.convs = nn.CellList()
  7. for _ in range(groups):
  8. self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
  9. kernel_size=kernel_size, stride=stride, has_bias=has_bias,
  10. padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))
  11. def construct(self, x):
  12. features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
  13. outputs = ()
  14. for i in range(self.groups):
  15. outputs = outputs + (self.convs[i](features[i].astype("float32")),)
  16. out = ops.cat(outputs, axis=1)
  17. return out
  1. class ShuffleV1Block(nn.Cell):
  2. def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
  3. super(ShuffleV1Block, self).__init__()
  4. self.stride = stride
  5. pad = ksize // 2
  6. self.group = group
  7. if stride == 2:
  8. outputs = oup - inp
  9. else:
  10. outputs = oup
  11. self.relu = nn.ReLU()
  12. branch_main_1 = [
  13. GroupConv(in_channels=inp, out_channels=mid_channels,
  14. kernel_size=1, stride=1, pad_mode="pad", pad=0,
  15. groups=1 if first_group else group),
  16. nn.BatchNorm2d(mid_channels),
  17. nn.ReLU(),
  18. ]
  19. branch_main_2 = [
  20. nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
  21. pad_mode='pad', padding=pad, group=mid_channels,
  22. weight_init='xavier_uniform', has_bias=False),
  23. nn.BatchNorm2d(mid_channels),
  24. GroupConv(in_channels=mid_channels, out_channels=outputs,
  25. kernel_size=1, stride=1, pad_mode="pad", pad=0,
  26. groups=group),
  27. nn.BatchNorm2d(outputs),
  28. ]
  29. self.branch_main_1 = nn.SequentialCell(branch_main_1)
  30. self.branch_main_2 = nn.SequentialCell(branch_main_2)
  31. if stride == 2:
  32. self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')
  33. def construct(self, old_x):
  34. left = old_x
  35. right = old_x
  36. out = old_x
  37. right = self.branch_main_1(right)
  38. if self.group > 1:
  39. right = self.channel_shuffle(right)
  40. right = self.branch_main_2(right)
  41. if self.stride == 1:
  42. out = self.relu(left + right)
  43. elif self.stride == 2:
  44. left = self.branch_proj(left)
  45. out = ops.cat((left, right), 1)
  46. out = self.relu(out)
  47. return out
  48. def channel_shuffle(self, x):
  49. batchsize, num_channels, height, width = ops.shape(x)
  50. group_channels = num_channels // self.group
  51. x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
  52. x = ops.transpose(x, (0, 2, 1, 3, 4))
  53. x = ops.reshape(x, (batchsize, num_channels, height, width))
  54. return x
  1. class ShuffleNetV1(nn.Cell):
  2. def __init__(self, n_class=1000, model_size='2.0x', group=3):
  3. super(ShuffleNetV1, self).__init__()
  4. print('model size is ', model_size)
  5. self.stage_repeats = [4, 8, 4]
  6. self.model_size = model_size
  7. if group == 3:
  8. if model_size == '0.5x':
  9. self.stage_out_channels = [-1, 12, 120, 240, 480]
  10. elif model_size == '1.0x':
  11. self.stage_out_channels = [-1, 24, 240, 480, 960]
  12. elif model_size == '1.5x':
  13. self.stage_out_channels = [-1, 24, 360, 720, 1440]
  14. elif model_size == '2.0x':
  15. self.stage_out_channels = [-1, 48, 480, 960, 1920]
  16. else:
  17. raise NotImplementedError
  18. elif group == 8:
  19. if model_size == '0.5x':
  20. self.stage_out_channels = [-1, 16, 192, 384, 768]
  21. elif model_size == '1.0x':
  22. self.stage_out_channels = [-1, 24, 384, 768, 1536]
  23. elif model_size == '1.5x':
  24. self.stage_out_channels = [-1, 24, 576, 1152, 2304]
  25. elif model_size == '2.0x':
  26. self.stage_out_channels = [-1, 48, 768, 1536, 3072]
  27. else:
  28. raise NotImplementedError
  29. input_channel = self.stage_out_channels[1]
  30. self.first_conv = nn.SequentialCell(
  31. nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
  32. nn.BatchNorm2d(input_channel),
  33. nn.ReLU(),
  34. )
  35. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  36. features = []
  37. for idxstage in range(len(self.stage_repeats)):
  38. numrepeat = self.stage_repeats[idxstage]
  39. output_channel = self.stage_out_channels[idxstage + 2]
  40. for i in range(numrepeat):
  41. stride = 2 if i == 0 else 1
  42. first_group = idxstage == 0 and i == 0
  43. features.append(ShuffleV1Block(input_channel, output_channel,
  44. group=group, first_group=first_group,
  45. mid_channels=output_channel // 4, ksize=3, stride=stride))
  46. input_channel = output_channel
  47. self.features = nn.SequentialCell(features)
  48. self.globalpool = nn.AvgPool2d(7)
  49. self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)
  50. def construct(self, x):
  51. x = self.first_conv(x)
  52. x = self.maxpool(x)
  53. x = self.features(x)
  54. x = self.globalpool(x)
  55. x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
  56. x = self.classifier(x)
  57. return x

3. 数据处理

  1. def get_dataset(train_dataset_path, batch_size, usage):
  2. image_trans = []
  3. if usage == "train":
  4. image_trans = [
  5. vision.RandomCrop((32, 32), (4, 4, 4, 4)),
  6. vision.RandomHorizontalFlip(prob=0.5),
  7. vision.Resize((224, 224)),
  8. vision.Rescale(1.0 / 255.0, 0.0),
  9. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  10. vision.HWC2CHW()
  11. ]
  12. elif usage == "test":
  13. image_trans = [
  14. vision.Resize((224, 224)),
  15. vision.Rescale(1.0 / 255.0, 0.0),
  16. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  17. vision.HWC2CHW()
  18. ]
  19. label_trans = transforms.TypeCast(ms.int32)
  20. dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)
  21. dataset = dataset.map(image_trans, 'image')
  22. dataset = dataset.map(label_trans, 'label')
  23. dataset = dataset.batch(batch_size, drop_remainder=True)
  24. return dataset
  25. dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
  26. batches_per_epoch = dataset.get_dataset_size()

4. 模型训练

  1. def train():
  2. mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
  3. net = ShuffleNetV1(model_size="2.0x", n_class=10)
  4. loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
  5. min_lr = 0.0005
  6. base_lr = 0.05
  7. lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
  8. base_lr,
  9. batches_per_epoch*250,
  10. batches_per_epoch,
  11. decay_epoch=250)
  12. lr = Tensor(lr_scheduler[-1])
  13. optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
  14. loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
  15. model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
  16. callback = [TimeMonitor(), LossMonitor()]
  17. save_ckpt_path = "./"
  18. config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
  19. ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
  20. callback += [ckpt_callback]
  21. print("============== Starting Training ==============")
  22. start_time = time.time()
  23. # 由于时间原因,epoch = 5,可根据需求进行调整
  24. model.train(5, dataset, callbacks=callback)
  25. use_time = time.time() - start_time
  26. hour = str(int(use_time // 60 // 60))
  27. minute = str(int(use_time // 60 % 60))
  28. second = str(int(use_time % 60))
  29. print("total time:" + hour + "h " + minute + "m " + second + "s")
  30. print("============== Train Success ==============")
  31. if __name__ == '__main__':
  32. train()

5. 模型评估

  1. def test():
  2. mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
  3. dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
  4. net = ShuffleNetV1(model_size="2.0x", n_class=10)
  5. param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
  6. load_param_into_net(net, param_dict)
  7. net.set_train(False)
  8. loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
  9. eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
  10. 'Top_5_Acc': Top5CategoricalAccuracy()}
  11. model = Model(net, loss_fn=loss, metrics=eval_metrics)
  12. start_time = time.time()
  13. res = model.eval(dataset, dataset_sink_mode=False)
  14. use_time = time.time() - start_time
  15. hour = str(int(use_time // 60 // 60))
  16. minute = str(int(use_time // 60 % 60))
  17. second = str(int(use_time % 60))
  18. log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
  19. + "', time: " + hour + "h " + minute + "m " + second + "s"
  20. print(log)
  21. filename = './eval_log.txt'
  22. with open(filename, 'a') as file_object:
  23. file_object.write(log + '\n')
  24. if __name__ == '__main__':
  25. test()

6. 模型预测

  1. net = ShuffleNetV1(model_size="2.0x", n_class=10)
  2. show_lst = []
  3. param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
  4. load_param_into_net(net, param_dict)
  5. model = Model(net)
  6. dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
  7. dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
  8. dataset_show = dataset_show.batch(16)
  9. show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
  10. image_trans = [
  11. vision.RandomCrop((32, 32), (4, 4, 4, 4)),
  12. vision.RandomHorizontalFlip(prob=0.5),
  13. vision.Resize((224, 224)),
  14. vision.Rescale(1.0 / 255.0, 0.0),
  15. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  16. vision.HWC2CHW()
  17. ]
  18. dataset_predict = dataset_predict.map(image_trans, 'image')
  19. dataset_predict = dataset_predict.batch(16)
  20. class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
  21. # 推理效果展示(上方为预测的结果,下方为推理效果图片)
  22. plt.figure(figsize=(16, 5))
  23. predict_data = next(dataset_predict.create_dict_iterator())
  24. output = model.predict(ms.Tensor(predict_data['image']))
  25. pred = np.argmax(output.asnumpy(), axis=1)
  26. index = 0
  27. for image in show_images_lst:
  28. plt.subplot(2, 8, index+1)
  29. plt.title('{}'.format(class_dict[pred[index]]))
  30. index += 1
  31. plt.imshow(image)
  32. plt.axis("off")
  33. plt.show()

ShuffleNet作为一种轻量级的神经网络结构,其独特的数据重排和分组卷积设计,使得它在保持较高精度的同时,大大降低了模型的计算量和参数量。这种设计思路让我意识到,在追求模型性能的同时,也需要考虑模型的轻量化和效率问题。这对于我们来说是一个非常重要的启示,特别是在移动设备和嵌入式系统上的深度学习应用中。

在使用MindSpore和ShuffleNet进行图像分类的实践过程中,我也遇到了一些挑战和困难。例如,如何调整超参数以优化模型的性能、如何处理数据的不平衡问题等。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/782924
推荐阅读
相关标签
  

闽ICP备14008679号