当前位置:   article > 正文

python模型训练_python训练模型

python训练模型

目录

1、新建模型   train_model.py

2、运行模型

(1)首先会下载data文件库

(2)完成之后会开始训练模型(10次)

3、 训练好之后,进入命令集

 4、输入命令:python -m tensorboard.main --logdir="C:\Users\15535\Desktop\day6\train"

(1)目录的绝对路径获得方法

 5、打开网页可视化图形

(1)运行完之后会自动有一个网址,点进去

 (2)显示


1、新建模型   train_model.py

  1. import torch
  2. import torchvision.transforms
  3. from torch.utils.tensorboard import SummaryWriter
  4. from torchvision import datasets
  5. from torch.utils.data import DataLoader
  6. import torch.nn as nn
  7. from torch.nn import CrossEntropyLoss
  8. #step1.下载数据集
  9. train_data=datasets.CIFAR10('./data',train=True,\
  10. transform=torchvision.transforms.ToTensor(),
  11. download=True)
  12. test_data=datasets.CIFAR10('./data',train=False,\
  13. transform=torchvision.transforms.ToTensor(),
  14. download=True)
  15. print(len(train_data))
  16. print(len(test_data))
  17. #step2.数据集打包
  18. train_data_loader=DataLoader(train_data,batch_size=64,shuffle=False)
  19. test_data_loader=DataLoader(test_data,batch_size=64,shuffle=False)
  20. #step3.搭建网络模型
  21. class My_Module(nn.Module):
  22. def __init__(self):
  23. super(My_Module,self).__init__()
  24. #64*32*32*32
  25. self.conv1=nn.Conv2d(in_channels=3,out_channels=32,\
  26. kernel_size=5,padding=2)
  27. #64*32*16*16
  28. self.maxpool1=nn.MaxPool2d(2)
  29. #64*32*16*16
  30. self.conv2=nn.Conv2d(in_channels=32,out_channels=32,\
  31. kernel_size=5,padding=2)
  32. #64*32*8*8
  33. self.maxpool2=nn.MaxPool2d(2)
  34. #64*64*8*8
  35. self.conv3=nn.Conv2d(in_channels=32,out_channels=64,\
  36. kernel_size=5,padding=2)
  37. #64*64*4*4
  38. self.maxpool3=nn.MaxPool2d(2)
  39. #线性化
  40. self.flatten=nn.Flatten()
  41. self.linear1=nn.Linear(in_features=1024,out_features=64)
  42. self.linear2=nn.Linear(in_features=64,out_features=10)
  43. def forward(self,input):
  44. #input:64,3,32,32
  45. output1=self.conv1(input)
  46. output2=self.maxpool1(output1)
  47. output3=self.conv2(output2)
  48. output4=self.maxpool2(output3)
  49. output5=self.conv3(output4)
  50. output6=self.maxpool3(output5)
  51. output7=self.flatten(output6)
  52. output8=self.linear1(output7)
  53. output9=self.linear2(output8)
  54. return output9
  55. my_model=My_Module()
  56. # print(my_model)
  57. loss_func=CrossEntropyLoss()#衡量模型训练的过程(输入输出之间的差值)
  58. #优化器,lr越大模型就越“聪明”
  59. optim = torch.optim.SGD(my_model.parameters(),lr=0.001)
  60. writer=SummaryWriter('./train')
  61. #################################训练###############################
  62. for looptime in range(10): #模型训练的次数:10
  63. print("------looptime:{}------".format(looptime+1))
  64. num=0
  65. loss_all=0
  66. for data in (train_data_loader):
  67. num+=1
  68. #前向
  69. imgs, targets = data
  70. output = my_model(imgs)
  71. loss_train = loss_func(output,targets)
  72. loss_all=loss_all+loss_train
  73. if num%100==0:
  74. print(loss_train)
  75. #后向backward 三步法 获取最小的损失函数
  76. optim.zero_grad()
  77. loss_train.backward()
  78. optim.step()
  79. # print(output.shape)
  80. loss_av=loss_all/len(test_data_loader)
  81. print(loss_av)
  82. writer.add_scalar('train_loss',loss_av,looptime)
  83. writer.close()
  84. #################################验证#########################
  85. with torch.no_grad():
  86. accuracy=0
  87. test_loss_all=0
  88. for data in test_data_loader:
  89. imgs,targets = data
  90. output = my_model(imgs)
  91. loss_test = loss_func(output,targets)
  92. #output.argmax(1)---输出标签
  93. accuracy=(output.argmax(1)==targets).sum()
  94. test_loss_all = test_loss_all+loss_test
  95. test_loss_av = test_loss_all/len(test_data_loader)
  96. acc_av = accuracy/len(test_data_loader)
  97. print("测试集的平均损失{},测试集的准确率{}".format(test_loss_av,acc_av))
  98. writer.add_scalar('test_loss',test_loss_av,looptime)
  99. writer.add_scalar('acc',acc_av,looptime)
  100. writer.close()

2、运行模型

(1)首先会下载data文件库

(2)完成之后会开始训练模型(10次)

3、 训练好之后,进入命令集

 4、输入命令:python -m tensorboard.main --logdir="C:\Users\15535\Desktop\day6\train"

(1)目录的绝对路径获得方法

执行下面的操作自动复制

 

 

 5、打开网页可视化图形

(1)运行完之后会自动有一个网址,点进去

 (2)显示

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/运维做开发/article/detail/761311
推荐阅读
相关标签
  

闽ICP备14008679号