赞
踩
1、先创建一个model.py
- import jittor as jt
- from jittor import nn, Module
- import numpy as np
- import sys, os
- import random
- import math
- from jittor import init
-
class Model(Module):
    """Small CNN that classifies MNIST digits into 10 classes.

    Input is a batch of 3-channel images (conv1 takes 3 input channels).
    The ``64 * 12 * 12`` input size of ``fc1`` assumes 28x28 images: the two
    unpadded 3x3 convolutions shrink 28 -> 26 -> 24, and the stride-2 pool
    halves that to 12. The output is raw class scores of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: both convolutions use stride 1 and no padding.
        self.conv1 = nn.Conv(3, 32, 3, 1)
        self.conv2 = nn.Conv(32, 64, 3, 1)
        self.bn = nn.BatchNorm(64)

        self.max_pool = nn.Pool(2, 2)
        self.relu = nn.Relu()
        # Classifier head: flattened 64x12x12 feature map -> 256 -> 10 scores.
        self.fc1 = nn.Linear(64 * 12 * 12, 256)
        self.fc2 = nn.Linear(256, 10)

    def execute(self, x):
        """Run the forward pass (Jittor's name for ``forward``) and return the scores."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.bn(self.conv2(x)))
        x = self.max_pool(x)
        # Flatten every dimension except the batch dimension.
        x = jt.reshape(x, [x.shape[0], -1])
        x = self.relu(self.fc1(x))
        return self.fc2(x)
![](https://csdnimg.cn/release/blogv2/dist/pc/img/newCodeMoreWhite.png)
2、创建train.py来进行训练、保存模型
- import jittor as jt
- from jittor import nn, Module
- import numpy as np
- import sys, os
- import random
- import math
- from jittor import init
- from model import Model
- from jittor.dataset.mnist import MNIST
- import jittor.transform as trans
- import pylab as pl
-
# Run on the CPU; set this to 1 to train on the GPU (requires a CUDA-enabled Jittor build).
jt.flags.use_cuda = 0


def train(model, train_loader, optimizer, epoch, losses, losses_idx):
    """Train ``model`` for one epoch and record the loss of every batch.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: batched Jittor dataset yielding ``(inputs, targets)``;
            its ``len()`` is assumed to be the number of batches per epoch.
        optimizer: Jittor optimizer; ``optimizer.step(loss)`` performs the
            backward pass and the parameter update in one call.
        epoch: zero-based epoch index, used for logging and to build a
            global iteration index.
        losses: list that receives one float loss value per batch (mutated).
        losses_idx: list that receives the matching global iteration index
            per batch (mutated), suitable as the x-axis of a loss plot.
    """
    model.train()
    num_batches = len(train_loader)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        outputs = model(inputs)
        loss = nn.cross_entropy_loss(outputs, targets)
        optimizer.step(loss)
        # Read the scalar once. `.item()` returns a Python float and, unlike
        # `.data[0]`, also works when the loss is a 0-d Var.
        loss_value = loss.item()
        losses.append(loss_value)
        losses_idx.append(epoch * num_batches + batch_idx)
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, num_batches,
                100. * batch_idx / num_batches, loss_value))
-
def test(model, val_loader, epoch):
    """Evaluate ``model`` on ``val_loader``; print per-batch and overall accuracy.

    Args:
        model: trained network; it is switched to evaluation mode here so
            BatchNorm uses its running statistics.
        val_loader: batched Jittor dataset yielding ``(inputs, targets)``.
        epoch: epoch index, used only in the log lines.

    Returns:
        Overall accuracy (correct / total) over the whole loader, or 0.0 if the
        loader yielded no samples. Callers that ignore the result are unaffected.
    """
    model.eval()
    total_correct = 0
    total_num = 0
    num_batches = len(val_loader)
    # No gradients are needed for evaluation.
    with jt.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            batch_size = inputs.shape[0]
            outputs = model(inputs)
            pred = np.argmax(outputs.data, axis=1)
            batch_correct = np.sum(targets.data == pred)
            total_correct += batch_correct
            total_num += batch_size
            print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(
                epoch, batch_idx, num_batches,
                100. * float(batch_idx) / num_batches,
                batch_correct / batch_size))
    # Guard against an empty loader, which previously raised ZeroDivisionError.
    total_acc = total_correct / total_num if total_num else 0.0
    print('Total test acc =', total_acc)
    return total_acc
