赞
踩
早停法(Early Stopping)是一种用于防止模型过拟合的技术,在训练过程中监视验证集(或者测试集)上的损失值。具体设立早停的限制包括两个主要参数:
Patience(耐心):这是指验证集损失在连续多少个epoch没有显著改善时,才触发早停。当验证集损失连续几个epoch没有下降或者停止减少时,表示模型可能已经过拟合或者陷入局部最优点,这时候早停就会被触发。
Best Loss(最佳损失):这是指在早停过程中保存的最低验证集损失值。当验证集损失值低于当前最佳损失时,更新最佳损失并重置耐心计数器。如果验证集损失连续不降,耐心计数器超过设定的耐心值时,早停就会被触发,训练过程停止。
早停的具体设立是基于验证集上的损失值 val_loss
。每次验证后,如果当前的 val_loss
比 best_loss
还要低,就更新 best_loss
并重置 patience_counter
;否则,增加 patience_counter
。当 patience_counter
达到设定的 patience
值时,早停被触发,即停止训练过程以防止模型过拟合。
总结来说,早停的设立限制是基于耐心参数和最佳损失值,用来判断模型是否应该停止训练以避免过拟合。
- # 训练模型
- num_epochs = 200 # 总的训练轮数
- best_loss = float('inf') # 初始化最佳验证损失为正无穷大
- patience = 10 # 早停的耐心值
- patience_counter = 0 # 耐心计数器
-
- for epoch in range(num_epochs):
- model.train()
- for geno, pheno in train_loader:
- optimizer.zero_grad() # 梯度清零
- outputs = model(geno) # 前向传播
- loss = criterion(outputs.squeeze(), pheno) # 计算损失
- loss.backward() # 反向传播
- optimizer.step() # 优化模型参数
-
- model.eval()
- val_loss = 0
- with torch.no_grad(): # 不计算梯度
- for geno, pheno in test_loader:
- outputs = model(geno) # 前向传播
- val_loss += criterion(outputs.squeeze(), pheno).item() # 计算验证损失
- val_loss /= len(test_loader) # 计算平均验证损失
- print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')
-
- scheduler.step(val_loss) # 更新学习率
-
- # 早停法
- if val_loss < best_loss:
- best_loss = val_loss # 更新最佳验证损失
- patience_counter = 0 # 重置耐心计数器
- else:
- patience_counter += 1 # 增加耐心计数器
- if patience_counter >= patience: # 如果耐心计数器达到设定的耐心值
- print("Early stopping triggered") # 触发早停
- break
EarlyStopping
类:
__init__
方法初始化早停的参数,如 patience
(耐心值)、verbose
(是否打印消息)和 delta
(损失改进的最小变化)。__call__
方法根据验证损失来决定是否更新 best_loss
,以及是否增加计数器或者触发早停。early_stopping
对象,传入当前的验证损失。early_stopping.early_stop
标志,如果为 True
,则打印消息并停止训练。通过使用 EarlyStopping
类,你可以更简洁和模块化地实现早停功能,使代码更易于维护和扩展。
- import torch
- import numpy as np
-
- class EarlyStopping:
- def __init__(self, patience=10, verbose=False, delta=0):
- """
- EarlyStopping 初始化.
- Args:
- patience (int): 当验证集损失在指定的epoch数内没有减少时触发早停.
- verbose (bool): 如果为True,则每次验证集损失改进时会打印一条消息.
- delta (float): 验证集损失改进的最小变化.
- """
- self.patience = patience
- self.verbose = verbose
- self.delta = delta
- self.best_loss = None
- self.counter = 0
- self.early_stop = False
-
- def __call__(self, val_loss):
- if self.best_loss is None:
- self.best_loss = val_loss
- elif val_loss > self.best_loss - self.delta:
- self.counter += 1
- if self.verbose:
- print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
- if self.counter >= self.patience:
- self.early_stop = True
- else:
- self.best_loss = val_loss
- self.counter = 0
- if self.verbose:
- print(f'Validation loss decreased to {self.best_loss:.6f}. Resetting counter.')
-
- # 初始化EarlyStopping对象
- early_stopping = EarlyStopping(patience=10, verbose=True)
-
- # 训练模型
- num_epochs = 200
- for epoch in range(num_epochs):
- model.train()
- for geno, pheno in train_loader:
- optimizer.zero_grad()
- outputs = model(geno)
- loss = criterion(outputs.squeeze(), pheno)
- loss.backward()
- optimizer.step()
-
- model.eval()
- val_loss = 0
- with torch.no_grad():
- for geno, pheno in test_loader:
- outputs = model(geno)
- val_loss += criterion(outputs.squeeze(), pheno).item()
- val_loss /= len(test_loader)
- print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')
-
- scheduler.step(val_loss)
-
- # 检查是否触发早停
- early_stopping(val_loss)
- if early_stopping.early_stop:
- print("Early stopping triggered")
- break
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。