mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except:
    print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
    mixed_precision = False  # not installed
wdir = 'weights' + os.sep # weights dir
os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = 'results.txt'
# Configure
init_seeds(1)
with open(opt.data) as f:
    data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
train_path = data_dict['train']
test_path = data_dict['val']
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
# Remove previous results
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
    os.remove(f)
# Create model
model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
model.names = data_dict['names']
assert is an assertion statement: execution only continues when the condition after assert is true, i.e. when the nc in the data file matches the nc in the model cfg; otherwise an AssertionError is raised with the message on the right.
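As a minimal illustration with hypothetical values (not from train.py), assert raises AssertionError with the given message when the condition is false:

```python
# Minimal illustration of assert (hypothetical values, not from train.py)
nc = 2          # classes declared in the data yaml
model_nc = 80   # classes declared in the model cfg

try:
    assert model_nc == nc, 'data nc=%g but cfg nc=%g classes' % (nc, model_nc)
except AssertionError as e:
    print(e)  # -> data nc=2 but cfg nc=80 classes
```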
# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
# Optimizer
nbs = 64  # nominal batch size
accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
for k, v in model.named_parameters():
    if v.requires_grad:
        if '.bias' in k:
            pg2.append(v)  # biases
        elif '.weight' in k and '.bn' not in k:
            pg1.append(v)  # apply weight decay
        else:
            pg0.append(v)  # all else

optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
    optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
del pg0, pg1, pg2
Optimizer groups: 102 .bias, 108 conv.weight, 99 other
del does not delete the data itself; it deletes the variable, i.e. it removes the name that points to the data.
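A quick sketch of that point: del unbinds the name, and the underlying object survives as long as another reference exists.

```python
data = [1, 2, 3]
alias = data      # second reference to the same list object
del data          # removes the name 'data', not the list itself
print(alias)      # [1, 2, 3] -- the object is still alive via 'alias'
# print(data)     # would raise NameError: name 'data' is not defined
```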
# Load Model
google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'):  # pytorch format
    ckpt = torch.load(weights, map_location=device)  # load checkpoint

    # load model
    try:
        ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
                         if model.state_dict()[k].shape == v.shape}  # to FP32, filter
        model.load_state_dict(ckpt['model'], strict=False)
    except KeyError as e:
        s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
            % (opt.weights, opt.cfg, opt.weights)
        raise KeyError(s) from e

    # load optimizer
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        best_fitness = ckpt['best_fitness']

    # load results
    if ckpt.get('training_results') is not None:
        with open(results_file, 'w') as file:
            file.write(ckpt['training_results'])  # write results.txt

    start_epoch = ckpt['epoch'] + 1
    del ckpt
If mixed_precision was set to False above (Apex not installed), mixed-precision training is skipped here.
if mixed_precision:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
Note that opt_level='O1' is the capital letter 'O' followed by the digit 1, not 'zero one'.
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
scheduler.last_epoch = start_epoch - 1 # do not move
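To see what the lambda produces, here is a short sketch (assuming epochs = 100 as in this run) that evaluates the cosine factor at a few epochs; LambdaLR multiplies the initial learning rate by this factor, which decays smoothly from 1.0 at epoch 0 to 0.1 at the final epoch.

```python
import math

epochs = 100  # assumed, matching the run shown later in this post
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine

for e in [0, 25, 50, 75, 99]:
    print('epoch %3d -> lr factor %.3f' % (e, lf(e)))
# epoch   0 -> lr factor 1.000
# epoch  25 -> lr factor 0.868
# epoch  50 -> lr factor 0.550
# epoch  75 -> lr factor 0.232
# epoch  99 -> lr factor 0.100
```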
# Initialize distributed training
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
    dist.init_process_group(backend='nccl',  # distributed backend
                            init_method='tcp://127.0.0.1:9999',  # init method
                            world_size=1,  # number of nodes
                            rank=0)  # node rank
    model = torch.nn.parallel.DistributedDataParallel(model)
Distributed training is enabled only when the three conditions above are satisfied (the device is not the CPU, there is more than one CUDA device, and torch.distributed is available).
I trained on a single GPU, so this condition was not met and distributed training was not used.
torch.nn.parallel.DistributedDataParallel supports multi-process parallelism on a single machine or across machines; each process has its own optimizer and performs its own update step.
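A minimal, hypothetical sketch of that idea (not part of train.py): each spawned process initializes the process group, wraps the model in DistributedDataParallel, and keeps its own optimizer; gradients are all-reduced across processes during backward. The gloo backend is used here so the sketch runs on CPU; multi-GPU training would use 'nccl'.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)  # 'nccl' on multi-GPU
    model = torch.nn.Linear(10, 2)
    ddp_model = DDP(model)                                   # gradients synced across processes
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)  # one optimizer per process
    loss = ddp_model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)  # launch 2 worker processes
```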
# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                        hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)

# Testloader
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
                               hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
The differences between dataloader and testloader:
- testloader: no data augmentation, and rect=True, so test images keep their original aspect ratio and are only padded up to a stride multiple (see the sketch below).
- dataloader: data augmentation is enabled, and rectangular training is optional (rect=opt.rect).
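As a rough sketch of what rect=True means (a simplified assumption, not the exact dataset code): the image is resized so its long side equals imgsz while keeping its aspect ratio, then each side is padded only up to the nearest stride multiple instead of being forced to a square imgsz x imgsz.

```python
import math

def rect_shape(h, w, imgsz=608, stride=32):
    """Scale (h, w) so the long side equals imgsz, then round each side up
    to the nearest stride multiple (simplified sketch, not the dataset code)."""
    r = imgsz / max(h, w)                       # resize ratio, keeps aspect ratio
    h, w = int(round(h * r)), int(round(w * r))
    return math.ceil(h / stride) * stride, math.ceil(w / stride) * stride

print(rect_shape(720, 1280))  # (352, 608): far less padding than a square (608, 608)
```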
# Model parameters
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
# Class frequency
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))
if tb_writer:
    plot_labels(labels)
    tb_writer.add_histogram('classes', c, 0)
# Check anchors
if not opt.noautoanchor:
    check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
# Exponential moving average
ema = torch_utils.ModelEMA(model)
In deep learning, EMA (exponential moving average) is often applied to the model parameters in order to improve test metrics and make the model more robust.
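A minimal sketch of the idea behind ModelEMA (simplified; the real torch_utils.ModelEMA also ramps the decay and copies model attributes): a shadow copy of the weights is updated as an exponential moving average after each optimizer step, and this smoothed copy is what gets evaluated and saved.

```python
import copy
import torch

class SimpleEMA:
    """Keep an exponential moving average of a model's parameters (sketch only)."""
    def __init__(self, model, decay=0.999):
        self.ema = copy.deepcopy(model).eval()  # shadow model holding the averaged weights
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(self.decay).add_(msd[k], alpha=1 - self.decay)  # v = d*v + (1-d)*new

# usage: ema = SimpleEMA(model); after each optimizer.step(), call ema.update(model)
```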
The code below records the start time and the number of batches, and prints the image sizes, the number of dataloader workers, and the number of epochs.
# Start training
t0 = time.time() # start time
nb = len(dataloader) # number of batches
n_burn = max(3 * nb, 1e3) # burn-in iterations, max(3 epochs, 1k iterations)
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)
# torch.autograd.set_detect_anomaly(True)
In the training loop below: load image weights (optional), define the progress bar, set up the burn-in phase, optionally use multi-scale training, run the forward pass, compute the loss, backpropagate, step the optimizer, update the progress bar, log training metrics to TensorBoard, compute mAP, append the results to results.txt, and save the model checkpoints (best and last).
for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
    model.train()

    # Update image weights (optional)
    if dataset.image_weights:
        w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
        image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
        dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx

    # Update mosaic border
    # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
    # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

    mloss = torch.zeros(4, device=device)  # mean losses
    print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
    pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
    for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
        ni = i + nb * epoch  # number integrated batches (since train start)
        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0

        # Burn-in
        if ni <= n_burn:
            xi = [0, n_burn]  # x interp
            # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
            accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])

        # Multi-scale
        if opt.multi_scale:
            sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
            sf = sz / max(imgs.shape[2:])  # scale factor
            if sf != 1:
                ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

        # Forward
        pred = model(imgs)

        # Loss
        loss, loss_items = compute_loss(pred, targets.to(device), model)
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss_items)
            return results

        # Backward
        if mixed_precision:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # Optimize
        if ni % accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()
            ema.update(model)

        # Print
        mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
        mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
        s = ('%10s' * 2 + '%10.4g' * 6) % (
            '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
        pbar.set_description(s)

        # Plot
        if ni < 3:
            f = 'train_batch%g.jpg' % ni  # filename
            result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
            if tb_writer and result is not None:
                tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                # tb_writer.add_graph(model, imgs)  # add model to tensorboard

        # end batch ------------------------------------------------------------------------------------------------

    # Scheduler
    scheduler.step()

    # mAP
    ema.update_attr(model)
    final_epoch = epoch + 1 == epochs
    if not opt.notest or final_epoch:  # Calculate mAP
        results, maps, times = test.test(opt.data,
                                         batch_size=batch_size,
                                         imgsz=imgsz_test,
                                         save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
                                         model=ema.ema,
                                         single_cls=opt.single_cls,
                                         dataloader=testloader)

    # Write
    with open(results_file, 'a') as f:
        f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
    if len(opt.name) and opt.bucket:
        os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))

    # Tensorboard
    if tb_writer:
        tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
                'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
        for x, tag in zip(list(mloss[:-1]) + list(results), tags):
            tb_writer.add_scalar(tag, x, epoch)

    # Update best mAP
    fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
    if fi > best_fitness:
        best_fitness = fi

    # Save model
    save = (not opt.nosave) or (final_epoch and not opt.evolve)
    if save:
        with open(results_file, 'r') as f:  # create checkpoint
            ckpt = {'epoch': epoch,
                    'best_fitness': best_fitness,
                    'training_results': f.read(),
                    'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
                    'optimizer': None if final_epoch else optimizer.state_dict()}

        # Save last, best and delete
        torch.save(ckpt, last)
        if (best_fitness == fi) and not final_epoch:
            torch.save(ckpt, best)
        del ckpt
    # end epoch ----------------------------------------------------------------------------------------------------
# end training
Image sizes 608 train, 608 test (the training and test image sizes)
Using 8 dataloader workers (8 worker processes load the data; this is the number of workers, not the batch size)
Starting training for 100 epochs... (training runs for 100 epochs)
——————————————————————————————————————
tqdm is a fast, extensible Python progress bar: it adds a progress indicator to long loops, and all you need to do is wrap any iterable in tqdm(iterator).
The tqdm progress bar used here:
pbar = tqdm(enumerate(dataloader), total=nb)
creates the progress bar, where total=nb is the expected number of iterations, i.e. the number of batches per epoch (nb = len(dataloader)), not the number of epochs.
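A tiny, self-contained usage example (a hypothetical loop, not from train.py):

```python
import time
from tqdm import tqdm

nb = 50  # pretend there are 50 batches in one epoch
pbar = tqdm(enumerate(range(nb)), total=nb)  # total sets the expected iteration count
for i, batch in pbar:
    time.sleep(0.01)                          # stand-in for one training step
    pbar.set_description('batch %g/%g' % (i + 1, nb))
```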
——————————————————————————————————————
results.txt stores one line of results per epoch, for example:
0/49 6.44G 0.09249 0.07952 0.05631 0.2283 6 608 0.1107 0.1954 0.1029 0.03088 0.07504 0.06971 0.03865
The columns are: epoch, gpu_mem, GIoU loss, obj loss, cls loss, total loss, targets, img_size, followed by the seven test results P, R, mAP@0.5, F1 and the validation losses (GIoU, obj, cls).
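If you want to read those columns back programmatically, here is a small sketch (assuming the 15-column layout above; the first two fields are strings, so they are skipped):

```python
import numpy as np

cols = ['giou', 'obj', 'cls', 'total', 'targets', 'img_size',
        'P', 'R', 'mAP@0.5', 'F1', 'val_giou', 'val_obj', 'val_cls']
data = np.loadtxt('results.txt', usecols=range(2, 15), ndmin=2)  # skip 'epoch' and 'gpu_mem'
print(dict(zip(cols, data[-1])))  # metrics of the last epoch
```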
n = opt.name
if len(n):
    n = '_' + n if not n.isnumeric() else n
    fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
    for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
        if os.path.exists(f1):
            os.rename(f1, f2)  # rename
            ispt = f2.endswith('.pt')  # is *.pt
            strip_optimizer(f2) if ispt else None  # strip optimizer
            os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload

if not opt.evolve:
    plot_results()  # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
torch.cuda.empty_cache()
return results
50 epochs completed in 11.954 hours.