How to Modify a PyTorch scripted_model

This article shows how to modify the graph structure of a PyTorch scripted_model. Background:

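  • The inference software stacks of some AI accelerator cards apply graph optimizations to the model. For some models the stack's graph pattern matching is incomplete, and compilation fails.
  • Option 1 is to wait for the vendor to fix it; option 2 is to modify the graph yourself so that it moves closer to the structures the vendor supports.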
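Option 2 starts with knowing which ops the graph actually contains, since that is what the vendor's pattern matcher sees. Below is a minimal sketch, assuming a recent PyTorch where `nn.Linear` scripts to a single `aten::linear`. The toy model `TinyEncoder` and the helper `op_histogram` are hypothetical; `torch._C._jit_pass_inline` is the same private pass the script below uses to flatten submodule calls.

```python
from collections import Counter

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical toy model standing in for a real network: Linear -> ReLU -> Linear."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


def op_histogram(graph: torch._C.Graph) -> Counter:
    """Count node kinds at the top level of a TorchScript graph (nested blocks are skipped)."""
    return Counter(node.kind() for node in graph.nodes())


scripted = torch.jit.script(TinyEncoder().eval())
graph = scripted.graph
before = op_histogram(graph)  # submodule calls still appear as prim::CallMethod here
torch._C._jit_pass_inline(graph)  # inline fc1/fc2 forward calls, modifying the graph in place
hist = op_histogram(graph)
for kind, count in sorted(hist.items()):
    print(f"{kind:24s} {count}")
```

Only top-level nodes are counted; ops nested inside `prim::If` or `prim::Loop` blocks would need a recursive walk over each node's sub-blocks.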

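Source code

The script below loads a TorchScript model (judging from input_ids, position_ids and start_logits/end_logits in the patterns, a BERT-style question-answering model), simplifies its graph, applies three pattern-based rewrites (a static-length position_ids slice, linear replaced by matmul, and removal of the output split), dumps the graph before and after, and runs a quick inference test.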

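import os
import torch

# Hypothetical placeholders: adjust to your environment
model_path = "model.pt"   # path to the TorchScript model file
prefix = "graph_dump"     # directory for the graph dumps
os.makedirs(prefix, exist_ok=True)

max_seq_length = 384
# Random token ids; only used as a smoke-test input at the end
input = torch.randint(0, 2, (1, max_seq_length), dtype=torch.long)
scripted_model = torch.jit.load(model_path).eval()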

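# Simplify first (private torch._C passes, applied in place to the forward graph):
# fold constants, remove dead nodes, and inline submodule calls so that
# every layer's aten ops appear in one flat graph
torch._C._jit_pass_constant_propagation(scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)
torch._C._jit_pass_inline(scripted_model.graph)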

# Dump the graph before the rewrites, to diff against graph_opt.txt later
with open(os.path.join(prefix, "graph.txt"), "w") as f:
    f.write(str(scripted_model.graph))

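# Embedding rewrite: the end of the position_ids slice is computed at run time
# from the size of input_ids; replace it with the constant max_seq_length so the
# slice has a static shape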
pattern = """
    graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
        %63  = aten::size(%input_ids.1, %44) 
        %seq_length.1  = prim::NumToTensor(%63)
        %65  = aten::add(%seq_length.1, %37, %44) 
        %66  = aten::Int(%65)
        %67  = aten::slice(%position_ids, %45, %45, %28, %44) 
        %input.11  = aten::slice(%67, %44, %45, %66, %44)
        return (%input.11)
"""
# The constant has to equal max_seq_length, so it is interpolated from it
replacement = f"""
    graph(%input_ids.1,%44,%position_ids,%37,%45,%28):
        %35 : int = prim::Constant[value={max_seq_length}]()
        %67  = aten::slice(%position_ids, %45, %45, %28, %44) 
        %input.11  = aten::slice(%67, %44, %45, %35, %44)
        return (%input.11)
"""

torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement,scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)

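# Replace aten::linear with t() + matmul + add_
# (assumes every linear layer has a bias; pattern inputs are matched by
# position, so their names in the replacement may differ)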
pattern = """
    graph(%input.9, %weight.6, %bias.6):
        %x.5 = aten::linear(%input.9, %weight.6, %bias.6)
        return (%x.5)
"""
replacement = """
    graph(%input.7, %weight.6, %bias.6):
        %120  = aten::t(%weight.6)
        %45 : int = prim::Constant[value=1]()
        %output.10  = aten::matmul(%input.7, %120) 
        %122  = aten::add_(%output.10, %bias.6, %45)
        return (%122)
"""

torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement,scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)

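# Remove the split of the logits into start_logits/end_logits.
# Afterwards the graph returns the unsplit logits tensor, and the caller
# has to do the split itself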
pattern = """
    graph(%1056,%45,%44,%43):
        %1057 = aten::split(%1056, %44, %43)
        %start_logits.1 , %end_logits.1  = prim::ListUnpack(%1057)
        %1060 = aten::squeeze(%start_logits.1, %43)
        %1061 = aten::contiguous(%1060, %45)
        %1062 = aten::squeeze(%end_logits.1, %43)
        %1063 = aten::contiguous(%1062, %45)
        %1064 = prim::TupleConstruct(%1061, %1063)
        %11,%12  = prim::TupleUnpack(%1064)
        %15 =prim::TupleConstruct(%11, %12)
        return (%15)
"""
replacement = """
    graph(%1056,%45,%44,%43):
        return (%1056)
"""

torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement,scripted_model.graph)
torch._C._jit_pass_dce(scripted_model.graph)

with open("{}/graph_opt.txt".format(prefix),"w") as f:
    f.write(str(scripted_model.graph))

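# Inference test to confirm the rewritten model still runs.
# The same random tensor is fed to all three model inputs.
# After the split is removed, the graph returns a single logits tensor instead of a tuple.
out = scripted_model(input, input, input)
outputs = out if isinstance(out, (tuple, list)) else (out,)
for i in outputs:
    print(i.shape)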
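The linear-to-matmul rewrite can be checked in isolation on a toy model before touching the real one. The sketch below (the module `TinyHead` and the helper `rewrite_linear_to_matmul` are hypothetical) applies the same `aten::linear` → `aten::t` + `aten::matmul` + add rewrite with `torch._C._jit_pass_custom_pattern_based_rewrite_graph`, using an out-of-place `aten::add` instead of `add_`, and compares the result with the eager model. Like the script above, it rewrites the graph before the scripted module is first called; changing the graph after the module has already run may not take effect.

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical toy model: one Linear layer standing in for the real network."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# Pattern: y = linear(x, W, b)
LINEAR_PATTERN = """
graph(%input, %weight, %bias):
    %out = aten::linear(%input, %weight, %bias)
    return (%out)
"""

# Replacement: y = matmul(x, W^T) + b, with an out-of-place add
MATMUL_REPLACEMENT = """
graph(%input, %weight, %bias):
    %alpha : int = prim::Constant[value=1]()
    %weight_t = aten::t(%weight)
    %mm = aten::matmul(%input, %weight_t)
    %out = aten::add(%mm, %bias, %alpha)
    return (%out)
"""


def rewrite_linear_to_matmul(scripted: torch.jit.ScriptModule) -> None:
    """Rewrite every aten::linear in the forward graph, in place."""
    graph = scripted.graph
    torch._C._jit_pass_inline(graph)  # expose aten::linear from self.fc.forward
    torch._C._jit_pass_custom_pattern_based_rewrite_graph(
        LINEAR_PATTERN, MATMUL_REPLACEMENT, graph
    )
    torch._C._jit_pass_dce(graph)  # drop nodes left dangling by the rewrite


torch.manual_seed(0)
eager = TinyHead().eval()
x = torch.randn(1, 4, 8)
with torch.no_grad():
    expected = eager(x)  # reference output from the unmodified eager model

scripted = torch.jit.script(eager)  # not run yet; the rewrite happens before the first call
rewrite_linear_to_matmul(scripted)
kinds = [node.kind() for node in scripted.graph.nodes()]
with torch.no_grad():
    actual = scripted(x)

print("aten::linear" in kinds, "aten::matmul" in kinds)
print(torch.allclose(expected, actual, atol=1e-6))
```

The rewrite pass reports nothing when the pattern does not match, so checking the op kinds afterwards, as `kinds` does here, is a cheap way to confirm that a rewrite actually fired.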
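Once the split is removed from the graph, the model returns the unsplit logits and the split moves to the caller. A minimal sketch, assuming the removed subgraph was `logits.split(1, dim=-1)` followed by `squeeze(-1).contiguous()` on each half (as in Hugging Face's `BertForQuestionAnswering`); `split_qa_logits` is a hypothetical helper name:

```python
from typing import Tuple

import torch


def split_qa_logits(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Host-side replacement for the split removed from the graph.

    Assumes logits has shape [batch, seq_len, 2]: channel 0 holds the start
    logits, channel 1 holds the end logits.
    """
    start_logits, end_logits = logits.split(1, dim=-1)  # two [batch, seq_len, 1] views
    start_logits = start_logits.squeeze(-1).contiguous()  # -> [batch, seq_len]
    end_logits = end_logits.squeeze(-1).contiguous()
    return start_logits, end_logits


# Stand-in for the rewritten model's output (batch 1, max_seq_length 384)
logits = torch.randn(1, 384, 2)
start, end = split_qa_logits(logits)
print(start.shape, end.shape)
```

Keeping this step on the host leaves the accelerator graph free of the split/squeeze/tuple nodes that the vendor's matcher failed on.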
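Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site assumes no legal liability for it. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/2023面试高手/article/detail/149432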