当前位置:   article > 正文

计算机视觉(十五):综合案例:垃圾分类_垃圾分类计算机视觉项目

垃圾分类计算机视觉项目

计算机视觉笔记总目录


一、垃圾分类之模型构建

1 垃圾分类介绍

这里不做详细介绍了,有兴趣的可以看看:百度百科

2 华为垃圾分类比赛介绍

官网:https://competition.huaweicloud.com/information/1000007620/introduction

本次比赛选取40种生活中常见的垃圾,选手根据公布的数据集进行模型训练,将训练好的模型发布到华为ModelArts平台上,在线预测华为的私有数据集,采用识别准确率作为评价指标。这次比赛中有很多容易混淆的类,比如饮料瓶和调料瓶、筷子和牙签、果皮和果肉等外形极为相似的垃圾,因此此次竞赛也可看作是细粒度图像分类任务。

{
    "0": "其他垃圾/一次性快餐盒",
    "1": "其他垃圾/污损塑料",
    "2": "其他垃圾/烟蒂",
    "3": "其他垃圾/牙签",
    "4": "其他垃圾/破碎花盆及碟碗",
    "5": "其他垃圾/竹筷",
    "6": "厨余垃圾/剩饭剩菜",
    "7": "厨余垃圾/大骨头",
    "8": "厨余垃圾/水果果皮",
    "9": "厨余垃圾/水果果肉",
    "10": "厨余垃圾/茶叶渣",
    "11": "厨余垃圾/菜叶菜根",
    "12": "厨余垃圾/蛋壳",
    "13": "厨余垃圾/鱼骨",
    "14": "可回收物/充电宝",
    "15": "可回收物/包",
    "16": "可回收物/化妆品瓶",
    "17": "可回收物/塑料玩具",
    "18": "可回收物/塑料碗盆",
    "19": "可回收物/塑料衣架",
    "20": "可回收物/快递纸袋",
    "21": "可回收物/插头电线",
    "22": "可回收物/旧衣服",
    "23": "可回收物/易拉罐",
    "24": "可回收物/枕头",
    "25": "可回收物/毛绒玩具",
    "26": "可回收物/洗发水瓶",
    "27": "可回收物/玻璃杯",
    "28": "可回收物/皮鞋",
    "29": "可回收物/砧板",
    "30": "可回收物/纸板箱",
    "31": "可回收物/调料瓶",
    "32": "可回收物/酒瓶",
    "33": "可回收物/金属食品罐",
    "34": "可回收物/锅",
    "35": "可回收物/食用油桶",
    "36": "可回收物/饮料瓶",
    "37": "有害垃圾/干电池",
    "38": "有害垃圾/软膏",
    "39": "有害垃圾/过期药物"
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

重要:比赛或者项目解题思路

  • 1、拿到数据后,首先做数据分析。统计数据样本分布,尺寸分布,图片形态等,基于分析可以做一些针对性的数据预处理算法,对后期的模型训练会有很大的帮助
  • 2、选择好的baseline。需要不断的尝试各种现有的网络结构,进行结果对比,挑选出适合该网络的模型结构,然后基于该模型进行不断的调参,调试出性能较好的参数
  • 3、做结果验证,将上述模型在验证集上做结果验证,找出错误样本,分析出错原因,然后针对性的调整网络和数据。
  • 4、基于新数据和模型,再次进行模型调优

比赛表现前5团队效果

名次准确率推理(inference)时间(ms)
第一名0.969636102.8
第二名0.9625195.43
第三名0.96204597.25
第四名0.96173582.99
第五名0.957397108.49

2.1 赛题分析

  • 1、问题描述

    经典图像分类问题。采用深圳市垃圾分类标准,输出该物品属于可回收物、 厨余垃圾、有害垃圾和其他垃圾中的二级分类,共40个类别

  • 2、评价指标

    识别准确率 = 识别正确的图片数 / 图片总数

  • 3、挑战

    官方训练集有19459张图片,数据量小;
    类别较多(40),且各类样本不平衡;
    图片大小、分辨率不一,垃圾物品有多种尺度;
    垃圾分类是细粒度、粗粒度兼有的一种分类问题,轮廓、纹理、对象位置分 布都需要考察

2.2 对策

  • 1、数据集分析和选择
  • 2、模型选择
  • 3、图像分类问题常见trick(优化)

3 项目构建

3.1 项目模块图

在这里插入图片描述

  • data:存放数据的目录
  • data_gen目录:批次数据预处理代码,包括数据增强、标签平滑、mixup功能
  • deploy:模型导出以及部署模块
  • efficientnet:efficientnet模型源码存放位置
  • utils:封装的工具类,如warmup以及余弦退火学习率
  • train.y与eval.py:训练网络部分包括数据流获取、网络构建、优化器

3.2 步骤以及知识点应用分析

  • 1、数据读取以及预处理模块

    数据获取
    数据增强
    归一化
    随机擦除
    Mixup

  • 2、模型网络结构实现

    efficientnet模型介绍
    垃圾分类模型修改
    模型学习率优化-warmup与余弦退火学习率
    模型优化器-Adam优化器改进RAdam/NRdam

  • 3、模型训练保存与预测

    模型完整训练过程实现
    预估流程实现

  • 4、模型导出以及部署

    tf.saved_model模块使用
    TensorFlow serving模块使用

4 processing_data.py代码展示

import math
import os
import random
import numpy as np
import PIL
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical, Sequence
from sklearn.model_selection import train_test_split

from random_eraser import get_random_eraser


class GarbageDataSequence(Sequence):
    """
    数据流生成器,每次迭代返回一个batch
    可直接用于fit_generator的generator参数,能保证在多进程下的一个epoch中不会重复取相同的样本
    """

    def __init__(self, img_paths, labels, batch_size, img_size, use_aug):
        # 训练之与目标值合并结果 , [batch_size, 1], [batch_size, 40] -> [batch_size, 41]
        self.x_y = np.hstack(((np.array(img_paths).reshape(len(img_paths), 1)), np.array(labels)))

        self.batch_size = batch_size
        self.img_size = img_size  # [300,300]
        self.use_aug = use_aug
        self.alpha = 0.2
        # 随机擦除方法
        self.eraser = get_random_eraser(s_h=0.3, pixel_level=True)

    def __len__(self):
        return math.ceil(len(self.x_y) / self.batch_size)

    @staticmethod
    def center_img(img, size=None, fill_value=255):
        """
        改变图片尺寸到300x300,并且做填充使得图像处于中间位置
        """
        h, w = img.shape[:2]
        if size is None:
            size = max(h, w)
        shape = (size, size) + img.shape[2:]
        background = np.full(shape, fill_value, np.uint8)
        center_x = (size - w) // 2
        center_y = (size - h) // 2
        background[center_y:center_y + h, center_x:center_x + w] = img
        return background

    def preprocess_img(self, img_path):
        """
        图片的处理流程函数,数据增强、center_img处理
        """

        # 图像读取,[180 , 200]->[300/200 * 180, 300/200 * 200]
        # 这样做为了不使图形直接变形,后续再统一长宽
        img = PIL.Image.open(img_path)  # [180, 200, 3]
        resize_scale = self.img_size[0] / max(img.size[:2])
        img = img.resize((int(img.size[0] * resize_scale), int(img.size[1] * resize_scale)))
        img = img.convert('RGB')
        img = np.array(img)

        # 数据增强:如果是训练集进行数据增强操作
        if self.use_aug:
            # 先随机擦除,然后翻转
            img = self.eraser(img)
            datagen = ImageDataGenerator(
                width_shift_range=0.05,
                height_shift_range=0.05,
                horizontal_flip=True,
                vertical_flip=True,
            )
            # 由于变换的种类很多,这里是随机使用某一种变换在图像上面
            img = datagen.random_transform(img)

        # 把图片大小调整到[300, 300, 3],调整的方式为直接填充小的坐标。为了模型需要
        img = self.center_img(img, self.img_size[0])
        return img

    def mixup(self, batch_x, batch_y):
        """
        数据混合mixup
        :param batch_x: 要mixup的batch_X
        :param batch_y: 要mixup的batch_y
        :return: mixup后的数据
        """
        size = self.batch_size
        l = np.random.beta(self.alpha, self.alpha, size)

        X_l = l.reshape(size, 1, 1, 1)
        y_l = l.reshape(size, 1)

        X1 = batch_x
        Y1 = batch_y
        X2 = batch_x[::-1]
        Y2 = batch_y[::-1]

        X = X1 * X_l + X2 * (1 - X_l)
        Y = Y1 * y_l + Y2 * (1 - y_l)

        return X, Y

    @staticmethod
    def preprocess_input(x):
        """归一化处理样本特征值
        :param x:
        :return:
        """
        assert x.ndim in (3, 4)
        assert x.shape[-1] == 3

        MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]

        x = x - np.array(MEAN_RGB)
        x = x / np.array(STDDEV_RGB)

        return x

    def __getitem__(self, idx):
        # 处理图片大小、数据增强等过程
        batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
        batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]

        batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
        batch_y = np.array(batch_y).astype(np.float32)

        # 2、mixup进行构造新的样本分布数据
        batch_x, batch_y = self.mixup(batch_x, batch_y)

        # 3、输入模型的归一化数据
        batch_x = self.preprocess_input(batch_x)
        return batch_x, batch_y

    def on_epoch_end(self):
        np.random.shuffle(self.x_y)


