
Code for Solving Overfitting

1. Use a larger dataset

2. Data augmentation

Data augmentation includes flipping (horizontal or vertical), rotation by an arbitrary angle, random scaling, random cropping, center cropping, adding noise, and so on. The code lives in the dataloader you build, so a fresh random augmentation is applied each time a sample is read during training.

Flipping, rotation, and scaling/cropping are shown here. Because this is a saliency detection task, the image and its label must be transformed identically. (A sketch of wiring these functions into a Dataset follows the code.)

import random

import cv2
import numpy as np
import torchvision.transforms as tf
from skimage import transform


def flip(image, mask):
    # image/mask are (bands, h, w); flip both the same way
    if random.random() > 0.5:  # vertical flip
        image = np.flip(image, 1)
        mask = np.flip(mask, 1)
    if random.random() < 0.5:  # horizontal flip
        image = np.flip(image, 2)
        mask = np.flip(mask, 2)
    return image, mask


def rotate_image(image, angle, center=None, scale=1.0):
    # grab the dimensions of the image
    (h, w) = image.shape[:2]
    # if the center is None, initialize it as the center of the image
    if center is None:
        center = (w // 2, h // 2)
    # perform the rotation and return the rotated image
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))
    return rotated


def rotate(image, mask):
    angle = tf.RandomRotation.get_params([-180, 180])  # pick one random angle in [-180, 180]
    for i in range(image.shape[0]):
        image[i] = rotate_image(image[i], angle)
    for i in range(mask.shape[0]):
        mask[i] = rotate_image(mask[i], angle)
    return image, mask


def scale_up(image, mask):
    # enlarge, then take a random crop back to the original size
    i_b, i_h, i_w = image.shape
    m_b, m_h, m_w = mask.shape
    sch = random.uniform(1.01, 1.5)
    scw = random.uniform(1.01, 1.5)
    ih = int(i_h * sch)
    iw = int(i_w * scw)
    image = transform.resize(image, (i_b, ih, iw))
    mask = transform.resize(mask, (m_b, ih, iw))
    rh = random.randrange(0, max(ih - i_h, 1))  # max() guards against ih == i_h after int()
    rw = random.randrange(0, max(iw - i_w, 1))
    image = image[:, rh:(rh + i_h), rw:(rw + i_w)]
    mask = mask[:, rh:(rh + i_h), rw:(rw + i_w)]
    return image, mask


def scale_down(image, mask):
    # shrink, then paste at a random offset inside a zero canvas of the original size
    i_b, i_h, i_w = image.shape
    m_b, m_h, m_w = mask.shape
    sch = random.uniform(0.5, 0.99)
    scw = random.uniform(0.5, 0.99)
    ih = int(i_h * sch)
    iw = int(i_w * scw)
    image = transform.resize(image, (i_b, ih, iw))
    mask = transform.resize(mask, (m_b, ih, iw))
    rh = random.randrange(0, i_h - ih, 1)
    rw = random.randrange(0, i_w - iw, 1)
    image_ = np.zeros((i_b, i_h, i_w))
    mask_ = np.zeros((i_b, i_h, i_w))
    image_[:, rh:(rh + ih), rw:(rw + iw)] = image
    mask_[:, rh:(rh + ih), rw:(rw + iw)] = mask
    return image_, mask_


def scale(image, mask):
    if random.random() > 0.5:
        image, mask = scale_down(image, mask)
    if random.random() < 0.5:
        image, mask = scale_up(image, mask)
    return image, mask


def augmentation(img, label):
    img, label = scale(img, label)
    img, label = flip(img, label)
    img, label = rotate(img, label)
    return img, label
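Since the augmentation is meant to run inside the dataloader, here is a minimal sketch of hooking it into a Dataset. The SaliencyDataset class and its in-memory storage are hypothetical stand-ins for however you actually load image/label pairs:

import numpy as np
import torch
from torch.utils.data import Dataset


class SaliencyDataset(Dataset):  # hypothetical wrapper; adapt the loading to your data layout
    def __init__(self, images, masks, train=True):
        self.images = images  # list of (bands, h, w) float arrays
        self.masks = masks    # list of (bands, h, w) float arrays
        self.train = train

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, mask = self.images[idx], self.masks[idx]
        if self.train:
            image, mask = augmentation(image, mask)  # fresh random transform on every read
        # np.flip returns negative-stride views, so make the arrays contiguous first
        image = torch.from_numpy(np.ascontiguousarray(image)).float()
        mask = torch.from_numpy(np.ascontiguousarray(mask)).float()
        return image, mask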

3. Use BatchNorm and increase the batch size

To effectively enlarge the batch size seen by the normalization layers, use SyncBatchNorm.

Under multi-GPU training, sync increases the number of samples used to compute the BatchNorm statistics. Taking DDP as an example:

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                  output_device=args.local_rank,
                                                  find_unused_parameters=True)
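For context, these two lines assume the process group is already initialized and that device and args.local_rank are defined. A minimal single-node setup sketch (assuming launch via torchrun or torch.distributed.launch, which supplies one process per GPU and the local rank) would be:

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")  # one process per GPU
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
# ... then convert the BatchNorm layers and wrap the model as shown above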

4. Add Dropout

Using Dropout together with BatchNorm causes a known problem: performance in eval() mode is far below that in train() mode.

To pair dropout with BatchNorm you need uniform-distribution dropout (Uout); for details see my other write-up, Dropout、高斯Dropout、均匀分布Dropout(Uout)_天明月落的博客-CSDN博客.
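As a rough sketch of the idea (my reading of Uout, not the linked post's exact code): instead of zeroing activations, multiply them by uniform noise centered on 1, which keeps the train/eval variance shift small enough to coexist with BatchNorm. The beta default here is an assumed hyperparameter:

import torch
import torch.nn as nn


class Uout(nn.Module):
    '''Uniform-distribution dropout sketch: x * (1 + U(-beta, beta)) during training.'''
    def __init__(self, beta=0.1):  # beta is an assumed default, tune for your model
        super().__init__()
        self.beta = beta

    def forward(self, x):
        if not self.training:
            return x  # identity at eval time, like standard dropout
        noise = torch.empty_like(x).uniform_(-self.beta, self.beta)
        return x * (1.0 + noise)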

5. Add regularization

Regularization comes in L1 and L2 variants. It folds the magnitude of the model's parameters into the optimization objective, biasing training toward the simplest model that fits.
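In formula form (my notation; λ is the decay coefficient and the W_i are the weight tensors), the training objective becomes

    L_{total} = L_{data} + \lambda \sum_i \lVert W_i \rVert_p, \quad p = 1 \text{ (L1) or } p = 2 \text{ (L2)}

Note that the Regularization class below sums p-norms, so its p=2 variant penalizes ‖W‖₂ rather than the ‖W‖₂² that the optimizer's weight_decay effectively penalizes.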

L2 regularization is the easiest: just set the weight_decay argument when defining the network's optimizer, for example

optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.1)

The larger the value, the stronger the penalty.

L1 regularization you have to write yourself:

import torch


class Regularization(torch.nn.Module):
    def __init__(self, model, weight_decay, p=1):
        '''
        :param model: the model to regularize
        :param weight_decay: regularization coefficient
        :param p: p=2 for L2 regularization, p=1 for L1 regularization
        '''
        super(Regularization, self).__init__()
        if weight_decay <= 0:
            print("param weight_decay can not be <= 0")
            exit(0)
        self.model = model
        self.weight_decay = weight_decay
        self.p = p
        self.weight_list = self.get_weight(model)
        # self.weight_info(self.weight_list)

    def to(self, device):
        '''
        Select the device to run on
        :param device: cuda or cpu
        '''
        self.device = device
        super().to(device)
        return self

    def forward(self, model):
        self.weight_list = self.get_weight(model)  # fetch the latest weights
        reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
        return reg_loss

    def get_weight(self, model):
        '''
        Collect the model's weight tensors (biases are skipped)
        '''
        weight_list = []
        for name, param in model.named_parameters():
            if 'weight' in name:
                weight = (name, param)
                weight_list.append(weight)
        return weight_list

    def regularization_loss(self, weight_list, weight_decay, p=1):
        '''
        Sum the p-norms of the weight tensors and scale by weight_decay
        :param p: exponent of the norm; p=1 gives the L1 norm
        '''
        reg_loss = 0
        for name, w in weight_list:
            reg_loss = reg_loss + torch.norm(w, p=p)
        reg_loss = weight_decay * reg_loss
        return reg_loss

    def weight_info(self, weight_list):
        '''
        Print the names of the regularized weights
        '''
        print("---------------regularization weight---------------")
        for name, w in weight_list:
            print(name)
        print("---------------------------------------------------")
# during setup: build the L1 penalty if requested (args.L1decay is the coefficient)
if args.L1decay > 0:
    reg_loss = Regularization(model, args.L1decay, p=1).to(device)
else:
    print("no regularization")
# in the training loop: add the penalty to the task loss before backprop
loss = fuse_loss(out, labels_v)
loss = loss + reg_loss(model)  # no .item() here; detaching would strip the penalty's gradient
loss.backward()

Note: first define a regularization class; changing p switches it between the L1 and L2 penalties. Then, when computing the loss, add the regularization loss to the model's loss. The regularization loss can be read as the sum of norms of the model's weight tensors, with L1 and L2 corresponding to the two ways of computing the norm. During backpropagation the optimization is pushed toward small-magnitude parameters, which helps prevent overfitting.
