
Issues with Reproducing ESRGAN

Reproducing ESRGAN: the opening of the training entry script, covering option parsing, distributed-training setup, and resume-state handling.

import os
import math
import argparse
import random
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model

def init_dist(backend='nccl', **kwargs):
    '''Initialization for distributed training.'''
    # if mp.get_start_method(allow_none=True) is None:
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)  # pin this process to one GPU
    dist.init_process_group(backend=backend, **kwargs)

def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # dis
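option.parse() turns the -opt YAML file into a nested Python dictionary that the rest of main() reads from. Only a handful of its fields appear in the fragment above; the sketch below is a hypothetical, heavily trimmed version of that dictionary restricted to exactly those fields, so the later checks can be read in isolation (real ESRGAN configs also carry sections for the datasets, the generator and discriminator networks, and the training schedule).

opt = {
    'dist': False,              # overwritten by main() according to --launcher
    'path': {
        'resume_state': None,   # path to a saved training state, or None
    },
}

# Mirrors the resume check at the end of the listing above.
if opt['path'].get('resume_state', None):
    print('would resume from', opt['path']['resume_state'])
else:
    print('no resume state; training starts from scratch')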
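The script supports two modes: a plain single-process run (--launcher none, the default) and a launcher-driven distributed run (--launcher pytorch), typically started with something like python -m torch.distributed.launch --nproc_per_node=<num_gpus> <entry script> -opt <options yaml> --launcher pytorch (the exact file names depend on your copy of the repository). In the distributed case, init_dist() relies on the launcher having exported RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT into each process's environment before dist.init_process_group() is called. The snippet below is a minimal sketch of that contract; the single-process environment and the CPU-only gloo backend are assumptions chosen here so it can be run without GPUs or a launcher.

import os
import torch.distributed as dist

# Fake what torch.distributed.launch would normally export: one process,
# a "world" of size 1, and a local rendezvous address.
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

# gloo is the CPU backend; the training script itself uses nccl on GPUs.
dist.init_process_group(backend='gloo',
                        rank=int(os.environ['RANK']),
                        world_size=int(os.environ['WORLD_SIZE']))
print('rank', dist.get_rank(), 'of', dist.get_world_size())
dist.destroy_process_group()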
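The listing is cut off inside the resume-state branch. In general, resuming comes down to a torch.load() whose map_location moves the saved tensors onto the GPU this process was pinned to in init_dist(). The helper below is a hypothetical sketch in that spirit; its name and structure are assumptions, not code taken from the original repository.

import torch

def load_resume_state(resume_path):
    """Return the saved training state loaded onto the current GPU, or None."""
    if not resume_path:
        return None
    device_id = torch.cuda.current_device()
    # Remap tensors saved on any device onto this process's GPU.
    return torch.load(resume_path,
                      map_location=lambda storage, loc: storage.cuda(device_id))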