def smooth_labels(y, smooth_factor=0.1):
    assert len(y.shape) == 2
    if 0 <= smooth_factor <= 1:
        y *= 1 - smooth_factor
        y += smooth_factor / y.shape[1]
    else:
        raise Exception(
            'Invalid label smoothing factor: ' + str(smooth_factor))
    return y


def data_from_sequence(train_data_dir, batch_size, num_classes, input_size):
    """
    读取本地图片和标签数据,处理成sequence数据类型
    :param train_data_dir: 训练数据目录
    :param batch_size: 批次大小
    :param num_classes: 垃圾分类总类别数
    :param input_size: 输入模型的图片大小(300, 300)
    :return:
    """
    # 1、获取txt文件,打乱一次文件
    label_files = []

    for filename in os.listdir(train_data_dir):
        if filename.endswith('.txt'):
            label_files.append(os.path.join(train_data_dir, filename))

    random.seed(2)
    random.shuffle(label_files)

    # 解析txt文件当中特征值以及目标值
    img_paths = []
    labels = []

    for _, file_path in enumerate(label_files):
        with open(file_path, 'r') as f:
            line = f.readline()
        line_split = line.strip().split(', ')
        if len(line_split) != 2:
            print('%s 文件中格式错误' % file_path)
            continue
        img_name = line_split[0]
        label = int(line_split[1])
        img_paths.append(os.path.join(train_data_dir, img_name))
        labels.append(label)

    # 进行标签类别one-hot处理,以及标签平滑
    labels = to_categorical(labels, num_classes)
    labels = smooth_labels(labels)

    # 进行所有数据的分割,训练集和验证集
    train_img_paths, validation_img_paths, train_labels, validation_labels = \
        train_test_split(img_paths, labels, test_size=0.15, random_state=9)
    print('总共样本数: %d, 训练样本数: %d, 验证样本数据: %d' % (
        len(img_paths), len(train_img_paths), len(validation_img_paths)))

    # sequence序列数据制作
    train_sequence = GarbageDataSequence(train_img_paths, train_labels, batch_size,
                                         [input_size, input_size], use_aug=True)

    validation_sequence = GarbageDataSequence(validation_img_paths, validation_labels, batch_size,
                                              [input_size, input_size], use_aug=False)

    return train_sequence, validation_sequence


