当前位置:   article > 正文

Python迁移学习:用Torchvision、Pytorch进行交通标志图像分类|附代码数据

Python迁移学习:用Torchvision、Pytorch进行交通标志图像分类|附代码数据

原文链接:https://tecdat.cn/?p=36539

本研究旨在探索如何应用迁移学习技术对交通标志图像进行分类。通过构建适用于Torchvision的图像数据集,并利用预训练模型进行微调,我们实现了对原始像素的交通标志图像的分类点击文末“阅读原文”获取完整代码数据)。

相关视频

此外,我们还引入了一个新的“未知”类别,并对模型进行了重新训练,以提高其在实际应用中的泛化能力。

随着深度学习技术的快速发展,图像分类在交通管理、自动驾驶等领域的应用日益广泛。然而,对于特定的图像分类任务,如交通标志识别,从头开始训练一个深度学习模型往往需要大量的时间和计算资源。因此,迁移学习技术应运而生,它通过利用在大型数据集上预训练的模型,可以大大加快模型的训练速度并提高分类性能。

方法

在本研究中,我们采用了以下步骤来构建和训练交通标志图像分类模型:

  1. 交通标志图像数据集概述:我们首先对所使用的交通标志图像数据集进行了概述,包括数据集的来源、规模、类别分布等信息。

  2. 构建数据集:我们将原始图像数据转换为适用于Torchvision的数据集格式,并进行了必要的数据预处理和增强操作,以提高模型的泛化能力。

  3. 使用Torchvision的预训练模型:我们选择了一个在大型数据集上预训练的深度学习模型作为起点,通过对其进行微调,使其适应交通标志图像的分类任务。

  4. 添加新的“未知”类别并重新训练模型:为了处理实际应用中可能出现的未知类别的图像,我们在数据集中添加了一个新的“未知”类别,并对模型进行了重新训练。通过这种方法,模型可以在遇到未知类别的图像时给出相应的预测结果。

fc20b9cd4165bbd9ffb671447e74aa6c.png

8f6039662f576b04188c19cdd348a6dc.png

配置

  1. %reload_ext watermark
  2. %watermark -v -p numpy,pandas,torch,torchvision

d75bcc097da662eb6231089c2a4f340e.png

  1. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  2. device

967fcf2e8bd92b5cf03c40e82ea423e4.png

交通标志识别

德国交通标志识别基准(GTSRB)包含了超过50,000张带有40多种交通标志注释的图像。给定一张图像,您需要识别出其中的交通标志。

!unzip -qq GTSRB_Final_Training_Images.zip

444c42eb8eb184d8b8128c1a8a23f01a.png

代码模拟

让我们先来了解一下数据。每个交通标志的图像都存储在一个单独的目录中。我们有多少个?

len(train_folders)

7942ea36bf214a5054be62a557cf492a.png

我们将创建 3 个辅助函数,使用 OpenCV 和 Torchvision 来加载和显示图像:

  1. def load_image(img_path, resize=True):
  2. img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
  3. if resize:
  4. img = cv2.resize(img, (64, 64), interpolation = cv2.INTER_AREA)

让我们看看每个交通标志的一些示例:

  1. sample_images = [np.random.choice(glob(f'{tf}/*ppm')) for tf in train_folders]
  2. show_sign_grid(sample_images)

626bac7adf04734a992d6a6f255607ec.png

这里有一个标志:

  1. img_path = glob(f'{train_folders[16]}/*ppm')[1]
  2. show_image(img_path)

9c677ad71764d5dce87c57adc29d5d78.png


点击标题查阅往期内容

f6ff76fd591eaa300e182820e3e18b1f.jpeg

R语言基于Keras的小数据集深度学习图像分类

outside_default.png

左右滑动查看更多

outside_default.png

01

fbf3ba37358b86e1f0c8aa21532284dd.png

02

a4dfc58e5badc3152eb9ac5ea04a4f66.png

03

f65c2b2445af81e2543552c14d251331.png

04

10a893423998a5ec51c5e2cb09f0ed52.png

建立数据集

为了简单起见,我们将重点对一些最常用的交通标志进行分类:

  1. class_names = ['priority_road', 'give_way', 'stop', 'no_entry']
  2. class_indices = [12, 13, 14, 17]

