赞
踩
现在很多算法都是用 PyTorch 框架训练的,但在移动端部署时又常常使用 TensorFlow Lite,因此需要把 PyTorch 模型转换为 TensorFlow Lite 模型。
将 PyTorch 模型转换为 TensorFlow Lite 的流程是 PyTorch → ONNX → TensorFlow → TensorFlow Lite,本文记录一下其中踩坑的过程。
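第一步:PyTorch → ONNX。这一步比较简单,使用 PyTorch 自带的 torch.onnx.export 接口即可。不过需要注意 opset 版本,它可能会影响后续的转换:例如用 opset 10 导出后,onnxruntime 运行时会在 PixelShuffle 处出错,因此下面使用 opset 11。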
# Step 1: export the PyTorch model to ONNX and sanity-check it with onnxruntime.
import os

import numpy as np


def image_to_batch(img_bgr):
    """Turn an OpenCV image into a network input batch.

    Args:
        img_bgr: ``H x W x 3`` uint8 array in BGR channel order, as returned
            by ``cv2.imread``.

    Returns:
        ``1 x 3 x H x W`` float32 array in RGB channel order, scaled to [0, 1].
    """
    rgb = img_bgr[:, :, ::-1]
    chw = np.transpose(rgb, (2, 0, 1)) / 255.0
    return chw[np.newaxis].astype(np.float32)


def chw_to_bgr_image(chw):
    """Turn one ``3 x H x W`` RGB network output into an OpenCV image.

    Values are clipped to [0, 1] and *rounded* to the nearest 8-bit level.
    Plain ``uint8`` casting truncates, which biases every pixel downward
    (e.g. 0.999 -> 254 instead of 255).

    Returns:
        Contiguous ``H x W x 3`` uint8 array in BGR channel order.
    """
    levels = np.rint(np.clip(chw, 0.0, 1.0) * 255.0).astype(np.uint8)
    hwc = np.transpose(levels, (1, 2, 0))
    return np.ascontiguousarray(hwc[:, :, ::-1])


def export_pytorch_to_onnx(model_path="model.pth", onnx_path="model.onnx",
                           image_path="babyx2.bmp", result_path="tmp.jpg"):
    """Export the model to ONNX (opset 11), then run one image through onnxruntime.

    The onnxruntime result is written to ``result_path`` so it can be compared
    against the PyTorch output before moving on to the next conversion step.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read by OpenCV.
    """
    # Pin GPU 0 unless the caller already chose devices; this must happen
    # before torch initialises CUDA.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    # Heavy and project-local dependencies are imported lazily so the numpy
    # helpers above can be reused without them.
    import cv2
    import onnxruntime
    import torch

    # NOTE(review): project-local modules; their import path is not shown in
    # the original snippet -- adjust to the repository layout.
    import architecture
    import utils

    model = architecture.IMDN_RTC(upscale=2).cuda()
    model.load_state_dict(utils.load_state_dict(model_path), strict=True)
    model.eval()
    model.requires_grad_(False)

    dummy_input = torch.rand(1, 3, 224, 224).cuda()
    # With opset 10 the exported model fails in onnxruntime at the
    # PixelShuffle node; opset 11 works.
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=11,
        verbose=True,
        input_names=["input"],
        output_names=["output"],
        # Batch, height and width are dynamic so any image size can be fed.
        dynamic_axes={"input": [0, 2, 3], "output": [0, 2, 3]},
    )

    # onnxruntime >= 1.9 requires explicit providers on builds that ship
    # several execution providers (e.g. GPU builds); CPU suffices for a check.
    session = onnxruntime.InferenceSession(
        onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    output_names = [out.name for out in session.get_outputs()]

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    outputs = session.run(output_names, {input_name: image_to_batch(img)})
    # First output tensor, first (only) image of the batch.
    cv2.imwrite(result_path, chw_to_bgr_image(outputs[0][0]))


if __name__ == "__main__":
    export_pytorch_to_onnx()
调用 torch.onnx.export 之后就得到了 ONNX 模型,后面的代码使用 onnxruntime 测试转换后的 ONNX 模型。建议每一步转换之后都测试一下转换后模型的输出,确保每一步都是正确的。
第二步:ONNX → TensorFlow。这一步需要安装 onnx-tensorflow(pip install onnx-tf)进行转换。
# Step 2: convert the ONNX model to a TensorFlow SavedModel with onnx-tf and check it.
import numpy as np


def image_to_batch(img_bgr):
    """Turn an OpenCV image into a network input batch.

    Args:
        img_bgr: ``H x W x 3`` uint8 array in BGR channel order, as returned
            by ``cv2.imread``.

    Returns:
        ``1 x 3 x H x W`` float32 array in RGB channel order, scaled to [0, 1].
    """
    rgb = img_bgr[:, :, ::-1]
    chw = np.transpose(rgb, (2, 0, 1)) / 255.0
    return chw[np.newaxis].astype(np.float32)


def chw_to_bgr_image(chw):
    """Turn one ``3 x H x W`` RGB network output into an OpenCV image.

    Values are clipped to [0, 1] and *rounded* to the nearest 8-bit level.
    Plain ``uint8`` casting truncates, which biases every pixel downward
    (e.g. 0.999 -> 254 instead of 255).

    Returns:
        Contiguous ``H x W x 3`` uint8 array in BGR channel order.
    """
    levels = np.rint(np.clip(chw, 0.0, 1.0) * 255.0).astype(np.uint8)
    hwc = np.transpose(levels, (1, 2, 0))
    return np.ascontiguousarray(hwc[:, :, ::-1])


def convert_onnx_to_tf(onnx_path="model.onnx", tf_path="model.tf",
                       image_path="babyx2.bmp", use_tf_rep=False):
    """Convert ``onnx_path`` to a SavedModel at ``tf_path`` and test it on one image.

    The conversion itself is ``onnx.load`` -> ``prepare`` -> ``export_graph``;
    the rest runs the converted model so its output can be checked.

    Args:
        use_tf_rep: if True, test through onnx-tf's in-memory representation
            (``tf_rep.run``, result in ``tfres.jpg``); otherwise reload the
            exported SavedModel from disk (result in ``savedmodelres.jpg``).

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read by OpenCV.
    """
    # Heavy dependencies are imported lazily so the numpy helpers above can
    # be reused without them.
    import cv2
    import onnx
    import tensorflow as tf
    from onnx_tf.backend import prepare

    onnx_model = onnx.load(onnx_path)   # load onnx model
    tf_rep = prepare(onnx_model)        # prepare tf representation
    tf_rep.export_graph(tf_path)        # export the SavedModel

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    batch = image_to_batch(img)

    if use_tf_rep:
        outputs = tf_rep.run(batch)
        # "output" is the ONNX output name chosen at export time.
        result = outputs.output[0]
        result_path = "tfres.jpg"
    else:
        saved_model = tf.saved_model.load(tf_path)
        serve = saved_model.signatures["serving_default"]
        outputs = serve(tf.constant(batch))
        result = np.array(outputs["output"])[0]
        result_path = "savedmodelres.jpg"

    cv2.imwrite(result_path, chw_to_bgr_image(result))


if __name__ == "__main__":
    convert_onnx_to_tf()
真正负责转换的只有 onnx.load、prepare 和 tf_rep.export_graph 这三步,其余代码用于测试转换得到的 TensorFlow 模型,确保转换结果没有问题。
第三步:TensorFlow → TensorFlow Lite。没想到这一步坑比较多,换了几个 TensorFlow 版本,最终使用 TF 2.5 转换成功(参考相关 issue)。
# Step 3: convert the TensorFlow SavedModel to TensorFlow Lite and check it.
import numpy as np


def image_to_batch(img_bgr):
    """Turn an OpenCV image into a network input batch.

    Args:
        img_bgr: ``H x W x 3`` uint8 array in BGR channel order, as returned
            by ``cv2.imread``.

    Returns:
        ``1 x 3 x H x W`` float32 array in RGB channel order, scaled to [0, 1].
    """
    rgb = img_bgr[:, :, ::-1]
    chw = np.transpose(rgb, (2, 0, 1)) / 255.0
    return chw[np.newaxis].astype(np.float32)


def chw_to_bgr_image(chw):
    """Turn one ``3 x H x W`` RGB network output into an OpenCV image.

    Values are clipped to [0, 1] and *rounded* to the nearest 8-bit level.
    Plain ``uint8`` casting truncates, which biases every pixel downward
    (e.g. 0.999 -> 254 instead of 255).

    Returns:
        Contiguous ``H x W x 3`` uint8 array in BGR channel order.
    """
    levels = np.rint(np.clip(chw, 0.0, 1.0) * 255.0).astype(np.uint8)
    hwc = np.transpose(levels, (1, 2, 0))
    return np.ascontiguousarray(hwc[:, :, ::-1])


def convert_tf_to_tflite(saved_model_dir="model.tf", tflite_path="model.tflite",
                         image_path="babyx2.bmp", result_path="tfliteres.jpg"):
    """Convert a SavedModel to ``tflite_path`` and test it on one image.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read by OpenCV.
    """
    # Heavy dependencies are imported lazily so the numpy helpers above can
    # be reused without them.
    import cv2
    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    # Some ops of this model have no TFLite builtin; without SELECT_TF_OPS
    # (TensorFlow/Flex fallback) the conversion fails.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    batch = image_to_batch(img)

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    # Look the input tensor up instead of assuming it is tensor 0, and size
    # it to the actual image (H/W are dynamic) rather than a fixed 256x256.
    input_index = interpreter.get_input_details()[0]["index"]
    interpreter.resize_tensor_input(input_index, list(batch.shape))
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_index, batch)
    interpreter.invoke()

    output_index = interpreter.get_output_details()[0]["index"]
    # get_tensor() returns a copy of the tensor data.
    output_data = interpreter.get_tensor(output_index)
    cv2.imwrite(result_path, chw_to_bgr_image(output_data[0]))


if __name__ == "__main__":
    convert_tf_to_tflite()
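这里需要注意:converter.target_spec.supported_ops 必须同时加上 TFLITE_BUILTINS 和 SELECT_TF_OPS,否则有些算子在 TensorFlow Lite 中不受支持,转换会失败。启用 SELECT_TF_OPS 后,在移动端运行该模型还需要链接 TensorFlow Lite 的 Select TF ops(Flex)运行库。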
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。