if __name__ == '__main__':
    batch_size = 32
    train_data_dir = '../data/garbage_classify/train_data'

    train_sequence, validation_sequence = data_from_sequence(train_data_dir, batch_size, num_classes=40, input_size=300)

    # for i in range(100):
    #     print("第 %d 批次数据" % i)
    #     batch_x, bacth_y = train_sequence.__getitem__(i)
    #     print(batch_x, bacth_y)
    #     batch_x, bacth_y = validation_sequence.__getitem__(i)
    #     print(batch_x, bacth_y)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215

二、垃圾分类之训练、推理

使用EfficientNet模型,论文地址:https://arxiv.org/pdf/1905.11946.pdf

1 lr_scheduler.py代码实现

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """
    每批次带有warmup余弦退火学习率计算
    :param global_step: 当前到达的步数
    :param learning_rate_base: warmup之后的基础学习率
    :param total_steps: 总需要批次数
    :param warmup_learning_rate: warmup开始的学习率
    :param warmup_steps: warmup学习率 步数
    :param hold_base_rate_steps: 预留总步数和warmup步数间隔
    :return: lr
    """
    if total_steps < warmup_steps:
        raise ValueError('总步数必须大于warmup')

    # 余弦退火学习率计算,从warmup结束之后计算
    lr = 0.5 * learning_rate_base * (1 + np.cos(np.pi * (global_step - warmup_steps - hold_base_rate_steps)
                                                / float(total_steps - warmup_steps - hold_base_rate_steps)))
    # warmup之后的学习率计算
    # 如果预留大于0,判断目前步数是否 > warmup步数+预留步数
    # 是的话返回刚才上面计算的学习率,不是的话使用warmup之后的基础学习率
    if hold_base_rate_steps > 0:
        lr = np.where(global_step > warmup_steps + hold_base_rate_steps, lr, learning_rate_base)
    # warmup步数是大于0的
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('warmup后学习率必须大于warmup开始学习率')
        # warmup的斜率
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        # 计算warmup第global_step的学习率
        warmup_lr = slope * global_step + warmup_learning_rate
        # 判断global_step小于warmup_steps的话,返回这个warmup当时的学习率,否则直接返回余弦退火计算的
        lr = np.where(global_step < warmup_steps, warmup_lr, lr)

    # 如果最后当前到达的步数大于总步数,则归0
    # 否则返回当前的计算出来的学习率(可能是warmup学习率也可能是余弦衰减结果)
    return np.where(global_step > total_steps, 0.0, lr)


