赞
踩
实现一个完整的图像分类任务,大致需要五个步骤:
本次实战选择pytorch开源框架,按照上述步骤实现一个基本的图像分类任务,并详细阐述其中的细节。
表面缺陷检测是生产制造过程中必不可少的一步,尤其在带钢原料钢卷的轧制工艺过程中形成的表面缺陷是造成废、次品的主要原因,因此必须加强对带钢表面缺陷检测,通过缺陷检测,对于加强轧制工艺管理,剔除废品等都有重要的意义。
本次实战选择的数据库为由东北大学(NEU)发布的热轧钢带表面缺陷数据库,收集了热轧钢带的六种典型表面缺陷,即轧制氧化皮(RS),斑块(Pa),开裂(Cr),点蚀表面( PS),内含物(In)和划痕(Sc)。该数据库包括1,800个灰度图像:六种不同类型的典型表面缺陷,每一类缺陷包含300个样本。
数据库下载地址 NEU-CLS
提取码:175m
下面展示了6中缺陷样本的图像
首先需要将数据集分类处理成pytorch可以读取的形式,即是将缺陷图像按类别放置在不同的文件夹中。代码如下:
import os import shutil ### 数据集根目录 root_dir = '数据集绝对地址' ### 数据集转移目录 shutil_dir = '处理数据集绝对地址' all_images = os.listdir(root_dir) #读取所有文件 images_classes= ['Cr', 'In', 'Pa', 'PS', 'RS', 'Sc'] for img in all_images: img_shutil_dir = os.path.join(shutil_dir, str(images_classes.index(img[0:2]))) if not os.path.isdir(img_shutil_dir): os.mkdir(img_shutil_dir) shutil.copyfile(os.path.join(root_dir, img), os.path.join(img_shutil_dir, img))
运行后,数据集形式如下:每个文件夹中放置的是同类型的缺陷图像。
在这一步,需要实现数据集的加载和数据集划分,数据集加载运用ImageFolder()
和DataLoader()
, 数据集划分运用random_spilt()
,同时实现数据集加载时的数据增强。
数据增强介绍:数据增强
Pytorch常用图像处理和数据增强方法:Pytorch
import torch.utils.data as Data import torchvision import torchvision.transforms as transforms train_transform = transforms.Compose([ transforms.RandomResizedCrop(200), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) dataset = torchvision.datasets.ImageFolder(shutil_dir, transform=train_transform) #全部训练用例 ''' 按照8 :2 比例切分数据集为训练集和验证集 train_dataset 为训练集,valid_dataset为验证集 ''' train_size = int(0.8*len(dataset)) valid_size = len(dataset)-train_size train_dataset, valid_dataset = Data.random_split(dataset, [train_size, valid_size]) train_data = Data.DataLoader(train_dataset, batch_size=1, shuffle=True) valid_data = Data.DataLoader(valid_dataset, batch_size=1, shuffle=False)
本例中的Normalize使用的参数为在ImageNet数据集上计算得到的方差和均值,实际使用时需要重新计算。参考链接:pytorch标准化。
常用的图像分类网络有VGG、ResNet、ResNext、DenseNet、Inception、ShuffleNet等,
参考链接:
图像分类:常用分类网络结构(附论文下载)
常用的分类网络
在本次实战中,主要选取了ResNet-50经典网络做为训练模型,
import torchvision import torch.nn as nn basic_model = torchvision.models.resnet50(pretrained=True) class resnet_classifier(nn.Module): def __init__(self, classnumber=21): super(resnet_classifier, self).__init__() self.features = nn.Sequential(*list(basic_model.children())[:-1]) fc_features = basic_model.fc.in_features self.classifier = nn.Linear(fc_features, classnumber, bias=False) def forward(self, x): features = self.features(x) features = torch.flatten(features, 1) classifier = self.classifier(features) return classifier
损失函数选择标准的交叉熵损失函数(详细介绍损失函数)
优化方式选择Adam优化(详细介绍优化方式)
在训练中,在网络结构中加载了预训练模型,可以加快训练速度和提升训练精度,初始学习率设置为1e-4, 在网络结构的特征层和分类层采取不同的学习率,分类层的学习率为特征层的10倍,学习率调整策略为指数衰减。(参考链接学习率调整)
model = resnet_classifier()
train_params = [{'params':model.features.parameters(), 'lr':lr},
{'params':model.classifier.parameters(),'lr':10*lr}]
optimizer = torch.optim.Adam(train_params)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)
训练和测试代码:
def training(self,epoch): train_loss = 0.0 self.model.train() tbar = tqdm(self.train_data) num_img_tr = len(self.train_data) for i, sample in enumerate(tbar): img, label = sample if self.cuda: img = img.cuda() self.optimizer.zero_grad() output = self.model(img) loss = self.Loss(output.cpu(), label) loss.backward() self.optimizer.step() self.scheduler.step() train_loss += loss.item() ### 记录训练过程 监控loss值 tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.batch_size + img.data.shape[0])) print('Loss: %.3f' % train_loss) def validation(self, epoch): self.model.eval() tbar = tqdm(self.valid_data, desc='\r') test_loss = 0.0 train_acc_sum = 0.0 num_img_tr = len(self.valid_data) * self.batch_size for i, sample in enumerate(tbar): img, label = sample if self.cuda: img = img.cuda() with torch.no_grad(): output = model(img) loss = self.loss(output.cpu(), label) test_loss += loss.item() tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) # Add batch sample into evaluator train_acc_sum += (output.cpu().argmax(dim=1) == label).sum().cpu().item() ### 监控验证过程 记录正确率 accuracy = train_acc_sum / num_img_tr self.writer.add_scalar('test/total_loss_epoch', test_loss, epoch) self.writer.add_scalar('accuracy', accuracy, epoch)
选用不同的模型和训练参数,对比训练精度,对模型或者超参数进行调整优化。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。