当前位置:   article > 正文

水果数据集(Fruit-Dataset )+水果分类识别训练代码(支持googlenet, resnet, inception_v3, mobilenet_v2)

水果数据集

水果数据集(Fruit-Dataset )+水果分类识别训练代码(支持googlenet, resnet, inception_v3, mobilenet_v2)

目录

水果数据集(Fruit-Dataset )+水果分类识别训练代码(支持googlenet, resnet, inception_v3, mobilenet_v2)

1. 前言

2. Fruit-Dataset水果数据集

(1)Fruit-Dataset

(2)Fruits 360蔬果数据集

(3)自定义数据集

3. 水果分类识别模型训练

(1)项目安装

(2)准备Train和Test数据

(3)配置文件:​config.yaml​

(4)开始训练

(5)可视化训练过程

(6)一些优化建议

(7) 一些运行错误处理方法:

cannot import name 'load_state_dict_from_url' 

4. 水果分类识别模型测试效果

5.项目源码下载


1. 前言

本项目将采用深度学习的方法,搭建一个水果分类识别的训练和测试系统,实现一个简单的水果图像分类识别系统。目前,基于ResNet18的水果分类识别,支持262种水果分类识别,在水果数据集Fruit-Dataset上,训练集的Accuracy在95%左右,测试集的Accuracy在83%左右,骨干网络,可支持googlenet, resnet[18,34,50], inception_v3,mobilenet_v2等常用模型。如果想进一步提高准确率,可以尝试:

  1. 最重要的: 清洗数据集,水果数据集Fruits-Dataset,部分数据是通过网上爬取的,存在部分错误的图片,尽管鄙人已经清洗一部分了,但还是建议你,训练前,再次清洗数据集,不然会影响模型的识别的准确率。
  2. 减少种类:Fruit-Dataset共有262种类水果,可以剔除部分不常见的水果
  3. 使用不同backbone模型,比如resnet50或者更深的模型
  4. 增加数据增强: 已经支持: 随机裁剪,随机翻转,随机旋转,颜色变换等数据增强方式,可以尝试诸如mixup,CutMix等更复杂的数据增强方式
  5. 样本均衡: 建议进行样本均衡处理
  6. 调超参: 比如学习率调整策略,优化器(SGD,Adam等)
  7. 损失函数: 目前训练代码已经支持:交叉熵,LabelSmoothing,可以尝试FocalLoss等损失函数

【源码下载】Fruit-Dataset水果数据集+水果分类识别训练代码

【尊重原创,转载请注明出处】https://panjinquan.blog.csdn.net/article/details/126411788

2. Fruit-Dataset水果数据集

(1)Fruit-Dataset

这里分享一个水果数据集Fruit -Dataset,该数据集包含 262 种不同种类的水果,包含常见的苹果(apple ),香蕉(banana )等种类,总共有225,640 张水果图像,可满足深度学习水果种类分类识别的需求。

 Fruit-Dataset包含的262种水果,分别是:

abiu, acai, acerola, ackee, alligator apple, ambarella, apple, apricot, araza, avocado, bael, banana, barbadine, barberry, bayberry, beach plum, bearberry, bell pepper, betel nut, bignay, bilimbi, bitter gourd, black berry, black cherry, black currant, black mullberry, black sapote, blueberry, bolwarra, bottle gourd, brazil nut, bread fruit, buddha s hand, buffaloberry, burdekin plum, burmese grape, caimito, camu camu, canistel, cantaloupe, cape gooseberry, carambola, cardon, cashew, cedar bay cherry, cempedak, ceylon gooseberry, che, chenet, cherimoya, cherry, chico, chokeberry, clementine, cloudberry, cluster fig, cocoa bean, coconut, coffee, common buckthorn, corn kernel, cornelian cherry, crab apple, cranberry, crowberry, cupuacu, custard apple, damson, date, desert fig, desert lime, dewberry, dragonfruit, durian, eggplant, elderberry, elephant apple, emblic, entawak, etrog, feijoa, fibrous satinash, fig, finger lime, galia melon, gandaria, genipap, goji, gooseberry, goumi, grape, grapefruit, greengage, grenadilla, guanabana, guarana, guava, guavaberry, hackberry, hard kiwi, hawthorn, hog plum, honeyberry, honeysuckle, horned melon, illawarra plum, indian almond, indian strawberry, ita palm, jaboticaba, jackfruit, jalapeno, jamaica cherry, jambul, japanese raisin, jasmine, jatoba, jocote, jostaberry, jujube, juniper berry, kaffir lime, kahikatea, kakadu plum, keppel, kiwi, kumquat, kundong, kutjera, lablab, langsat, lapsi, lemon, lemon aspen, leucaena, lillipilli, lime, lingonberry, loganberry, longan, loquat, lucuma, lulo, lychee, mabolo, macadamia, malay apple, mamey apple, mandarine, mango, mangosteen, manila tamarind, marang, mayhaw, maypop, medlar, melinjo, melon pear, midyim, miracle fruit, mock strawberry, monkfruit, monstera deliciosa, morinda, mountain papaya, mountain soursop, mundu, muskmelon, myrtle, nance, nannyberry, naranjilla, native cherry, native gooseberry, nectarine, neem, nungu, nutmeg, oil palm, old world sycomore, olive, orange, oregon grape, otaheite apple, papaya, passion fruit, pawpaw, pea, peanut, pear, pequi, persimmon, pigeon plum, pigface, pili nut, pineapple, pineberry, pitomba, plumcot, podocarpus, pomegranate, pomelo, prikly pear, pulasan, pumpkin, pupunha, purple apple berry, quandong, quince, rambutan, rangpur, raspberry, red mulberry, redcurrant, riberry, ridged gourd, rimu, rose hip, rose myrtle, rose-leaf bramble, saguaro, salak, salal, salmonberry, sandpaper fig, santol, sapodilla, saskatoon, sea buckthorn, sea grape, snowberry, soncoya, strawberry, strawberry guava, sugar apple, surinam cherry, sycamore fig, tamarillo, tangelo, tanjong, taxus baccata, tayberry, texas persimmon, thimbleberry, tomato, toyon, ugli fruit, vanilla, velvet tamarind, watermelon, wax gourd, white aspen, white currant, white mulberry, white sapote, wineberry, wongi, yali pear, yellow plum, yuzu, zigzag vine, zucchini
Fruit-Dataset数据说明:
  • 一个目录名代表一个标签,每个目录中该标签下的所有图像数据(图像有编号,但可能缺少数字。
  • 同一种水果的不同品种一般存放在同一个目录下(例如:青苹果、黄苹果和红苹果)。
  • 数据集中存在的水果图像可以包含水果在其生命的所有阶段,也可以包含水果切片。
  • 图像包含至少 50% 的水果信息。
  • 图像的背景可以是任何东西:单色背景、人手、水果的自然栖息地、树叶等。
  • 没有重复的图像,但有一些图像(具有相同标签)具有高度相似性。
  • 图像可能包含小水印。
  • 部分不常见的水果,数据较难采集,只有 50~100 张图像,实际工程中,可以丢弃以获得更好的平衡和更少的种类。

水果数据集Fruit-Dataset,水果数据集Fruit-Dataset,部分数据是通过网上爬取的,存在部分错误的图片,尽管鄙人已经清洗一部分了,但还是建议你,训练前,再次清洗数据集,不然会影响模型的识别的准确率。

(2)Fruits 360蔬果数据集

Fruits 360蔬果数据集包含131种不同的水果和蔬菜,共含有90483张图片,其中

  • 训练集⼤⼩:67692张图像(每张图像⼀个⽔果),
  • 测试集⼤⼩:22688张图像(每张图像⼀个⽔果)
  • ⽂件名格式:图像索引_100.jpg(例如32_100.jpg)或r_图像索引_100.jpg(例如r_32_100.jpg)或r2_图像索引_100.jpg或r3_图像索引_100.jpg。“ r”代表旋转的⽔果。“ r2”表⽰⽔果绕第三轴旋转。“100”来⾃图像尺⼨(100x100像素)。同⼀⽔果(例如苹果)的不同品种被存储为属于不同类别。

 具体有以下种类:

Apples (different varieties: Crimson Snow, Golden, Golden-Red, Granny Smith, Pink Lady, Red, Red Delicious), Apricot, Avocado, Avocado ripe, Banana (Yellow, Red, Lady Finger), Beetroot Red, Blueberry, Cactus fruit, Cantaloupe (2 varieties), Carambula, Cauliflower, Cherry (different varieties, Rainier), Cherry Wax (Yellow, Red, Black), Chestnut, Clementine, Cocos, Dates, Eggplant, Ginger Root, Granadilla, Grape (Blue, Pink, White (different varieties)), Grapefruit (Pink, White), Guava, Hazelnut, Huckleberry, Kiwi, Kaki, Kohlrabi, Kumsquats, Lemon (normal, Meyer), Lime, Lychee, Mandarine, Mango (Green, Red), Mangostan, Maracuja, Melon Piel de Sapo, Mulberry, Nectarine (Regular, Flat), Nut (Forest, Pecan), Onion (Red, White), Orange, Papaya, Passion fruit, Peach (different varieties), Pepino, Pear (different varieties, Abate, Forelle, Kaiser, Monster, Red, Williams), Pepper (Red, Green, Yellow), Physalis (normal, with Husk), Pineapple (normal, Mini), Pitahaya Red, Plum (different varieties), Pomegranate, Pomelo Sweetie, Potato (Red, Sweet, White), Quince, Rambutan, Raspberry, Redcurrant, Salak, Strawberry (normal, Wedge), Tamarillo, Tangelo, Tomato (different varieties, Maroon, Cherry Red, Yellow), Walnut.

Fruits 360蔬果数据集的图片质量特别高,很干净,几乎每张图片的水果背景都是白色的(可能是被抠出背景了),而且存在很多旋转拍照角度的高度相似的图片。因此这种水果数据集,不太符合实际业务需求,毕竟实际应用中,我们不太可能将图片抠除背景再进行水果识别,这识别成本太高了。

(3)自定义数据集

如果需要新增类别数据,或者需要自定数据集进行训练,可以如下进行处理:

  • Train和Test数据集,要求相同类别的图片,放在同一个文件夹下;且子目录文件夹命名为类别名称,如

  • 类别文件:一行一个列表:​class_name.txt​
     (最后一行,请多回车一行)
  1. A
  2. B
  3. C
  4. D
  • 修改配置文件的数据路径:​config.yaml​
  1. train_data: # 可添加多个数据集
  2. - 'data/dataset/train1'
  3. - 'data/dataset/train2'
  4. test_data: 'data/dataset/test'
  5. class_name: 'data/dataset/class_name.txt'
  6. ...
  7. ...

3. 水果分类识别模型训练

考虑到Fruits 360蔬果数据集比较简单,且不合适用于实际应用中,因此本项目以Fruit-Dataset水果数据集为训练样本。

(1)项目安装

整套工程基本框架结构如下:

  1. .
  2. ├── classifier                 # 训练模型相关工具
  3. ├── configs                    # 训练配置文件
  4. ├── data                      # 训练数据
  5. ├── libs           
  6. ├── demo.py              # 模型推理demo
  7. ├── README.md            # 项目工程说明文档
  8. ├── requirements.txt    # 项目相关依赖包
  9. └── train.py             # 训练文件

   项目依赖python包请参考requirements.txt,使用pip安装即可:

  1. numpy==1.16.3
  2. matplotlib==3.1.0
  3. Pillow==6.0.0
  4. easydict==1.9
  5. opencv-contrib-python==4.5.2.52
  6. opencv-python==4.5.1.48
  7. pandas==1.1.5
  8. PyYAML==5.3.1
  9. scikit-image==0.17.2
  10. scikit-learn==0.24.0
  11. scipy==1.5.4
  12. seaborn==0.11.2
  13. tensorboard==2.5.0
  14. tensorboardX==2.1
  15. torch==1.7.1+cu110
  16. torchvision==0.8.2+cu110
  17. tqdm==4.55.1
  18. xmltodict==0.12.0
  19. basetrainer
  20. pybaseutils==0.6.5

 项目安装教程请参考(初学者入门,麻烦先看完下面教程,配置好开发环境):

(2)准备Train和Test数据

下载水果分类数据集,Train和Test数据集,要求相同类别的图片,放在同一个文件夹下;且子目录文件夹命名为类别名称。

数据增强方式主要采用: 随机裁剪,随机翻转,随机旋转,颜色变换等处理方式

  1. import numbers
  2. import random
  3. import PIL.Image as Image
  4. import numpy as np
  5. from torchvision import transforms
  6. def image_transform(input_size, rgb_mean=[0.5, 0.5, 0.5], rgb_std=[0.5, 0.5, 0.5], trans_type="train"):
  7. """
  8. 不推荐使用:RandomResizedCrop(input_size), # bug:目标容易被crop掉
  9. :param input_size: [w,h]
  10. :param rgb_mean:
  11. :param rgb_std:
  12. :param trans_type:
  13. :return::
  14. """
  15. if trans_type == "train":
  16. transform = transforms.Compose([
  17. transforms.Resize([int(128 * input_size[1] / 112), int(128 * input_size[0] / 112)]),
  18. transforms.RandomHorizontalFlip(), # 随机左右翻转
  19. # transforms.RandomVerticalFlip(), # 随机上下翻转
  20. transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
  21. transforms.RandomRotation(degrees=5),
  22. transforms.RandomCrop([input_size[1], input_size[0]]),
  23. transforms.ToTensor(),
  24. transforms.Normalize(mean=rgb_mean, std=rgb_std),
  25. ])
  26. elif trans_type == "val" or trans_type == "test":
  27. transform = transforms.Compose([
  28. transforms.Resize([input_size[1], input_size[0]]),
  29. # transforms.CenterCrop([input_size[1], input_size[0]]),
  30. # transforms.Resize(input_size),
  31. transforms.ToTensor(),
  32. transforms.Normalize(mean=rgb_mean, std=rgb_std),
  33. ])
  34. else:
  35. raise Exception("transform_type ERROR:{}".format(trans_type))
  36. return transform

修改配置文件数据路径:​config.yaml​

  1. # 训练数据集,可支持多个数据集
  2. train_data:
  3. - '/path/to/Fruit-Dataset/train'
  4. # 测试数据集
  5. test_data: '/path/to/Fruit-Dataset/test'
  6. # 类别文件
  7. class_name: '/path/to/Fruit-Dataset/class_name.txt'

(3)配置文件:​config.yaml​

  • 目前支持的backbone有:googlenet,resnet[18,34,50],inception_v3,mobilenet_v2等, 其他backbone可以自定义添加
  • 训练参数可以通过(configs/config.yaml)配置文件进行设置

 配置文件:​config.yaml​说明如下:

  1. # 训练数据集,可支持多个数据集
  2. train_data:
  3. - '/path/to/Fruit-Dataset/train'
  4. # 测试数据集
  5. test_data: '/path/to/Fruit-Dataset/test'
  6. # 类别文件
  7. class_name: '/path/to/Fruit-Dataset/class_name.txt'
  8. train_transform: "train" # 训练使用的数据增强方法
  9. test_transform: "val" # 测试使用的数据增强方法
  10. work_dir: "work_space/" # 保存输出模型的目录
  11. net_type: "resnet18" # 骨干网络,支持:resnet18/50,mobilenet_v2,googlenet,inception_v3
  12. width_mult: 1.0
  13. input_size: [ 224,224 ] # 模型输入大小
  14. rgb_mean: [ 0.5, 0.5, 0.5 ] # for normalize inputs to [-1, 1],Sequence of means for each channel.
  15. rgb_std: [ 0.5, 0.5, 0.5 ] # for normalize,Sequence of standard deviations for each channel.
  16. batch_size: 32
  17. lr: 0.01 # 初始学习率
  18. optim_type: "SGD" # 选择优化器,SGD,Adam
  19. loss_type: "CrossEntropyLoss" # 选择损失函数:支持CrossEntropyLoss,LabelSmoothing
  20. momentum: 0.9 # SGD momentum
  21. num_epochs: 100 # 训练循环次数
  22. num_warn_up: 3 # warn-up次数
  23. num_workers: 8 # 加载数据工作进程数
  24. weight_decay: 0.0005 # weight_decay,默认5e-4
  25. scheduler: "multi-step" # 学习率调整策略
  26. milestones: [ 20,50,80 ] # 下调学习率方式
  27. gpu_id: [ 0 ] # GPU ID
  28. log_freq: 50 # LOG打印频率
  29. progress: True # 是否显示进度条
  30. pretrained: False # 是否使用pretrained模型
  31. finetune: False # 是否进行finetune

参数类型参考值说明
train_datastr, list-训练数据文件,可支持多个文件
test_datastr, list-测试数据文件,可支持多个文件
class_namestr-类别文件
work_dirstrwork_space训练输出工作空间
net_typestrresnet18
backbone类型,{resnet18/50,mobilenet_v2,googlenet,inception_v3}
input_sizelist[128,128]模型输入大小[W,H]
batch_sizeint32batch size
lrfloat0.1初始学习率大小
optim_typestrSGD优化器,{SGD,Adam}
loss_typestrCELoss损失函数
schedulerstrmulti-step学习率调整策略,{multi-step,cosine}
milestoneslist[30,80,100]降低学习率的节点,仅仅scheduler=multi-step有效
momentumfloat0.9SGD动量因子
num_epochsint120循环训练的次数
num_warn_upint3warn_up的次数
num_workersint12DataLoader开启线程数
weight_decayfloat5e-4权重衰减系数
gpu_idlist[ 0 ]指定训练的GPU卡号,可指定多个
log_freqin20显示LOG信息的频率
finetunestrmodel.pthfinetune的模型
progressboolTrue是否显示进度条
distributedboolFalse是否使用分布式训练

(4)开始训练

整套训练代码非常简单操作,用户只需要将相同类别的数据放在同一个目录下,并填写好对应的数据路径,即可开始训练了。

python train.py -c configs/config.yaml 

(5)可视化训练过程

训练过程可视化工具是使用Tensorboard,使用方法,在终端输入:
  1. # 基本方法
  2. tensorboard --logdir=path/to/log/
  3. # 例如
  4. tensorboard --logdir=work_space/mobilenet_v2_1.0_CrossEntropyLoss/log

可视化效果 

(6)一些优化建议

训练完成后,在水果数据集Fruit-Dataset上,训练集的Accuracy在95%左右,测试集的Accuracy在83%左右,骨干网络,可支持googlenet, resnet[18,34,50], inception_v3,mobilenet_v2等常用模型。如果想进一步提高准确率,可以尝试:

  1. 最重要的: 清洗数据集,水果数据集Fruit-Dataset,部分数据是通过网上爬取的,存在部分错误的图片,尽管鄙人已经清洗一部分了,但还是建议你,训练前,再次清洗数据集,不然会影响模型的识别的准确率。
  2. 使用不同backbone模型,比如resnet50或者更深的模型
  3. 增加数据增强: 已经支持: 随机裁剪,随机翻转,随机旋转,颜色变换等数据增强方式,可以尝试诸如mixup,CutMix等更复杂的数据增强方式
  4. 样本均衡: 建议进行样本均衡处理
  5. 调超参: 比如学习率调整策略,优化器(SGD,Adam等)
  6. 损失函数: 目前训练代码已经支持:交叉熵,LabelSmoothing,可以尝试FocalLoss等损失函数

(7) 一些运行错误处理方法:

  • 项目不要出现含有中文字符的目录文件或路径,否则会出现很多异常!!!!!!!!

  • cannot import name 'load_state_dict_from_url' 

由于一些版本升级,会导致部分接口函数不能使用,请确保版本对应

torch==1.7.1

torchvision==0.8.2

或者将对应python文件将

from torchvision.models.resnet import model_urls, load_state_dict_from_url

修改为:

  1. from torch.hub import load_state_dict_from_url
  2. model_urls = {
  3. 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
  4. 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
  5. 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
  6. 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
  7. 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
  8. 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
  9. 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
  10. 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
  11. 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
  12. 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
  13. }

4. 水果分类识别模型测试效果

 demo.py文件用于推理和测试模型的效果,填写好配置文件,模型文件以及测试图片即可运行测试了

  1. def get_parser():
  2. # 配置文件
  3. config_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220826100725/config.yaml"
  4. # 模型文件
  5. model_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220826100725/model/best_model_098_83.5305.pth"
  6. # 待测试图片目录
  7. image_dir = "data/test_images/fruit"
  8. parser = argparse.ArgumentParser(description="Inference Argument")
  9. parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
  10. parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
  11. parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
  12. parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
  13. return parser
  1. #!/usr/bin/env bash
  2. # Usage:
  3. # python demo.py -c "path/to/config.yaml" -m "path/to/model.pth" --image_dir "path/to/image_dir"
  4. python demo.py -c "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220826100725/config.yaml" -m "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220826100725/model/best_model_098_83.5305.pth" --image_dir "data/test_images/fruit"

运行测试结果: 

pred_index:['apple'],pred_score:[0.9730666]

 pred_index:['apple'],pred_score:[0.8644004]

 pred_index:['banana'],pred_score:[0.9996606]

 pred_index:['banana'],pred_score:[0.99923694]


5.项目源码下载

整套项目源码内容包含:Fruit-Dataset水果数据集+水果分类识别训练代码

  • Fruit-Dataset水果数据集: 该数据集包含 262 种不同种类的水果,包含常见的苹果(apple),香蕉(banana)等种类,总共有225,640 张水果图像,可满足深度学习水果种类分类识别的需求
  • Fruits 360蔬果数据集: 包含131种不同的水果和蔬菜,共含有90483张图片
  • 支持自定义数据集训练
  • 整套水果分类训练代码和测试代码(Pytorch版本), 支持的backbone骨干网络模型有:googlenet,resnet[18,34,50],inception_v3,mobilenet_v2等, 其他backbone可以自定义添加

 【源码下载】Fruit-Dataset水果数据集+水果分类识别训练代码

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/678107
推荐阅读
  

闽ICP备14008679号