class WarmUpCosineDecayScheduler(tf.keras.callbacks.Callback):
    """
    带有warmup的余弦退火学习率调度
    """

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 verbose=0):
        """
        初始化参数
        :param learning_rate_base: 基础学习率
        :param total_steps: 总共迭代的批次步数 epoch * num_samples / batch_size
        :param global_step_init: 初始
        :param warmup_learning_rate: 预热学习率默认0.0
        :param warmup_steps:预热的步数默认0
        :param hold_base_rate_steps: 预留步数
        :param verbose:每次训练结束是狗打印学习率
        """
        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        # 是否在每次训练结束打印学习率
        self.verbose = verbose
        # 记录所有批次下来的每次准确的学习率,可以用于打印显示
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = cosine_decay_with_warmup(
            global_step=self.global_step,
            learning_rate_base=self.learning_rate_base,
            total_steps=self.total_steps,
            warmup_learning_rate=self.warmup_learning_rate,
            warmup_steps=self.warmup_steps,
            hold_base_rate_steps=self.hold_base_rate_steps)

        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\n批次数 %05d: 设置学习率为 %s.' % (self.global_step + 1, lr))


if __name__ == '__main__':
    # 1、创建模型
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(10, activation='softmax'))
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # 2、参数设置
    sample_count = 1000  # 样本数
    epochs = 4  # 总迭代次数
    warmup_epoch = 3  # warmup 迭代次数
    batch_size = 16  # 批次大小
    learning_rate_base = 0.0001  # warmup后的初始学习率
    total_steps = int(epochs * sample_count / batch_size)  # 总迭代批次步数
    warmup_steps = int(warmup_epoch * sample_count / batch_size)  # warmup总批次数

    # 3、创建测试数据
    data = np.random.random((sample_count, 100))
    labels = np.random.randint(10, size=(sample_count, 1))
    # 转换目标类别
    one_hot_labels = tf.keras.utils.to_categorical(labels, num_classes=10)

    # 5、创建余弦warmup调度器
    warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                            total_steps=total_steps,
                                            warmup_learning_rate=4e-06,  # warmup开始学习率
                                            warmup_steps=warmup_steps,
                                            hold_base_rate_steps=0,
                                            verbose=0
                                            )

    # 训练模型
    model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size, verbose=0, callbacks=[warm_up_lr])

    print(warm_up_lr.learning_rates)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145

