当前位置:   article > 正文

Jittor MNIST图片识别模型训练+测试_如何用jittor识别文字

如何用jittor识别文字

1、先创建一个model.py

  1. import jittor as jt
  2. from jittor import nn, Module
  3. import numpy as np
  4. import sys, os
  5. import random
  6. import math
  7. from jittor import init
  class Model(Module):
      """Small CNN that maps a batch of images to 10 class scores.

      Layout: conv(3->32) -> ReLU -> conv(32->64) -> BatchNorm -> ReLU
      -> 2x2 max-pool -> flatten -> fc(64*12*12 -> 256) -> ReLU -> fc(256 -> 10).
      The flatten size 64*12*12 assumes 28x28 inputs with 3 channels.
      """

      def __init__(self):
          super().__init__()
          # 3x3 convolutions, stride 1, no padding: spatial size 28 -> 26 -> 24.
          self.conv1 = nn.Conv(3, 32, 3, 1)
          self.conv2 = nn.Conv(32, 64, 3, 1)
          self.bn = nn.BatchNorm(64)
          # 2x2 pooling with stride 2 halves the spatial size: 24 -> 12.
          self.max_pool = nn.Pool(2, 2)
          self.relu = nn.Relu()
          self.fc1 = nn.Linear(64 * 12 * 12, 256)
          self.fc2 = nn.Linear(256, 10)

      def execute(self, x):
          """Return unnormalised class scores of shape [batch, 10] for *x*."""
          features = self.relu(self.conv1(x))
          features = self.relu(self.bn(self.conv2(features)))
          features = self.max_pool(features)
          flat = jt.reshape(features, [features.shape[0], -1])
          hidden = self.relu(self.fc1(flat))
          return self.fc2(hidden)

2、创建train.py来进行训练、保存模型

  1. import jittor as jt
  2. from jittor import nn, Module
  3. import numpy as np
  4. import sys, os
  5. import random
  6. import math
  7. from jittor import init
  8. from model import Model
  9. from jittor.dataset.mnist import MNIST
  10. import jittor.transform as trans
  11. import pylab as pl
  # Run on the CPU; set this flag to 1 to run on a CUDA GPU instead.
  jt.flags.use_cuda = 0
  def train(model, train_loader, optimizer, epoch, losses, losses_idx):
      """Run one training epoch and record the loss of every batch.

      Args:
          model: network to optimise; switched to training mode here.
          train_loader: iterable of ``(inputs, targets)`` batches; ``len()`` of
              it is used as the number of batches per epoch.
          optimizer: optimizer whose ``step(loss)`` is called once per batch
              (Jittor style: the loss is passed to ``step`` directly).
          epoch: zero-based epoch index, used for the global iteration index
              and in the log line.
          losses: list that receives one loss value per batch (appended to).
          losses_idx: list that receives the matching global iteration index.
      """
      model.train()
      num_batches = len(train_loader)
      for batch_idx, (inputs, targets) in enumerate(train_loader):
          outputs = model(inputs)
          loss = nn.cross_entropy_loss(outputs, targets)
          optimizer.step(loss)
          # Read the scalar once. .item() works whether the reduced loss is
          # 0-dim or shape [1] (.data[0] fails on 0-dim), and the value is not
          # fetched a second time for logging.
          loss_value = loss.item()
          losses.append(loss_value)
          losses_idx.append(epoch * num_batches + batch_idx)
          if batch_idx % 10 == 0:
              print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_idx, num_batches,
                  100. * batch_idx / num_batches, loss_value))
  def test(model, val_loader, epoch):
      """Evaluate *model* on *val_loader*, print the accuracy and return it.

      Prints the accuracy of every batch followed by the overall accuracy.

      Args:
          model: classifier; switched to evaluation mode here.
          val_loader: iterable of ``(inputs, targets)`` batches; ``len()`` of
              it is used for the progress percentage.
          epoch: epoch index shown in the log lines.

      Returns:
          Overall accuracy (fraction of correct predictions), or 0.0 if the
          loader yields no samples. Callers that ignore it are unaffected.
      """
      model.eval()
      total_correct = 0
      total_num = 0
      num_batches = len(val_loader)
      for batch_idx, (inputs, targets) in enumerate(val_loader):
          batch_size = inputs.shape[0]
          outputs = model(inputs)
          pred = np.argmax(outputs.data, axis=1)
          batch_correct = np.sum(targets.data == pred)
          total_correct += batch_correct
          total_num += batch_size
          acc = batch_correct / batch_size
          print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(
              epoch, batch_idx, num_batches,
              100. * float(batch_idx) / num_batches, acc))
      # Guard the empty-loader case, which previously raised ZeroDivisionError.
      overall_acc = float(total_correct / total_num) if total_num else 0.0
      print('Total test acc =', overall_acc)
      return overall_acc
  def main(model_path=None):
      """Train the MNIST model, save its parameters and plot the training loss.

      Args:
          model_path: file to save the trained parameters to. Defaults to
              ``mnist_model.pkl`` next to this script instead of a
              machine-specific absolute path; its directory is created if
              needed.
      """
      batch_size = 64
      learning_rate = 0.1
      momentum = 0.9
      weight_decay = 1e-4
      epochs = 5
      if model_path is None:
          model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                    'mnist_model.pkl')
      losses = []
      losses_idx = []
      train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(
          batch_size=batch_size, shuffle=True)
      val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(
          batch_size=1, shuffle=False)
      model = Model()
      optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)
      for epoch in range(epochs):
          train(model, train_loader, optimizer, epoch, losses, losses_idx)
          test(model, val_loader, epoch)
      # Save before plotting: pl.show() blocks until the plot window is closed,
      # so a process killed while the plot is open would otherwise lose the model.
      os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
      model.save(model_path)
      pl.plot(losses_idx, losses)
      pl.xlabel('Iterations')
      pl.ylabel('Train_loss')
      pl.show()


  if __name__ == '__main__':
      main()

