Source code: Swin-Transformer
Pretrained model: swinv2_base_patch4_window12_192_22k.pth
My machine runs Ubuntu. To train on my own dataset, I made a few small adjustments on top of the original code. First, organize the dataset in the following layout:
imagenet
├── train
│   ├── class1
│   │   ├── cat0001.jpg
│   │   ├── cat0002.jpg
│   │   └── ...
│   ├── class2
│   │   ├── dog0001.jpg
│   │   ├── dog0002.jpg
│   │   └── ...
│   └── class3
│       ├── bird0001.jpg
│       ├── bird0002.jpg
│       └── ...
└── val
    ├── class1
    ├── class2
    └── class3
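If the raw images start out in one folder per class, the layout above can be generated with a short script. A minimal sketch, assuming the sources live under raw_data/<class_name>/ and using a 9:1 train/val split (both the path and the ratio are my own choices, not from the original setup):

import os
import random
import shutil

SRC = "raw_data"   # assumed input: raw_data/<class_name>/*.jpg
DST = "imagenet"
VAL_RATIO = 0.1    # hold out 10% of each class for validation

random.seed(0)
for cls in os.listdir(SRC):
    images = os.listdir(os.path.join(SRC, cls))
    random.shuffle(images)
    n_val = max(1, int(len(images) * VAL_RATIO))
    for split, names in (("val", images[:n_val]), ("train", images[n_val:])):
        out_dir = os.path.join(DST, split, cls)
        os.makedirs(out_dir, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(SRC, cls, name), os.path.join(out_dir, name))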
Taking swinv2_base_patch4_window12_192_22k.yaml as an example:
DATA:
  # To match the dataset layout above, DATASET must be set to imagenet
  DATASET: imagenet
  IMG_SIZE: 384
  # NAME_CLASSES is a field I added myself; it is used for visualization at inference time
  NAME_CLASSES: ["cat", "dog", "bird"]
MODEL:
  TYPE: swinv2
  NAME: swinv2_base_patch4_window12_192_22k
  DROP_PATH_RATE: 0.2
  # NUM_CLASSES is added here; the default is 1000
  NUM_CLASSES: 3
  SWINV2:
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 12
TRAIN:
  EPOCHS: 90
  WARMUP_EPOCHS: 5
  WEIGHT_DECAY: 0.1
  BASE_LR: 1.25e-4 # 4096 batch-size
  WARMUP_LR: 1.25e-7
  MIN_LR: 1.25e-6
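Note the `# 4096 batch-size` comment: the learning rates in the yaml are quoted for a large total batch size, and main.py in the Swin-Transformer repo rescales them linearly by the actual total batch size before building the optimizer. A sketch of that scaling as I understand the upstream code (verify the 512.0 reference divisor against your version of the repo):

import torch.distributed as dist

def scaled_lrs(config):
    # Linear LR scaling as in the repo's main.py: the config LRs are treated
    # as being defined for a reference total batch size of 512 and rescaled
    # to the actual per-GPU batch size times the world size.
    total_batch = config.DATA.BATCH_SIZE * dist.get_world_size()
    base_lr = config.TRAIN.BASE_LR * total_batch / 512.0
    warmup_lr = config.TRAIN.WARMUP_LR * total_batch / 512.0
    min_lr = config.TRAIN.MIN_LR * total_batch / 512.0
    return base_lr, warmup_lr, min_lr

# e.g. with --batch-size 4 on a single GPU: 1.25e-4 * 4 / 512 ≈ 9.8e-7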
To match the adjustments above, the config.py file needs a corresponding change:
_C.DATA = CN()
# Default value for the newly added NAME_CLASSES field
_C.DATA.NAME_CLASSES = []
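After registering the default, yacs merges the list from the yaml file, and the class names are then available wherever the config is passed. A quick sanity check, assuming config.py exposes the `_C` root node as in the original repo:

# Quick check that the new field survives the yaml merge.
from config import _C

cfg = _C.clone()
cfg.merge_from_file("configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml")
print(cfg.DATA.NAME_CLASSES)  # ['cat', 'dog', 'bird']
print(cfg.MODEL.NUM_CLASSES)  # 3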
main.py
if __name__ == '__main__':
args, config = parse_option()
    # The training environment is a single local machine with one GPU, so write the
    # distributed-training fields into the environment by hand (see the init sketch after this snippet)
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
# ...
if config.TRAIN.AUTO_RESUME:
resume_file = auto_resume_helper(config.OUTPUT, get_best=True)
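The four environment variables written above are exactly what torch.distributed's env:// rendezvous reads, so the repo's DDP initialization runs unchanged with a single process. Roughly, main.py does the equivalent of:

import os
import torch
import torch.distributed as dist

# With WORLD_SIZE=1 and RANK=0 set above, this creates a process group
# containing only the local process, so all of the dist.* calls in the
# training code keep working on a single GPU.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", init_method="env://",
                        world_size=world_size, rank=rank)
dist.barrier()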
# The original code reports top-1 and top-5 accuracy, but my dataset has only
# 3 classes, so this is changed to report top-1 and top-2 accuracy.
# Per-class accuracy output is also added.
@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc2_meter = AverageMeter()
    cla_num_meter = np.zeros(config.MODEL.NUM_CLASSES)
    pre_num_meter = np.zeros(config.MODEL.NUM_CLASSES)

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc2 = accuracy(output, target, topk=(1, 2))
        cla_num, pre_num = cla_accuracy(output, target, config.MODEL.NUM_CLASSES)
        cla_num_meter += cla_num
        pre_num_meter += pre_num

        acc1 = reduce_tensor(acc1)
        acc2 = reduce_tensor(acc2)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc2_meter.update(acc2.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@2 {acc2_meter.val:.3f} ({acc2_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@2 {acc2_meter.avg:.3f}')

    # Per-class accuracy (assumes every class appears at least once in the val set)
    ans = ''
    acc_each_class = [pre_num_meter[i] / cla_num_meter[i] for i in range(config.MODEL.NUM_CLASSES)]
    for i in range(config.MODEL.NUM_CLASSES):
        ans += f'Acc of {config.DATA.NAME_CLASSES[i]}: {acc_each_class[i]}\t'
    logger.info(ans)
    return acc1_meter.avg, acc2_meter.avg, loss_meter.avg


def cla_accuracy(output, target, num_class):
    # Count, per class, the number of ground-truth samples and the number
    # of correctly predicted samples
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()[0]
    sam_nums = np.zeros(num_class)
    pre_cor_nums = np.zeros(num_class)
    for i in range(len(target)):
        sam_nums[int(target[i])] += 1
        if int(target[i]) == int(pred[i]):
            pre_cor_nums[int(target[i])] += 1
    return sam_nums, pre_cor_nums
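cla_accuracy can be sanity-checked on a toy batch (the logits and labels below are made up for illustration):

import torch

# 3 samples, 3 classes; per-row argmax gives predictions [0, 2, 1]
# against ground truth [0, 0, 1].
output = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 1.5],
                       [0.1, 3.0, 0.2]])
target = torch.tensor([0, 0, 1])
sam_nums, pre_cor_nums = cla_accuracy(output, target, num_class=3)
print(sam_nums)      # [2. 1. 0.] -> samples per class
print(pre_cor_nums)  # [1. 1. 0.] -> correct predictions per class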
# The original code saves a checkpoint every epoch; this is changed to keep
# only best_ckpt.pth and last_epoch_ckpt.pth
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
    data_loader_train.sampler.set_epoch(epoch)

    train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn,
                    lr_scheduler, loss_scaler)
    acc1, acc2, loss = validate(config, data_loader_val, model)
    if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
        # acc1 has not been folded into max_accuracy yet, so a new best epoch
        # overwrites best_ckpt.pth; otherwise last_epoch_ckpt.pth is overwritten
        if acc1 > max_accuracy:
            ckpt_name = "best_ckpt"
        else:
            ckpt_name = "last_epoch_ckpt"
        save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler,
                        loss_scaler, logger, ckpt_name)

    logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
    max_accuracy = max(max_accuracy, acc1)
    logger.info(f'Max accuracy: {max_accuracy:.2f}%')
data/build.py
def build_loader(config):
config.defrost()
    # The original line is: dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    # The number of classes is already specified in the config file, so the returned value is discarded
dataset_train, _ = build_dataset(is_train=True, config=config)
utils.py
# Modified so that resume loads best_ckpt.pth
def auto_resume_helper(output_dir, get_best=False):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    # The original code picks the most recently modified checkpoint;
    # with get_best=True it reads best_ckpt.pth instead
    if len(checkpoints) > 0 and not get_best:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    elif get_best and "best_ckpt.pth" in checkpoints:
        print(f"The best checkpoint found: {os.path.join(output_dir, 'best_ckpt.pth')}")
        resume_file = os.path.join(output_dir, 'best_ckpt.pth')
    else:
        resume_file = None
    return resume_file
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, ckpt_name):
save_state = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'max_accuracy': max_accuracy,
'scaler': loss_scaler.state_dict(),
'epoch': epoch,
'config': config}
save_path = os.path.join(config.OUTPUT, f'{ckpt_name}.pth')
logger.info(f"{save_path} saving......")
torch.save(save_state, save_path)
logger.info(f"{save_path} saved !!!")
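Everything stored by save_checkpoint comes back out with a plain torch.load. A minimal sketch of restoring the best weights (it assumes `config` is the parsed yacs config used throughout this post; the path matches the commands below):

import torch
from models import build_model

# `config` is assumed to be the parsed config, as in main.py.
model = build_model(config)
# Restore only the model weights; optimizer/lr_scheduler/scaler states
# are available under the same keys written by save_checkpoint.
ckpt = torch.load('output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth',
                  map_location='cpu')
msg = model.load_state_dict(ckpt['model'], strict=False)
print(msg)  # reports any missing/unexpected keys
print(f"epoch={ckpt['epoch']}, max_accuracy={ckpt['max_accuracy']}")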
Training command:
python main.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --batch-size 4 --data-path imagenet --pretrained swinv2_base_patch4_window12_192_22k.pth --local_rank 0
Evaluation command:
python main.py --eval --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --resume output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --data-path imagenet --local_rank 0
Terminal output from the evaluation stage (screenshot):
The original authors do not provide inference code, so I wrote a simple inference script following the evaluate flow.
import os
import argparse

import cv2
import torch
from torchvision import transforms
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from config import get_config
from models import build_model

try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR

    import timm.data.transforms as timm_transforms
    timm_transforms._pil_interp = _pil_interp
except ImportError:
    from timm.data.transforms import _pil_interp


def parse_option():
    parser = argparse.ArgumentParser('Swin Transformer inference script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    parser.add_argument("--opts", help="Modify config options by adding 'KEY VALUE' pairs.",
                        default=None, nargs='+')
    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    parser.add_argument("--local_rank", type=int, required=True,
                        help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    # overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    args, unparsed = parser.parse_known_args()
    config = get_config(args)
    return args, config


if __name__ == '__main__':
    args, config = parse_option()

    transform_test = transforms.Compose([
        transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                          interpolation=_pil_interp(config.DATA.INTERPOLATION)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])

    classes = config.DATA.NAME_CLASSES
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = build_model(config)
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    model.to(DEVICE)

    path = config.DATA.DATA_PATH
    testList = os.listdir(path)
    for file in testList:
        img = Image.open(os.path.join(path, file)).convert('RGB')
        img = transform_test(img)
        img = img.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            out = model(img)
        _, pred = torch.max(out.data, 1)

        # draw the prediction onto the original image and display it
        ori_img = cv2.imread(os.path.join(path, file))
        text = 'ImageName:{}, predict:{}'.format(file, classes[pred.data.item()])
        font = cv2.FONT_HERSHEY_SIMPLEX
        txt_size = cv2.getTextSize(text, font, 0.7, 1)[0]
        x0 = int(ori_img.shape[1] / 2.0)
        cv2.putText(ori_img, text, (x0 - int(txt_size[0] / 2.0), int(0 + txt_size[1])),
                    font, 0.7, (0, 0, 255), thickness=1)
        cv2.imshow(os.path.join(path, file), ori_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
Inference command:
python inference.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --data-path images/ --pretrained output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --local_rank 0
See my next article: visualizing Swin Transformer features with Grad-CAM.