当前位置:   article > 正文

Pytorch 多卡并行训练教程 (DDP)_pytorch多卡训练

pytorch多卡训练

Pytorch 多卡并行训练教程 (DDP)

在使用GPU训练大模型时,往往会面临单卡显存不足的情况,这时候就希望通过多卡并行的形式来扩大显存。PyTorch主要提供了两个类来实现多卡并行分别是

  • torch.nn.DataParallel(DP)
  • torch.nn.DistributedDataParallel(DDP)

关于这两者的区别和原理也有许多博客如Pytorch 并行训练(DP, DDP)的原理和应用; DDP系列第一篇:入门教程进行总结,这里就不在赘述了。不过总结来说的话:DP 比较简单,对小白比较友好,一行代码便可以搞定。DDP 每个进程对应一个独立的训练过程,且只对梯度等少量数据进行信息交换。每个进程包含独立的解释器和 GIL。

博主能力有限,很多原理上的东西看得不是特别懂,所以理解起来也比较肤浅,但是编程的时候一直没找到一套合适的蓝本,最终参考了很多网上的博客,吭哧吭哧写了一套不会报错的代码出来,下面把我个人的理解整理出来,不当之处希望大家指出,一起交流学习。后续可能会随着自己的理解的加深持续完善。
主要参考了以下一些博客:

初始化

增加参数local_rank来确定当前进程使用哪块GPU, 用于在每个进程中指定不同的device。

def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    return args

def main():
    args = parse()
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        'nccl',
        init_method='env://'
    )
    device = torch.device(f'cuda:{args.local_rank}')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

其中 torch.distributed.init_process_group 用于初始化GPU通信方式(NCCL)和参数的获取方式(env代表通过环境变量)。

设置随机种子点

假如model中用到了随机数种子来保证可复现性, 那么此时不能再用固定的常数作为seed, 否则会导致DDP中的所有进程都拥有一样的seed, 进而生成同态性的数据, 因此需要在程序中显示地设置随机种子点。

 # 固定随机种子点
seed = np.random.randint(1, 10000)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
  • 1
  • 2
  • 3
  • 4
  • 5

Dataloader

对于数据加载,在初始化 data loader 的时候需要使用到 torch.utils.data.distributed.DistributedSampler 这个函数:

train_dataset = ...
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True) # 这个sampler会自动分配数据到各个gpu上

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, sampler=train_sampler)
  • 1
  • 2
  • 3
  • 4

通过以上的函数便可以给每个进程一个不同的 sampler,告诉每个进程自己分别取哪些数据。

在每一个epoch开始的阶段需要为sampler重新设定eopch即:

for ep in range(total_epoch):
    train_sampler.set_epoch(ep)
  • 1
  • 2

这样做的目的是:如果在DistributedSampler设置了shuffle,DistributedSampler使用当前epoch作为随机数种子,从而使得不同epoch下有不同的shuffle结果,但是在DistributedSampler源代码中默认的epoch为0,那么每次dataloader获取的shuffle都是相同的。所以,每次 epoch 开始前都需要要调用 sampler 的 set_epoch 方法,这样才能让数据集随机 shuffle 起来。

模型初始化

对于模型的处理主要包括模型初始化,将模型加载至CUDA;加载预训练权重;或利用主进程的权重 初始化所有的进程;将模型中的BN转换为SyncBN;设置模型并行。

由于 BN 层需要基于传入模型的数据计算均值和方差,造成普通 BN 在多卡模式下实际上就是单卡模式。此时需要使用 SyncBN 利用DDP的分布式计算接口来实现真正的多卡BN。

SyncBN利用分布式通讯接口在各卡间进行通讯,传输各自进程小 batch mean 和小 batch variance,在传输少量数据的基础上利用所有数据进行BN计算。

同时由于 SyncBN 用到 all_gather 这个分布式计算接口,而使用这个接口需要先初始化DDP环境,因此 SyncBN 需要在 DDP 环境初始化后初始化,但是要在 DDP 模型前就准备好。

最后由于 SyncBN 是直接搜索 model 中每个 module,如果这个 module 是 torch.nn.modules.batchnorm._BatchNorm 的子类,就将其替换为 SyncBN。因此如果你的 Normalization 层是自己定义的特殊类,没有继承过 _BatchNorm 类,那么convert_sync_batchnorm 是不支持的,需要你自己实现一个新的SyncBN!

