当前位置:   article > 正文

Python深度学习 第七章

input shape(64,999)

7.1 第二种模型表示方式:函数式 API

  1. # coding=utf-8
  2. """
  3. __project_ = 'Python深度学习'
  4. __file_name__ = '7.1函数式API'
  5. __author__ = 'WIN10'
  6. __time__ = '2020/4/15 21:20'
  7. __product_name = PyCharm
  8. """
  9. from keras import Input, layers
  10. from keras.models import Model
  11. input_tensor = Input(shape=(64,))
  12. x = layers.Dense(32, activation='relu')(input_tensor)
  13. x = layers.Dense(32, activation='relu')(x)
  14. output_tensor = layers.Dense(10, activation='softmax')(x)
  15. model = Model(input_tensor, output_tensor)
  16. model.summary()
  17. # 双输入模型
  18. text_input = Input(shape=(None,), dtype='int32', name='text')
  19. embedded_text = layers.Embedding(10000, 64)(text_input)
  20. encoded_text = layers.LSTM(32)(embedded_text)
  21. question_input = Input(shape=(None,), dtype='int32', name='question')
  22. embedded_question = layers.Embedding(10000, 32)(question_input)
  23. encoded_question = layers.LSTM(16)(embedded_question)
  24. concatenated = layers.concatenate([encoded_text, encoded_question], axis=-1)
  25. answer = layers.Dense(500, activation='softmax')(concatenated)
  26. model = Model([text_input, question_input], answer)
  27. model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics='acc')
  28. # 多输出模型
  29. posts_input = Input(shape=(None,), dtype='int32', name='posts')
  30. embedded_posts = layers.Embedding(256, 50000)(posts_input)
  31. x = layers.Conv1D(128, 5, activation='relu')(embedded_posts)
  32. x = layers.MaxPooling1D(5)(x)
  33. x = layers.Conv1D(256, 5, activation='relu')(x)
  34. x = layers.Conv1D(256, 5, activation='relu')(x)
  35. x = layers.MaxPooling1D(5)(x)
  36. x = layers.Conv1D(256, 5, activation='relu')(x)
  37. x = layers.Conv1D(256, 5, activation='relu')(x)
  38. x = layers.GlobalMaxPool1D(x)
  39. x = layers.Dense(128, activation='relu')(x)
  40. # 输出层,必须有名字
  41. age_prediction = layers.Dense(1, name='age')(x)
  42. income_prediction = layers.Dense(10, activation='softmax', name='income')(x)
  43. gender_prediction = layers.Dense(1, activation='sigmoid', name='gender')(x)
  44. model = Model(posts_input, [age_prediction, income_prediction, gender_prediction])
  45. model.compile(optimizer='rmsprop',
  46. loss={'age': 'mse',
  47. 'income': 'categorical_crossentropy',
  48. 'gender': 'binary_crossentropy'
  49. },
  50. loss_weights={'age': 0.25,
  51. 'income': 1.,
  52. 'gender': 10.
  53. }
  54. )

7.2 层组成的有向无环图:Inception 模块与残差连接

  1. # coding=utf-8
  2. """
  3. __project_ = 'Python深度学习'
  4. __file_name__ = '7.2有向无环图'
  5. __author__ = 'WIN10'
  6. __time__ = '2020/4/15 22:38'
  7. __product_name = PyCharm
  8. """
  9. from keras import layers,Input
  10. #inception模块
  11. x= Input(shape=(64,))
  12. branch_a=layers.Conv2D(128,1,activation='relu',strides=2)(x)
  13. branch_b=layers.Conv2D(128,1,activation='relu')(x)
  14. branch_b=layers.Conv2D(128,3,activation='relu',strides=2)(branch_b)
  15. branch_c=layers.AveragePooling2D(3,strides=2)(x)
  16. branch_c=layers.Conv2D(128,3,activation='relu',strides=2)(branch_c)
  17. branch_d=layers.Conv2D(128,1,activation='relu')(x)
  18. branch_b=layers.Conv2D(128,3,activation='relu')(branch_d)
  19. branch_b=layers.Conv2D(128,3,activation='relu',strides=2)(branch_d)
  20. output=layers.concatenate([branch_a,branch_b,branch_c,branch_d],axis=-1)
  21. #resnet 模块
  22. x= Input(shape=(64,))
  23. y=layers.Conv2D(128,3,activation='relu',padding='same')(x)
  24. y=layers.Conv2D(128,3,activation='relu',padding='same')(y)
  25. y=layers.Conv2D(128,3,activation='relu',padding='same')(y)
  26. y=layers.add([y,x])
  27. #resnet 模块 特征尺寸不同,使用1*1 下采样
  28. x= Input(shape=(64,))
  29. y=layers.Conv2D(128,3,activation='relu',padding='same')(x)
  30. y=layers.Conv2D(128,3,activation='relu',padding='same')(y)
  31. y=layers.Conv2D(128,3,activation='relu',padding='same')(y)
  32. residual=layers.Conv2D(128,1,strides=2,padding='same')(x)
  33. y=layers.add([y,residual])