-
def main():
    """Train the MNIST classifier, save its weights, then plot the training loss.

    The weights are written to ``mnist_model.pkl`` in the same directory as
    this script. Saving happens *before* the blocking ``pl.show()`` call, so
    closing or interrupting the plot window cannot lose the trained model.
    """
    batch_size = 64
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 5
    losses = []
    losses_idx = []
    train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(
        batch_size=batch_size, shuffle=True)
    val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(
        batch_size=1, shuffle=False)
    model = Model()
    optimizer = nn.SGD(model.parameters(), learning_rate,
                       momentum=momentum, weight_decay=weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch, losses, losses_idx)
        test(model, val_loader, epoch)

    # Resolve the path relative to this script instead of a hard-coded
    # absolute path under one particular user's home directory.
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'mnist_model.pkl')
    model.save(model_path)
    print('Model saved to', model_path)

    pl.plot(losses_idx, losses)
    pl.xlabel('Iterations')
    pl.ylabel('Train_loss')
    pl.show()


if __name__ == '__main__':
    main()
![](https://csdnimg.cn/release/blogv2/dist/pc/img/newCodeMoreWhite.png)
3、运行train.py文件结果如图(testdata为自己创建的测试数据文件夹、mnist_model.pkl为模型文件)
4、创建test.py,加载本地图片对模型进行测试。
- from datetime import date
- from matplotlib import pyplot as plt
- import jittor as jt
- from numpy.core.fromnumeric import shape
- from numpy.lib.type_check import imag
- from model import Model
- from jittor.dataset.mnist import MNIST
- import jittor.transform as trans
- import numpy as np
- import cv2
- import os
- from PIL import Image
- import matplotlib.pyplot as plt
-
- """
- #加载MNIST模型库中的图片
- model_path = '/home/lizhi528/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl'
- new_model = Model()
- new_model.load_parameters(jt.load(model_path))
- val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=1, shuffle=False)
- data_iter = iter(val_loader)
- val_data, val_label = next(data_iter)#
- outputs = new_model(val_data)
- prediction = np.argmax(outputs.data, axis=1)
- print(val_label.data)
- print(prediction)
- """
-
-
- def ImageClassification(imagPath,model):
-
- img_path=imagPath
- # 得到一个 HxWx3 的 array(224, 225, 3)
- image = cv2.imread(img_path)
- cv2.imshow("img",image)
- cv2.waitKey(0)
- # 把图像缩放到 28x28 个像素(28, 28, 3)
- image = cv2.resize(image, (28, 28))
- print(image.shape)
- image = image / 255.0 # 把图像的 RGB 值从 [0, 255] 变为 [0, 1]
- image = image.transpose(2, 0, 1) # 把输入格式从 HWC 改为 CHW
- image = jt.float32(image) # 变为 Jittor Var
- image = image.unsqueeze(dim=0) # 加入 batch 维度,变为 [1, C, H, W]
- outputs = model(image)
- prediction = np.argmax(outputs.data, axis=1)
- print(prediction)
-
- """
- img = Image.open(imagPath)
- plt.figure("Image")
- plt.imshow(img)
- plt.show()
- """
-
def main():
    """Load the trained weights and classify one local test image.

    Both files are looked up relative to this script's directory:
    ``mnist_model.pkl`` and ``testdata/0.jpg``.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # Load the trained model.
    model_path = os.path.join(base_dir, 'mnist_model.pkl')
    model = Model()
    model.load_parameters(jt.load(model_path))
    # Inference mode: BatchNorm must use its running statistics.
    model.eval()
    # Classify a local image.
    img_path = os.path.join(base_dir, 'testdata', '0.jpg')
    ImageClassification(img_path, model)


if __name__ == '__main__':
    main()
![](https://csdnimg.cn/release/blogv2/dist/pc/img/newCodeMoreWhite.png)
5、测试结果如图:
6、用cv2.imshow展示图片时,cv2.waitKey(0)会无限期等待键盘输入:直接点击关闭弹窗并不能让程序继续运行,需要在图片窗口上按任意键,或者给waitKey传入大于0的等待毫秒数(例如cv2.waitKey(3000))。而用plt展示图片就没有这个问题,关闭弹窗后程序会继续运行。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。