当前位置:   article > 正文

用python读取mat格式的元胞数组(图片训练集),训练一个卷积神经网络二分类模型_python读取matlab元胞数组

python读取matlab元胞数组

应用背景

matlab在图像处理上有着非常大的优势,而在某些情况下预处理完的数据需要在python上训练。

在数据导入的环节往往会出现许多问题。

数据来源

 训练数据集是我自己拍摄的红外三通道图像,经预处理后为131张单通道的规模为1430*1072的图像,保存在train.mat中,标签名为B

测试集是后60张图片,保存在test.mat。

训练标签是保存在Excel里的分类目标

 python代码(数据读入):

  1. from scipy.io import loadmat
  2. import pandas as pd
  3. import numpy
  4. train_labels = pd.read_excel('训练.xlsx')
  5. print(train_labels)
  6. test_labels = pd.read_excel('测试.xlsx')
  7. train_labels = train_labels.to_numpy()
  8. test_labels = test_labels.to_numpy()
  9. path = r"train.mat" # mat文件路径
  10. data1 = loadmat(path) # 读取mat文件
  11. path = r"test.mat" # mat文件路径
  12. data2 = loadmat(path)
  13. #print(data1.keys()) # 查看mat文件中包含的变量
  14. #Out:
  15. #dict_keys(['__header__', '__version__', '__globals__', 'A', 'C', 'n', 's'])
  16. train_images = data1['B']
  17. train_images = numpy.stack(train_images[:, 0], axis=0)
  18. train_images = numpy.expand_dims(train_images, axis=-1)
  19. test_images = data2['C']
  20. test_images = numpy.stack(test_images[:, 0], axis=0)
  21. test_images = numpy.expand_dims(test_images, axis=-1)

解释:

        xlsx文件标签集的读取选择使用pandas库,读入后的数据结构为dataframe,为了格式统一将其转换成数组形式;

  1. train_labels = train_labels.to_numpy()
  2. test_labels = test_labels.to_numpy()

         mat文件的读取使用scipy.io中的loadmat,他会返回一个字典,指定对应的标签可获得一个数组。但是这里我希望的数组应该是一个三维图片数组(131,维度,维度),而返回的仅仅为(131,1),因此需要对其中的数据堆叠成能够适用于模型的大小;

  1. train_images = data1['B']
  2. train_images = numpy.stack(train_images[:, 0], axis=0)
  3. train_images = numpy.expand_dims(train_images, axis=-1)

python代码(搭建神经网络):

  1. # 构建卷积神经网络模型
  2. model = models.Sequential()
  3. model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(1430,1072,1)))
  4. model.add(layers.MaxPooling2D((2, 2)))
  5. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  6. # 添加全连接层
  7. model.add(layers.Flatten())
  8. model.add(layers.Dense(1, activation='sigmoid'))
  9. # 编译模型
  10. model.compile(optimizer='adam',
  11. loss=tf.keras.losses.BinaryCrossentropy(),
  12. metrics=['accuracy'])
  13. # 训练模型
  14. history = model.fit(train_images, train_labels, epochs=10, batch_size=10,
  15. validation_data=(test_images, test_labels))
  16. predictions = model.predict(test_images)
  17. predictions = numpy.round(predictions).astype(int)
  18. print(predictions)

一个简单的卷积神经网络代码。

全部代码如下:

  1. import tensorflow as tf
  2. from tensorflow.keras import datasets, layers, models
  3. from scipy.io import loadmat
  4. import pandas as pd
  5. import numpy
  6. train_labels = pd.read_excel('训练.xlsx')
  7. print(train_labels)
  8. test_labels = pd.read_excel('测试.xlsx')
  9. train_labels = train_labels.to_numpy()
  10. test_labels = test_labels.to_numpy()
  11. path = r"train.mat" # mat文件路径
  12. data1 = loadmat(path) # 读取mat文件
  13. path = r"test.mat" # mat文件路径
  14. data2 = loadmat(path)
  15. #print(data1.keys()) # 查看mat文件中包含的变量
  16. #Out:
  17. #dict_keys(['__header__', '__version__', '__globals__', 'A', 'C', 'n', 's'])
  18. train_images = data1['B']
  19. train_images = numpy.stack(train_images[:, 0], axis=0)
  20. train_images = numpy.expand_dims(train_images, axis=-1)
  21. test_images = data2['C']
  22. test_images = numpy.stack(test_images[:, 0], axis=0)
  23. test_images = numpy.expand_dims(test_images, axis=-1)
  24. # 构建卷积神经网络模型
  25. model = models.Sequential()
  26. model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(1430,1072,1)))
  27. model.add(layers.MaxPooling2D((2, 2)))
  28. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  29. # 添加全连接层
  30. model.add(layers.Flatten())
  31. model.add(layers.Dense(1, activation='sigmoid'))
  32. # 编译模型
  33. model.compile(optimizer='adam',
  34. loss=tf.keras.losses.BinaryCrossentropy(),
  35. metrics=['accuracy'])
  36. # 训练模型
  37. history = model.fit(train_images, train_labels, epochs=10, batch_size=10,
  38. validation_data=(test_images, test_labels))
  39. predictions = model.predict(test_images)
  40. predictions = numpy.round(predictions).astype(int)
  41. print(predictions)

结果:

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/509473
推荐阅读
相关标签
  

闽ICP备14008679号