当前位置:   article > 正文

Python电能质量扰动信号分类(三)基于Transformer的一维信号分类模型_transformer对信号进行分类

transformer对信号进行分类

目录

引言

1 数据集制作与加载

1.1 导入数据

1.2 制作数据集

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型

2.2 定义模型参数

3 Transformer模型训练与评估

3.1 模型训练

3.2 模型评估

代码、数据如下:


往期精彩内容:

电能质量扰动信号数据介绍与分类-Python实现-CSDN博客

Python电能质量扰动信号分类(一)基于LSTM模型的一维信号分类-CSDN博客

Python电能质量扰动信号分类(二)基于CNN模型的一维信号分类-CSDN博客

引言

本文基于Python仿真的电能质量扰动信号,先经过数据预处理进行数据集的制作和加载,然后通过Pytorch实现Transformer模型对扰动信号的分类。Python仿真电能质量扰动信号的详细介绍可以参考下文(文末附10分类数据集):

电能质量扰动信号数据介绍与分类-Python实现-CSDN博客

部分扰动信号类型波形图如下所示:

1 数据集制作与加载

1.1 导入数据

在参考IEEE Std1159-2019电能质量检测标准与相关文献的基础上构建了扰动信号的模型,生成包括正常信号在内的10中单一信号和多种复合扰动信号。参考之前的文章,进行扰动信号10分类的预处理:

第一步,按照公式模型生成单一信号

单一扰动信号可视化:

第二步,导入十分类数据

  1. import pandas as pd
  2. import numpy as np
  3. # 样本时长0.2s 样本步长1024 每个信号生成500个样本 噪声0DB
  4. window_step = 1024
  5. samples = 500
  6. noise = 0
  7. split_rate = [0.7, 0.2, 0.1] # 训练集、验证集、测试集划分比例
  8. # 读取已处理的 CSV 文件
  9. dataframe_10c = pd.read_csv('PDQ_10c_Clasiffy_data.csv' )
  10. dataframe_10c.shape

1.2 制作数据集

第一步,定义制作数据集函数

第二步,制作数据集与分类标签

  1. from joblib import dump, load
  2. # 生成数据
  3. train_dataframe, val_dataframe, test_dataframe = make_data(dataframe_10c, split_rate)
  4. # 制作标签
  5. train_xdata, train_ylabel = make_data_labels(train_dataframe)
  6. val_xdata, val_ylabel = make_data_labels(val_dataframe)
  7. test_xdata, test_ylabel = make_data_labels(test_dataframe)
  8. # 保存数据
  9. dump(train_xdata, 'TrainX_1024_0DB_10c')
  10. dump(val_xdata, 'ValX_1024_0DB_10c')
  11. dump(test_xdata, 'TestX_1024_0DB_10c')
  12. dump(train_ylabel, 'TrainY_1024_0DB_10c')
  13. dump(val_ylabel, 'ValY_1024_0DB_10c')
  14. dump(test_ylabel, 'TestY_1024_0DB_10c')

2 Transformer分类模型和超参数选取

2.1 定义Transformer分类模型

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为 32 * 32, 就使输入序列的长度降下来了。

2.2 定义模型参数

  1. # 模型参数
  2. input_dim = 32 # 输入维度
  3. hidden_dim = 512 # 注意力维度
  4. output_dim = 10 # 输出维度
  5. num_layers = 4 # 编码器层数
  6. num_heads = 8 # 多头注意力头数
  7. batch_size = 64
  8. # 模型
  9. model = TransformerModel(input_dim, output_dim, hidden_dim, num_layers, num_heads, batch_size)
  10. model = model.to(device)
  11. loss_function = nn.CrossEntropyLoss(reduction='sum') # loss
  12. learn_rate = 0.0003
  13. optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate) # 优化器

3 Transformer模型训练与评估

3.1 模型训练

训练结果

100个epoch,准确率将近90%,Transformer模型分类效果良好,参数过拟合了,适当调整模型参数,降低模型复杂度,还可以进一步提高分类准确率。

注意调整参数:

  • 可以适当增加 Transformer层数和隐藏层维度数,微调学习率;

  • 增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

3.2 模型评估

  1. # 模型 测试集 验证
  2. import torch.nn.functional as F
  3. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练
  4. # 加载模型
  5. model =torch.load('best_model_transformer.pt')
  6. # 将模型设置为评估模式
  7. model.eval()
  8. # 使用测试集数据进行推断
  9. with torch.no_grad():
  10. correct_test = 0
  11. test_loss = 0
  12. for test_data, test_label in test_loader:
  13. test_data, test_label = test_data.to(device), test_label.to(device)
  14. test_output = model(test_data)
  15. probabilities = F.softmax(test_output, dim=1)
  16. predicted_labels = torch.argmax(probabilities, dim=1)
  17. correct_test += (predicted_labels == test_label).sum().item()
  18. loss = loss_function(test_output, test_label)
  19. test_loss += loss.item()
  20. test_accuracy = correct_test / len(test_loader.dataset)
  21. test_loss = test_loss / len(test_loader.dataset)
  22. print(f'Test Accuracy: {test_accuracy:4.4f} Test Loss: {test_loss:10.8f}')
  23. Test Accuracy: 0.9070 Test Loss: 0.22114271

代码、数据如下:

对数据集和代码感兴趣的,可以关注最后一行

  1. # 加载数据
  2. import torch
  3. from joblib import dump, load
  4. import torch.utils.data as Data
  5. import numpy as np
  6. import pandas as pd
  7. import torch
  8. import torch.nn as nn
  9. # 参数与配置
  10. torch.manual_seed(100) # 设置随机种子,以使实验结果具有可重复性
  11. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  12. #代码和数据集:https://mbd.pub/o/bread/ZZiZmphq

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/887016
推荐阅读
相关标签
  

闽ICP备14008679号