当前位置:   article > 正文

Tensorflow pb模型转tflite,并量化_pb转tflite

pb转tflite

一、tensorflow2.x版本pb模型转换tflite及量化

1、h5模型转tflite,不进行量化

import tensorflow as tf
import numpy as np
from pathlib import Path

print("TensorFlow version: ", tf.__version__)

# Load the trained Keras model from disk.
keras_model = tf.keras.models.load_model('model.h5')

### No quantization: convert the Keras model to a plain float TFLite
### flatbuffer as-is.
flat_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
flatbuffer = flat_converter.convert()
Path("mnist_model_null.tflite").write_bytes(flatbuffer)

# Reload the flatbuffer and report the converted model's I/O dtypes.
checker = tf.lite.Interpreter(model_content=flatbuffer)
print('input: ', checker.get_input_details()[0]['dtype'])
print('output: ', checker.get_output_details()[0]['dtype'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

`

2、h5模型转tflite,进行动态范围量化 (官方参考代码)

import tensorflow as tf
import numpy as np
from pathlib import Path
print("TensorFlow version: ", tf.__version__)

model = tf.keras.models.load_model('model.h5')
### Dynamic range quantization: weights are stored as 8-bit integers
### while activations remain float, so no calibration data is needed.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Optimize.DEFAULT with no representative dataset selects dynamic
# range quantization.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_dynamic = converter.convert()
tflite_model_file = Path("mnist_model_dynamic.tflite")
tflite_model_file.write_bytes(tflite_model_dynamic)

# Reload the flatbuffer to report the converted model's I/O dtypes
# (dynamic range quantization keeps the float model interface).
interpreter = tf.lite.Interpreter(model_content=tflite_model_dynamic)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

`

3、h5模型转tflite,进行int8整型量化 (官方参考代码)

import tensorflow as tf
import numpy as np
from pathlib import Path
print("TensorFlow version: ", tf.__version__)

model = tf.keras.models.load_model('model.h5')
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
print(type(train_images), train_images.shape)
train_images = train_images.astype(np.float32) / 255.0
def representative_data_gen():
  """Yield 100 single-image batches used to calibrate quantization.

  The converter feeds these samples through the float model to observe
  activation ranges, which it needs to pick int8 scale/zero-point.
  """
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Restrict to the int8 builtin op set so conversion fails loudly instead
# of silently falling back to float kernels for unsupported ops.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Quantize the model's interface too, so inputs/outputs become uint8.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model_int8 = converter.convert()
tflite_model_file = Path("mnist_model_int8.tflite")
tflite_model_file.write_bytes(tflite_model_int8)

# Reload the flatbuffer to confirm the I/O dtypes are now uint8.
interpreter = tf.lite.Interpreter(model_content=tflite_model_int8)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

`

4、h5模型转tflite,进行float16量化 (官方参考代码)

import tensorflow as tf
import numpy as np
from pathlib import Path

print("TensorFlow version: ", tf.__version__)

# Load the trained Keras model from disk.
source_model = tf.keras.models.load_model('model.h5')

# Float16 quantization: store weights in half precision, roughly
# halving the model size with minimal accuracy impact.
fp16_converter = tf.lite.TFLiteConverter.from_keras_model(source_model)
fp16_converter.optimizations = [tf.lite.Optimize.DEFAULT]
fp16_converter.target_spec.supported_types = [tf.float16]
fp16_bytes = fp16_converter.convert()
Path("mnist_model_float16.tflite").write_bytes(fp16_bytes)

# Reload the flatbuffer and report the converted model's I/O dtypes.
probe = tf.lite.Interpreter(model_content=fp16_bytes)
print('input: ', probe.get_input_details()[0]['dtype'])
print('output: ', probe.get_output_details()[0]['dtype'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

`

二、tensorflow2.x版本调用1.x(.compat.v1)pb模型转换tflite及量化 (官方api)

1、pb模型转tflite,不进行量化

# Convert a TF1 frozen-graph .pb model to TFLite with no quantization.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
tflite_model = converter.convert()
# Use a context manager so the file handle is closed deterministically
# (the original `open(...).write(...)` leaked the handle).
with open("model_null.tflite", "wb") as f:
    f.write(tflite_model)

# Reload the flatbuffer and print the converted model's I/O details.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Renamed from `input`/`output` to avoid shadowing the builtins.
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

2、pb模型转tflite,进行动态范围量化

#  Dynamic range quantization of a TF1 frozen-graph .pb model.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
# BUG FIX: dynamic range quantization is requested via `optimizations`.
# The original only set `quantized_input_stats`, which is a (mean, std)
# calibration hint for fully-quantized (uint8) inference input and does
# not quantize anything by itself, so the output was an unquantized
# float model despite the file name.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Context manager closes the file handle (original leaked it).
with open("model_dynamic.tflite", "wb") as f:
    f.write(tflite_model)

# Reload the flatbuffer and print the converted model's I/O details.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Renamed from `input`/`output` to avoid shadowing the builtins.
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

3、pb模型转tflite,进行int8整型量化

 # 整型量化
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
converter.quantized_input_stats = {"x_img_g": (0., 1.), "is_training": (0., 1.)}
converter.inference_type = tf.int8
tflite_model = converter.convert()
open("model_int8.tflite", "wb").write(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input = interpreter.get_input_details()
print(input)
output = interpreter.get_output_details()
print(output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

4、pb模型转tflite,进行float16量化

#  Float16 quantization of a TF1 frozen-graph .pb model.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file = '0824.pb',
        input_arrays = ['x_img_g', 'is_training'],
        input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
        output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
)
# BUG FIX: float16 quantization is requested through `optimizations`
# plus `target_spec.supported_types`; `inference_type = tf.float16` is
# not a supported converter inference type, and `quantized_input_stats`
# only applies to uint8/int8 quantized-inference input, so both
# original lines were dropped.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
# Context manager closes the file handle (original leaked it).
with open("model_float16.tflite", "wb") as f:
    f.write(tflite_model)

# Reload the flatbuffer and print the converted model's I/O details.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Renamed from `input`/`output` to avoid shadowing the builtins.
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

·

三、调用tflite

import os
import cv2
import time
import numpy as np
import tensorflow as tf
from PIL import Image

#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# A helper function to evaluate a TFLite model on a labelled test set.
def evaluate_model(interpreter, images=None, labels=None):
  """Run *interpreter* over a test set and return classification accuracy.

  Args:
    interpreter: an allocated TFLite ``Interpreter`` (or any object with
      the same ``get_input_details`` / ``get_output_details`` /
      ``set_tensor`` / ``invoke`` / ``tensor`` interface).
    images: iterable of input images. Defaults to the module-level
      ``test_images`` for backward compatibility with the original,
      which read that global directly.
    labels: ground-truth class indices aligned with *images*. Defaults
      to the module-level ``test_labels``.

  Returns:
    Accuracy as a float in [0, 1]; 0.0 for an empty test set.
  """
  if images is None:
    images = test_images
  if labels is None:
    labels = test_labels

  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the test set.
  prediction_digits = []
  for test_image in images:
    # Add a batch dimension and convert to float32 to match the model's
    # expected input format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Drop the batch dimension and take the most probable class.
    output = interpreter.tensor(output_index)
    prediction_digits.append(np.argmax(output()[0]))

  # Guard the empty set (the original raised ZeroDivisionError here).
  if not prediction_digits:
    return 0.0

  # Fraction of predictions that match the ground-truth labels.
  accurate_count = sum(
      1 for pred, truth in zip(prediction_digits, labels) if pred == truth
  )
  return accurate_count * 1.0 / len(prediction_digits)


# Select which converted model to benchmark (uncomment exactly one).
# interpreter = tf.compat.v1.lite.Interpreter(model_path="model_null.tflite")
interpreter = tf.compat.v1.lite.Interpreter(model_path="model_int8.tflite")
# interpreter = tf.compat.v1.lite.Interpreter(model_path="model_float16.tflite")
# interpreter = tf.compat.v1.lite.Interpreter(model_path="model_dynamic.tflite")
interpreter.allocate_tensors()

# Inspect the model's I/O signature before feeding data.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

test_image = cv2.imread('test.png')                             # (1080, 1920, 3)
# Resize to the network input resolution; note cv2.resize takes
# (width, height), hence (512, 256) for a 256x512x3 input tensor.
r_w, r_h = 512, 256
img_data =  cv2.resize(test_image, (r_w, r_h))                  # (256, 512, 3)
# NOTE(review): cv2.imread yields uint8; casting to int8 wraps pixel
# values above 127 — confirm this matches the int8 model's expected
# input quantization (scale / zero point) rather than being a bug.
img_data = np.expand_dims(img_data, axis=0).astype(np.int8)

interpreter.set_tensor(input_details[0]['index'], img_data)
# NOTE(review): assumes input_details[1] is 'is_training' — the order
# of get_input_details() is not guaranteed to match input_arrays;
# verify against the printed details above.
interpreter.set_tensor(input_details[1]['index'], [False])
# Time a single inference call.
t1 = time.time()
interpreter.invoke()
t2 = time.time()
prediction = interpreter.get_tensor(output_details[0]['index'])
print(t2-t1)

print(prediction.shape)
# Drop the batch dimension: (1, H, W, C) -> (H, W, C).
prediction = prediction[0]
print(prediction.shape)

# Visualize output channel 0 as a grayscale image.
prediction1 = prediction[:,:,0]
print(prediction1.shape)
print(np.max(prediction1),np.min(prediction1))
img = Image.fromarray(prediction1)
img.show()

# Visualize output channel 1 as a grayscale image.
prediction2 = prediction[:,:,1]
print(prediction2.shape)
print(np.max(prediction2),np.min(prediction2))
img = Image.fromarray(prediction2)
img.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82

`

四、参考

1、官方转换教程参考
2、tensorflow 将.pb文件量化操作为.tflite
3、tensorflow2转tflite提示OP不支持的解决方案
4、Tensorflow2 lite 模型量化

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/154312
推荐阅读
相关标签
  

闽ICP备14008679号