当前位置:   article > 正文

bert-NER 转化成 onnx 模型

bert-NER 转化成 onnx 模型

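保存模型（原文未给出此步骤的代码：微调完成后，可用 `model.save_pretrained('./save_model')` 和 `tokenizer.save_pretrained('./save_model')` 将模型、配置和分词器保存到下文加载所用的 `./save_model` 目录）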

加载模型

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForTokenClassification,
    AutoTokenizer,
)

# Directory holding the fine-tuned NER model, its config and tokenizer
# (loaded below with from_pretrained).
NER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
# The config carries label2id / id2label, which the ONNX pipeline uses later.
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
# AutoModelForTokenClassification must be imported explicitly above; the
# original import list only contained AutoModel, so this line raised NameError.
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
# Inference mode (disables dropout) before testing or exporting the model.
ner_model.eval()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

测试 NER 效果

（原文此处为 PyTorch 模型 NER 效果测试结果截图，图片未保留）

测试速度

（原文此处为 PyTorch 模型推理速度测试截图，图片未保留）

导出到 ONNX

# !pip install onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple/

# Export the fine-tuned token-classification model to ONNX.
import onnxruntime
import torch
from itertools import chain
from transformers.onnx.features import FeaturesManager

config = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"

# Use the *token-classification* ONNX config. Its "logits" output declares both
# a batch axis (0) and a sequence axis (1) as dynamic, which matches the
# (batch, seq_len, num_labels) logits of a token-classification head. The
# "sequence-classification" config only marks axis 0 as dynamic.
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE[config.model_type]['token-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

# A dict as the last element of the args tuple is passed to forward() as
# keyword arguments (input_ids=..., attention_mask=..., token_type_ids=...).
# `use_external_data_format` and `enable_onnx_checker` were deprecated in
# PyTorch 1.11 and later removed, so current versions raise TypeError on them.
# Current versions switch to external data automatically for models > 2GB.
torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    },
    do_constant_folding=True,
    opset_version=onnx_config.default_onnx_opset,
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

加载ONNX模型

自定义pipeline

import torch
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession

class PipeLineOnnx:
    """NER (token-classification) pipeline backed by an exported ONNX model.

    Tokenizes one text with a Hugging Face tokenizer, runs the ONNX graph on
    CPU with onnxruntime, and turns per-token predictions into merged entity
    dicts via ``post_process``.
    """

    def __init__(self, tokenizer, onnx_path, config):
        """
        Args:
            tokenizer: Hugging Face tokenizer matching the exported model. It
                must be a "fast" tokenizer, because offset mappings are requested.
            onnx_path: Path to the ``.onnx`` file written by ``torch.onnx.export``.
            config: Model config; ``label2id`` / ``id2label`` map labels to ids.
        """
        self.tokenizer = tokenizer
        self.config = config  # provides label2id / id2label
        options = SessionOptions()
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        # To limit CPU threads, set options.intra_op_num_threads (e.g. 4).
        self.session = InferenceSession(
            onnx_path, sess_options=options, providers=["CPUExecutionProvider"]
        )
        # Disable session.run()'s fallback mechanism, which would otherwise
        # reset the execution provider on failure.
        self.session.disable_fallback()
        # Names of the graph inputs; other tokenizer outputs are not fed to it.
        self._input_names = {node.name for node in self.session.get_inputs()}

    def __call__(self, text):
        """Run NER on a single string and return the merged entities.

        Args:
            text: One input string. Batched input is not supported, because
                predictions are flattened and paired with a single token sequence.

        Returns:
            The result of ``post_process``: a list of dicts with ``entity``,
            ``score``, ``word``, ``start`` and ``end`` keys.
        """
        # A single tokenizer call yields the model inputs together with the
        # character offsets and special-token mask of the same (truncated)
        # token sequence.
        encoding = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_tensors='pt',
        )
        offsets = encoding.pop("offset_mapping")[0].tolist()
        special_tokens_mask = encoding.pop("special_tokens_mask")[0].tolist()
        ids = encoding["input_ids"]
        input_feed = {
            name: tensor.detach().cpu().numpy()
            for name, tensor in encoding.items()
            if name in self._input_names
        }

        # 'logits' must match the output name used in torch.onnx.export.
        output = self.session.run(output_names=['logits'], input_feed=input_feed)[0]
        logits = torch.tensor(output)

        num_labels = len(self.config.label2id)
        active_logits = logits.view(-1, num_labels)  # (batch_size * seq_len, num_labels)
        scores = torch.softmax(active_logits, dim=1).max(dim=1).values.tolist()
        predictions = torch.argmax(active_logits, dim=1).tolist()  # token-level label ids

        tokens = self.tokenizer.convert_ids_to_tokens(ids[0].tolist())
        labels = [self.config.id2label[i] for i in predictions]

        ner_result = [
            {
                "index": idx,
                "word": token,
                "entity": label,
                "start": start,
                "end": end,
                "score": score,
            }
            for idx, (token, label, (start, end), score, is_special) in enumerate(
                zip(tokens, labels, offsets, scores, special_tokens_mask)
            )
            # Special tokens ([CLS]/[SEP]) are not part of the text. Typical
            # fine-tuning ignores them in the loss, so their predicted label is
            # not meaningful and must not become an entity.
            if label != 'O' and not is_special
        ]
        return post_process(ner_result)
        

def allow_merge(a, b):
    """Return True if tag ``b`` can continue an entity whose last tag is ``a``.

    Tags are "<flag>-<type>" strings such as "B-PER", "I-PER" or "E-PER".
    A merge is allowed only when both tags have the same type and the flag
    pair is one of B->I, B->E, I->I or I->E.

    Only the first '-' separates flag from type, so types that themselves
    contain '-' (e.g. "B-ORG-SUB") are supported. A tag without any '-' never
    merges; previously such tags raised ValueError.
    """
    a_flag, a_sep, a_type = a.partition('-')
    b_flag, b_sep, b_type = b.partition('-')
    if not (a_sep and b_sep) or a_type != b_type:
        return False
    # This set also excludes every pair where b starts a new entity ("B")
    # or a already closed one ("E").
    return (a_flag, b_flag) in {("B", "I"), ("B", "E"), ("I", "I"), ("I", "E")}

def divide_entities(ner_results):
    """Group token-level predictions into runs that each form one entity.

    Tokens are processed in ``index`` order. A token joins the current group
    when ``allow_merge`` accepts its tag after the group's last tag; otherwise
    it starts a new group.

    Args:
        ner_results: list of dicts with at least ``index`` and ``entity`` keys.

    Returns:
        A list of non-empty groups (lists of the input dicts). Empty input
        yields ``[]``.
    """
    divided_entities = []
    current_entity = []

    for item in sorted(ner_results, key=lambda x: x['index']):
        if current_entity and not allow_merge(current_entity[-1]['entity'], item['entity']):
            divided_entities.append(current_entity)
            current_entity = []
        current_entity.append(item)
    # Flush the final group only if it holds tokens. Previously an empty input
    # produced [[]], which made merge_entities raise IndexError.
    if current_entity:
        divided_entities.append(current_entity)
    return divided_entities

def merge_entities(same_entities):
    """Collapse one group of token predictions into a single entity dict.

    Args:
        same_entities: non-empty list of token dicts (``entity``, ``score``,
            ``word``, ``start``, ``end``) in token order, as produced by
            ``divide_entities``.

    Returns:
        A dict with:
          * ``entity``: the first tag with its leading "<flag>-" removed;
          * ``score``: the mean token score;
          * ``word``: the tokens concatenated with WordPiece "##" markers removed;
          * ``start`` / ``end``: the span from the first token's start to the
            last token's end.
    """
    tag = same_entities[0]['entity']
    # Strip only the leading "<flag>-". This keeps hyphenated types intact
    # and leaves a tag without '-' unchanged instead of raising IndexError.
    _, sep, entity_type = tag.partition('-')
    scores = [e['score'] for e in same_entities]
    return {
        'entity': entity_type if sep else tag,
        'score': sum(scores) / len(scores),
        # Tokens are joined without spaces (suits character-level text such as Chinese).
        'word': ''.join(e['word'].replace('##', '') for e in same_entities),
        'start': same_entities[0]['start'],
        'end': same_entities[-1]['end'],
    }

def post_process(ner_results):
    """Turn token-level NER predictions into merged entity dicts.

    Splits the predictions into groups with ``divide_entities``, then
    collapses each group into one dict with ``merge_entities``.
    """
    groups = divide_entities(ner_results)
    return list(map(merge_entities, groups))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88

加载模型

from transformers import AutoConfig, AutoTokenizer

# Only the tokenizer and the config (for label2id / id2label) are loaded here.
# Inference runs on the exported ONNX graph, not on the PyTorch model.
NER_MODEL_PATH = './save_model'
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)

pipe2 = PipeLineOnnx(
    tokenizer=ner_tokenizer,
    onnx_path="bert-ner.onnx",
    config=ner_config,
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

测试效果

（原文此处为 ONNX 模型 NER 效果测试结果截图，图片未保留）

测试速度

（原文此处为 ONNX 模型推理速度测试截图，图片未保留）

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/565365
推荐阅读
相关标签
  

闽ICP备14008679号