当前位置:   article > 正文

pyTorch入门(五)——训练自己的数据集_pytorch训练自己的数据集

pytorch训练自己的数据集

学更好的别人,

做更好的自己。

——《微卡智享》

419c279068cd36a2704020162893d3c7.jpeg

本文长度为1749,预计阅读5分钟

前言

前面四篇将Minist数据集的训练及OpenCV的推理都介绍完了,在实际应用项目中,往往需要用自己的数据集进行训练,所以本篇就专门介绍一下pyTorch怎么训练自己的数据集。

57242f16008f15a727e298f868261a8b.png

微卡智享

生成自己的训练图片

上一篇《pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别》中使用VS Studio实现了OpenCV的推理,介绍过在推理前需要将图片进行预处理,包括灰度、二值化,查找及排序轮廓都已经处理了,所以只要对上面的代码进行改造一下,将提取的信息保存出来,就是我们想要训练的数据了。先上源码:

  1. #pragma once
  2. #include<iostream>
  3. #include<chrono>
  4. #include<time.h>
  5. #include<opencv2/opencv.hpp>
  6. #include<opencv2/dnn/dnn.hpp>
  7. using namespace cv;
  8. using namespace std;
  9. //参数iType 0-提取图片保存 1-使用DNN推理
  10. int iType = 1;
  11. dnn::Net net;
  12. //排序矩形
  13. void SortRect(vector<Rect>& inputrects) {
  14. for (int i = 0; i < inputrects.size(); ++i) {
  15. for (int j = i; j < inputrects.size(); ++j) {
  16. //说明顺序在上方,这里不用变
  17. if (inputrects[i].y + inputrects[i].height < inputrects[i].y) {
  18. }
  19. //同一排
  20. else if (inputrects[i].y <= inputrects[j].y + inputrects[j].height) {
  21. if (inputrects[i].x > inputrects[j].x) {
  22. swap(inputrects[i], inputrects[j]);
  23. }
  24. }
  25. //下一排
  26. else if (inputrects[i].y > inputrects[j].y + inputrects[j].height) {
  27. swap(inputrects[i], inputrects[j]);
  28. }
  29. }
  30. }
  31. }
  32. //处理DNN检测的MINIST图像,防止长方形图像直接转为28*28扁了
  33. void DealInputMat(Mat& src, int row = 28, int col = 28, int tmppadding = 5) {
  34. int w = src.cols;
  35. int h = src.rows;
  36. //看图像的宽高对比,进行处理,先用padding填充黑色,保证图像接近正方形,这样缩放28*28比例不会失衡
  37. if (w > h) {
  38. int tmptopbottompadding = (w - h) / 2 + tmppadding;
  39. copyMakeBorder(src, src, tmptopbottompadding, tmptopbottompadding, tmppadding, tmppadding,
  40. BORDER_CONSTANT, Scalar(0));
  41. }
  42. else {
  43. int tmpleftrightpadding = (h - w) / 2 + tmppadding;
  44. copyMakeBorder(src, src, tmppadding, tmppadding, tmpleftrightpadding, tmpleftrightpadding,
  45. BORDER_CONSTANT, Scalar(0));
  46. }
  47. resize(src, src, Size(row, col));
  48. }
  49. // 获取当时系统时间
  50. const string GetCurrentSystemTime()
  51. {
  52. auto t = chrono::system_clock::to_time_t(std::chrono::system_clock::now());
  53. struct tm ptm { 60, 59, 23, 31, 11, 1900, 6, 365, -1 };
  54. _localtime64_s(&ptm, &t);
  55. char date[60] = { 0 };
  56. sprintf_s(date, "%d%02d%02d%02d%02d%02d",
  57. (int)ptm.tm_year + 1900, (int)ptm.tm_mon + 1, (int)ptm.tm_mday,
  58. (int)ptm.tm_hour, (int)ptm.tm_min, (int)ptm.tm_sec);
  59. return move(std::string(date));
  60. }
  61. int main(int argc, char** argv) {
  62. //定义onnx文件
  63. string onnxfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/torchminist/ResNet.onnx";
  64. //测试图片文件
  65. string testfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/testpic/test3.png";
  66. //提取的图片保存位置
  67. string savefile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/findcontoursMat";
  68. if (iType == 1) {
  69. net = dnn::readNetFromONNX(onnxfile);
  70. if (net.empty()) {
  71. cout << "加载Onnx文件失败!" << endl;
  72. return -1;
  73. }
  74. }
  75. //读取图片,灰度,高斯模糊
  76. Mat src = imread(testfile);
  77. //备份源图
  78. Mat backsrc;
  79. src.copyTo(backsrc);
  80. cvtColor(src, src, COLOR_BGR2GRAY);
  81. GaussianBlur(src, src, Size(3, 3), 0.5, 0.5);
  82. //二值化图片,注意用THRESH_BINARY_INV改为黑底白字,对应MINIST
  83. threshold(src, src, 0, 255, THRESH_BINARY_INV | THRESH_OTSU);
  84. //做彭账处理,防止手写的数字没有连起来,这里做了3次膨胀处理
  85. Mat kernel = getStructuringElement(MORPH_RECT, Size(3, 3));
  86. //加入开运算先去燥点
  87. morphologyEx(src, src, MORPH_OPEN, kernel, Point(-1, -1));
  88. morphologyEx(src, src, MORPH_DILATE, kernel, Point(-1, -1), 3);
  89. imshow("src", src);
  90. vector<vector<Point>> contours;
  91. vector<Vec4i> hierarchy;
  92. vector<Rect> rects;
  93. //查找轮廓
  94. findContours(src, contours, hierarchy, RETR_EXTERNAL, CHAIN_APPROX_NONE);
  95. for (int i = 0; i < contours.size(); ++i) {
  96. RotatedRect rect = minAreaRect(contours[i]);
  97. Rect outrect = rect.boundingRect();
  98. //插入到矩形列表中
  99. rects.push_back(outrect);
  100. }
  101. //按从左到右,从上到下排序
  102. SortRect(rects);
  103. //要输出的图像参数
  104. for (int i = 0; i < rects.size(); ++i) {
  105. Mat tmpsrc = src(rects[i]);
  106. DealInputMat(tmpsrc);
  107. if (iType == 1) {
  108. //Mat inputBlob = dnn::blobFromImage(tmpsrc, 0.3081, Size(28, 28), Scalar(0.1307), false, false);
  109. Mat inputBlob = dnn::blobFromImage(tmpsrc, 1, Size(28, 28), Scalar(), false, false);
  110. //输入参数值
  111. net.setInput(inputBlob, "input");
  112. //预测结果
  113. Mat output = net.forward("output");
  114. //查找出结果中推理的最大值
  115. Point maxLoc;
  116. minMaxLoc(output, NULL, NULL, NULL, &maxLoc);
  117. cout << "预测值:" << maxLoc.x << endl;
  118. //画出截取图像位置,并显示识别的数字
  119. rectangle(backsrc, rects[i], Scalar(255, 0, 255));
  120. putText(backsrc, to_string(maxLoc.x), Point(rects[i].x, rects[i].y), FONT_HERSHEY_PLAIN, 5, Scalar(255, 0, 255), 1, -1);
  121. }
  122. else {
  123. string filename = savefile + "/" + GetCurrentSystemTime() + "-" + to_string(i) + ".jpg";
  124. cout << filename << endl;
  125. imwrite(filename, tmpsrc);
  126. }
  127. }
  128. imshow("backsrc", backsrc);
  129. waitKey(0);
  130. return 0;
  131. }