def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    parser.add_argument('--resume', type=str, default=None, help='specified the dir of saved models for resume the training')
    args = parser.parse_args()
    return args
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
args = parse()
device = torch.device(args.device)
model = mymodel().to(device)
if args.resume:
    checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
else:    
    save_path = 'initial_weights.pth'
    if opts.local_rank == 0:
        torch.save(model.state_dict(), save_path)
    dist.barrier()
    # 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
    model.load_state_dict(torch.load(save_path, map_location=device))

## 设置同步
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)    
## 设置模型并行
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) ## 注意要使用find_unused_parameters=True,因为有时候模型里面定义的一些模块 在forward函数里面没有调用,如果不使用find_unused_parameters=True 会报错
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

输出日志设置

在每一次需要输出或打印日志时都应该先使用opts.local_rank == 0 来判断,也就是在主进程才执行一些操作,不然日志或者打印的结果会非常混乱。

logger = None
if opts.local_rank == 0:
    log_dir = os.path.join(opts.display_dir, 'logger', opts.name)
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, 'log.txt')
    if os.path.exists(log_path):
        os.remove(log_path)
    logger = logger_config(log_path=log_path, logging_name='Timer')
    logger.info('Parameter Space: ABS: {:.1f}, REL: {:.4f}'.format(count_parameters(MPF_model), count_parameters(MPF_model) / 1024 / 1024))
    logger.info(MPF_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

模型保存

state = {'model':model.module.state_dict(),
         'ep':ep, 
         'total_it':total_it}
save_path = os.path.join(self.model_dir, 'model_{:0>5d}.pth'.format(ep))
torch.save(state, save_path)
  • 1
  • 2
  • 3
  • 4
  • 5

在保存模型是需要注意的是,保存的是{'model':model.module.state_dict()}, 而不是我们之前的{'model':model.state_dict()}, 因为在使用DDP后,原来的model会被封装为新的model的module属性里。

启动方式

PyTorch为提供了一个很方便的启动器 torch.distributed.lunch 用于启动文件,所以可以将运行训练代码的方式调整成下面这样:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
  • 1

最后附上完成了train代码和超参解析代码:

train.py

import torch.optim as optim
from create_dataset import *
from utils import *
from MPFNet_Trans_skip import MPFNet
from options import * 
from saver import Saver, resume
from time import time
from tqdm import tqdm
from optimizer import Optimizer
import datetime
import torch.distributed as dist

def main():
    # parse options    
    parser = TrainOptions()
    opts = parser.parse()
    # define model, optimiser and scheduler
    torch.cuda.set_device(opts.local_rank)
    torch.distributed.init_process_group('nccl', init_method='env://')

    # device = torch.device(f'cuda:{opts.local_rank}') #device 这样的设置可能会有问题
    
    device = torch.device(opts.gpu)
    # device = torch.device("cuda:{}".format(opts.gpu) if torch.cuda.is_available() else "cpu")
    # 固定随机种子
    seed = np.random.randint(1, 10000)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

     # define dataset    
    train_dataset = MSRSData(opts, is_train=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=opts.batch_size,
        num_workers = opts.nThreads,
        sampler=train_sampler,
        pin_memory=False,
        )
    test_dataset = MSRSData(opts, is_train=False)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=12,
        sampler=test_sampler,
        num_workers = opts.nThreads,
        )    
    ## 先加载dataloader 计算每个epoch的的迭代步数 然后计算总的迭代步数
    ep_iter = len(train_loader)
    max_iter = opts.n_ep * ep_iter
    
    if opts.local_rank == 0:
        print('Training iter: {}'.format(max_iter))    
    print(opts.local_rank)    
    ## 初始化模型
    MPF_model = MPFNet(opts.class_nb).to(device)
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-3
    # max_iter = 150000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optimizer = Optimizer(
            model = MPF_model,
            lr0 = lr_start,
            momentum = momentum,
            wd = weight_decay,
            warmup_steps = warmup_steps,
            warmup_start_lr = warmup_start_lr,
            max_iter = max_iter,
            power = power)
    if opts.resume:
        if opts.local_rank == 0:
            MPF_model, ep, total_it = resume(MPF_model, opts.resume, device)
            optimizer = Optimizer(
                model = MPF_model,
                lr0 = lr_start,
                momentum = momentum,
                wd = weight_decay,
                warmup_steps = warmup_steps,
                warmup_start_lr = warmup_start_lr,
                max_iter = max_iter,
                power = power, 
                it=total_it)
            lr = optimizer.get_lr()
            print('lr:{}'.format(lr))
    else: 
        model_dir = os.path.join(opts.result_dir, opts.name)
        os.makedirs(model_dir, exist_ok=True)
        save_path = os.path.join(model_dir, 'initial_weights.pth')
        if opts.local_rank == 0:
            torch.save(MPF_model.state_dict(), save_path)
        dist.barrier()
        # 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
        MPF_model.load_state_dict(torch.load(save_path, map_location=device))
        ep = -1
        total_it = 0
    ep += 1    

    MPF_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(MPF_model)    
    MPF_model = torch.nn.parallel.DistributedDataParallel(MPF_model, device_ids=[opts.local_rank], output_device=opts.local_rank, find_unused_parameters=True)
    # optimizer = optim.Adam(MPF_model.parameters(), lr=opts.lr)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
    logger = None
    if opts.local_rank == 0:
        log_dir = os.path.join(opts.display_dir, 'logger', opts.name)
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, 'log.txt')
        if os.path.exists(log_path):
            os.remove(log_path)
        logger = logger_config(log_path=log_path, logging_name='Timer')
        logger.info('Parameter Space: ABS: {:.1f}, REL: {:.4f}'.format(count_parameters(MPF_model), count_parameters(MPF_model) / 1024 / 1024))
        logger.info(MPF_model)
    
   
    # Train and evaluate multi-task network
    multi_task_trainer(train_loader,
                        train_sampler,
                        test_loader,
                        MPF_model,
                        device,
                        optimizer,
                        opts,
                        logger,
                        ep,
                        total_it)
    
