赞
踩
torchvision
是 PyTorch 的一个官方库,主要用于处理计算机视觉任务。提供了许多常用的数据集、模型架构、图像转换等功能,使得计算机视觉任务的开发变得更加高效和便捷。以下是对 torchvision
主要功能的详细介绍:
torchvision
提供了许多常用的计算机视觉数据集,如 CIFAR-10、MNIST、ImageNet 等。这些数据集可以直接通过 torchvision.datasets
模块加载。
from torchvision import datasets
from torch.utils.data import DataLoader
# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)
# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
torchvision.transforms
模块提供了许多常用的图像转换操作,如裁剪、缩放、旋转、翻转等。这些转换操作可以单独使用,也可以组合使用。
from torchvision import transforms
# 定义转换操作
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 应用转换操作
train_dataset.transform = transform
test_dataset.transform = transform
torchvision.models
模块提供了许多常用的预训练模型,如 ResNet、VGG、AlexNet、DenseNet 等。这些模型可以直接用于迁移学习或作为基准模型。
from torchvision import models
import torch.nn as nn
# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)
# 修改最后一层以适应新的分类任务
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)
torch.utils.data.DataLoader
是一个实用的数据加载器,可以与 torchvision
提供的数据集一起使用,方便地进行批量加载和数据迭代。
from torch.utils.data import DataLoader
# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 训练模型
for images, labels in train_loader:
# 训练代码
pass
如果需要使用自定义数据集,可以继承 torch.utils.data.Dataset
类,并实现 __len__
和 __getitem__
方法。
from torch.utils.data import Dataset from PIL import Image import os class CustomDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.images = os.listdir(root_dir) def __len__(self): return len(self.images) def __getitem__(self, idx): img_path = os.path.join(self.root_dir, self.images[idx]) image = Image.open(img_path) if self.transform: image = self.transform(image) return image # 使用自定义数据集 custom_dataset = CustomDataset(root_dir='path/to/dataset', transform=transform) custom_loader = DataLoader(custom_dataset, batch_size=64, shuffle=True)
torchvision
还提供了一些用于可视化的工具,如 torchvision.utils.make_grid
可以将多个图像拼接成一个网格图像。
import matplotlib.pyplot as plt
from torchvision import utils
# 获取一批图像
images, labels = next(iter(train_loader))
# 将图像拼接成网格
grid = utils.make_grid(images)
# 显示图像
plt.imshow(grid.permute(1, 2, 0))
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。