划重点

696b35749fc2d08092aaea17c12bcf3e.png

加了一个参数,设置的时候0为提取保存的图片,1是上一篇的推理。

9b73b9524203bf40b821f0230a283166.png

增加了一个获取当前时间的函数,主要作用就是保存图片的时候在文件名加上时间。

a06bdbdd0839f168cd0eb8e90323b9e3.png

增加了一个保存图片的位置

634a67e4e50c2621215c694fa2b13718.png

根据上面的参数,设置为1时还是原来的DNN推理,0时通过imwrite将图片进行保存。

832dc656540202bd25a5f88deea31459.png

接下来我们自己做点数据集,用画图工具在上面写上数字,将0--9的数字分别做了10张图出来。

9663be5306a98fa4547d066a5315b660.png

6e3194a29168ced8937551dfed883c7a.png

79f8857319913e02716ae435ad402ca6.png

c6e6bd5458548371d1f7c98cbd0e86bd.png

运行的效果如下:

fa0b279295cfed140faa2d03a12cecfb.png

可以看出上图中我们将数字9的图片分开截取并保存到指定的目录了。

db8689d46c2755d651e916d14e92e4b1.png

同时在Dataset下创建mydata目录,并创建出train训练的目录,在目录下创建了0-9的文件夹,这样做的目录是在pyTorch调用时会直接根据train下不同的文件夹目录设置对应的label标签了,不用我们在每个进行对照,相应的,提取出的数字图片也要放到对应的目录中

