
Keras pruning, quantization, and inference

Keras pruning

TensorFlow provides an optimization toolkit, tensorflow_model_optimization, dedicated to optimizing Keras models.

It mainly supports pruning, quantization, and weight clustering.

Here we use the first two.

The dataset is the one from an earlier article: training, converting, and predicting with an MNN model.

The full training code is below.

Note: tensorflow_model_optimization has to be installed first; pip install tensorflow_model_optimization will do.

import tempfile
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_model_optimization as tfmot

batch_size = 2
img_height = 180
img_width = 180
num_classes = 5
epochs = 50
validation_split = 0.2
data_dir = 'flower_photos'

# Prepare the datasets
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=validation_split,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=validation_split,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)  # raw logits; softmax is folded into the loss
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
model.fit(train_ds,
          validation_data=val_ds,
          epochs=epochs)

# Save the float baseline and convert it to TFLite
tf.keras.models.save_model(model, 'baseline_model.h5', include_optimizer=False)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('baseline_model.tflite', 'wb') as f:
    f.write(tflite_model)

# Pruning: ramp sparsity from 50% up to 80% over the pruning epochs
print('start pruning')
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
num_images = 3670  # total images in flower_photos
pruning_epochs = 5
# end_step must match how long we actually fine-tune with pruning,
# and only the training subset of the images is seen each epoch
end_step = np.ceil(num_images * (1 - validation_split) / batch_size).astype(np.int32) * pruning_epochs
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])
model_for_pruning.summary()

logdir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),  # required: advances the pruning step
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
# Note: batch_size must not be passed to fit() when the input is a tf.data.Dataset
model_for_pruning.fit(train_ds,
                      epochs=pruning_epochs,
                      validation_data=val_ds,
                      callbacks=callbacks)
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# Quantization-aware training on the pruned model
print('start quantize')
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model_for_export)
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
q_aware_model.summary()
q_aware_model.fit(train_ds,
                  epochs=5,
                  validation_data=val_ds)

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()
quantized_and_pruned_tflite_file = 'pruned_and_quantized.tflite'
with open(quantized_and_pruned_tflite_file, 'wb') as f:
    f.write(quantized_and_pruned_tflite_model)

After the run finishes, let's look at the model files:

The file has indeed shrunk considerably, by roughly 4x.
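One thing worth knowing when reading those numbers: a .tflite file stores pruned weights densely (the zeros are still written out), so the raw file size mostly reflects quantization; the pruning benefit shows up after standard compression. A small sketch to compare raw and zipped sizes, assuming the two .tflite files produced above sit in the working directory:

```python
import os
import tempfile
import zipfile

def zipped_size(path):
    # Pruned weights are mostly zeros, which DEFLATE compresses very well,
    # so the zipped size better reflects the pruning benefit.
    fd, zipped = tempfile.mkstemp('.zip')
    os.close(fd)
    with zipfile.ZipFile(zipped, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        zf.write(path)
    size = os.path.getsize(zipped)
    os.remove(zipped)
    return size

for name in ('baseline_model.tflite', 'pruned_and_quantized.tflite'):
    if os.path.exists(name):
        print('%s: %d bytes raw, %d bytes zipped'
              % (name, os.path.getsize(name), zipped_size(name)))
```

The `zipped_size` helper is just for illustration; the same idea works with gzip or any other general-purpose compressor.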

Next, let's test the inference speed.

Inference with the optimized model:

import time

import cv2
import numpy as np
import tensorflow as tf

start = time.time()
image = cv2.imread('397.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model was trained on RGB
image = cv2.resize(image, (180, 180))
image = image[np.newaxis, :, :, :].astype(np.float32)
print(image.shape)

interpreter = tf.lite.Interpreter(model_path='pruned_and_quantized.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for _ in range(10):
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
print('avg infer time is %.6f s' % ((time.time() - start) / 10.0))

Result:
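Note that output_data holds raw logits, because the last Dense layer has no activation and training used from_logits=True. To turn a logit row into a class prediction, a softmax plus argmax is enough. A quick sketch (the logit values below are made up for illustration; in practice they would come from the interpreter):

```python
import numpy as np

# Hypothetical logits for one image, shape (1, num_classes);
# in practice this would be output_data from the interpreter.
logits = np.array([[2.1, -0.3, 0.7, 4.5, 1.0]])

# Numerically stable softmax over the class axis
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

print('predicted class:', int(probs.argmax(axis=1)[0]))  # prints "predicted class: 3"
```

The index can then be mapped back to a folder name via train_ds.class_names.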

Inference with the original model:

import time

import cv2
import numpy as np
import tensorflow as tf

start = time.time()
image = cv2.imread('397.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model was trained on RGB
image = cv2.resize(image, (180, 180))
image = image[np.newaxis, :, :, :].astype(np.float32)
print(image.shape)

interpreter = tf.lite.Interpreter(model_path='baseline_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for _ in range(10):
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
print('avg infer time is %.6f s' % ((time.time() - start) / 10.0))

Result:

Unexpectedly, although the model file got smaller, inference actually got slower, by several dozen times.

This was on Ubuntu; I also tried Windows 10, where the gap was even larger.

In short: after optimization, inference got slower...

Since I couldn't find the TFLite benchmark tool, I had to test it this way, so the timings may not be reliable.
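For reference, TensorFlow does ship an official benchmark_model binary for TFLite (under tensorflow/lite/tools/benchmark in the TensorFlow source tree), which is the more trustworthy option. Failing that, a fairer Python measurement would at least keep interpreter construction, tensor allocation, and warm-up runs out of the timed loop, since the script above starts the clock before the model is even loaded. A rough sketch, assuming a TF version recent enough to accept the num_threads argument:

```python
import os
import time

import numpy as np
import tensorflow as tf

def bench(model_path, runs=50, warmup=5, num_threads=4):
    # Build and warm up the interpreter outside the timed region,
    # so the measurement covers invoke() only.
    interpreter = tf.lite.Interpreter(model_path=model_path,
                                      num_threads=num_threads)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    dummy = np.random.rand(*inp['shape']).astype(inp['dtype'])
    for _ in range(warmup):
        interpreter.set_tensor(inp['index'], dummy)
        interpreter.invoke()
    start = time.time()
    for _ in range(runs):
        interpreter.set_tensor(inp['index'], dummy)
        interpreter.invoke()
        interpreter.get_tensor(out['index'])
    return (time.time() - start) / runs

for path in ('baseline_model.tflite', 'pruned_and_quantized.tflite'):
    if os.path.exists(path):
        print('%s: %.6f s per inference' % (path, bench(path)))
```

Even with a clean measurement, a fake-quantized model produced by quantization-aware training can still run slower than the float baseline on desktop CPUs, where float kernels are heavily optimized; the int8 speedup is mainly expected on mobile/ARM targets.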