2 train.py代码实现

import multiprocessing
import numpy as np
import argparse
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard, Callback
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, RMSprop

from efficientnet import model as EfficientNet
from data_gen import data_from_sequence
from utils.lr_scheduler import WarmUpCosineDecayScheduler
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# 注意关闭默认的eager模式
tf.compat.v1.disable_eager_execution()

parser = argparse.ArgumentParser()
parser.add_argument("data_url", type=str, default='./data/garbage_classify/train_data', help="data dir", nargs='?')
parser.add_argument("train_url", type=str, default='./garbage_ckpt/', help="save model dir", nargs='?')
parser.add_argument("num_classes", type=int, default=40, help="num_classes", nargs='?')
parser.add_argument("input_size", type=int, default=300, help="input_size", nargs='?')
parser.add_argument("batch_size", type=int, default=16, help="batch_size", nargs='?')
# parser.add_argument("batch_size", type=int, default=64, help="batch_size", nargs='?')
parser.add_argument("learning_rate", type=float, default=0.001, help="learning_rate", nargs='?')
parser.add_argument("max_epochs", type=int, default=3, help="max_epochs", nargs='?')
parser.add_argument("deploy_script_path", type=str, default='', help="deploy_script_path", nargs='?')
parser.add_argument("test_data_url", type=str, default='', help="test_data_url", nargs='?')


def model_fn(param):
    """修改符合垃圾分类的模型
    :param param: 命令行参数
    :return:
    """
    base_model = EfficientNet.EfficientNetB3(include_top=False,
                                             input_shape=(param.input_size, param.input_size, 3),
                                             classes=param.num_classes)

    x = base_model.output
    # 自定义修改40个分类的后面基层
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    predictions = Dense(param.num_classes, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    return model


def train_model(param):
    """训练模型逻辑
    :param param: 各种参数命令行
    :return:
    """
    # 1、读取sequence数据
    train_sequence, validation_sequence = data_from_sequence(param.data_url, param.batch_size, param.num_classes, param.input_size)

    # 2、建立模型,修改模型指定训练相关参数
    model = model_fn(param)

    optimizer = Adam(lr=param.learning_rate)
    objective = 'categorical_crossentropy'
    metrics = ['acc']
    # 模型修改
    # 模型训练优化器指定
    model.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    model.summary()

    # 3、指定相关回调函数
    # Tensorboard
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./graph', histogram_freq=1, write_graph=True, write_images=True)

    # modelcheckpoint
    # (3)模型保存相关参数
    check = tf.keras.callbacks.ModelCheckpoint(param.train_url + 'weights_{epoch:02d}-{val_acc:.2f}.h5',
                                               monitor='val_acc',
                                               save_best_only=True,
                                               save_weights_only=False,
                                               mode='auto',
                                               period=1)

    # 余弦退回warmup
    # 得到总样本数
    # sample_count = len(train_sequence)* param.batch_size
    sample_count = len(train_sequence)
    batch_size = param.batch_size

    # 第二阶段学习率以及总步数
    learning_rate_base = param.learning_rate
    total_steps = int(param.max_epochs * sample_count) / batch_size
    # 计算第一阶段的步数需要多少 warmup_steps
    warmup_epoch = 1
    warmup_steps = int(warmup_epoch * sample_count) / batch_size


    warm_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                        total_steps=total_steps,
                                        warmup_learning_rate=0,
                                        warmup_steps=warmup_steps,
                                        hold_base_rate_steps=0,)

    # 4、训练步骤
    model.fit_generator(
        train_sequence,
        steps_per_epoch=int(sample_count / batch_size),  # 一个epoch需要多少步 , 1epoch sample_out 140000多样本, 140000 / 16 = 步数
        epochs=param.max_epochs,
        verbose=1,
        callbacks=[check, tensorboard, warm_lr],
        validation_data=validation_sequence,
        max_queue_size=10,
        workers=int(multiprocessing.cpu_count() * 0.7),
        use_multiprocessing=True,
        shuffle=True
    )

    return None


if __name__ == '__main__':
    args = parser.parse_args()
    train_model(args)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/374096
推荐阅读
相关标签
  

闽ICP备14008679号