当前位置:   article > 正文

pytorch搭建训练自己数据集的模型(预处理、读取自己的图片、进行训练和测试、保存模型、加载模型和测试)_pycharm跑数据模型

pycharm跑数据模型

第一阶段:读取图片并保存为.txt

import os
import random

#把训练集和测试集分为8:2
train_ratio = 0.8
test_ratio = 1 - train_ratio

rootdata = '/home/hsy/PycharmProjects/数据集/5月下旬'

train_list, test_list = [], []

data_list = []

#图片的标签
class_flag = -1

'''
要取得该文件夹下的所有文件,可以使用 for(root,dirs,files) in walk(roots)函数
roots:代表需要便利的根文件夹
root: 表示正在遍历的文件夹的名字
dirs:记录正在遍历的文件夹中的文件
'''
for root, dirs, files in os.walk(rootdata):

    for i in range(len(files)):
        '''
        os.path.join()函数:连接两个或者更多的路径名组价你
        1.如果各组件首字母不包含'/',则函数会自动加上
        2.如果一个组件是一个绝对路径,则在它之前的所有组件均会被舍弃
        3.如果最后一个组件为空,则成一个路径以一个'/'分隔符结尾

        root='/home/hsy/PycharmProjects/数据集/5月下旬/train/鱼腥草'
        files[i]='yuxingcao_1.jpg'

        os.path.join(root,files[i])='/home/hsy/PycharmProjects/数据集/5月下旬/train/鱼腥草/yuxingcao_1.jpg'
        '''

        data_list.append(os.path.join(root, files[i]))

    for i in range(0, int(len(files) * train_ratio)):
        train_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(train_data)

    for i in range(int(len(files) * train_ratio), len(files)):
        test_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(test_data)

    class_flag += 1

# print(train_list)

# 将数据打乱
random.shuffle(train_list)
random.shuffle(test_list)


# 保存到txt
with open('../data/train.txt', 'w', encoding='UTF-8') as f:
    for train in train_list:
        f.write(train)

with open('../data/test.txt', 'w', encoding='UTF-8') as f:
    for test in test_list:
        f.write(test)


print(test_list)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67

在这里插入图片描述
train.txt

/home/hsy/PycharmProjects/数据集/5月下旬/瞿麦/qumai_109.jpg	16
/home/hsy/PycharmProjects/数据集/5月下旬/洋金花/yangjinhua_33.jpg	4
/home/hsy/PycharmProjects/数据集/5月下旬/萱草/xuancao_1.jpg	19
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_137.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/香加皮/xiangjiapi_50.jpg	17
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_117.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_136.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/洋金花/yangjinhua_28.jpg	4
/home/hsy/PycharmProjects/数据集/5月下旬/金芥麦/jinjiemai_107.jpg	6
/home/hsy/PycharmProjects/数据集/5月下旬/何首乌/heshouwu_42.jpg	3
	.......
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

test.txt

/home/hsy/PycharmProjects/数据集/5月下旬/垂盆草/chuipencao_7.jpg	18
/home/hsy/PycharmProjects/数据集/5月下旬/夏枯草/xiakucao_124.jpg	2
/home/hsy/PycharmProjects/数据集/5月下旬/车前草/cheqiancao_106.jpg	8
/home/hsy/PycharmProjects/数据集/5月下旬/京大戟/jingdaji_39.jpg	7
/home/hsy/PycharmProjects/数据集/5月下旬/射干/shegan_76.jpg	5
/home/hsy/PycharmProjects/数据集/5月下旬/夏枯草/xiakucao_151.jpg	2
/home/hsy/PycharmProjects/数据集/5月下旬/牛蒡子/niubangzi_184.jpg	1
/home/hsy/PycharmProjects/数据集/5月下旬/决明子/juemingzi_100.jpg	10
/home/hsy/PycharmProjects/数据集/5月下旬/瞿麦/qumai_23.jpg	16
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_105.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/决明子/juemingzi_92.jpg	10
/home/hsy/PycharmProjects/数据集/5月下旬/鱼腥草/yuxingcao_45.jpg	0
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_24.jpg	12
/home/hsy/PycharmProjects/数据集/5月下旬/金芥麦/jinjiemai_98.jpg	6
.......
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

