当前位置:   article > 正文

pytorch模型转到TensorFlow lite:pytorch->onnx->tensorflow->tensorflow lite_tf_rep.run

tf_rep.run

现在很多算法都是用pytorch框架训练的,但是在移动端部署很多又使用TensorFlow lite,因此需要将pytorch模型转换到TensorFlow lite。

将pytorch模型转到TensorFlow lite的流程是pytorch->onnx->tensorflow->tensorflow lite,本文记录一下踩坑的过程。

1、pytorch转onnx

这一步比较简单,使用pytorch自带接口就行。不过有一点需要注意的,就是opset版本,可能会影响后续的转换。

  1. os.environ['CUDA_VISIBLE_DEVICES']='0'
  2. model_path = 'model.pth'
  3. model = architecture.IMDN_RTC(upscale=2).cuda()
  4. model_dict = utils.load_state_dict(model_path)
  5. model.load_state_dict(model_dict, strict=True)
  6. model.eval()
  7. for k, v in model.named_parameters():
  8. v.requires_grad = False
  9. dummy_input = torch.rand(1, 3, 224, 224).cuda()
  10. input_names = ["input"]
  11. #output_names = ["output1", "output2", "output3"]
  12. output_names = ["output"]
  13. #使用pytorch的onnx模块来进行转换
  14. #opset 10转换后,使用onnxruntime运行,在pixelshuffle处会出错
  15. torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11, verbose=True,
  16. input_names=input_names, output_names=output_names,
  17. dynamic_axes={'input': [0, 2, 3], 'output': [0, 2, 3]})
  18. session = onnxruntime.InferenceSession("model.onnx")
  19. input_name = session.get_inputs()[0].name
  20. #output_name = session.get_outputs()[0].name
  21. output_names = [s.name for s in session.get_outputs()]
  22. input_shape = session.get_inputs()[0].shape
  23. img = cv2.imread('babyx2.bmp')[:,:,::-1]
  24. img = np.transpose(img, (2, 0, 1)) / 255.
  25. img = torch.from_numpy(img).unsqueeze(0).float()
  26. res = session.run(output_names, {input_name: img.cpu().numpy()})
  27. tmp = res[0]
  28. tmp = np.clip(tmp[0], 0, 1)
  29. img = np.array(tmp*255, dtype=np.uint8)
  30. img = np.transpose(img, (1, 2, 0))[:,:,::-1]
  31. cv2.imwrite('tmp.jpg', img)

torch.onnx.export后,就得到了onnx模型,后面的代码是使用onnxruntime测试转换后的onnx模型。建议每一步转换后,都测试一下转换后模型的结果,确保每一步都是正确的。

2、onnx转TensorFlow

需要安装onnx-tensorflow进行转换。

  1. from onnx_tf.backend import prepare
  2. import onnx
  3. import tensorflow as tf
  4. if __name__ == '__main__':
  5. onnx_model = onnx.load("model.onnx") # load onnx model
  6. tf_rep = prepare(onnx_model) # prepare tf representation
  7. tf_rep.export_graph("model.tf") # export the model
  8. img = cv2.imread('babyx2.bmp')[:,:,::-1]
  9. img = np.transpose(img, (2, 0, 1)) / 255.
  10. img = torch.from_numpy(img).unsqueeze(0).float()
  11. input = img.numpy()
  12. if 0:
  13. output = tf_rep.run(input) # run the loaded model
  14. res = output.output[0]
  15. res = np.clip(res, 0, 1)
  16. im = np.array(res*255, dtype=np.uint8)
  17. im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
  18. cv2.imwrite('tfres.jpg', im1)
  19. else:
  20. saved_model = tf.saved_model.load("model.tf")
  21. detect_fn = saved_model.signatures["serving_default"]
  22. output = detect_fn(tf.constant(input))
  23. tmp = np.array(output['output'])[0]
  24. res = np.clip(tmp, 0, 1)
  25. im = np.array(res*255, dtype=np.uint8)
  26. im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
  27. cv2.imwrite('savedmodelres.jpg', im1)

转换部分就是前三行代码,后面是对TensorFlow模型的测试,确保转换结果没有问题。

3、TensorFlow转TensorFlow lite

没想到这一步是比较坑的,换了几个TensorFlow版本,最终使用tf2.5,转换成功了,参考issue

  1. import tensorflow as tf
  2. if __name__ == '__main__':
  3. # Convert the model
  4. converter = tf.lite.TFLiteConverter.from_saved_model('model.tf') # path to the SavedModel directory
  5. converter.target_spec.supported_ops = [
  6. tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  7. tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
  8. ]
  9. tflite_model = converter.convert()
  10. # Save the model.
  11. with open('model.tflite', 'wb') as f:
  12. f.write(tflite_model)
  13. # test tflite model
  14. interpreter = tf.lite.Interpreter(model_path='model.tflite')
  15. #my_signature = interpreter.get_signature_runner()
  16. img = cv2.imread('babyx2.bmp')[:,:,::-1]
  17. img = np.transpose(img, (2, 0, 1)) / 255.
  18. img = img[np.newaxis, :]
  19. #output = my_signature(tf.constant(img))
  20. print()
  21. interpreter.resize_tensor_input(0, [1, 3, 256, 256])
  22. interpreter.allocate_tensors()
  23. # Get input and output tensors.
  24. input_details = interpreter.get_input_details()
  25. output_details = interpreter.get_output_details()
  26. # Test the model on random input data.
  27. #input_shape = input_details[0]['shape_signature']
  28. #input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
  29. input_data = img.astype(np.float32)
  30. interpreter.set_tensor(input_details[0]['index'], input_data)
  31. interpreter.invoke()
  32. # The function `get_tensor()` returns a copy of the tensor data.
  33. # Use `tensor()` in order to get a pointer to the tensor.
  34. output_data = interpreter.get_tensor(output_details[0]['index'])
  35. res = np.clip(output_data[0], 0, 1)
  36. im = np.array(res*255, dtype=np.uint8)
  37. im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
  38. cv2.imwrite('tfliteres.jpg', im1)

这里需要注意一点,converter.target_spec.supported_ops这个需要加上,不然有些op在TensorFlow lite中不支持,转换不成功。

 

 

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/154247
推荐阅读
相关标签
  

闽ICP备14008679号