3、运行train.py文件结果如图(testdata为自己创建的测试数据文件夹、mnist_model.pkl为模型文件)

 4、创建test.py,加载本地图片对模型进行测试。

  1. from datetime import date
  2. from matplotlib import pyplot as plt
  3. import jittor as jt
  4. from numpy.core.fromnumeric import shape
  5. from numpy.lib.type_check import imag
  6. from model import Model
  7. from jittor.dataset.mnist import MNIST
  8. import jittor.transform as trans
  9. import numpy as np
  10. import cv2
  11. import os
  12. from PIL import Image
  13. import matplotlib.pyplot as plt
  # Disabled alternative: classify the first image of the MNIST validation
  # split instead of a local file. Uncomment to use.
  # model_path = '/home/lizhi528/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl'
  # new_model = Model()
  # new_model.load_parameters(jt.load(model_path))
  # new_model.eval()  # use BatchNorm running statistics for inference
  # val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=1, shuffle=False)
  # data_iter = iter(val_loader)
  # val_data, val_label = next(data_iter)
  # outputs = new_model(val_data)
  # prediction = np.argmax(outputs.data, axis=1)
  # print(val_label.data)
  # print(prediction)
  def ImageClassification(imagPath, model, show=True):
      """Classify one image file with *model*, print and return the prediction.

      The image is read with OpenCV, optionally shown in a window, resized to
      28x28, scaled to [0, 1], reordered to CHW, given a batch dimension and
      passed through *model*.

      Args:
          imagPath: path of the image file to classify.
          model: trained classifier. It is switched to evaluation mode here so
              BatchNorm uses its running statistics rather than the statistics
              of this single image.
          show: display the image first (default True, as before). The call
              blocks until a key is pressed in the image window.

      Returns:
          numpy array holding the predicted class index for the single image.

      Raises:
          FileNotFoundError: if OpenCV cannot read *imagPath*.
      """
      # cv2.imread returns None instead of raising for a missing/unreadable
      # file; fail clearly here rather than with a cryptic cv2.imshow error.
      image = cv2.imread(imagPath)
      if image is None:
          raise FileNotFoundError('Cannot read image: {}'.format(imagPath))
      if show:
          cv2.imshow("img", image)
          cv2.waitKey(0)  # waits for a key press in the image window
          cv2.destroyWindow("img")
      # OpenCV loads channels as BGR; the model takes 3-channel input and was
      # presumably trained on RGB images from the Jittor MNIST loader.
      image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
      # NOTE(review): MNIST digits are light strokes on a dark background; dark
      # ink on white paper may need inverting (255 - image) — not done here.
      image = cv2.resize(image, (28, 28))  # -> (28, 28, 3)
      print(image.shape)
      image = image / 255.0                 # pixel values [0, 255] -> [0, 1]
      image = image.transpose(2, 0, 1)      # HWC -> CHW
      image = jt.float32(image)             # numpy array -> Jittor Var
      image = image.unsqueeze(dim=0)        # add batch dim -> [1, C, H, W]
      model.eval()
      outputs = model(image)
      prediction = np.argmax(outputs.data, axis=1)
      print(prediction)
      return prediction
  def main(model_path=None, img_path=None):
      """Load the trained MNIST model and classify one local test image.

      Args:
          model_path: saved parameters; defaults to ``mnist_model.pkl`` next to
              this script instead of a machine-specific absolute path.
          img_path: image to classify; defaults to ``testdata/0.jpg`` next to
              this script.
      """
      base_dir = os.path.dirname(os.path.abspath(__file__))
      if model_path is None:
          model_path = os.path.join(base_dir, 'mnist_model.pkl')
      if img_path is None:
          img_path = os.path.join(base_dir, 'testdata', '0.jpg')
      model = Model()
      model.load_parameters(jt.load(model_path))
      # Inference: BatchNorm must use its running statistics, not those of the
      # single input image.
      model.eval()
      ImageClassification(img_path, model)


  if __name__ == '__main__':
      main()

5、测试结果如图:

6、注意：用 cv2.imshow 展示图片时，cv2.waitKey(0) 会一直阻塞，直到在图片窗口中按下任意键；如果直接点击关闭弹窗，程序会卡住无法继续运行。可以在图片窗口中按任意键关闭，或者给 waitKey 传入大于 0 的等待时间（毫秒）。而用 plt 展示图片没有这个问题，关闭弹窗后程序会继续运行。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/263969
推荐阅读
相关标签
  

闽ICP备14008679号