赞
踩
import os
import random

# Split an image dataset laid out as <root>/<class folder>/<images> into
# train/test annotation files at 8:2 per class. Nested folders are allowed:
# every folder that directly contains image files becomes one class.

# Only files with these (case-insensitive) extensions are treated as images,
# so stray files such as .DS_Store or notes never reach the data loader.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')


def split_dataset(root_dir, train_ratio=0.8, seed=None):
    """Build shuffled train/test annotation lines from a class-per-folder tree.

    Each returned line is the image path and its integer class index
    separated by a tab and terminated by a newline -- the format parsed back
    by the dataset loader.

    Args:
        root_dir: Dataset root. Folders below it that directly contain image
            files become classes, numbered in sorted traversal order. Images
            lying directly in ``root_dir`` belong to no class and are ignored.
        train_ratio: Fraction of each class's images put in the train list
            (``int(n * train_ratio)`` images; the remainder go to test).
        seed: Optional seed for all shuffles, for a reproducible split.

    Returns:
        ``(train_list, test_list, class_names)`` where ``class_names[i]`` is
        the folder name of class ``i``.

    Raises:
        FileNotFoundError: If ``root_dir`` is not an existing directory
            (os.walk would otherwise silently yield nothing).
    """
    if not os.path.isdir(root_dir):
        raise FileNotFoundError('dataset root not found: {}'.format(root_dir))
    rng = random.Random(seed)
    train_list, test_list, class_names = [], [], []
    for root, dirs, files in os.walk(root_dir):
        # os.walk visits folders in filesystem order, which differs between
        # machines; sorting in place makes folder -> class id reproducible.
        dirs.sort()
        if root == root_dir:
            continue
        images = sorted(name for name in files
                        if name.lower().endswith(IMAGE_EXTENSIONS))
        if not images:
            continue  # e.g. an intermediate folder such as 'train/'
        label = len(class_names)
        class_names.append(os.path.basename(root))
        # Shuffle before cutting so the test part is a random sample of the
        # class instead of the last files in listing order.
        rng.shuffle(images)
        n_train = int(len(images) * train_ratio)
        lines = ['{}\t{}\n'.format(os.path.join(root, name), label)
                 for name in images]
        train_list.extend(lines[:n_train])
        test_list.extend(lines[n_train:])
    # Mix the classes together so the files are not grouped by label.
    rng.shuffle(train_list)
    rng.shuffle(test_list)
    return train_list, test_list, class_names


def write_list(path, lines):
    """Write annotation ``lines`` to ``path`` as UTF-8, creating parent folders."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='UTF-8') as f:
        f.writelines(lines)


if __name__ == '__main__':
    rootdata = '/home/hsy/PycharmProjects/数据集/5月下旬'
    train_list, test_list, class_names = split_dataset(rootdata, train_ratio=0.8)
    write_list('../data/train.txt', train_list)
    write_list('../data/test.txt', test_list)
    print('train: {} images, test: {} images'.format(len(train_list), len(test_list)))
    # Print the label mapping so it can be copied into the training/predict code.
    for label, name in enumerate(class_names):
        print(label, name)
train.txt
/home/hsy/PycharmProjects/数据集/5月下旬/瞿麦/qumai_109.jpg 16
/home/hsy/PycharmProjects/数据集/5月下旬/洋金花/yangjinhua_33.jpg 4
/home/hsy/PycharmProjects/数据集/5月下旬/萱草/xuancao_1.jpg 19
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_137.jpg 12
/home/hsy/PycharmProjects/数据集/5月下旬/香加皮/xiangjiapi_50.jpg 17
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_117.jpg 12
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_136.jpg 12
/home/hsy/PycharmProjects/数据集/5月下旬/洋金花/yangjinhua_28.jpg 4
/home/hsy/PycharmProjects/数据集/5月下旬/金芥麦/jinjiemai_107.jpg 6
/home/hsy/PycharmProjects/数据集/5月下旬/何首乌/heshouwu_42.jpg 3
.......
test.txt
/home/hsy/PycharmProjects/数据集/5月下旬/垂盆草/chuipencao_7.jpg 18
/home/hsy/PycharmProjects/数据集/5月下旬/夏枯草/xiakucao_124.jpg 2
/home/hsy/PycharmProjects/数据集/5月下旬/车前草/cheqiancao_106.jpg 8
/home/hsy/PycharmProjects/数据集/5月下旬/京大戟/jingdaji_39.jpg 7
/home/hsy/PycharmProjects/数据集/5月下旬/射干/shegan_76.jpg 5
/home/hsy/PycharmProjects/数据集/5月下旬/夏枯草/xiakucao_151.jpg 2
/home/hsy/PycharmProjects/数据集/5月下旬/牛蒡子/niubangzi_184.jpg 1
/home/hsy/PycharmProjects/数据集/5月下旬/决明子/juemingzi_100.jpg 10
/home/hsy/PycharmProjects/数据集/5月下旬/瞿麦/qumai_23.jpg 16
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_105.jpg 12
/home/hsy/PycharmProjects/数据集/5月下旬/决明子/juemingzi_92.jpg 10
/home/hsy/PycharmProjects/数据集/5月下旬/鱼腥草/yuxingcao_45.jpg 0
/home/hsy/PycharmProjects/数据集/5月下旬/紫苏/zisu_24.jpg 12
/home/hsy/PycharmProjects/数据集/5月下旬/金芥麦/jinjiemai_98.jpg 6
.......
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, datasets


class LoadData(Dataset):
    """Image-classification dataset backed by a text annotation file.

    Each non-empty line of the file holds an image path and an integer label
    separated by a tab, as written by the dataset split script.

    Args:
        txt_path: Path to the UTF-8 annotation file.
        train_flag: Apply the training transform if True, else the test one.
    """

    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_imags(txt_path)
        self.train_flag = train_flag
        # ImageNet mean/std, applied after ToTensor() has scaled pixels to [0, 1].
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Data augmentation (e.g. transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip()) can be inserted before ToTensor().
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    def get_imags(self, txt_path):
        """Parse ``txt_path`` into ``[[path, label_str], ...]``.

        Blank lines (such as a trailing newline) are skipped so they cannot
        turn into malformed samples. The label is split off the last tab so
        the path part is kept whole.
        """
        with open(txt_path, 'r', encoding='UTF-8') as f:
            return [line.strip().rsplit('\t', 1) for line in f if line.strip()]

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for sample ``index``."""
        img_path, label = self.imgs_info[index]
        # The context manager releases the file handle right after decoding;
        # convert() returns a fully loaded copy that outlives the close.
        with Image.open(img_path) as img:
            img = img.convert("RGB")
        transform = self.transform_train if self.train_flag else self.transform_test
        return transform(img), int(label)

    def __len__(self):
        """Number of samples listed in the annotation file."""
        return len(self.imgs_info)
