赞
踩
通过FastAPI框架部署Torchvision.Models预训练模型进行图像识别。
PyTorch框架中有一个非常重要的包:torchvision,它由3个子包组成,分别是:
其中torchvision.models
中包含了很多预训练模型,可以直接使用。由于国内的网络环境,可以通过coggle.club手动下载预训练模型镜像。
通常使用Flask
框架为预训练模型创建API服务,但如果想做一个满足高并发的机器学习API服务,异步框架FastAPI
是一个不错的选择。
相比Flask,FastAPI框架具有以下几大功能:
asyncio
;Starlette
和Pydantic
);swagger UI
;Uvicorn
;import io
import json
from PIL import Image
from torchvision import models
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
import uvicorn
app = FastAPI()
# 加载预训练模型
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
# 图片文件读取,输出Image.Image格式 def read_imagefile(file) -> Image.Image: image = Image.open(io.BytesIO(file)) return image # 图片预处理,torchvision.transforms转换Image格式为torch tensor def transform_image(image_bytes: Image.Image): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) return my_transforms(image_bytes).unsqueeze(0) # 定义预测函数,图片预处理->模型预测->预测结果转换 def get_prediction(image_bytes: Image.Image): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]
/predict
@app.post('/predict') async def predict(file: UploadFile = File(...)): ''' Parameters ---------- file : UploadFile, optional DESCRIPTION. The default is an image file. Returns ------- json : Response with list of dicts. Each dict contains class_id, class_name ''' extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" img_bytes = read_imagefile(await file.read()) class_id, class_name = get_prediction(image_bytes=img_bytes) return {'class_id': class_id, 'class_name': class_name}
uvicorn.run
if __name__ == "__main__":
app_str = 'api_server:app'
uvicorn.run(app_str, host='localhost', port=8000, debug=True, reload=True, workers=1)
import requests
def test_request(image = 'images/dog.jpg'):
resp = requests.post("http://localhost:8000/predict",
files={"file": open(image,'rb')})
print(resp.json())
if __name__ == '__main__':
test_request()
swagger UI
界面http://127.0.0.1:8000/docs
# Clone the repo
$ git clone https://gitee.com/vencen/cv-fast-api.git
# 创建虚拟环境,安装依赖包
$ conda create -n venv python=3.8
#on windows
$ activate venv
#on linux
$ source activate venv
$ pip install -r requirements.txt
创建.env
文件
# .env file example
启动虚拟环境
# 启动虚拟环境
#on windows
$ activate venv
#on linux
$ source activate venv
启动api
# 启动api
$ uvicorn server:app --reload
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。