
Advanced use of load_state_dict() for neural networks

load_state_dict

In many cases we need to load a pretrained model before training. Usually a plain model.load_state_dict(torch.load(state_path)) is enough, but sometimes not every parameter in the pretrained model can be matched to a parameter in the model we actually want to train.

There are three situations: 1. a parameter exists in the pretrained model but not in the target model; 2. a parameter exists in the target model but not in the pretrained model; 3. a parameter exists in both models but does not match (for example, the shapes differ). The first two cases do not prevent loading: just set strict to False, and load_state_dict will also return missing_keys and unexpected_keys that tell you which parameters were not matched.
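As a quick illustration of the first two cases, here is a minimal sketch (the toy modules and layer names are invented purely for illustration) showing what strict=False reports:

import torch.nn as nn

# hypothetical "pretrained" model: shares fc1 with the target but has an extra head
pretrained = nn.Sequential()
pretrained.add_module('fc1', nn.Linear(4, 8))
pretrained.add_module('old_head', nn.Linear(8, 2))

# hypothetical target model: shares fc1 but defines a new head of its own
target = nn.Sequential()
target.add_module('fc1', nn.Linear(4, 8))
target.add_module('new_head', nn.Linear(8, 3))

# strict=False loads the matching keys and reports the rest instead of raising
result = target.load_state_dict(pretrained.state_dict(), strict=False)
print(result.missing_keys)     # parameters only in the target, e.g. ['new_head.weight', 'new_head.bias']
print(result.unexpected_keys)  # parameters only in the checkpoint, e.g. ['old_head.weight', 'old_head.bias']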

However, if a parameter name matches but the shape of its value does not, loading fails. In that case we can remove the offending entries with del state_dict[wrong_key] before calling load_state_dict:

missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)

The complete code is shown below. map_location redirects the loaded tensors to the desired device; here the lambda puts them onto the current GPU (a CPU-only alternative is noted after the snippet).

device = torch.cuda.current_device()
state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))
src_state_dict = state_dict['net']
target_state_dict = model.state_dict()
skip_keys = []
# skip mismatched-size tensors in case of pretraining
for k in src_state_dict.keys():
    if k not in target_state_dict:
        continue
    if src_state_dict[k].size() != target_state_dict[k].size():
        skip_keys.append(k)
for k in skip_keys:
    del src_state_dict[k]
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)
if skip_keys:
    logger.info(
        f'removed keys in source state_dict due to size mismatch: {", ".join(skip_keys)}')
if missing_keys:
    logger.info(f'missing keys in source state_dict: {", ".join(missing_keys)}')
if unexpected_keys:
    logger.info(f'unexpected keys in source state_dict: {", ".join(unexpected_keys)}')
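A side note on map_location: the lambda above moves every stored tensor onto the current GPU as it is deserialized. If you only want to inspect a checkpoint, or load it on a machine without a GPU, torch.load also accepts a plain device string (the path below is just a placeholder):

import torch

# remap all tensors in the checkpoint onto the CPU while loading
state_dict = torch.load('checkpoint.pth', map_location='cpu')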

In practice, it is also a good idea to save the optimizer's state during training, so that training can resume even if it is interrupted; the optimizer's internal parameters are adjusted automatically as training progresses, so preserving them matters. On top of that, storing the epoch number lets us know how many epochs have already been trained.

# load optimizer
if optimizer is not None:
    assert 'optimizer' in state_dict
    optimizer.load_state_dict(state_dict['optimizer'])
if 'epoch' in state_dict:
    epoch = state_dict['epoch']
else:
    epoch = 0
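For reference, an optimizer's state_dict() in PyTorch has two entries: 'state' (per-parameter buffers such as Adam's running moments, filled in as training progresses) and 'param_groups' (learning rate and other hyperparameters), which is why restoring it lets training continue where it left off. A quick check:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(optimizer.state_dict().keys())  # dict_keys(['state', 'param_groups'])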

The complete loading code:

import torch


def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):
    device = torch.cuda.current_device()
    state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))
    src_state_dict = state_dict['net']
    target_state_dict = model.state_dict()
    skip_keys = []
    # skip mismatched-size tensors in case of pretraining
    for k in src_state_dict.keys():
        if k not in target_state_dict:
            continue
        if src_state_dict[k].size() != target_state_dict[k].size():
            skip_keys.append(k)
    for k in skip_keys:
        del src_state_dict[k]
    missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)
    if skip_keys:
        logger.info(
            f'removed keys in source state_dict due to size mismatch: {", ".join(skip_keys)}')
    if missing_keys:
        logger.info(f'missing keys in source state_dict: {", ".join(missing_keys)}')
    if unexpected_keys:
        logger.info(f'unexpected keys in source state_dict: {", ".join(unexpected_keys)}')
    # load optimizer
    if optimizer is not None:
        assert 'optimizer' in state_dict
        optimizer.load_state_dict(state_dict['optimizer'])
    if 'epoch' in state_dict:
        epoch = state_dict['epoch']
    else:
        epoch = 0
    return epoch + 1
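A typical call might look like the following sketch (the path, logger setup and stand-in model are placeholders; a CUDA device is assumed, as in the function above):

import logging

import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('train')

model = nn.Linear(4, 2).cuda()     # stand-in for your real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# resume from the latest checkpoint; the return value is the epoch to continue from
start_epoch = load_checkpoint('work_dir/latest.pth', logger, model, optimizer, strict=False)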

For completeness, here is the matching code for saving checkpoints during training. It relies on two small helpers, is_multiple and is_power2, to decide which old checkpoint files are worth keeping.

import os
import os.path as osp

import torch


def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
    f = os.path.join(work_dir, f'epoch_{epoch}.pth')
    checkpoint = {
        'net': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, f)
    # if a latest.pth file already exists, remove it, then use `ln -s` to link
    # epoch_xxx.pth to latest.pth
    if os.path.exists(f'{work_dir}/latest.pth'):
        os.remove(f'{work_dir}/latest.pth')
    os.system(f'cd {work_dir}; ln -s {osp.basename(f)} latest.pth')
    # unless the previous epoch is a power of 2 or a multiple of save_freq,
    # remove its .pth file to avoid keeping too many checkpoints
    epoch = epoch - 1
    f = os.path.join(work_dir, f'epoch_{epoch}.pth')
    if os.path.isfile(f):
        if not is_multiple(epoch, save_freq) and not is_power2(epoch):
            os.remove(f)


def is_power2(num):
    return num != 0 and ((num & (num - 1)) == 0)


def is_multiple(num, multiple):
    return num != 0 and num % multiple == 0
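Putting loading and saving together, a resumable training loop might look like the sketch below (model, optimizer and logger as in the usage example above; train_one_epoch and the output directory stand in for your own training step and paths):

import os

work_dir = 'work_dir'                      # placeholder output directory
os.makedirs(work_dir, exist_ok=True)

start_epoch = 1
latest = os.path.join(work_dir, 'latest.pth')
if os.path.exists(latest):
    # resume from the most recently saved checkpoint
    start_epoch = load_checkpoint(latest, logger, model, optimizer)

for epoch in range(start_epoch, 101):
    train_one_epoch(model, optimizer)      # hypothetical training step
    checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16)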