def multi_task_trainer(train_loader, train_sampler, test_loader, multi_task_model, device, optimizer, opt, logger=None, start_ep=0, total_it=0):
    total_epoch = opt.n_ep
    saver = Saver(opt)    
    ## 计算分割损失相关的设计
    score_thres = 0.7
    ignore_idx = 255
    n_min = 8 * 256 * 256 // 8
    criteria = OhemCELoss(
        thresh=score_thres, n_min=n_min, device=device, ignore_lb=ignore_idx)    
    binary_class_weight = np.array([1.4548, 19.8962])    
    binary_class_weight = torch.tensor(binary_class_weight).float().to(device)    
    binary_class_weight = binary_class_weight.unsqueeze(0)
    binary_class_weight = binary_class_weight.unsqueeze(2)
    binary_class_weight = binary_class_weight.unsqueeze(2)
    
    lb_ignore = [255]
    if opt.resume:
        best_mIou = multi_task_tester(test_loader, multi_task_model, device, opt)
    else:
        best_mIou = 0.0
    if opt.local_rank == 0:
        print('best mIoU: {:.4f}'.format(best_mIou))
    start = glob_st = time()
    for ep in range(start_ep, total_epoch): ## 每一个epoch 计算一次动态权重
        train_sampler.set_epoch(ep)
        multi_task_model.train()        
        seg_metric = SegmentationMetric(opt.class_nb, device=device)   ## 这里可能会有问题       
        for it, (img_ir, img_vi, label, bi, bd, mask) in enumerate(train_loader):
            total_it += 1
            img_ir = img_ir.to(device)
            img_vi = img_vi.to(device)
            label = label.to(device)
            bi = bi.to(device).squeeze(1)
            bd = bd.to(device).squeeze(1)            
            vi_Y, vi_Cb, vi_Cr = RGB2YCrCb(img_vi)
            vi_Y = vi_Y.to(device)
            vi_Cb = vi_Cb.to(device)
            vi_Cr = vi_Cr.to(device)
            mask = mask.to(device)
            seg_pred, bi_pred, bd_pred, fused_img, re_vi, re_ir = multi_task_model(img_vi, img_ir)            
            # seg_pred = F.softmax(seg_pred, dim=1) 
            # seg_pred = multi_task_model(img_vi, img_ir)
            optimizer.zero_grad()
            seg_loss = Seg_loss(seg_pred, label, device, criteria)
            bd = F.one_hot(bd,num_classes=2)
            bd=bd.permute(0,3,1,2).float()
            bi = F.one_hot(bi,num_classes=2)
            bi= bi.permute(0,3,1,2).float()
            bd_loss = F.binary_cross_entropy_with_logits(bd_pred, bd) 
            bi_loss = F.binary_cross_entropy_with_logits(bi_pred, bi, pos_weight=binary_class_weight)
            seg_results = torch.argmax(seg_pred, dim=1, keepdim=True) ## print(seg_result.shape())
            train_seg_loss = 10 * seg_loss + 5 * bi_loss + 5 * bd_loss

            ## reconstruction-related loss            
            fusion_loss, ssim_loss, grad_loss, int_loss = Fusion_loss(img_ir, vi_Y, fused_img, mask)            
            vi_re_loss, vi_int_loss, vi_grad_loss = Re_loss(re_vi, vi_Y, mask=mask, ir_flag=False)
            ir_re_loss, ir_int_loss, ir_grad_loss = Re_loss(re_ir, img_ir, mask=mask, ir_flag=True)
            
            train_loss = 1 * train_seg_loss + 1 * fusion_loss + 0.5 * vi_re_loss + 0.5 * ir_re_loss
            train_loss.backward()
            optimizer.step()
            seg_metric.addBatch(seg_results, label, lb_ignore)
            # dist.destroy_process_group()
        if opt.local_rank == 0:            
            lr = optimizer.get_lr()
            mIoU = np.array(seg_metric.meanIntersectionOverUnion().item())
            Acc = np.array(seg_metric.pixelAccuracy().item())
            end = time()
            training_time, glob_t_intv = end - start, end - glob_st
            now_it = total_it+1
            eta = int((total_epoch * len(train_loader) - now_it) * (glob_t_intv / (now_it)))
            eta = str(datetime.timedelta(seconds=eta))
            logger.info('ep: [{}/{}], learning rate: {:.6f}, time consuming: {:.2f}s, segmentation loss: {:.4f}, fusion loss: {:.4f}, vi rec loss: {:.4f}, ir rec loss: {:.4f}'.format(ep+1, total_epoch, lr, training_time, seg_loss.item(), fusion_loss.item(), vi_re_loss.item(), ir_re_loss.item()))
            logger.info('ssim loss: [{:.4f}], grad loss: [{:.4f}], int loss: [{:.4f}], segmentation loss: {:.4f}, mIou: {:.4f}, Acc: {:.4f}, Eta: {}\n'.format(ssim_loss.item(), grad_loss.item(), int_loss.item(), seg_loss.item(), mIoU, Acc, eta))
            start = time()

        ## save Visualization results
        if (ep + 1) % opt.img_save_freq == 0 and opt.local_rank == 0:
            input = [img_ir, img_vi, fused_img, label]
            fused_rgb = YCbCr2RGB(fused_img, vi_Cb, vi_Cr)
            vi_rgb = YCbCr2RGB(re_vi, vi_Cb, vi_Cr)
            output = [re_ir, vi_rgb, fused_rgb, seg_results]
            saver.write_img(ep, input, output)
        ## save model
        if (ep + 1) % opt.model_save_freq == 0 and opt.local_rank == 0:
            test_mIoU = multi_task_tester(test_loader, multi_task_model, device, opt)            
            logger.info('test mIoU: {:.4f}, best mIoU:{:.4f}'.format(test_mIoU, best_mIou))
            if test_mIoU > best_mIou:
                best_mIou = test_mIoU
                saver.write_model(ep, total_it, multi_task_model, optimizer.optim, best_mIou, device)