import os
import time

import torch
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader

from data.CreateDataloader import LoadData
# Star imports kept for the model classes (RestNet18_Net etc.). They are done
# before the definitions below so that, if a model module happens to define a
# function named e.g. `train` or `test`, it cannot replace this file's own.
from model.VggNet import *
from model.VGG11 import *
from model.ResNet18 import *

# Class names in label order as produced by the author's dataset split
# (index i == label i in train.txt/test.txt). This may change if the split
# is regenerated -- keep it in sync with the split script's output.
CLASS_NAMES = ['鱼腥草', '牛蒡子', '夏枯草', '何首乌', '洋金花', '射干', '金芥麦',
               '京大戟', '车前草', '千金子', '决明子', '红花', '紫苏', '白勺',
               '薄荷', '当归', '瞿麦', '香加皮', '垂盆草', '萱草']

batch_size = 14
num_epochs = 30
num_class = 20
learning_rate = 0.001
momentum = 0.9
weight_decay = 0.0005
num_print = 40
test_num = 0
device = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_DIR = '../save_model/ResNet18'


def load_dataset(batch_size):
    """Return ``(train_iter, test_iter)`` DataLoaders over ../data/{train,test}.txt."""
    train_set = LoadData("../data/train.txt", True)
    test_set = LoadData("../data/test.txt", False)
    train_iter = DataLoader(dataset=train_set, batch_size=batch_size,
                            shuffle=True, num_workers=4)
    # Evaluation order does not affect the metrics, so no shuffling here.
    test_iter = DataLoader(dataset=test_set, batch_size=batch_size,
                           shuffle=False, num_workers=4)
    return train_iter, test_iter


def get_cur_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    return optimizer.param_groups[0]['lr']


def learning_curve(record_train, record_test=None):
    """Plot per-epoch train (and optionally test) accuracy, in percent."""
    plt.style.use('ggplot')
    plt.plot(range(1, len(record_train) + 1), record_train, label='train acc')
    if record_test is not None:
        plt.plot(range(1, len(record_test) + 1), record_test, label="test acc")
    plt.legend(loc=4)
    plt.title("learning curve")
    plt.xticks(range(0, len(record_train) + 1, 5))
    plt.yticks(range(0, 101, 5))
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.show()


