赞
踩
Flask是一个使用Python编写的轻量级Web应用框架,可扩展性很强,相较于Django框架,灵活度很高,开发成本底。它仅仅实现了Web应用的核心功能,Flask由两个主要依赖组成,提供路由、调试、Web服务器网关接口的Werkzeug 实现的和模板语言依赖的jinja2,其他的一切都可以由第三方库来完成。
在使用Flask之前需要安装一下,安装Flask非常简单只需要在在命令行输入
pip install flask即可
# 导入 Flask 类
from flask import Flask
# 创建了这个类的实例。第一个参数是应用模块或者包的名称。
app = Flask(__name__)
# 使用 route() 装饰器来告诉 Flask 触发函数的 URL
@app.route("/")
def hello():
return "Hello World!"
if __name__ == "__main__":
# 使用 run() 函数来运行本地服务器和我们的应用
app.run()
本文通过使用轻量级的WEB框架Flask来实现Python在服务端的部署CIFAR-10的图像分类。效果如下:
CIFAR-10是一个小型图像分类数据集,数据格式类似于MNIST手写数字数据集,在CIFAR-10数据中图片共有10个类别,分别为airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck。
对于CIFAR-10分类任务,PyTorch里的torchvision库提供了专门数据处理函数torchvision.datasets.CIFAR10,构建DataLoader代码如下:
import torchvision from torchvision import transforms import torch from config import data_folder, batch_size def create_dataset(data_folder, transform_train=None, transform_test=None): if transform_train is None: transform_train = transforms.Compose( [ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize( (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) ) ] ) if transform_test is None: transform_test = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize( (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) ) ] ) trainset = torchvision.datasets.CIFAR10( root=data_folder, train=True, download=True, transform=transform_train ) trainloader = torch.utils.data.DataLoader( trainset, batch_size=batch_size, shuffle=True, num_workers=2 ) testset = torchvision.datasets.CIFAR10( root=data_folder, train=False, download=True, transform=transform_test ) testloader = torch.utils.data.DataLoader( testset, batch_size=batch_size, shuffle=False, num_workers=2 ) return trainloader, testloader
from torch import nn import torch.nn.functional as F # 定义残差块ResBlock class ResBlock(nn.Module): def __init__(self, inchannel, outchannel, stride=1): super(ResBlock, self).__init__() # 这里定义了残差块内连续的2个卷积层 self.left = nn.Sequential( nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(outchannel), nn.ReLU(inplace=True), nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(outchannel) ) self.shortcut = nn.Sequential() if stride != 1 or inchannel != outchannel: # s
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。