当前位置:   article > 正文

通过FastAPI框架部署Torchvision预训练模型_fastapi 部署pytorch模型

fastapi 部署pytorch模型

通过FastAPI框架部署Torchvision预训练模型

介绍

通过FastAPI框架部署Torchvision.Models预训练模型进行图像识别。

PyTorch框架中有一个非常重要的包:torchvision,它由3个子包组成,分别是:

  • torchvision.datasets
  • torchvision.models
  • torchvision.transforms

其中torchvision.models中包含了很多预训练模型,可以直接使用。由于国内的网络环境,可以通过coggle.club手动下载预训练模型镜像。

通常使用Flask框架为预训练模型创建API服务,但如果想做一个满足高并发的机器学习API服务,异步框架FastAPI是一个不错的选择。

相比Flask,FastAPI框架具有以下几大功能:

  • 异步web框架,支持asyncio
  • 拥有非常高的性能(归功于StarlettePydantic);
  • 通过不同的参数声明实现丰富的功能;
  • 自动类型检查,自动生成交互式文档,自动swagger UI
  • 自带快如闪电的异步服务器Uvicorn
Coding - api_server.py
  1. 加载所需包
import io
import json
from PIL import Image
from torchvision import models
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
import uvicorn
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  1. 初始化
app = FastAPI()
# 加载预训练模型
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
  • 1
  • 2
  • 3
  • 4
  • 5
  1. 自定义函数
# 图片文件读取,输出Image.Image格式
def read_imagefile(file) -> Image.Image:
    image = Image.open(io.BytesIO(file))
    return image

# 图片预处理,torchvision.transforms转换Image格式为torch tensor
def transform_image(image_bytes: Image.Image):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    return my_transforms(image_bytes).unsqueeze(0)

# 定义预测函数,图片预处理->模型预测->预测结果转换
def get_prediction(image_bytes: Image.Image):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  1. 路由/predict
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    '''
    Parameters
    ----------
    file : UploadFile, optional
        DESCRIPTION. The default is an image file.

    Returns
    -------
    json : Response with list of dicts.
		Each dict contains class_id, class_name

    '''
    extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"

    img_bytes = read_imagefile(await file.read())
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return {'class_id': class_id, 'class_name': class_name}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  1. 运行uvicorn.run
if __name__ == "__main__":
    app_str = 'api_server:app'
    uvicorn.run(app_str, host='localhost', port=8000, debug=True, reload=True, workers=1)
  • 1
  • 2
  • 3
Coding - api_test.py

在这里插入图片描述

import requests

def test_request(image = 'images/dog.jpg'):
    resp = requests.post("http://localhost:8000/predict",
                         files={"file": open(image,'rb')})
    
    print(resp.json())

if __name__ == '__main__':
	test_request()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
swagger UI界面

http://127.0.0.1:8000/docs在这里插入图片描述
在这里插入图片描述

安装教程

# Clone the repo
$ git clone https://gitee.com/vencen/cv-fast-api.git
# 创建虚拟环境,安装依赖包

$ conda create -n venv python=3.8

#on windows
$ activate venv
#on linux
$ source activate venv

$ pip install -r requirements.txt
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
使用说明
  1. 创建.env文件

    # .env file example
    
    
    • 1
    • 2
  2. 启动虚拟环境

    # 启动虚拟环境
    #on windows
    $ activate venv
    #on linux
    $ source activate venv
    
    • 1
    • 2
    • 3
    • 4
    • 5
  3. 启动api

    # 启动api
    $ uvicorn server:app --reload
    
    • 1
    • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/824965
推荐阅读
相关标签
  

闽ICP备14008679号