
DataLoader and Dataset

1. DataLoader

torch.utils.data.DataLoader

Purpose: builds an iterable loader over a dataset. Its main arguments:

  • dataset: a Dataset instance that decides where the data live and how a sample is read
  • batch_size: the batch size
  • num_workers: the number of subprocesses used to load data (multi-process loading can fail on some setups, e.g. on Windows or inside notebooks; if it does, fall back to num_workers=0)
  • shuffle: whether to reshuffle the samples at every epoch (DataLoader does this by permuting the sample indices each epoch, which is effectively what random.shuffle would do to the index list)
  • drop_last: whether to drop the last batch when the sample count is not divisible by batch_size

Epoch: one complete pass of all training samples through the model.

Iteration: one batch of samples passed through the model.

Batchsize: the batch size; it determines how many Iterations make up one Epoch.

With 80 samples and batch_size = 8: 1 Epoch = 10 Iterations.

With 87 samples and batch_size = 8 (drop_last defaults to False):

drop_last=True: 1 Epoch = 10 Iterations; the 7 leftover samples are dropped
drop_last=False: 1 Epoch = 11 Iterations; the last batch holds only the 7 leftover samples
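
A minimal sketch illustrating this with a toy TensorDataset of 87 samples (the data here are dummy tensors, used only to count batches):

    import torch
    from torch.utils.data import TensorDataset, DataLoader

    # 87 dummy samples, batch size 8, mirroring the example above
    data = TensorDataset(torch.randn(87, 3))

    loader_keep = DataLoader(data, batch_size=8, drop_last=False)
    loader_drop = DataLoader(data, batch_size=8, drop_last=True)

    print(len(loader_keep))  # 11 -- the last batch holds only the 7 leftover samples
    print(len(loader_drop))  # 10 -- the 7 leftover samples are discarded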


2. Dataset

torch.utils.data.Dataset

Purpose: Dataset is an abstract class; every custom dataset must inherit from it and override __getitem__() (map-style datasets should also implement __len__() so DataLoader knows the sample count).

__getitem__(): receives an index and returns the corresponding sample.
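
A minimal sketch of a custom Dataset (the class and its contents are illustrative, not from the original post):

    import torch
    from torch.utils.data import Dataset

    class SquaresDataset(Dataset):
        """Toy dataset: sample i is the pair (i, i*i)."""

        def __init__(self, n):
            self.n = n

        def __getitem__(self, index):
            # receives an index, returns one (input, target) sample
            x = torch.tensor([float(index)])
            y = torch.tensor([float(index * index)])
            return x, y

        def __len__(self):
            # lets DataLoader (and len()) know how many samples exist
            return self.n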

3. RMB binary classification

The dataset has two classes, 1-yuan and 100-yuan notes, each stored in its own folder.

The program proceeds in the following steps:

1. Dataset preparation

Split the dataset; the raw data contain 100 images per class:

train: 80
valid: 10
test: 10
    import os
    import random
    # shutil: high-level file, folder and archive operations
    import shutil

    # create a directory if it does not exist yet
    def makedir(new_dir):
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)

    # the block below runs only when the module is executed directly, not when imported
    if __name__ == '__main__':
        dataset_dir = os.path.join("data", "RMB_data")
        split_dir = os.path.join("data", "rmb_split")
        # each class is split into train/valid/test
        train_dir = os.path.join(split_dir, "train")
        valid_dir = os.path.join(split_dir, "valid")
        test_dir = os.path.join(split_dir, "test")

        # split ratios
        train_pct = 0.8
        valid_pct = 0.1
        test_pct = 0.1

        for root, dirs, files in os.walk(dataset_dir):
            for sub_dir in dirs:
                imgs = os.listdir(os.path.join(root, sub_dir))
                imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
                random.shuffle(imgs)
                img_count = len(imgs)
                train_point = int(img_count * train_pct)
                valid_point = int(img_count * (train_pct + valid_pct))

                for i in range(img_count):
                    if i < train_point:
                        out_dir = os.path.join(train_dir, sub_dir)
                    elif i < valid_point:
                        out_dir = os.path.join(valid_dir, sub_dir)
                    else:
                        out_dir = os.path.join(test_dir, sub_dir)
                    makedir(out_dir)
                    target_path = os.path.join(out_dir, imgs[i])
                    src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
                    # copy the file from source to target
                    shutil.copy(src_path, target_path)

                print('Class:{}, train:{}, valid:{}, test:{}'.format(
                    sub_dir, train_point, valid_point - train_point, img_count - valid_point))
                print("Split data created under {}\n".format(out_dir))
Class:1, train:80, valid:10, test:10
Split data created under data\rmb_split\test\1

Class:100, train:80, valid:10, test:10
Split data created under data\rmb_split\test\100
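
To sanity-check the split, a quick sketch that counts the files in each output folder (folder names follow the script above):

    import os

    split_dir = os.path.join("data", "rmb_split")
    for split in ("train", "valid", "test"):
        for cls in ("1", "100"):
            folder = os.path.join(split_dir, split, cls)
            n = len(os.listdir(folder)) if os.path.isdir(folder) else 0
            print("{}/{}: {} images".format(split, cls, n))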

2. Training: note that apart from a simple random crop, no data augmentation is applied.

    import os
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms
    import torch.optim as optim
    from matplotlib import pyplot as plt

    # lenet.py lives in the model folder; my_dataset.py and common_tools.py in the tools folder
    from model.lenet import LeNet
    from tools.my_dataset import RMBDataset
    from tools.common_tools import set_seed

    set_seed()
    rmb_label = {"1": 0, "100": 1}

    # hyper-parameters
    MAX_EPOCH = 10
    BATCH_SIZE = 16
    LR = 0.01
    log_interval = 10
    val_interval = 1

    # dataset paths
    split_dir = os.path.join("data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")

    # transforms (ImageNet mean/std; RandomCrop is the only augmentation)
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # build Dataset instances
    train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
    valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

    # build DataLoaders
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

    # model
    net = LeNet(classes=2)
    net.initialize_weights()

    # loss function
    criterion = nn.CrossEntropyLoss()

    # optimizer and learning-rate schedule
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    train_curve = list()
    valid_curve = list()

    for epoch in range(MAX_EPOCH):
        loss_mean = 0.
        correct = 0.
        total = 0.
        net.train()
        for i, data in enumerate(train_loader):
            # forward
            inputs, labels = data
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()

            # update weights
            optimizer.step()

            # accumulate classification statistics
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).squeeze().sum().item()

            # log training progress
            loss_mean += loss.item()
            train_curve.append(loss.item())
            if (i + 1) % log_interval == 0:
                loss_mean = loss_mean / log_interval
                print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
                loss_mean = 0.

        # update the learning rate once per epoch
        scheduler.step()

        # validation
        if (epoch + 1) % val_interval == 0:
            correct_val = 0.
            total_val = 0.
            loss_val = 0.
            net.eval()
            # no gradient updates are needed for evaluation
            with torch.no_grad():
                for j, data in enumerate(valid_loader):
                    inputs, labels = data
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)

                    _, predicted = torch.max(outputs.data, 1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).squeeze().sum().item()
                    loss_val += loss.item()

                loss_val_epoch = loss_val / len(valid_loader)
                valid_curve.append(loss_val_epoch)
                print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_epoch, correct_val / total_val))

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    # valid_curve records one loss per epoch, so map its points onto the iteration axis
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')
    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
Training:Epoch[000/010] Iteration[010/013] Loss: 0.6107 Acc:69.38%
Valid:	 Epoch[000/010] Iteration[003/003] Loss: 0.7187 Acc:57.89%
Training:Epoch[001/010] Iteration[010/013] Loss: 0.7202 Acc:65.00%
Valid:	 Epoch[001/010] Iteration[003/003] Loss: 0.3015 Acc:100.00%
Training:Epoch[002/010] Iteration[010/013] Loss: 0.1356 Acc:100.00%
Valid:	 Epoch[002/010] Iteration[003/003] Loss: 0.0002 Acc:100.00%
Training:Epoch[003/010] Iteration[010/013] Loss: 0.0214 Acc:99.38%
Valid:	 Epoch[003/010] Iteration[003/003] Loss: 0.0001 Acc:100.00%
Training:Epoch[004/010] Iteration[010/013] Loss: 0.0008 Acc:100.00%
Valid:	 Epoch[004/010] Iteration[003/003] Loss: 0.0000 Acc:100.00%
Training:Epoch[005/010] Iteration[010/013] Loss: 0.0000 Acc:100.00%
Valid:	 Epoch[005/010] Iteration[003/003] Loss: 0.0000 Acc:100.00%
Training:Epoch[006/010] Iteration[010/013] Loss: 0.0001 Acc:100.00%
Valid:	 Epoch[006/010] Iteration[003/003] Loss: 0.0000 Acc:100.00%
Training:Epoch[007/010] Iteration[010/013] Loss: 0.0001 Acc:100.00%
Valid:	 Epoch[007/010] Iteration[003/003] Loss: 0.0000 Acc:100.00%
Training:Epoch[008/010] Iteration[010/013] Loss: 0.0001 Acc:100.00%
Valid:	 Epoch[008/010] Iteration[003/003] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/013] Loss: 0.0000 Acc:100.00%
Valid:	 Epoch[009/010] Iteration[003/003] Loss: 0.0000 Acc:100.00%
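
The test in the next step reuses the net object still in memory. If you want to persist the trained weights between runs, a standard sketch (the file name here is illustrative):

    # save the trained parameters (file name is illustrative)
    torch.save(net.state_dict(), "lenet_rmb.pth")

    # later: rebuild the model and load the weights back
    net2 = LeNet(classes=2)
    net2.load_state_dict(torch.load("lenet_rmb.pth"))
    net2.eval()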


3. Testing the trained model

    # test with images from a separate folder
    test_dir = os.path.join("test_data")

    test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=1)

    # inference: switch to eval mode and disable gradient tracking
    net.eval()
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            # forward
            inputs, labels = data
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)

            rmb = 1 if predicted.item() == 0 else 100
            print("The model predicts a {}-yuan note".format(rmb))

Result: The model predicts a 100-yuan note.

Supporting files:

(1) lenet.py, stored in the model folder

    import torch.nn as nn
    import torch.nn.functional as F

    class LeNet(nn.Module):
        def __init__(self, classes):
            super(LeNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, classes)

        def forward(self, x):
            out = F.relu(self.conv1(x))
            out = F.max_pool2d(out, 2)
            out = F.relu(self.conv2(out))
            out = F.max_pool2d(out, 2)
            out = out.view(out.size(0), -1)
            out = F.relu(self.fc1(out))
            out = F.relu(self.fc2(out))
            out = self.fc3(out)
            return out

        def initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.xavier_normal_(m.weight.data)
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight.data, 0, 0.1)
                    m.bias.data.zero_()

    class LeNet2(nn.Module):
        def __init__(self, classes):
            super(LeNet2, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)
            )
            self.classifier = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, classes)
            )

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size()[0], -1)
            x = self.classifier(x)
            return x

    class LeNet_bn(nn.Module):
        def __init__(self, classes):
            super(LeNet_bn, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.bn1 = nn.BatchNorm2d(num_features=6)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.bn2 = nn.BatchNorm2d(num_features=16)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.bn3 = nn.BatchNorm1d(num_features=120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, classes)

        def forward(self, x):
            out = self.conv1(x)
            out = self.bn1(out)
            out = F.relu(out)
            out = F.max_pool2d(out, 2)
            out = self.conv2(out)
            out = self.bn2(out)
            out = F.relu(out)
            out = F.max_pool2d(out, 2)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.bn3(out)
            out = F.relu(out)
            out = F.relu(self.fc2(out))
            out = self.fc3(out)
            return out

        def initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.xavier_normal_(m.weight.data)
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight.data, 0, 1)
                    m.bias.data.zero_()
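
A quick sanity check of the architecture: fc1's 16*5*5 input assumes 3-channel 32x32 images (32 → 28 after conv1 → 14 after pooling → 10 after conv2 → 5 after pooling):

    import torch
    from model.lenet import LeNet

    net = LeNet(classes=2)
    x = torch.randn(1, 3, 32, 32)  # LeNet here expects 3-channel 32x32 inputs
    print(net(x).shape)            # torch.Size([1, 2])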

(2) common_tools.py, stored in the tools folder

    import torch
    import random
    import psutil
    import numpy as np
    from PIL import Image
    import torchvision.transforms as transforms

    def transform_invert(img_, transform_train):
        """
        Invert the transforms applied to a tensor image.
        :param img_: tensor
        :param transform_train: torchvision.transforms
        :return: PIL image
        """
        if 'Normalize' in str(transform_train):
            norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
            mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
            std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
            img_.mul_(std[:, None, None]).add_(mean[:, None, None])

        img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
        if 'ToTensor' in str(transform_train):
            img_ = img_.detach().numpy() * 255

        if img_.shape[2] == 3:
            img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
        elif img_.shape[2] == 1:
            img_ = Image.fromarray(img_.astype('uint8').squeeze())
        else:
            raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]))
        return img_

    def set_seed(seed=1):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def get_memory_info():
        virtual_memory = psutil.virtual_memory()
        used_memory = virtual_memory.used / 1024 / 1024 / 1024
        free_memory = virtual_memory.free / 1024 / 1024 / 1024
        memory_percent = virtual_memory.percent
        memory_info = "Usage Memory:{:.2f} G, Percentage: {:.1f}%, Free Memory:{:.2f} G".format(
            used_memory, memory_percent, free_memory)
        return memory_info
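
(3) my_dataset.py, stored in tools. The training script imports RMBDataset from this file, but the original post never lists it. Below is a minimal sketch of what it plausibly contains, written against the rmb_label mapping and the folder layout produced by the split script; treat it as an assumption, not the author's exact code:

    import os
    from PIL import Image
    from torch.utils.data import Dataset

    # assumed label mapping, matching rmb_label in the training script
    rmb_label = {"1": 0, "100": 1}

    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
            # data_info holds (image_path, label) pairs gathered from the class folders
            self.data_info = self.get_img_info(data_dir)
            self.transform = transform

        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            return img, label

        def __len__(self):
            return len(self.data_info)

        @staticmethod
        def get_img_info(data_dir):
            # walk data_dir, treating each sub-folder name ("1", "100") as the class
            data_info = list()
            for root, dirs, _ in os.walk(data_dir):
                for sub_dir in dirs:
                    img_names = os.listdir(os.path.join(root, sub_dir))
                    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                    for img_name in img_names:
                        path_img = os.path.join(root, sub_dir, img_name)
                        label = rmb_label[sub_dir]
                        data_info.append((path_img, label))
            return data_info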