def train(model, train_iter, criterion, optimizer, device, num_print, lr_scheduler=None):
    """Train ``model`` for one epoch and return the epoch's accuracy in percent.

    model.train() puts layers such as dropout and batch normalization into
    training mode.

    Args:
        model: Network to train (already on ``device``).
        train_iter: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: Loss function.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        num_print: Print running loss/accuracy every ``num_print`` batches.
        lr_scheduler: If given, stepped once at the end of the epoch.

    Returns:
        Accuracy over all samples seen this epoch (0.0 for an empty loader).
    """
    model.train()
    total, correct, train_loss = 0, 0, 0.0
    train_acc = 0.0
    start = time.time()
    for i, (inputs, labels) in enumerate(train_iter):
        inputs, labels = inputs.to(device), labels.to(device)
        output = model(inputs)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += labels.size(0)
        correct += torch.eq(output.argmax(dim=1), labels).sum().item()
        train_acc = 100 * correct / total
        if (i + 1) % num_print == 0:
            print("step: [{}/{}], train_loss: {:.3f} | train_acc: {:6.3f}% | lr: {:.6f}"
                  .format(i + 1, len(train_iter), train_loss / (i + 1),
                          train_acc, get_cur_lr(optimizer)))
    if lr_scheduler is not None:
        lr_scheduler.step()
    print("-----cost time:{:.4f}s----".format(time.time() - start))
    return train_acc


def test(model, test_iter, criterion, device, test_num):
    """Evaluate ``model`` on ``test_iter`` and return accuracy in percent.

    model.eval() switches dropout off and makes batch normalization use its
    running statistics; torch.no_grad() skips gradient tracking. The printed
    loss is the per-sample average over the whole test set. ``test_num`` is
    not used; it is kept so existing callers keep working.

    Returns:
        Test accuracy (0.0 for an empty loader).
    """
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        print("*************************test***************************")
        for inputs, labels in test_iter:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            n = labels.size(0)
            # Assumes the criterion averages over the batch (the default for
            # nn.CrossEntropyLoss); weighting by n keeps the overall mean
            # per-sample even when the last batch is smaller.
            loss_sum += criterion(output, labels).item() * n
            total += n
            correct += torch.eq(output.argmax(dim=1), labels).sum().item()
    test_loss = loss_sum / total if total else 0.0
    test_acc = 100.0 * correct / total if total else 0.0
    print("test_loss:{:.3} | test_acc:{:6.3f}%".format(test_loss, test_acc))
    print("*************************************************************")
    return test_acc


def main():
    """Train and evaluate every epoch, then save the model and plot accuracy."""
    # Replace this with your own network model.
    model = RestNet18_Net().to(device)
    train_iter, test_iter = load_dataset(batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=True,
    )
    # Multiply the learning rate by 0.1 every 8 epochs.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)
    train_acc, test_acc = [], []
    for epoch in range(num_epochs):
        print('=================epoch:[{}/{}]======================'.format(epoch + 1, num_epochs))
        train_acc.append(train(model, train_iter, criterion, optimizer, device,
                               num_print, lr_scheduler))
        test_acc.append(test(model, test_iter, criterion, device, epoch + 1))
    print("Finished Training")
    # torch.save fails on a missing folder, which would throw away the whole
    # training run, so create it first.
    os.makedirs(SAVE_DIR, exist_ok=True)
    torch.save(model, os.path.join(SAVE_DIR, '1.pth'))
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, '1_params.pth'))
    learning_curve(train_acc, test_acc)


if __name__ == '__main__':
    main()
如果看不懂这段代码,可以参考:https://blog.csdn.net/m0_50127633/article/details/117045008,那里有比较详细的注释。
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image


def pridict(img_path='/home/hsy/PycharmProjects/数据集/5月下旬/当归/danggui_49.jpg',
            model_path='../save_model/ResNet18/1.pth'):
    """Classify a single image with a model saved via ``torch.save(model, path)``.

    Args:
        img_path: Image file to classify.
        model_path: Full-model checkpoint. Loading it unpickles arbitrary
            Python objects, so only load files you trust.

    Returns:
        The predicted class index (int); it is also printed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
        # full models, so full unpickling is requested explicitly.
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch versions do not accept the weights_only keyword.
        model = torch.load(model_path, map_location=device)
    model = model.to(device)
    model.eval()

    # Same preprocessing as the dataset's test transform used in training
    # (red-channel std is 0.229; a different value skews every input).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Close the file right after decoding; convert("RGB") also normalizes
    # grayscale/RGBA/palette images to 3 channels.
    with Image.open(img_path) as img:
        img = img.convert("RGB")
    batch = transform(img).unsqueeze(0).to(device)  # add a batch dimension of 1

    with torch.no_grad():
        logits = model(batch)
    # torch.max(x, 1) returns (max values, their indices) along dim 1; the
    # index of the largest score is the predicted class.
    _, predicted = torch.max(logits, 1)
    class_index = predicted.item()
    print("预测结果", class_index)
    return class_index


if __name__ == '__main__':
    pridict()
这是根据我自己的数据集编写的,如果要用你自己的数据进行训练,需要相应地修改代码,欢迎指出不足。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。