4ce412583b43388e6f667e7fe32ea462.png

将刚才生成的数字9的图片都剪切到9的文件夹下,其余的数字也是用同样方式。

5cc7864cdaeb7777d1deca672abf98ab.png

test测试集也用相同的方式处理,只不过我们拷过来后删了一大部分,就做别的处理。做完这些,提取图片的准备工作就完成了,接下来就是通过pyTorch训练。

3e1769d38670b88898e73c7cf0744cb1.png

微卡智享

pyTorch训练自己数据集

420c8516d113837d9c52dde07fee8ab4.png

新建了一个trainmydata.py的文件,训练的流程其实和原来差不多,只不过我们是在原来的基础上进行再训练,所以这些的模型是先加载原来的训练模型后,再进行训练,还是先上代码

  1. import torch
  2. import time
  3. from torchvision import datasets
  4. from torch.utils.data import DataLoader
  5. from torchvision import transforms
  6. import torch.optim as optim
  7. import matplotlib.pyplot as plt
  8. from pylab import mpl
  9. import trainModel as tm
  10. ##训练轮数
  11. epoch_times = 15
  12. ##设置初始预测率,用于判断高于当前预测率的保存模型
  13. toppredicted = 0.0
  14. ##设置学习率
  15. learnrate = 0.01
  16. ##设置动量值,如果上一次的momentnum与本次梯度方向是相同的,梯度下降幅度会拉大,起到加速迭代的作用
  17. momentnum = 0.5
  18. ##自己训练的模型前面加个my
  19. savemodel_name = "my" + tm.savemodel_name
  20. ##生成图用的数组
  21. ##预测值
  22. predict_list = []
  23. ##训练轮次值
  24. epoch_list = []
  25. ##loss值
  26. loss_list = []
  27. transform = transforms.Compose([
  28. transforms.Grayscale(num_output_channels=1),
  29. transforms.ToTensor(),
  30. transforms.Normalize(mean=(0.1307,), std=(0.3081,))
  31. ]) ##Normalize 里面两个值0.1307是均值mean, 0.3081是标准差std,计算好的直接用了
  32. ##训练数据集位置
  33. train_mydata = datasets.ImageFolder(
  34. root = '../datasets/mydata/train',
  35. transform = transform
  36. )
  37. train_mydataloader = DataLoader(train_mydata, batch_size=64, shuffle=True, num_workers=0)
  38. ##测试数据集位置
  39. test_mydata = datasets.ImageFolder(
  40. root = '../datasets/mydata/test',
  41. transform = transform
  42. )
  43. test_mydataloader = DataLoader(test_mydata, batch_size=1, shuffle=True, num_workers=0)
  44. ##加载已经训练好的模型
  45. model = tm.Net(tm.train_name)
  46. model.load_state_dict(torch.load(tm.savemodel_name))
  47. ##加入判断是CPU训练还是GPU训练
  48. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  49. model.to(device)
  50. ##优化器
  51. optimizer = optim.SGD(model.parameters(), lr= learnrate, momentum= momentnum)
  52. ##训练函数
  53. def train(epoch):
  54. model.train()
  55. for batch_idx, data in enumerate(train_mydataloader, 0):
  56. inputs, target = data
  57. ##加入CPU和GPU选择
  58. inputs, target = inputs.to(device), target.to(device)
  59. optimizer.zero_grad()
  60. #前馈,反向传播,更新
  61. outputs = model(inputs)
  62. loss = model.criterion(outputs, target)
  63. loss.backward()
  64. optimizer.step()
  65. loss_list.append(loss.item())
  66. print("progress:", epoch, 'loss=', loss.item())
  67. def test():
  68. correct = 0
  69. total = 0
  70. model.eval()
  71. ##with这里标记是不再计算梯度
  72. with torch.no_grad():
  73. for data in test_mydataloader:
  74. inputs, labels = data
  75. ##加入CPU和GPU选择
  76. inputs, labels = inputs.to(device), labels.to(device)
  77. outputs = model(inputs)
  78. ##预测返回的是两列,第一列是下标就是0-9的值,第二列为预测值,下面的dim=1就是找维度1(第二列)最大值输出
  79. _, predicted = torch.max(outputs.data, dim=1)
  80. total += labels.size(0)
  81. correct += (predicted == labels).sum().item()
  82. currentpredicted = (100 * correct / total)
  83. ##用global声明toppredicted,用于在函数内部修改在函数外部声明的全局变量,否则报错
  84. global toppredicted
  85. ##当预测率大于原来的保存模型
  86. if currentpredicted > toppredicted:
  87. toppredicted = currentpredicted
  88. torch.save(model.state_dict(), savemodel_name)
  89. print(savemodel_name+" saved, currentpredicted:%d %%" % currentpredicted)
  90. predict_list.append(currentpredicted)
  91. print('Accuracy on test set: %d %%' % currentpredicted)
  92. ##开始训练
  93. timestart = time.time()
  94. for epoch in range(epoch_times):
  95. train(epoch)
  96. test()
  97. timeend = time.time() - timestart
  98. print("use time: {:.0f}m {:.0f}s".format(timeend // 60, timeend % 60))
  99. ##设置画布显示中文字体
  100. mpl.rcParams["font.sans-serif"] = ["SimHei"]
  101. ##设置正常显示符号
  102. mpl.rcParams["axes.unicode_minus"] = False
  103. ##创建画布
  104. fig, (axloss, axpredict) = plt.subplots(nrows=1, ncols=2, figsize=(8,6))
  105. #loss画布
  106. axloss.plot(range(epoch_times), loss_list, label = 'loss', color='r')
  107. ##设置刻度
  108. axloss.set_xticks(range(epoch_times)[::1])
  109. axloss.set_xticklabels(range(epoch_times)[::1])
  110. axloss.set_xlabel('训练轮数')
  111. axloss.set_ylabel('数值')
  112. axloss.set_title(tm.train_name+' 损失值')
  113. #添加图例
  114. axloss.legend(loc = 0)
  115. #predict画布
  116. axpredict.plot(range(epoch_times), predict_list, label = 'predict', color='g')
  117. ##设置刻度
  118. axpredict.set_xticks(range(epoch_times)[::1])
  119. axpredict.set_xticklabels(range(epoch_times)[::1])
  120. # axpredict.set_yticks(range(100)[::5])
  121. # axpredict.set_yticklabels(range(100)[::5])
  122. axpredict.set_xlabel('训练轮数')
  123. axpredict.set_ylabel('预测值')
  124. axpredict.set_title(tm.train_name+' 预测值')
  125. #添加图例
  126. axpredict.legend(loc = 0)
  127. #显示图像
  128. plt.show()

划重点

22cb244fdcf4fa43784ef000cdc6e357.png

自己训练的模型文件前面加上一个my,用于不覆盖原来的训练模型。

加载训练集和测试集

e9f3aebe4640a7592188864523459b18.png

在transform中,增加了一行transforms.Grayscale(num_output_channels=1),主要原因是在OpenCV中使用imwrite保存的文件,虽然是二值化的图片,但是是3通道的,而在pyTorch我们的训练数据都是1X28X28,即是单通道的图像,所以这里加上这一句是将读取的图片设置为单通道。

使用datasets.ImageFolder直接读取train目录下的数据,自动将图像及对应的标签加载进来了。

加载已训练的模型

71ae4d6479603740bce6094f709f95c2.png

这里的model模型直接通过load_state_dict加载进来,然后再训练自己的数据,下面的训练方式和原来train都一样了。

e00c31aa1abff71781568e86146dd6ac.png

4aa80b9bb402d525f1e20c6a8e52c462.png

因为我这边保存的数据很少,而且测试集的图片和训练集的一样,只训练了15轮,所以训练到第3轮的时候已经就到100%了。简单的训练自己的数据集就完成了。

f7d7a0741f48d3512b9cf069c4706a1d.png

922088850941b926acc627227ca65eee.png

往期精彩回顾

 

d8f73635e4ed1490e4abaca17a1c26c2.jpeg

pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别

 

 

8a9bd2f215a3d2935dbf77425f29f973.jpeg

pyTorch入门(三)——GoogleNet和ResNet训练

 

 

1ed72c8a20a8e0f7295b23e5f0478060.jpeg

pyTorch入门(二)——常用网络层函数及卷积神经网络训练

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/262465
推荐阅读
相关标签
  

闽ICP备14008679号