我们将把图像文件复制到一个新的目录中,以便于使用 Torchvision 的数据集助手。让我们从每个类的目录开始:

  1. for ds in DATASETS:
  2. for cls in class_names:
  3. (DATA_DIR / ds / cls).mkdir(parents=True,

我们将为每个类别保留 80% 的图像用于训练,10% 用于验证,10% 用于测试。将把每张图片复制到正确的数据集目录下:

  1. for i, cls_index in enumerate(class_indices):
  2. image_paths = np.array(glob(f'{train_folders[cls_index]}/*.ppm'))
  3. class_name = class_names[i]

fe5aef2633f719239dc88d07fe6f057f.png

我们的类别不平衡,但并不严重。我们可以忽略它。

我们将应用一些图像增强技术,人为地增加训练数据集的大小:

  1. transforms = {'train': T.Compose([
  2. T.RandomResizedCrop(size=256),
  3. T.RandomRotation(degrees=15),
  4. T.RandomHorizontalFlip(),

我们会随机调整大小、旋转和水平翻转。最后,我们使用每个通道的预设值对张量进行归一化处理。

这是 Torchvision 中预训练模型的要求。

我们将为每个图像数据集文件夹和数据加载器创建一个 PyTorch 数据集,以方便训练:

我们还将存储每个数据集中的示例数量和类名,以备日后使用:

dataset_sizes = {d: len(image_datasets[d]) for d in DATASETS

77e95d00f0485e10fc4a9c0e30998e11.png

让我们来看看一些应用了转换的图像示例。我们还需要反转归一化并重新排列颜色通道,以获得正确的图像数据:

  1. def imshow(inp, title=None):
  2. inp = inp.numpy().transpose((1, 2, 0))
  3. mean = np.array([mean_nums])

c0ebfc46d17e5559357c48fe153b124f.png

使用预训练模型:

我们的模型将接收原始图像像素,并尝试将它们分类为四个交通标志之一。这有多难?试试从头开始建立一个模型。

在这里,我们将使用迁移学习 复制非常流行的ResNet 模型的架构。此外,我们还将使用在 ImageNet 数据集 上训练时学习到的模型权重。Torchvision 让所有这些都变得简单易用:

  1. def create_model(n_classes):
  2. model = models.resnet34(pretrained=True)

除了输出层的变化,我们几乎重复使用了所有内容。这是因为我们数据集中的类数与 ImageNet 不同。

让我们创建一个模型实例:

bf774b8c3fa28b84279cdc0e3fba25e4.png

训练

我们将编写 3 个辅助函数来封装训练和评估逻辑。首先是 train_epoch

  1. loss.backward()
  2. optimizer.step()
  3. optimizer.zero_grad()
  4. scheduler.step()

首先,我们将模型调至训练模式,然后查看数据。在得到预测结果后,我们会得到概率最大的类别以及损失,这样我们就能计算出历时损失和准确率。

请注意,我们还使用了学习率调度器。

  1. losses.append(loss.item())
  2. return correct_predictions.double() / n_examples, np.mean(losses)

除了不进行梯度计算外,对模型的评估非常相似。

让我们把所有东西放在一起:

  1. model.load_state_dict(torch.load('best_model_state.bin'))
  2. return model, history

我们做了大量的字符串格式化和训练历史记录工作。困难的工作会委托给前面的辅助函数。我们还希望获得最佳模型,因此在训练过程中会存储最准确模型的权重。

让我们来训练第一个模型:

b8cd9036bfa40d5ffab075d9408572af.png

这里有一个小辅助函数,可以将训练历史可视化:

  1. plot_training_history(history):
  2. fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

94df00e0ea65b40292d2cd26bfcb2eed.png

预先训练好的模型非常出色,我们在 3 个历时后获得了非常高的准确率和较低的损失。遗憾的是,我们的验证集太小,无法从中获得一些有意义的指标。

评估

让我们看看测试集中对交通标志的预测:

  1. def show_predictions(model, class_names, n_images=6):
  2. model = model.eval()
  3. images_handeled = 0

8af291fa4db739544eac84fc423dabdf.png

即使是几乎看不见的优先道路标志也能正确分类。让我们再深入一点。

我们先从模型中获取预测结果:

80ababd6193160fa5bd0ab3e5c5e3002.png

show_confusion_matrix(cm, class_names)

355d0844acfe56f764b7f844bf7e5b1d.jpeg

没有错误。

未见图像分类

好了,但当我们面对真实世界的图像时,我们的模型会有多好呢?让我们来看看:

d3ded46af6ab4616957ad1de07055d62.png

show_image('stop-sign.jpg')

0b1e50270a26a5f7094d9664ccfcf13c.png

为此,我们将查看每个类别的置信度。让我们从模型中获取:

predict_proba(base_model, 'stop-sign.jpg')

70afd34279603987df35e7378b7ff455.png

这有点难以理解。让我们来绘制一下:

  1. })
  2. sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
  3. plt.xlim([0, 1]);

9f2dd24af934410199e5c1e59fac0a31.png

我们的模型再次表现出色!对正确的交通标志非常有信心!

分类未知交通标志

我们的模型面临的最后一个挑战是从未见过的交通标志:

40d69b6752bc4ca1a2cc7c9e0cf5647d.png

show_image('unknown-sign.jpg')

d0a10bfd817081f464c3c3fccdf7982c.png

让我们来预测一下:

predict_proba(base_model, 'unknown-sign.jpg')

e14223eba0ca40fad26452e165b0acda.png

d6666116a9ea0c293c313c5007002009.png

我们的模型非常确定(超过 95% 的置信度)这是一个让路信号。这显然是错误的。如何才能让你的模型看到这一点呢?

添加 "未知 "类

虽然有多种方法可以处理这种情况,但我们要做的事情更简单。

我们将获取原始数据集中未包含的所有交通标志的索引:

我们将为未知类创建一个新文件夹,并在其中复制一些图像:

  1. for ds, images in dataset_data:
  2. for img_path in images:
  3. shutil.copy(img_path, f'{DATA_DIR}/{ds}/unknown/')

接下来的步骤与我们已经做的完全相同:

  1. class_names = image_datasets['train'].classes
  2. dataset_sizes

81f027dc5237582df6feb98c91dfb3cc.png

e7186d60c1d2ca4fdd08abb6a27ea0f6.png

raining_history(history)

306e2b8dbb1efe809d449ef2ead18319.png

同样,我们的模型学习速度非常快。让我们再来看看样本图像:

b4db3715e0787b174cd7ebb3ece897a6.png

prediction_confidence(pred, class_names)

36a9c2c171330151ece4fc3aa3d36c27.png

很好,这个模型并不重视任何已知类别。它不知道这是一个双向符号,但却承认它是未知的。

让我们看看新数据集的一些例子:

076dc0ffcaf43ebd728e8b3619e1a25d.png

让我们来了解一下这款新车型的性能:

report(y_test, y_pred, target_names=clas

452e08d1a21142483c006ab725c3eb95.png

dff8fc57fb5f50b55aa0db5c41ecb9ff.png

我们的模型依然完美。

总结

您训练了两种不同的模型,用于根据原始像素对交通标志进行分类。

以下是所学到的内容:

  • 交通标志图像数据集概述

  • 建立数据集

  • 使用 Torchvision 预先训练的模型

  • 添加新的未知类并重新训练模型


资料获取

在公众号后台回复“领资料”,可免费获取数据分析、机器学习、深度学习等学习资料。

955fcc39a1065f35614da0802fcd5d9d.png

点击文末“阅读原文”

获取全文完整代码数据资料。

本文选自《Python迁移学习:用Torchvision、Pytorch进行交通标志图像分类》。

点击标题查阅往期内容

【视频讲解】卷积神经网络CNN肿瘤图像识别3实例附代码数据

Python对商店数据进行lstm和xgboost销售量时间序列建模预测分析

Matlab用深度学习长短期记忆(LSTM)神经网络对文本数据进行分类

RNN循环神经网络 、LSTM长短期记忆网络实现时间序列长期利率预测

结合新冠疫情COVID-19股票价格预测:ARIMA,KNN和神经网络时间序列分析

深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据

用PyTorch机器学习神经网络分类预测银行客户流失模型

PYTHON用LSTM长短期记忆神经网络的参数优化方法预测时间序列洗发水销售数据

Python用Keras神经网络序列模型回归拟合预测、准确度检查和结果可视化

Python用LSTM长短期记忆神经网络对不稳定降雨量时间序列进行预测分析

R语言中的神经网络预测时间序列:多层感知器(MLP)和极限学习机(ELM)数据分析报告

R语言深度学习:用keras神经网络回归模型预测时间序列数据

Matlab用深度学习长短期记忆(LSTM)神经网络对文本数据进行分类

R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)

MATLAB中用BP神经网络预测人体脂肪百分比数据

Python中用PyTorch机器学习神经网络分类预测银行客户流失模型

R语言实现CNN(卷积神经网络)模型进行回归数据分析

SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型

【视频】R语言实现CNN(卷积神经网络)模型进行回归数据分析

Python使用神经网络进行简单文本分类

R语言用神经网络改进Nelson-Siegel模型拟合收益率曲线分析

R语言基于递归神经网络RNN的温度时间序列预测

R语言神经网络模型预测车辆数量时间序列

R语言中的BP神经网络模型分析学生成绩

matlab使用长短期记忆(LSTM)神经网络对序列数据进行分类

R语言实现拟合神经网络预测和结果可视化

用R语言实现神经网络预测股票实例

使用PYTHON中KERAS的LSTM递归神经网络进行时间序列预测

python用于NLP的seq2seq模型实例:用Keras实现神经网络机器翻译

用于NLP的Python:使用Keras的多标签文本LSTM神经网络分类

70cfce7feb9f0d0341e8483e05a6c4b6.jpeg

24faf2e135caa8ad3bb5708a67737271.png

557f13c938eb4e8612dfe2fe2eb0ddef.png

5cba757e4b47e4b08ee47d63bf98c683.jpeg

1e53b55f1e7ac99d88cf1178f49d868f.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/803185
推荐阅读
相关标签
  

闽ICP备14008679号