def multi_task_tester(test_loader, multi_task_model, device, opts):
    multi_task_model.eval()
    test_bar= tqdm(test_loader)
    seg_metric = SegmentationMetric(opts.class_nb, device=device)
    lb_ignore = [255]
    ## define save dir
    with torch.no_grad():  # operations inside don't track history        
        for it, (img_ir, img_vi, label, img_names) in enumerate(test_bar):
            img_ir = img_ir.to(device)
            img_vi = img_vi.to(device)
            label = label.to(device)           
            Seg_pred, _, _, fused_img, re_vi, re_ir = multi_task_model(img_vi, img_ir)            
            seg_result = torch.argmax(Seg_pred, dim=1, keepdim=True) ## print(seg_result.shape())
            seg_metric.addBatch(seg_result, label, lb_ignore)        
    mIoU = np.array(seg_metric.meanIntersectionOverUnion().item())
    return mIoU
  
if __name__ == '__main__':
    main()
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240

options.py

import argparse
class TrainOptions():
  def __init__(self):
    self.parser = argparse.ArgumentParser()

    # data loader related
    self.parser.add_argument('--dataroot', type=str, default='/data/timer/Idea/mtan/dataset/MSRS', help='path of data')
    self.parser.add_argument('--phase', type=str, default='train', help='phase for dataloading')
    self.parser.add_argument('--batch_size', type=int, default=12
    , help='batch size')
    self.parser.add_argument('--nThreads', type=int, default=16, help='# of threads for data loader')    
    

    # training related
    self.parser.add_argument('--lr', default=1e-3, type=int, help='Initial learning rate for training model')
    self.parser.add_argument('--weight', default='dwa', type=str, help='multi-task weighting: equal, uncert, dwa')
    self.parser.add_argument('--n_ep', type=int, default=1500, help='number of epochs') # 400 * d_iter
    self.parser.add_argument('--n_ep_decay', type=int, default=1000, help='epoch start decay learning rate, set -1 if no decay') # 200 * d_iter
    self.parser.add_argument('--resume', type=str, default=None, help='specified the dir of saved models for resume the training')
     # 不要改该参数,系统会自动分配
    self.parser.add_argument('--gpu', type=str, default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    self.parser.add_argument('--temp', default=2.0, type=float, help='temperature for DWA (must be positive)')    
    
    # ouptput related
    self.parser.add_argument('--name', type=str, default='MPF-Trans-skip_DDP', help='folder name to save outputs')
    self.parser.add_argument('--class_nb', type=int, default=9, help='class number for segmentation model')
    self.parser.add_argument('--display_dir', type=str, default='/data/timer/Idea/mtan/logs', help='path for saving display results')
    self.parser.add_argument('--result_dir', type=str, default='/data/timer/Idea/mtan/results', help='path for saving result images and models')
    self.parser.add_argument('--display_freq', type=int, default=10, help='freq (iteration) of display')
    self.parser.add_argument('--img_save_freq', type=int, default=10, help='freq (epoch) of saving images')
    self.parser.add_argument('--model_save_freq', type=int, default=10, help='freq (epoch) of saving models')
    
    # DDP related
    self.parser.add_argument('--local_rank', type=int, default=0, help='Specifying the default GPU')
    
  def parse(self):
    self.opt = self.parser.parse_args()
    args = vars(self.opt)
    print('\n--- load options ---')
    for name, value in sorted(args.items()):
      print('%s: %s' % (str(name), str(value)))
    return self.opt

class TestOptions():
  def __init__(self):
    self.parser = argparse.ArgumentParser()

    # data loader related
    self.parser.add_argument('--dataroot', type=str, default='/data/timer/Idea/mtan/dataset/MSRS', help='path of data')
    self.parser.add_argument('--phase', type=str, default='test', help='phase for dataloading')
    self.parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    self.parser.add_argument('--nThreads', type=int, default=16, help='# of threads for data loader')    
    
    ## mode related
    self.parser.add_argument('--class_nb', type=int, default=9, help='class number for segmentation model')
    self.parser.add_argument('--resume', type=str, default='/data/timer/Idea/mtan/results/MPF-skip/best_model.pth', help='specified the dir of saved models for resume the training')
    self.parser.add_argument('--gpu', type=int, default=0, help='GPU id')
    
    # results related
    self.parser.add_argument('--name', type=str, default='MPF_skip', help='folder name to save outputs')
    self.parser.add_argument('--result_dir', type=str, default='/data/timer/Idea/mtan/test', help='path for saving result images and models')
    
  def parse(self):
    self.opt = self.parser.parse_args()
    args = vars(self.opt)
    print('\n--- load options ---')
    for name, value in sorted(args.items()):
      print('%s: %s' % (str(name), str(value)))
    return self.opt

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

一些主要的操作都在train.py文件里有所涉及,因为是第一次系统的使用DDP,还有很多地方理解的不够透彻,不当之处希望大家指出一起交流。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/708085
推荐阅读
相关标签
  

闽ICP备14008679号