当前位置:   article > 正文

TFLite文件解析及格式转换_.tflite

.tflite

        随着深度学习越来越流行,工业生产不光在PC端应用场景丰富,在移动端也越来越凸显出深度学习的重要性及应用价值。由于嵌入式平台受存储、指令集限制,需要提供更小的网络模型,并且某些DSP平台不支持float指令。tensorflow提供TOCO转换工具能够自动生成量化为U8的TFLite文件。本文将介绍如何解析tflite的网络结构以及权重信息。

一、tflite文件格式

        Tflite文件由tensorflow提供的TOCO工具生成的轻量级模型,存储格式是flatbuffer,它是google开源的一种二进制序列化格式,同功能的像protobuf。对flatbuffer可小结为三点。

1.内容分为vtable区和数据区,vtable区保存着变量的偏移值,数据区保存着变量值;

2.要解析变量a,是在vtable区组合一层层的offset偏移量计算出总偏移,然后以总偏移到数据区中定位从而获取变量a的值。

3.一个叫schema的文本文件定义了要进行序列化和反序列化的数据结构。

具体定义的结构可以参考tensorflow源码中的schema文件:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs

二、tflite解析

         由于工作需要,本文使用了google flatbuffer开源工具flatc,flatc可以实现tflite格式到jason文件的自动转换。

flatbuffer源码:https://github.com/google/flatbuffers 

安装命令: cmake -G "Unix Makefiles" //生成MakeFile

                   make //生成flatc

                   make install //安裝flatc

安装完成后,从tensorflow源码中copy 结构文件schema.fbs到flatbuffer根目录,执行#./flatc -t schema.fbs -- mobilenet_v1_1.0_224_quant.tflite,生成对应的json文件。Json文件结构如下图所示:

operator_codes: 以列表的形式存储该网络结构用的layer种类;

subgraphs: 为每一层的具体信息具体包括:

            1)tensors.包含input、weight、bias的shape信息、量化参数以及在buffer数据区的offset值;

            2)inputs: 整个网络的输入对应的tensors索引;

            3)outputs: 整个网络的输出对应的tensors索引;

            4)operators:网络结构所需要的相关参数;

buffers: 存放weight、bias等权重信息。

三、网络结构及权重提取

       使用python的json包可以很方便的读取tflite生成的json文件。关于解析过程有几点说明:

        1.flatc转换的json文件不是标准的key-value格式,需要稍作转换给索引key加上双引号具体代码如下:

  1. # -*- coding: UTF-8 -*-
  2. import os
  3. pathIn='xxx.json'
  4. pathDst='xxx_new.json'
  5. f = open(pathIn) # 返回一个文件对象
  6. line = f.readline() # 调用文件的 readline()方法
  7. fout = open(pathDst,'w')
  8. while line:
  9. #print(line)
  10. #print(len(line))
  11. dstline='aaa'
  12. if line.find(':')!=-1:
  13. quoteIdx2=line.find(':')
  14. #print("line has :, and index =%d" %quoteIdx2)
  15. linenew=line[:quoteIdx2] + '"' + line[quoteIdx2:]
  16. quoteIdx1=linenew.rfind(' ',0, quoteIdx2)
  17. #print("quoteIdx1 %d" %quoteIdx1)
  18. dstline=linenew[:quoteIdx1+1] + '"' + linenew[quoteIdx1+1:]
  19. #print(dstline)
  20. fout.write(dstline+os.linesep)
  21. else:
  22. dstline=line
  23. fout.write(line)
  24. #print("No")
  25. #print dstline
  26. line = f.readline()
  27. f.close()
  28. fout.close()

        2.由于量化后的bias为int32的类型,而flatc将bias数据按照uint8的格式进行了转换,这里需要对json文件的bias再转换回int32类型,相当于json中bias区域四个字节转换为一个int32。详细讨论参考tensorflow github链接:https://github.com/tensorflow/tensorflow/issues/22279

        解析部分代码分为两个部分包括网络结构以及权重解析,方法相似。网络结构参数解析,部分代码如下:

  1. from __future__ import division
  2. import json
  3. def write_blob_info(p_file, inputs, input_shape):
  4. p_file.write(str(inputs) + ', ')
  5. p_file.write(str(3) + ', ')
  6. p_file.write(str(input_shape[3]) + ', ')
  7. p_file.write(str(input_shape[1]) + ', ')
  8. p_file.write(str(input_shape[2]) + ', ')
  9. with open("mobilenet_v1_1.0_224_quant.json",'r') as f:
  10. load_dict = json.load(f)
  11. param_file=open("mobilenet_v1_1.0_224_quant.proto",'w')
  12. tensors = load_dict["subgraphs"][0]["tensors"]
  13. operators = load_dict["subgraphs"][0]["operators"]
  14. inputs = load_dict["subgraphs"][0]["inputs"]
  15. input_shape = tensors[inputs[0]]["shape"]
  16. param_file.write(str(len(operators) + 1) + ',\n')
  17. write_blob_info(param_file, \
  18. inputs[0], \
  19. input_shape)
  20. param_file.write('\n')
  21. for layer in operators:
  22. layer_name = layer["builtin_options_type"]
  23. operators_inputs = layer["inputs"]
  24. input_len = len(operators_inputs)
  25. builtin_options = layer["builtin_options"]
  26. if layer_name == "Conv2DOptions": #conv_2d, depthwiseconv_2d
  27. input_shape = tensors[operators_inputs[0]]["shape"]
  28. kernel_shape = tensors[operators_inputs[1]]["shape"]
  29. bias_shape = tensors[operators_inputs[2]]["shape"]
  30. kernel_H = kernel_shape[1]
  31. kernel_W = kernel_shape[2]
  32. param_file.write(str(kernel_H) + ', ')
  33. param_file.write(str(kernel_W) + ', ')
  34. stride_H = builtin_options["stride_h"]
  35. stride_W = builtin_options["stride_w"]
  36. param_file.write(str(stride_H) + ', ')
  37. param_file.write(str(stride_W) + ', ')
  38. dilation_W = builtin_options["dilation_w_factor"]
  39. dilation_H = builtin_options["dilation_h_factor"]
  40. param_file.write(str(dilation_H) + ', ')
  41. param_file.write(str(dilation_W) + ', ')
  42. bias_term = 1
  43. if input_len < 3 or bias_shape[0] == 0:
  44. bias_term = 0
  45. param_file.write(str(bias_term) + ', ')
  46. bottom_zero_point = tensors[operators_inputs[0]]["quantization"]["zero_point"][0]
  47. param_file.write(str(bottom_zero_point) + ', ')
  48. write_blob_info(param_file, \
  49. operators_inputs[0], \
  50. input_shape)
  51. #output_blob
  52. operators_outputs = layer["outputs"]
  53. output_shape = tensors[operators_outputs[0]]["shape"]
  54. write_blob_info(param_file, \
  55. operators_outputs[0], \
  56. output_shape)
  57. param_file.write('\n')

(水平有限,如有问题及遗漏欢迎补充指出,互相学习。)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/154319
推荐阅读
相关标签
  

闽ICP备14008679号