当前位置:   article > 正文

《python深度学习》笔记(七):多分类问题_python 多特征数据分类

python 多特征数据分类

本节使用路透社数据集,它包含许多新闻及其对应的主题,由路透社在1986年发布。它是一个简单的、广泛使用的文本分类数据集。包括46个主题:某些主题的样本更多,但训练集中每个主题都至少10个样本。

因为有多个类别,所以这是多分类问题。因为每个数据点只能划分到一个类别,所以这又是单标签、多分类问题。

完整代码实现:

  1. from keras.datasets import reuters
  2. import numpy as np
  3. # 第一步:加载数据
  4. (train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)
  5. # 第二步:编码数据,输入数据
  6. def vectorize_sequences(sequences, dimension=10000): # 定义一个向量化序列函数,将所有评论都向量化成一样的维度
  7. results = np.zeros((len(sequences), dimension)) # 创建一个零矩阵,二维张量
  8. for i, sequence in enumerate(sequences): # 将一个可遍历的数据对象组合成一个索引序列。
  9. results[i, sequence] = 1.
  10. return results # 例如,序列[3,5,6]将被转化为1000维向量,只有索引为3,5,6的元素才为1,其他都为0.
  11. # 数据向量化
  12. x_train = vectorize_sequences(train_data)
  13. x_test = vectorize_sequences(test_data)
  14. def to_one_hot(labels, dimension=46): # 标签有46
  15. results = np.zeros((len(labels), dimension))
  16. for i, label in enumerate(labels):
  17. results[i, label] = 1.
  18. return results
  19. # 标签向量化
  20. one_hot_train_labels = to_one_hot(train_labels)
  21. one_hot_test_labels = to_one_hot(test_labels)
  22. from keras.utils.np_utils import to_categorical
  23. one_hot_train_labels = to_categorical(train_labels)
  24. one_hot_test_labels = to_categorical(test_labels)
  25. # 第三步:构建网络模型
  26. # 定义模型
  27. from keras import models
  28. from keras import layers
  29. model = models.Sequential()
  30. """
  31. Q: 为什么此处输入单元数要使用64,为什么不使用电影评论分类时使用的16?
  32. A:16维空间对于这个例子来说太小了,无法学会区分46个不同的类别。
  33. 这种维度较小的层可能成为信息瓶颈,永久地丢失相关信息。
  34. 如果是三分类,四分类问题你依然可以使用16个隐藏单元
  35. Q:我能不能设置为640个单元?
  36. A:单元数不是越大越好,网络容量越大,网络就越容易记住训练过的数据。
  37. 网络会在训练过的数据上表现优异,但是在没有见过的数据上的表现则不容乐观。
  38. 因此单元数不是越大越好,需要在欠拟合与过拟合之间找到一个平衡点。
  39. """
  40. model.add(layers.Dense(64, activation='relu', input_shape=(10000, )))
  41. model.add(layers.Dense(4, activation='relu'))
  42. model.add(layers.Dense(46, activation='softmax'))
  43. # 编译模型
  44. model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
  45. # 第四步:训练模型,绘制图像
  46. # 留出验证集
  47. x_val = x_train[:1000]
  48. partial_x_train = x_train[1000:]
  49. y_val = one_hot_train_labels[:1000]
  50. partial_y_train = one_hot_train_labels[1000:]
  51. # 训练模型
  52. history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512, validation_data=(x_val, y_val))
  53. # 绘制训练损失和验证损失
  54. import matplotlib.pyplot as plt
  55. loss = history.history['loss']
  56. val_loss = history.history['val_loss']
  57. epochs = range(1, len(loss)+1)
  58. plt.plot(epochs, loss, 'bo', label='Training loss')
  59. plt.plot(epochs, val_loss, 'b', label='Validation loss')
  60. plt.title('Training and validation loss')
  61. plt.xlabel('Epochs')
  62. plt.ylabel('Loss')
  63. plt.legend()
  64. plt.show()
  65. # 绘制训练精度和验证精度
  66. plt.clf()
  67. acc = history.history['accuracy']
  68. val_acc = history.history['val_accuracy']
  69. plt.plot(epochs, acc, 'bo', label='Training acc')
  70. plt.plot(epochs, val_acc, 'b', label='Validation acc')
  71. plt.title('Training and validation accuracy')
  72. plt.xlabel('Epochs')
  73. plt.ylabel('Accuracy')
  74. plt.legend()
  75. plt.show()
  76. # 重新训练模型
  77. model = models.Sequential()
  78. model.add(layers.Dense(64, activation='relu', input_shape=(10000, )))
  79. model.add(layers.Dense(4, activation='relu'))
  80. model.add(layers.Dense(46, activation='softmax'))
  81. model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
  82. history = model.fit(partial_x_train, partial_y_train, epochs=8, batch_size=512, validation_data=(x_val, y_val))
  83. # 观察在测试集上表现
  84. results = model.evaluate(x_test, one_hot_test_labels)
  85. print(results)
  86. # [0.9868815943054715, 0.7862867116928101]
  87. # 80%左右的精度
  88. # 采取随机预测的方式
  89. import copy
  90. test_labels_copy = copy.copy(test_labels)
  91. np.random.shuffle(test_labels_copy)
  92. hits_array = np.array(test_labels) == np.array(test_labels_copy)
  93. print(float(np.sum(hits_array)) / len(test_labels))
  94. # 0.18788958147818344
  95. # 20%的精度,可以看出模型的预测效果好得多

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/901646
推荐阅读
相关标签
  

闽ICP备14008679号