
Transformers from Zero to Mastery: Pipeline


I. Pipeline

1. Listing the supported task types

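from transformers.pipelines import SUPPORTED_TASKS, get_supported_tasks

# SUPPORTED_TASKS: dict mapping each task name to its pipeline class, default model, etc.
# get_supported_tasks(): just the task names
print(SUPPORTED_TASKS.items(), get_supported_tasks())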
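Because `SUPPORTED_TASKS` is a plain dict, it can be filtered like any other. The sketch below groups task names by modality, assuming each entry carries a `"type"` key (such as `"text"` or `"image"`), as it does in recent transformers versions; check your installed version before relying on it. The helper name `tasks_by_type` and the sample dict are hypothetical.

```python
from collections import defaultdict


def tasks_by_type(supported_tasks):
    """Group task names by the "type" field of each SUPPORTED_TASKS-style entry.

    Assumes every value is a dict that may carry a "type" key; entries
    without one are filed under "unknown".
    """
    groups = defaultdict(list)
    for task, info in supported_tasks.items():
        groups[info.get("type", "unknown")].append(task)
    # Sort the names inside each group for stable, readable output
    return {kind: sorted(names) for kind, names in groups.items()}


if __name__ == "__main__":
    # Hypothetical stand-in shaped like SUPPORTED_TASKS (real entries hold more keys)
    sample = {
        "text-classification": {"type": "text"},
        "question-answering": {"type": "text"},
        "image-classification": {"type": "image"},
    }
    print(tasks_by_type(sample))
```

With transformers installed, `tasks_by_type(SUPPORTED_TASKS)` would apply the same grouping to the real registry.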

2. Ways to create and use a Pipeline

1. Create a Pipeline from the task type alone; the default models are all English-language models

from transformers import pipeline

pipe = pipeline("text-classification")  # no model given: a default English checkpoint is downloaded
pipe("very good!")
# [{'label': 'POSITIVE', 'score': 0.9998525381088257}]

2. Specify the task type and then a model, creating a Pipeline backed by that model

from transformers import pipeline

# Browse checkpoints at https://huggingface.co/models
pipe = pipeline("text-classification",
                model="uer/roberta-base-finetuned-dianping-chinese")
pipe("我觉得不太行!")  # "I don't think this is any good!"
# [{'label': 'negative (stars 1, 2 and 3)', 'score': 0.9735506772994995}]

3. Load the model first, then create the Pipeline

from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

model = AutoModelForSequenceClassification.from_pretrained("uer/roberta-base-finetuned-dianping-chinese")
tokenizer = AutoTokenizer.from_pretrained("uer/roberta-base-finetuned-dianping-chinese")

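# When passing a model object instead of a model name, also pass the matching tokenizer
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
pipe("你真是个人才!")  # "You really are a talent!"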
# [{'label': 'positive (stars 4 and 5)', 'score': 0.8717765808105469}]

3. GPU inference

  • Inference runs on the CPU by default
pipe.model.device
# device(type='cpu')

%%time
# IPython/Jupyter cell magic: reports the run time of the whole cell (it must be the cell's first line)

for i in range(100):
    pipe("你真是个人才!")

'''
CPU times: total: 19.4 s
Wall time: 4.94 s
'''

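import time
import torch

# Average latency over 100 calls. Run it once with the CPU pipeline above and once
# with the GPU pipeline (device=0) created below.
# torch.cuda.synchronize() waits for queued GPU work to finish so the timing is accurate;
# it raises an error on a machine without CUDA, hence the guard.
use_cuda = torch.cuda.is_available()
times = []
for i in range(100):
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    pipe("我觉得不太行!")
    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
    times.append(end - start)
print(sum(times) / len(times))
# Average seconds per call on the author's machine:
# 0.05427998542785645 CPU
# 0.012370436191558839 GPU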
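The timing loop above can be wrapped in a reusable helper. This is a standard-library sketch: `benchmark` and its `sync` hook are hypothetical names, and the hook is where `torch.cuda.synchronize` would go when timing a GPU pipeline. A few untimed warm-up calls are included because the first calls often carry one-off setup costs.

```python
import statistics
import time


def benchmark(fn, n=100, warmup=3, sync=None):
    """Return (mean, median) wall-clock seconds per call of fn().

    sync: optional zero-argument callable run right before starting the
    clock and right before stopping it (e.g. torch.cuda.synchronize).
    """
    for _ in range(warmup):
        fn()  # warm-up calls are not timed
    times = []
    for _ in range(n):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.median(times)


if __name__ == "__main__":
    mean_s, median_s = benchmark(lambda: sum(range(10_000)), n=50)
    print(f"mean {mean_s:.6f}s, median {median_s:.6f}s")
```

For the pipeline above this would be called as `benchmark(lambda: pipe("我觉得不太行!"), sync=torch.cuda.synchronize)` on a GPU machine, and with `sync=None` on the CPU.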

# device=0 places the model on the first GPU (cuda:0)
pipe = pipeline("text-classification", model="uer/roberta-base-finetuned-dianping-chinese", device=0)
pipe.model.device
# device(type='cuda', index=0)
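To keep one notebook working on machines with and without a GPU, the `device` argument can be chosen at runtime. A minimal sketch, assuming the integer convention where `-1` selects the CPU and `0`, `1`, ... select a GPU index; the helper name `pick_device` is hypothetical.

```python
def pick_device(cuda_available, gpu_index=0):
    """Return a pipeline-style device id: the GPU index if CUDA is available, else -1 (CPU)."""
    if gpu_index < 0:
        raise ValueError("gpu_index must be >= 0")
    return gpu_index if cuda_available else -1


if __name__ == "__main__":
    print(pick_device(True))   # 0
    print(pick_device(False))  # -1
```

In the notebook this would be used as `pipeline("text-classification", model="uer/roberta-base-finetuned-dianping-chinese", device=pick_device(torch.cuda.is_available()))`.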

4. Determining a Pipeline's parameters

qa_pipeline = pipeline("question-answering", model="uer/roberta-base-chinese-extractive-qa")
qa_pipeline(question="是谁?", context="是帅哥!")  # Q: "Who is it?"  Context: "A handsome guy!"
# {'score': 0.004711466375738382, 'start': 4, 'end': 6, 'answer': '帅哥'}
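  • To find out which parameters a pipeline accepts, look up the class of qa_pipeline, QuestionAnsweringPipeline, then Ctrl+click (in an IDE) through to the source of its __call__ method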

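Instead of Ctrl+clicking in an IDE, the same information can be read from Python with the standard `inspect` module. The helper below is a hypothetical sketch: it looks up `__call__` on an object's class and returns its signature and docstring. For a pipeline the docstring is usually the useful part, since `__call__` often just takes `*args, **kwargs`; in recent versions the question-answering docstring lists options such as `top_k` and `max_answer_len`, but the exact text depends on the installed version.

```python
import inspect


def call_help(obj):
    """Return the class name, __call__ signature and docstring of obj's class."""
    cls = type(obj)
    call = cls.__call__
    try:
        sig = str(inspect.signature(call))
    except (TypeError, ValueError):
        sig = "(signature unavailable)"
    doc = inspect.getdoc(call) or "(no docstring)"
    return f"{cls.__name__}.__call__{sig}\n\n{doc}"


if __name__ == "__main__":
    class Demo:
        def __call__(self, question, context, top_k=1):
            """Answer a question from the given context."""
            return None

    print(call_help(Demo()))
```

With transformers loaded, `print(call_help(qa_pipeline))` would show the `QuestionAnsweringPipeline.__call__` documentation without leaving the notebook.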


5. Other Pipeline examples

  • Zero-shot object detection
from transformers import pipeline

checkpoint = "google/owlvit-base-patch32"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")

import requests
from PIL import Image

url = "https://unsplash.com/photos/oj0zeY2Ltk4/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTR8fHBpY25pY3xlbnwwfHx8fDE2Nzc0OTE1NDk&force=true&w=640"
im = Image.open(requests.get(url, stream=True).raw)
im  # in Jupyter, the last expression of a cell displays the image


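# Candidate labels are free text; the model scores image regions against each label
predictions = detector(im, candidate_labels=["hat", "book"])

from PIL import ImageDraw

draw = ImageDraw.Draw(im)

for prediction in predictions:
    box = prediction["box"]  # dict with xmin, ymin, xmax, ymax in pixels
    label = prediction["label"]
    score = prediction["score"]
    # Read the corners by key rather than relying on the dict's value order
    xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
    draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
    draw.text((xmin, ymin), f"{label}: {round(score, 2)}", fill="red")

im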

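Zero-shot detection can return several weak boxes, so it can help to filter and sort the predictions before drawing. A small pure-Python sketch, assuming the output format used in the loop above (a list of dicts with `score`, `label` and a `box` dict holding `xmin`, `ymin`, `xmax`, `ymax`); the name `boxes_to_draw` and the `0.3` cutoff are illustrative choices, not library defaults.

```python
def boxes_to_draw(predictions, min_score=0.3):
    """Turn detector output into (rectangle, caption) pairs, best score first.

    Each prediction is assumed to look like
    {"score": float, "label": str,
     "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}.
    """
    kept = [p for p in predictions if p["score"] >= min_score]
    kept.sort(key=lambda p: p["score"], reverse=True)
    result = []
    for p in kept:
        b = p["box"]
        rect = (b["xmin"], b["ymin"], b["xmax"], b["ymax"])
        caption = f"{p['label']}: {round(p['score'], 2)}"
        result.append((rect, caption))
    return result


if __name__ == "__main__":
    # Hand-written sample in the assumed format (not real model output)
    sample = [
        {"score": 0.12, "label": "book", "box": {"xmin": 5, "ymin": 5, "xmax": 40, "ymax": 30}},
        {"score": 0.87, "label": "hat", "box": {"xmin": 50, "ymin": 10, "xmax": 120, "ymax": 60}},
    ]
    for rect, caption in boxes_to_draw(sample):
        print(rect, caption)
```

Each pair can then be drawn with `draw.rectangle(rect, outline="red", width=1)` and `draw.text(rect[:2], caption, fill="red")`.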


6. What a Pipeline does under the hood

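'''
Steps a text-classification Pipeline performs:
1. Preprocess the input (tokenizer)
2. Run the model to get logits
3. Post-process: softmax, argmax, then map the id to a label via model.config.id2label
'''
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("uer/roberta-base-finetuned-dianping-chinese")
tokenizer = AutoTokenizer.from_pretrained("uer/roberta-base-finetuned-dianping-chinese")

text = "我觉得不行!"  # "I don't think it's good!"
inputs = tokenizer(text, return_tensors="pt")  # note return_tensors="pt": the model expects PyTorch tensors
with torch.no_grad():  # inference only, no gradients needed
    output = model(**inputs)
probs = torch.softmax(output.logits, dim=-1)

pred_id = torch.argmax(probs).item()

# The checkpoint's config already has an id2label mapping; overriding it is optional.
# This override assumes label 0 = negative and label 1 = positive.
id2label = {
    1: "Positive",
    0: "Negative",
}
model.config.id2label = id2label
print(text, "\n", model.config.id2label.get(pred_id))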
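The three steps can be mimicked end to end without downloading anything. The sketch below is a toy stand-in, not the transformers implementation: the `ToyTextClassifier` class, its whitespace "tokenizer" and its hard-coded word-score "model" are all hypothetical. The method names `preprocess`, `_forward` and `postprocess` are borrowed from the hooks that transformers `Pipeline` subclasses implement, and only that structure is meant to carry over. The post-processing is the same softmax, argmax and `id2label` lookup as in the code above.

```python
import math


class ToyTextClassifier:
    """Hypothetical pipeline: preprocess -> _forward -> postprocess."""

    def __init__(self, id2label):
        self.id2label = id2label
        # Toy "model weights": word -> (negative score, positive score)
        self.weights = {"good": (0.0, 2.0), "great": (0.0, 3.0), "bad": (2.0, 0.0)}

    def preprocess(self, text):
        # Stand-in for the tokenizer: lowercase and split on whitespace
        return text.lower().split()

    def _forward(self, tokens):
        # Stand-in for the model: sum per-token scores into two logits
        logits = [0.0, 0.0]
        for tok in tokens:
            neg, pos = self.weights.get(tok, (0.0, 0.0))
            logits[0] += neg
            logits[1] += pos
        return logits

    def postprocess(self, logits):
        # Softmax (shifted by the max for numerical stability), then argmax and id2label
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        best = max(range(len(probs)), key=probs.__getitem__)
        return {"label": self.id2label[best], "score": probs[best]}

    def __call__(self, text):
        return self.postprocess(self._forward(self.preprocess(text)))


if __name__ == "__main__":
    clf = ToyTextClassifier({0: "Negative", 1: "Positive"})
    print(clf("very good"))
    print(clf("so bad"))
```

Keeping the three stages as separate methods mirrors the code above: the tokenizer call, the `model(**inputs)` call and the softmax/`id2label` step can each be swapped or inspected on their own.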

Source: https://www.wpsshop.cn/w/你好赵伟/article/detail/353663 (user-contributed; copyright belongs to the original author)