赞
踩
前几个Pytorch学习博客写了使用Pytorch的数据读取、数据增强、数据加载、模型定义,当完成上面几个步骤,就可以进行模型训练了。
使用Pytorch进行模型训练,通常可以将train
过程写成一个函数,简单的train
写法常规的传入参数如下:
DataLoader
model
criterion
optimizer
较为简单的train
函数可以写为如下:
def train(DataLoader, model, criterion, optimizer): model.cuda() # 指定为train模式 model.train() for i, (img, target) in tqdm(enumerate(DataLoader)): img = img.cuda() target = target.cuda() # 计算网络输出 output = model(img) # 计算损失 loss = criterion(output, target) # 计算梯度和做反向传播 optimizer.zero_grad() loss.backward() optimizer.step()
那么,一个较为完整的使用Pytorch训练分类任务pipeline可以简单的表示如下:
1. 定义数据加载 Dataset = torchvision.Dataset(root, transform) 2. 定义模型 model = torchvision.models.xxxx(num_class) 3. 定义数据加载器 DataLoader = torch.util.data.DataLoader(Dataset, batch_size, num_workers) 4. 模型训练 # 定义优化器 optimizer = # 定义损失函数 criterion = # 定义学习率调整 scheduler = for i in range(epoch_number): # 根据epoch调整学习率 scheduler.step() # 调用训练函数 train(train_loader, model, criterion, optimizer) # 模型保存 torch.save(model.state_dict(), path)
注:以上只是对于使用Pytorch中的API快速做分类任务训练的一个大框架Pipeline的简单伪代码展示,实际编写code中还有其它的一些util函数,比如计算准确率,训练到一定阶段进行验证集评估之类。。。
仅供参考!!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。