当前位置:   article > 正文

用PyTorch实现基于注意机制的小样本故障诊断模型的训练与验证_基于注意力机制的小样本故障诊断

基于注意力机制的小样本故障诊断

在本文中,我们将介绍如何使用PyTorch实现一个深度学习模型的训练和验证,并绘制验证集的准确率和损失曲线。我们将使用一个自定义的神经网络模型来进行实验。

导入相关库

首先,我们需要导入本项目中所需的相关库:

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. from datasave import train_loader, test_loader
  5. from early_stopping import EarlyStopping
  6. from label_smoothing import LSR
  7. from oneD_Meta_ACON import MetaAconC
  8. import time
  9. import matplotlib.pyplot as plt
  10. from torchsummary import summary
  11. from adabn import reset_bn, fix_bn

设置随机数种子

为了保证实验的可复现性,我们需要设置一个固定的随机数种子:

  1. def setup_seed(seed):
  2. torch.manual_seed(seed)
  3. torch.cuda.manual_seed_all(seed)
  4. np.random.seed(seed)
  5. torch.backends.cudnn.deterministic = True
  6. setup_seed(2)

定义神经网络模型

接下来,我们定义一个自定义的神经网络模型Net

  1. class Net(nn.Module):
  2. # 省略详细的网络结构代码...

实例化模型并查看网络结构

我们实例化上面定义的神经网络模型,并使用torchsummary库查看网络结构:

  1. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  2. model = Net().to(device)
  3. summary(model, input_size=(1, 1024))

定义损失函数与优化器

接下来,我们使用自定义的标签平滑损失函数LSR,并根据模型参数创建一个自定义的优化器AdamP

  1. criterion = LSR()
  2. # 省略对模型参数的处理代码...
  3. optimizer = AdamP(model.parameters(), lr=0.001, weight_decay=0.0001, nesterov=True, amsgrad=True)

初始化变量

我们需要初始化一些变量来保存训练和验证的损失值和准确率:

  1. losses = []
  2. acces = []
  3. eval_losses = []
  4. eval_acces = []
  5. early_stopping = EarlyStopping(patience=10, verbose=True)

开始训练与验证

接下来,我们使用for循环进行神经网络模型的训练和验证,并将损失值和准确率存储到相应的变量中:

  1. starttime = time.time()
  2. for epoch in range(150):
  3. # 省略具体的训练与验证代码...
  4. print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
  5. .format(epoch, train_loss / len
  1. 在训练过程中,我们使用了早停策略(Early Stopping)来防止模型过拟合。如果验证集损失在连续10个周期内没有改善,我们将停止训练并保存最佳模型。
  2. #### 保存模型参数
  3. 训练完成后,我们将模型参数保存到一个文件中:
  4. ```python
  5. torch.save(model.state_dict(), 'B0503_LSTM.pt')

绘制验证集准确率和损失曲线

我们将使用matplotlib库绘制验证集的准确率和损失曲线:

  1. import pandas as pd
  2. pd.set_option('display.max_columns', None)
  3. pd.set_option('display.max_rows', None)
  4. plt.plot(eval_acces, label='Validation Accuracy', linestyle='-', marker='o')
  5. plt.plot(eval_losses, label='Validation Loss', linestyle='--', marker='o')
  6. plt.title('Validation Accuracy and Loss')
  7. plt.xlabel('Epochs')
  8. plt.ylabel('Accuracy / Loss')
  9. plt.legend()
  10. plt.show()

至此,我们已经成功实现了一个使用PyTorch的故障诊断模型的训练与验证过程,并绘制了验证集准确率和损失曲线。可以根据实际需求修改网络结构和训练参数以达到更好的性能。

【全部代码及数据集,可直接运行】

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/263707
推荐阅读
相关标签
  

闽ICP备14008679号