当前位置:   article > 正文

训练猫狗数据集(及图像增强后训练)

猫狗数据集

一、所需环境(安装附链接)

tensorflow和keras,具体版本安装看个人所需。
安装链接如下:
https://blog.csdn.net/qq_41760767/article/details/97441967?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase
安装好后查看keras版本
在这里插入图片描述

二、数据集准备

下载图像数据集train,在Home目录下新建子目录"data",把下载的图像数据集train复制到"data"目录。具体如图所示:
在这里插入图片描述

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from IPython.display import Image
root_dir = os.getcwd()
data_path = os.path.join(root_dir,'data')
#根据目录路径
root_dir = os.getcwd()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
#存放数据集的目录
data_path = os.path.join(root_dir,'data')
import os,shutil

#原始的数据集目录
original_dataset_dir = os.path.join(data_path,'train')

#存储小数据集的目录
base_dir = os.path.join(data_path,'cats_and_dogs_small')
if not os.path.exists(base_dir):
    os.mkdir(base_dir)

#训练图像的目录
train_dir = os.path.join(base_dir,'train')
if not os.path.exists(train_dir):
    os.mkdir(train_dir)
#验证图像的目录
validation_dir = os.path.join(base_dir,'validation')
if not os.path.exists(validation_dir):
    os.mkdir(validation_dir)
#测试资料的目录
test_dir = os.path.join(base_dir,'test')
if not os.path.exists(test_dir):
    os.mkdir(test_dir)

#猫的图片的训练资料的目录
train_cats_dir = os.path.join(train_dir,'cats')
ifnot os.path.exists(train_cats_dir):
    os.mkdir(train_cats_dir)
#狗的图片的训练资料的目录
train_dogs_dir = os.path.join(train_dir,'dogs')
if not os.path.exists(train_dogs_dir):
    os.mkdir(train_dogs_dir)

#猫的图片的测试数据集目录
test_cats_dir = os.path.join(test_dir,'cats')
if not os.path.exists(test_cats_dir):
    os.mkdir(test_cats_dir)
#狗的图片的测试数据集目录
test_dogs_dir = os.path.join(test_dir,'dogs')
if not os.path.exists(test_dogs_dir):
    os.mkdir(test_dogs_dir)
#复制前600个猫的图片到train_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(600)]
for fname in fnames:
    src = os.path.join(original_dataset_dir,fname)
    dst = os.path.join(train_cats_dir,fname)
    if not os.path.exists(dst):
        shutil.copyfile(src,dst)
print("Copy next 600 cat images to train_cats_dir complete!")
#复制后面400个猫的图片到validation_cats_dir
fnames = ['cat.{}.jpg'.format(i) for i in range(1000,1400)]
for fname in fnames:
    src = os.path.join(original_dataset_dir,fname)
    dst = os.path.join(validation_cats_dir,fname)
    if not os.path.exists(dst):
        shutil.copyfile(src,dst)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/385261
推荐阅读
相关标签
  

闽ICP备14008679号