第二阶段:改写Dataset,保证下阶段读取自己的数据集

import torch
from  PIL import  Image
import os
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
from torchvision import  transforms,utils,datasets
import numpy as np


#图像标准化
# transform_BN=transforms.Normalize((0.485,0.456,0.406),(0.226,0.224,0.225))


class LoadData(Dataset):
    def __init__(self,txt_path,train_flag=True):
        self.imgs_info=self.get_imags(txt_path)
        self.train_flag=train_flag

        self.transform_train=transforms.Compose([

            # #随机水平翻转
            # transforms.RandomHorizontalFlip(),
            # #随机垂直翻转
            # transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])

        ])

        self.transform_test=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
        ])

    def get_imags(self, txt_path):
        with open(txt_path,'r',encoding='UTF-8') as f:
            imgs_info=f.readlines()
            imgs_info=list(map(lambda x:x.strip().split('\t'),imgs_info))

        return imgs_info


    def __getitem__(self, index):
        img_path,label=self.imgs_info[index]

        img=Image.open(img_path)

        img=img.convert("RGB")

        if self.train_flag:
            img=self.transform_train(img)
        else:
            img=self.transform_test(img)

        label=int(label)

        #返回打开的图片和它的标签
        return img,label

    def __len__(self):
        return len(self.imgs_info)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61

第三阶段:读取自己的数据集并训练和测试

from torch import optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import time

from data.CreateDataloader import LoadData

def load_dataset(batch_size):

    train_set=LoadData("../data/train.txt",True)
    test_set=LoadData("../data/test.txt",False)

    train_iter=torch.utils.data.DataLoader(
        dataset=train_set,batch_size=batch_size,shuffle=True,num_workers=4
    )

    test_iter=torch.utils.data.DataLoader(
        dataset=test_set,batch_size=batch_size,shuffle=True,num_workers=4
    )

    return train_iter,test_iter

def get_cur_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


def learning_curve(record_train,record_test=None):
    plt.style.use('ggplot')

    plt.plot(range(1,len(record_train)+1),record_train,label='train acc')
    if record_test is not None:
        plt.plot(range(1,len(record_test)+1),record_test,label="test acc")

    plt.legend(loc=4)
    plt.title("learning curve")
    plt.xticks(range(0,len(record_train)+1,5))
    plt.yticks(range(0,101,5))
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    plt.show()

'''
model.train()
在使用pytorch构建神经网络的时候,训练过程中会在程序上方添加一句model.train()
作用是启动batch.normalize和dropout

model.eval()
测试过程中会使用model.eval(),这时神经网络会沿用batch normalization的值,并不使用dropou
'''
def train(model,train_iter,criterion,optimizer,device,num_print,lr_scheduler=None):

    model.train()


    total,correct,train_loss=0,0,0
    start=time.time()

    for i,(inputs,labels) in enumerate(train_iter):
        inputs,labels=inputs.to(device),labels.to(device)

        output=model(inputs)
        # print(inputs.shape)
        loss=criterion(output,labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss+=loss.item()
        total+=labels.size(0)
        correct+=torch.eq(output.argmax(dim=1),labels).sum().item()

        train_acc=100*correct/total
        # print(train_acc)


        if (i + 1) % num_print == 0:
            print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}" \
                .format(i + 1, len(train_iter), train_loss / (i + 1), \
                train_acc, get_cur_lr(optimizer)))


    if lr_scheduler is not None:
        lr_scheduler.step()

    print("-----cost time:{:.4f}s----".format(time.time()-start))

    # if test_iter is not None:
    #     record_test.append(test(model,test_iter,criterion,device))


    return train_acc


