赞
踩
本文演示了如何修改 PyTorch scripted_model 的图结构。需求背景:
"""Rewrite the graph of a saved TorchScript model with pattern-based passes.

The script loads a scripted model, dumps its graph, applies three
pattern-based subgraph rewrites (each followed by dead-code elimination),
dumps the rewritten graph, and finally runs one forward pass as a smoke test.
"""
import os
import sys

import torch


def _rewrite_graph(graph, pattern, replacement):
    """Replace every match of `pattern` in `graph` and drop dead nodes.

    Args:
        graph: The TorchScript graph to modify in place.
        pattern: TorchScript IR text describing the subgraph to match.
        replacement: TorchScript IR text of the subgraph to substitute.
    """
    torch._C._jit_pass_custom_pattern_based_rewrite_graph(
        pattern, replacement, graph
    )
    # The rewrite can leave the matched nodes unused; remove them.
    torch._C._jit_pass_dce(graph)


# NOTE(review): the original snippet used `model_path` and `prefix` without
# ever defining them (NameError). Reading them from the command line is an
# assumption -- confirm against how the script is actually invoked.
if len(sys.argv) != 3:
    sys.exit("usage: python {} <model_path> <output_dir>".format(sys.argv[0]))
model_path, prefix = sys.argv[1], sys.argv[2]

max_seq_length = 384

# Dummy integer input of shape (1, max_seq_length) used for the smoke test.
dummy_input = torch.randint(0, 2, (1, max_seq_length), dtype=torch.long)

scripted_model = torch.jit.load(model_path).eval()

# Simplify the graph before matching: fold constants, drop dead code, and
# inline calls so that the patterns below see a flat graph.
torch._C._jit_pass_constant_propagation(scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)
torch._C._jit_pass_inline(scripted_model.graph)

# Keep a copy of the graph before rewriting, for comparison.
with open(os.path.join(prefix, "graph.txt"), "w", encoding="utf-8") as f:
    f.write(str(scripted_model.graph))

# Rewrite the embedding position-id slicing: the end index of the second
# slice is computed at run time from the input size; replace it with a
# constant equal to max_seq_length.
pattern = """
graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
    %63 = aten::size(%input_ids.1, %44)
    %seq_length.1 = prim::NumToTensor(%63)
    %65 = aten::add(%seq_length.1, %37, %44)
    %66 = aten::Int(%65)
    %67 = aten::slice(%position_ids, %45, %45, %28, %44)
    %input.11 = aten::slice(%67, %44, %45, %66, %44)
    return (%input.11)
"""
# The constant is derived from max_seq_length so the two cannot drift apart.
replacement = """
graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
    %35 : int = prim::Constant[value={max_seq_length}]()
    %67 = aten::slice(%position_ids, %45, %45, %28, %44)
    %input.11 = aten::slice(%67, %44, %45, %35, %44)
    return (%input.11)
""".format(max_seq_length=max_seq_length)
_rewrite_graph(scripted_model.graph, pattern, replacement)

# Replace aten::linear with an explicit transpose + matmul + in-place add.
pattern = """
graph(%input.9, %weight.6, %bias.6):
    %x.5 = aten::linear(%input.9, %weight.6, %bias.6)
    return (%x.5)
"""
replacement = """
graph(%input.7, %weight.6, %bias.6):
    %120 = aten::t(%weight.6)
    %45 : int = prim::Constant[value=1]()
    %output.10 = aten::matmul(%input.7, %120)
    %122 = aten::add_(%output.10, %bias.6, %45)
    return (%122)
"""
_rewrite_graph(scripted_model.graph, pattern, replacement)

# Remove the trailing split / squeeze / contiguous / tuple packing so the
# graph returns the un-split tensor (%1056) directly.
pattern = """
graph(%1056,%45,%44,%43):
    %1057 = aten::split(%1056, %44, %43)
    %start_logits.1 , %end_logits.1 = prim::ListUnpack(%1057)
    %1060 = aten::squeeze(%start_logits.1, %43)
    %1061 = aten::contiguous(%1060, %45)
    %1062 = aten::squeeze(%end_logits.1, %43)
    %1063 = aten::contiguous(%1062, %45)
    %1064 = prim::TupleConstruct(%1061, %1063)
    %11,%12 = prim::TupleUnpack(%1064)
    %15 =prim::TupleConstruct(%11, %12)
    return (%15)
"""
replacement = """
graph(%1056,%45,%44,%43):
    return (%1056)
"""
_rewrite_graph(scripted_model.graph, pattern, replacement)

# Dump the rewritten graph next to the original one.
with open(os.path.join(prefix, "graph_opt.txt"), "w", encoding="utf-8") as f:
    f.write(str(scripted_model.graph))

# Inference smoke test: confirm the rewritten model still runs.
# The same dummy tensor is passed for all three model inputs.
out = scripted_model(dummy_input, dummy_input, dummy_input)
for i in out:
    print(i.shape)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。