7.3回调函数

  1. # coding=utf-8
  2. """
  3. __project_ = 'Python深度学习'
  4. __file_name__ = '7.3回调函数'
  5. __author__ = 'WIN10'
  6. __time__ = '2020/4/15 22:54'
  7. __product_name = PyCharm
  8. """
  9. import keras
  10. x,y,x_val,y_val=[]
  11. callbacks_list=[
  12. #监控验证精度,如果多余一轮时间精度不改善,就中断训练
  13. keras.callbacks.EarlyStopping(
  14. monitor='acc',
  15. patience=1,
  16. ),
  17. #如果val_loss没有改善,不需要覆盖模型
  18. keras.callbacks.ModelCheckpoint(
  19. filepath='my_model.h5',
  20. monitor='val_loss',
  21. save_best_only=True
  22. )
  23. ]
  24. keras.models.Model.compile(optimizer='rmsprop',
  25. loss='binary_crossentropy',
  26. metrics=['acc']
  27. )
  28. keras.models.Model.fit(x,y,
  29. epochs=10,
  30. callbacks=callbacks_list,
  31. validation_data=(x_val,y_val))
  32. #降低学习率
  33. callbacks_list=[
  34. #监控val_loss 10轮回没改善 ,学习率除以10
  35. keras.callbacks.ReduceLROnPlateau(
  36. monitor='val_loss',
  37. factor=0.1,
  38. patience=10
  39. )
  40. ]

7.4深度可分离卷积神经网络

  1. # coding=utf-8
  2. """
  3. __project_ = 'Python深度学习'
  4. __file_name__ = '7.4深度可分离卷积神经网络'
  5. __author__ = 'WIN10'
  6. __time__ = '2020/4/15 23:17'
  7. __product_name = PyCharm
  8. """
  9. from keras.preprocessing.image import ImageDataGenerator
  10. from keras.models import Sequential
  11. from keras import layers
  12. from keras import optimizers
  13. import matplotlib.pyplot as plt
  14. def DataGen(dir_path, img_row, img_col, batch_size, is_train):
  15. if is_train:
  16. datagen = ImageDataGenerator(rescale=1. / 255,
  17. zoom_range=0.2,
  18. rotation_range=40.,
  19. shear_range=0.2,
  20. width_shift_range=0.2,
  21. height_shift_range=0.2,
  22. horizontal_flip=True,
  23. fill_mode='nearest')
  24. else:
  25. datagen = ImageDataGenerator(rescale=1. / 255)
  26. generator = datagen.flow_from_directory(
  27. dir_path, target_size=(img_row, img_col),
  28. batch_size=batch_size,
  29. # class_mode='binary',
  30. shuffle=is_train)
  31. return generator
  32. # 数据准备
  33. image_size = 65
  34. image_class=3
  35. batch_size = 128
  36. epochs=100
  37. train_image_path='G:\\DL\\MyData\\MattingImages\\train'
  38. test_image_path='G:\\DL\\MyData\\MattingImages\\val'
  39. train_generator = DataGen(train_image_path, image_size, image_size, batch_size, True)
  40. validation_generator = DataGen(test_image_path, image_size, image_size, batch_size, False)
  41. model=Sequential()
  42. model.add(layers.SeparableConv2D(32,3,activation='relu',input_shape=(65,65,3,)))
  43. model.add(layers.SeparableConv2D(64,3,activation='relu'))
  44. model.add(layers.MaxPooling2D(2))
  45. model.add(layers.SeparableConv2D(64,3,activation='relu'))
  46. model.add(layers.SeparableConv2D(128,3,activation='relu'))
  47. model.add(layers.MaxPooling2D(2))
  48. model.add(layers.SeparableConv2D(64,3,activation='relu'))
  49. model.add(layers.SeparableConv2D(128,3,activation='relu'))
  50. model.add(layers.GlobalAveragePooling2D())
  51. model.add(layers.Dense(32,activation='relu'))
  52. model.add(layers.Dense(3,activation='softmax'))
  53. #编译 需要3个参数 ,损失函数、优化器、训练和测试过程中的键控指标
  54. model.compile(optimizer=optimizers.RMSprop(lr=1e-4),
  55. loss='categorical_crossentropy',
  56. metrics=['accuracy'])
  57. #训练
  58. history=model.fit_generator(
  59. train_generator,
  60. steps_per_epoch=2,
  61. epochs=epochs,
  62. validation_data=validation_generator,
  63. validation_steps=1)
  64. model.save('my_model_split.h5')
  65. #绘制训练损失和验证损失
  66. history_dict=history.history
  67. acc=history_dict['accuracy']
  68. val_acc=history_dict['val_accuracy']
  69. loss=history_dict['loss']
  70. val_loss=history_dict['val_loss']
  71. epochs=range(1,len(acc)+1)
  72. plt.plot(epochs,acc,'bo',label='Training acc')
  73. plt.plot(epochs,val_acc,'b',label='Validation acc')
  74. plt.title('training and val acc')
  75. plt.legend()
  76. plt.figure()
  77. plt.plot(epochs,loss,'bo',label='Training loss')
  78. plt.plot(epochs,val_loss,'b',label='Validation loss')
  79. plt.title('train and val loss')
  80. plt.legend()
  81. plt.show()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/901652
推荐阅读
相关标签
  

闽ICP备14008679号