当前位置:   article > 正文

深度学习中的早停法

深度学习中的早停法

早停法(Early Stopping)是一种用于防止模型过拟合的技术,在训练过程中监视验证集(或者测试集)上的损失值。具体设立早停的限制包括两个主要参数:

  1. Patience(耐心):这是指验证集损失在连续多少个epoch没有显著改善时,才触发早停。当验证集损失连续几个epoch没有下降或者停止减少时,表示模型可能已经过拟合或者陷入局部最优点,这时候早停就会被触发。

  2. Best Loss(最佳损失):这是指在早停过程中保存的最低验证集损失值。当验证集损失值低于当前最佳损失时,更新最佳损失并重置耐心计数器。如果验证集损失连续不降,耐心计数器超过设定的耐心值时,早停就会被触发,训练过程停止。

    早停的具体设立是基于验证集上的损失值 val_loss。每次验证后,如果当前的 val_lossbest_loss 还要低,就更新 best_loss 并重置 patience_counter;否则,增加 patience_counter。当 patience_counter 达到设定的 patience 值时,早停被触发,即停止训练过程以防止模型过拟合。

    总结来说,早停的设立限制是基于耐心参数和最佳损失值,用来判断模型是否应该停止训练以避免过拟合。

  1. # 训练模型
  2. num_epochs = 200 # 总的训练轮数
  3. best_loss = float('inf') # 初始化最佳验证损失为正无穷大
  4. patience = 10 # 早停的耐心值
  5. patience_counter = 0 # 耐心计数器
  6. for epoch in range(num_epochs):
  7. model.train()
  8. for geno, pheno in train_loader:
  9. optimizer.zero_grad() # 梯度清零
  10. outputs = model(geno) # 前向传播
  11. loss = criterion(outputs.squeeze(), pheno) # 计算损失
  12. loss.backward() # 反向传播
  13. optimizer.step() # 优化模型参数
  14. model.eval()
  15. val_loss = 0
  16. with torch.no_grad(): # 不计算梯度
  17. for geno, pheno in test_loader:
  18. outputs = model(geno) # 前向传播
  19. val_loss += criterion(outputs.squeeze(), pheno).item() # 计算验证损失
  20. val_loss /= len(test_loader) # 计算平均验证损失
  21. print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')
  22. scheduler.step(val_loss) # 更新学习率
  23. # 早停法
  24. if val_loss < best_loss:
  25. best_loss = val_loss # 更新最佳验证损失
  26. patience_counter = 0 # 重置耐心计数器
  27. else:
  28. patience_counter += 1 # 增加耐心计数器
  29. if patience_counter >= patience: # 如果耐心计数器达到设定的耐心值
  30. print("Early stopping triggered") # 触发早停
  31. break
  1. EarlyStopping
    • __init__ 方法初始化早停的参数,如 patience(耐心值)、verbose(是否打印消息)和 delta(损失改进的最小变化)。
    • __call__ 方法根据验证损失来决定是否更新 best_loss,以及是否增加计数器或者触发早停。
  2. 训练循环
    • 训练和验证过程与之前相同。
    • 每个epoch结束时,调用 early_stopping 对象,传入当前的验证损失。
    • 检查 early_stopping.early_stop 标志,如果为 True,则打印消息并停止训练。

通过使用 EarlyStopping 类,你可以更简洁和模块化地实现早停功能,使代码更易于维护和扩展。

  1. import torch
  2. import numpy as np
  3. class EarlyStopping:
  4. def __init__(self, patience=10, verbose=False, delta=0):
  5. """
  6. EarlyStopping 初始化.
  7. Args:
  8. patience (int): 当验证集损失在指定的epoch数内没有减少时触发早停.
  9. verbose (bool): 如果为True,则每次验证集损失改进时会打印一条消息.
  10. delta (float): 验证集损失改进的最小变化.
  11. """
  12. self.patience = patience
  13. self.verbose = verbose
  14. self.delta = delta
  15. self.best_loss = None
  16. self.counter = 0
  17. self.early_stop = False
  18. def __call__(self, val_loss):
  19. if self.best_loss is None:
  20. self.best_loss = val_loss
  21. elif val_loss > self.best_loss - self.delta:
  22. self.counter += 1
  23. if self.verbose:
  24. print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
  25. if self.counter >= self.patience:
  26. self.early_stop = True
  27. else:
  28. self.best_loss = val_loss
  29. self.counter = 0
  30. if self.verbose:
  31. print(f'Validation loss decreased to {self.best_loss:.6f}. Resetting counter.')
  32. # 初始化EarlyStopping对象
  33. early_stopping = EarlyStopping(patience=10, verbose=True)
  34. # 训练模型
  35. num_epochs = 200
  36. for epoch in range(num_epochs):
  37. model.train()
  38. for geno, pheno in train_loader:
  39. optimizer.zero_grad()
  40. outputs = model(geno)
  41. loss = criterion(outputs.squeeze(), pheno)
  42. loss.backward()
  43. optimizer.step()
  44. model.eval()
  45. val_loss = 0
  46. with torch.no_grad():
  47. for geno, pheno in test_loader:
  48. outputs = model(geno)
  49. val_loss += criterion(outputs.squeeze(), pheno).item()
  50. val_loss /= len(test_loader)
  51. print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')
  52. scheduler.step(val_loss)
  53. # 检查是否触发早停
  54. early_stopping(val_loss)
  55. if early_stopping.early_stop:
  56. print("Early stopping triggered")
  57. break

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/人工智能uu/article/detail/936909
推荐阅读
相关标签
  

闽ICP备14008679号