def test(model, test_iter, criterion, device,test_num):
    j=0
    total,correct=0,0
    caoyao_list = ['鱼腥草', '牛蒡子', '夏枯草', '何首乌', '洋金花', '射干', '金芥麦', '京大戟', '车前草', '千金子',
                   '决明子', '红花', '紫苏', '白勺', '薄荷', '当归', '瞿麦', '香加皮', '垂盆草', '萱草'
                   ]

    model.eval()


    with torch.no_grad():
        print("*************************test***************************")

        for inputs,labels in test_iter:
            inputs,labels=inputs.to(device),labels.to(device)

            output=model(inputs)
            loss=criterion(output,labels)

            total+=labels.size(0)
            # print("labels.shape",labels.shape,labels.size(0))
            correct+=torch.eq(output.argmax(dim=1),labels).sum().item()


    test_acc=100.0*correct/total
    print("test_loss:{:.3} | test_acc:{:6.3f}%"\
          .format(loss.item(),test_acc)
          )


    print("*************************************************************")
    # model.train()


    return  test_acc


from model.VggNet import *
from model.VGG11 import *
from model.ResNet18 import *

batch_size=14
num_epochs=30
num_class=20
learning_rate=0.001
momentum=0.9
weight_decay=0.0005
num_print=40
test_num=0
device="cuda" if torch.cuda.is_available() else "cpu"
def main():
	#这里需要更改为自己的网络模型
    model=RestNet18_Net().to(device)

    train_iter,test_iter=load_dataset(batch_size)

    criterion=nn.CrossEntropyLoss()
    optimizer=optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=True

    )

    lr_scheduler=optim.lr_scheduler.StepLR(optimizer,step_size=8,gamma=0.1)

    train_acc=list()
    test_acc=list()
    test_num=0
    for epoch in range(num_epochs):
        test_num+=1
        print('=================epoch:[{}/{}]======================'.format(epoch+1,num_epochs))
        record_train=train(model,train_iter,criterion,optimizer,device,num_print,lr_scheduler)
        record_test=test(model,test_iter,criterion,device,test_num)

        train_acc.append(record_train)
        test_acc.append(record_test)


    print("Finished Training")
	#保存训练好的模型
    torch.save(model, '../save_model/ResNet18/1.pth')
    torch.save(model.state_dict(), '../save_model/ResNet18/1_params.pth')
    
    learning_curve(train_acc,test_acc)


if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188

如果这段代码看不懂可以看:https://blog.csdn.net/m0_50127633/article/details/117045008,在这里我有比较详细的注释。

第四阶段:模型加载并进行测试

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image




def pridict():

    device="cuda" if torch.cuda.is_available() else "cpu"

    path='../save_model/ResNet18/1.pth'

    model = torch.load(path)
    model=model.to(device)

    model.eval()

    img=Image.open('/home/hsy/PycharmProjects/数据集/5月下旬/当归/danggui_49.jpg')
    transform = transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.226, 0.224, 0.225])
                                    ])

    img = img.convert("RGB")  # 如果是标准的RGB格式,则可以不加
    img = transform(img)
    img = img.unsqueeze(0)
    img = img.to(device)

    with torch.no_grad():
        py = model(img)
    '''
    torch.max()这个函数返回的是两个值,第一个值是具体的value(我们用下划线_表示),第二个值是value所在的index
    下划线_ 表示的就是具体的value,也就是输出的最大值。
    数字1其实可以写为dim=1,这里简写为1,python也可以自动识别,dim=1表示输出所在行的最大值
    '''
    _,predicted = torch.max(py, 1)  # 获取分类结果
    #预测结果的标签
    classIndex = predicted.item()


    print"预测结果",classIndex)


if __name__ == '__main__':
    pridict()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

在这里插入图片描述
这是根据我自己的数据集进行写的,如果你要训练自己数据的话需要进行改写,欢迎指出不足。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/628517
推荐阅读
相关标签
  

闽ICP备14008679号