当前位置:   article > 正文

Flask深度学习模型服务端部署_flask 部署

flask 部署

1、Flask框架简介

Flask是一个使用Python编写的轻量级Web应用框架,可扩展性很强,相较于Django框架,灵活度很高,开发成本底。它仅仅实现了Web应用的核心功能,Flask由两个主要依赖组成,提供路由、调试、Web服务器网关接口的Werkzeug 实现的和模板语言依赖的jinja2,其他的一切都可以由第三方库来完成。

2、Flask框架安装

在使用Flask之前需要安装一下,安装Flask非常简单只需要在在命令行输入

pip install flask即可

3、Flask实现 Hello World案例

# 导入 Flask 类
from flask import Flask
# 创建了这个类的实例。第一个参数是应用模块或者包的名称。
app = Flask(__name__)

# 使用 route() 装饰器来告诉 Flask 触发函数的 URL
@app.route("/")
def hello():
    return "Hello World!"
 
if __name__ == "__main__":
    # 使用 run() 函数来运行本地服务器和我们的应用
    app.run()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

4、Flask深度学习模型部署

本文通过使用轻量级的WEB框架Flask来实现Python在服务端的部署CIFAR-10的图像分类。效果如下:

CIFAR-10是一个小型图像分类数据集,数据格式类似于MNIST手写数字数据集,在CIFAR-10数据中图片共有10个类别,分别为airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck。

4.1 数据加载

对于CIFAR-10分类任务,PyTorch里的torchvision库提供了专门数据处理函数torchvision.datasets.CIFAR10,构建DataLoader代码如下:

import torchvision
from torchvision import transforms
import torch
from config import data_folder, batch_size


def create_dataset(data_folder, transform_train=None, transform_test=None):
    if transform_train is None:
        transform_train = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010)
                )
            ]
        )
    if transform_test is None:
        transform_test = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010)
                )
            ]
        )

    trainset = torchvision.datasets.CIFAR10(
        root=data_folder, train=True, download=True, transform=transform_train
    )

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_folder, train=False, download=True, transform=transform_test
    )

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    return trainloader, testloader
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

4.2 构建模型resent18实现分类

from torch import nn
import torch.nn.functional as F


# 定义残差块ResBlock
class ResBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResBlock, self).__init__()
        # 这里定义了残差块内连续的2个卷积层
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            # s
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/604589
推荐阅读
相关标签
  

闽ICP备14008679号