当前位置:   article > 正文

2022李宏毅机器学习深度学习学习笔记第十二周--Dataset类代码实战_李宏毅深度学习代码

李宏毅深度学习代码


前言

在学习Dataset类代码实战之前,先了解python 的基础知识,比如初始化方法,类的继承以及self的使用;了解dataset如何获取数据,以及dataset两个重要的方法__getitem__方法和__len__方法。


一、python 基础

给对象增加属性:只需要在类的外部的代码中直接通过 . 设置一个属性即可。
比如Cat()类;创建一个Tom对象,Tom=Cat();为Tom对象增加一个name属性即可通过Tom.name="Tom"设置,name的定义必须在调用方法之前定义,否则会出错。这种方式虽然简单,但是不推荐,他只是在类的外部设置一个属性,对象属性的封装应该封装在类的内部。

class Cat:
    def eat(self):
        print("%s 爱吃鱼"% self.name)
    def drink(self):
        print("小猫要喝水")
tom = Cat()
tom.name = "Tom"
tom.eat()
tom.drink()
print(tom)
lazy_cat = Cat()
lazy_cat.name = "大懒猫"
lazy_cat.eat()
lazy_cat.drink()
print(lazy_cat)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

从控制台中可以看到哪一个对象调用的方法,self就是哪一个对象的引用。在调用方法的时候,不需要传递self参数;在方法内部,可以通过self.访问对象的属性,也可以通过self.调用其他的对象方法。

初始化方法
当使用类名()创建对象时,会自动执行以下操作:
1.为对象在内存中分配空间–创建对象
2.为对象的属性设置初始值–初始化方法(init)
这个初始化方法就是__init__方法,__init__是对象的内置方法
__init__方法是专门用来定义一个类具有哪些属性的方法。
在初始化方法内部定义属性:
在__init__方法内部使用self.属性名=属性的初始值就可以定义属性;在定义属性之后,在使用Cat类船舰对象,都会拥有该属性。

class Cat:
    def __init__(self):
        self.name="tom"
tom = Cat()
print(tom.name)  #输出tom
  • 1
  • 2
  • 3
  • 4
  • 5

初始化的同时设置初始值
如果希望在创建对象的同时就设置对象的属性,可以对__init__方法进行改造,把希望设置的属性值定义成__init__方法的参数,在方法内部使用self.属性=形参接受外部传递的参数,在创建对象时,使用类名(属性1,属性2.。。。)调用。在python中,使用print输出对象变量,默认会输出这个变量引用的对象是由哪一个类创建的对象,以及在内存中的地址(十六进制)。如果在开发中希望print输出对象变量时,能够打印自定义的内容,就可以利用__str__这个内置方法了。__str__方法必须返回一个字符串。

继承的语法:子类继承自父类,直接享受父类中封装好的方法,还可以封装子类特有的属性和方法。
class 类名(父类名):
pass
专业术语:狗类是动物类的派生类,动物类是狗类的基类,狗类从动物类派生。
子类拥有父类以及父类的父亲中封装的所有属性和方法。

二、Dataset类代码实战

Dataset 主要是提供一种方式去获取数据及其label。
Dataset 是一个抽象类,所有的数据集都需要继承这个类,所以的子类都需要重写 __getitem__方法,这个方法主要是获取每个数据及其对应的label,还可以重写__len__方法,返回数据集的大小。
实例:
先创建一个类MyData,继承自Dataset类

class MyData(Dataset):
  • 1

初始化
使用以下方法来读取图片,

from PIL import Image
  • 1

在这个例子中一般蚂蚁/蜜蜂的图片是数据或者input,而ants/bees是labels,想要获取图片,先提供图片的地址。
获取图片的地址(注意使用两个斜杠表示转义)

img_path="E:\\学习笔记\\learn_pytorch\\dataset\\train\\ants\\0013035.jpg"
  • 1

读取图片通过img=Image.open(img_path)
在Python控制台看到读取出来的 img,是一个JpegImageFile类的对象,img里有很多属性,可以通过img.进行查看。
在这里插入图片描述
img.show()可以查看图片。
获取所有图片
从数据集路径中,获取所有文件的名字,存储到一个列表中;listdir方法会将路径下的所有文件名(包括后缀名)组成一个列表

import os
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir_path)
  • 1
  • 2
  • 3

用索引去访问列表中的每个文件名

img_path_list[0]
Out[14]: '0013035.jpg'
root_dir = "dataset/train"
child_dir = "ants"
path = os.path.join(root_dir, child_dir)
  • 1
  • 2
  • 3
  • 4
  • 5

使用os.path.join方法,将两个路径拼接起来,就得到了ants子数据集的相对路径;有了这个数据集的路径后,就可以使用之前所讲的listdir方法,获取这个路径中所有文件的文件名,存储到一个列表中。

img_path_list = os.listdir(path)
idx = 0
img_path_list[idx]
Out[20]: '0013035.jpg'     
  • 1
  • 2
  • 3
  • 4

可以将这个文件名与路径进行组合,然后使用PIL获取具体的图像img对象。

img_name = img_path_list[idx]
img_item_path = os.path.join(root_dir, child_dir, img_name)
img = Image.open(img_item_path)
  • 1
  • 2
  • 3

__init__方法:获取根目录路径、子目录路径;然后将两个路径进行组合,就得到了目标数据集的路径;这个路径作为参数传入listdir函数,从而让img_path_list中存储该目录下所有文件名;通过索引就可以轻松获取每个文件名。
__getitem__方法:在__getitem__方法中,默认会有一个 item 参数,常命名为 idx,这个参数是一个索引编号,用于对我们初始化中得到的文件名列表进行索引访问,我们就得到了具体的文件名,然后与根目录、子目录再次组装,得到具体数据的相对路径,我们可以通过这个路径获取到索引编号对应的数据对象本身。

我们还可以将两个数据集对象进行组合,组合成一个大的数据集对象。

train_dataset = ants_dataset + bees_dataset
  • 1

我们在python Console中看这三个数据集对象的大小:

len1 = len(ants_dataset)
len2 = len(bees_dataset)
len3 = len(train_dataset)
  • 1
  • 2
  • 3

输出:124 121 245 ;刚好124+121=245。

完整的代码:

from torch.utils.data import Dataset
from PIL import Image
# 读取图片
import os
# 创建类MyData继承Dataset类
class MyData(Dataset):

    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        #获取图片的路径地址
        self.img_path=os.listdir(self.path)
        #获得图片列表
    def __getitem__(self, idx):
        img_name=self.img_path[idx]
        #图片名称
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        #每张图片的地址
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label
    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset+bees_dataset

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/263716
推荐阅读
  

闽ICP备14008679号