当前位置:   article > 正文

Pytorch学习(六) --- 模型训练的常规train函数flow及其配置_pytorch框架下train如何书写

pytorch框架下train如何书写

前几个Pytorch学习博客写了使用Pytorch的数据读取、数据增强、数据加载、模型定义,当完成上面几个步骤,就可以进行模型训练了。

使用Pytorch进行模型训练,通常可以将train过程写成一个函数,简单的train写法常规的传入参数如下:

  • 数据加载器DataLoader
  • 目标模型model
  • 损失函数criterion
  • 优化器optimizer

较为简单的train函数可以写为如下:

def train(DataLoader, model, criterion, optimizer):
	model.cuda()
	# 指定为train模式
	model.train()

	for i, (img, target) in tqdm(enumerate(DataLoader)):
		img = img.cuda()
		target = target.cuda()
		# 计算网络输出
		output = model(img)
		
		# 计算损失
		loss = criterion(output, target)
		
		
		# 计算梯度和做反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

那么,一个较为完整的使用Pytorch训练分类任务pipeline可以简单的表示如下:

1. 定义数据加载
Dataset = torchvision.Dataset(root, transform)

2. 定义模型
model = torchvision.models.xxxx(num_class)

3. 定义数据加载器
DataLoader = torch.util.data.DataLoader(Dataset, batch_size, num_workers)

4. 模型训练

# 定义优化器
optimizer = 
# 定义损失函数
criterion = 
# 定义学习率调整
scheduler = 
for i in range(epoch_number):
	# 根据epoch调整学习率
	scheduler.step()
	# 调用训练函数
	train(train_loader, model, criterion, optimizer)

	# 模型保存
	torch.save(model.state_dict(), path)
	
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

注:以上只是对于使用Pytorch中的API快速做分类任务训练的一个大框架Pipeline的简单伪代码展示,实际编写code中还有其它的一些util函数,比如计算准确率,训练到一定阶段进行验证集评估之类。。。
仅供参考!!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/101800
推荐阅读
相关标签
  

闽ICP备14008679号