当前位置:   article > 正文

【深度学习】Pytorch模型转成Onnx

【深度学习】Pytorch模型转成Onnx

前言

        工作时需要将模型转成onnx使用triton加载,记录将pytorch模型转成onnx的过程。

1.转化步骤

1-1.安装依赖库

  1. pip install onnx
  2. pip install onnxruntime

 1-2.导入模型

        将训练的模型导入

  1. from torch.utils.data import TensorDataset, DataLoader
  2. from transformers import BertTokenizer, BertModel,AdamW
  3. import torch.nn as nn
  4. import torch
  5. import pandas as pd
  6. import json
  7. import re
  8. import requests
  9. import json
  10. import numpy as np
  11. def encoder(max_length,text_list):
  12. #将text_list embedding成bert模型可用的输入形式
  13. #加载分词模型
  14. vocab_path = "/ssd/dongzhenheng/Pretrain_Model/Roberta_Large/"
  15. #tokenizer = RobertaTokenizer.from_pretrained(vocab_path)
  16. tokenizer = BertTokenizer.from_pretrained(vocab_path)
  17. input_dict = tokenizer.encode_plus(
  18. text,
  19. add_special_tokens=True, # 添加'[CLS]'和'[SEP]'
  20. max_length=max_length,
  21. truncation=True, # 截断或填充
  22. padding='max_length', # 填充至最大长度
  23. return_attention_mask=True, # 返回attention_mask
  24. return_token_type_ids=True, # 返回token_type_ids
  25. return_tensors='pt',
  26. )
  27. input_ids = input_dict['input_ids']
  28. token_type_ids = input_dict['token_type_ids']
  29. attention_mask = input_dict['attention_mask']
  30. print(input_ids.dtype,token_type_ids.dtype,attention_mask.dtype)
  31. input_ids = input_ids.to(torch.int32)
  32. token_type_ids = token_type_ids.to(torch.int32)
  33. attention_mask = attention_mask.to(torch.int32)
  34. print(input_ids.dtype,token_type_ids.dtype,attention_mask.dtype)
  35. return input_ids,token_type_ids,attention_mask
  1. class JointBertClassificationModel(nn.Module):
  2. def __init__(self):
  3. super(JointBertClassificationModel, self).__init__()
  4. #加载预训练模型
  5. pretrained_weights = "/ssd/dongzhenheng/Pretrain_Model/Roberta_Large/"
  6. self.bert = BertModel.from_pretrained(pretrained_weights)
  7. #self.bert = ErnieForMaskedLM.from_pretrained(pretrained_weights)
  8. for param in self.bert.parameters():
  9. param.requires_grad = True
  10. self.dropout = nn.Dropout(0.3)
  11. #定义联合分类
  12. self.pri_dense_1 = nn.Linear(1024, 89)
  13. def forward(self, input_ids,token_type_ids,attention_mask):
  14. #得到bert_output
  15. bert_output = self.bert(input_ids=input_ids, token_type_ids= token_type_ids,attention_mask=attention_mask)
  16. #获得预训练模型的输出
  17. bert_cls_hidden_state = bert_output[1]
  18. pri_cls_output_1 = self.pri_dense_1(bert_cls_hidden_state)
  19. return pri_cls_output_1
  20. class FeedBackBertClassificationModel(nn.Module):
  21. def __init__(self):
  22. super(FeedBackBertClassificationModel, self).__init__()
  23. #加载预训练模型
  24. pretrained_weights = "/ssd/dongzhenheng/Pretrain_Model/Roberta_Large/"
  25. self.bert = BertModel.from_pretrained(pretrained_weights)
  26. #self.bert = ErnieForMaskedLM.from_pretrained(pretrained_weights)
  27. for param in self.bert.parameters():
  28. param.requires_grad = True
  29. self.dropout = nn.Dropout(0.3)
  30. self.pri_dense_1 = nn.Linear(1024, 3)
  31. def forward(self, input_ids,token_type_ids,attention_mask):
  32. #得到bert_output
  33. bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,attention_mask=attention_mask)
  34. #获得预训练模型的输出
  35. bert_cls_hidden_state = bert_output[1]
  36. pri_cls_output_1 = self.pri_dense_1(bert_cls_hidden_state)
  37. #print(pri_cls_output_1.size())
  38. return pri_cls_output_1
  39. FeedBack_classifier_model_path = '/ssd/dongzhenheng/Work/Intelligent_customer_service/feed_back_model_large_1.pkl'
  40. FeedBack_classifier_model = torch.load(FeedBack_classifier_model_path, map_location=torch.device('cpu'))
  41. # 设置模型为评估模式
  42. FeedBack_classifier_model.eval()

1-3 转成onnx格式

  1. # 导出模型
  2. max_len = 100
  3. text = '你好'
  4. input_ids, token_type_ids, attention_mask = encoder(max_len,text)
  5. torch.onnx.export(model = FeedBack_classifier_model, # 模型
  6. args = (input_ids, token_type_ids, attention_mask), # 模型输入
  7. path = "/home/zhenhengdong/WORk/Triton/Bug_Cls/Onnx_model/model_repository/Feedback_classifition_onnx/1/model.onnx", # 输出文件名
  8. export_params=True, # 是否导出参数
  9. opset_version=15, # ONNX版本
  10. verbose=True,
  11. do_constant_folding=True, # 是否执行常量折叠优化
  12. input_names=["input_ids", "token_type_ids", "attention_mask"], # 输入名
  13. output_names=["pri_cls_output"], # 输出名
  14. dynamic_axes={"input_ids": {0: "batch_size"}, "token_type_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "pri_cls_output": {0: "batch_size"}}) # 动态维度

model :需要导出的pytorch模型
args:模型的输入参数,需要和模型接收到的参数一致。
path:输出的onnx模型的位置和名称。
export_params:输出模型是否可训练。default=True,表示导出trained model,否则untrained。

opset_version :ONNX版本
verbose:是否打印模型转换信息。default=False。
input_names:输入节点名称。default=None。
output_names:输出节点名称。default=None。
do_constant_folding:是否使用常量折叠,默认即可。default=True。
dynamic_axes:模型的输入输出有时是可变的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/578443
推荐阅读
相关标签
  

闽ICP备14008679号