赞
踩
- 之所以要求把数据存进 list,是因为之后 DataLoader 会按索引从这里取数据
- 按接口的要求来即可:它需要 list,我们就用 list 来存
- import os
-
- from matplotlib import pyplot as plt
- from torchvision import transforms, models, datasets
- import numpy as np
- import torch
- from PIL import Image
-
-
def load_annotations(ann_file):
    """Parse an annotation file into a ``{filename: label}`` mapping.

    Every non-blank line must hold a file name followed by an integer class
    label, separated by whitespace.

    Args:
        ann_file: Path of the annotation text file.

    Returns:
        dict mapping each file name to a 0-d ``np.int64`` array holding its
        label, in the order the lines appear in the file.

    Raises:
        ValueError: If a non-blank line has no label or the label is not an
            integer.
    """
    data_infos = {}
    with open(ann_file) as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. an empty line at the end of the file)
                # carry no sample; unpacking them used to raise ValueError.
                continue
            # Split on the last whitespace run only, so repeated separators
            # and file names that contain spaces are still parsed.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError(
                    '{}:{}: expected "<filename> <label>", got {!r}'.format(
                        ann_file, line_no, line))
            filename, gt_label = parts
            data_infos[filename] = np.array(int(gt_label), dtype=np.int64)
    return data_infos
-
# Read the training annotations and split them into parallel lists:
# file names on one side, labels on the other.
img_label = load_annotations('./flower_data/train.txt')
image_name = [*img_label]
label = [*img_label.values()]

# Dataset root plus the sub-directories holding the training / validation images.
data_dir = './flower_data/'
train_dir = '{}/train_filelist'.format(data_dir)
valid_dir = '{}/val_filelist'.format(data_dir)

# Full path of every training image, in annotation order.
image_path = [os.path.join(train_dir, name) for name in image_name]
-
- from torch.utils.data import Dataset, DataLoader
-
-
class FlowerDataset(Dataset):
    """Image dataset driven by a ``<filename> <label>`` annotation file.

    Args:
        root_dir: Directory that contains the image files.
        ann_file: Path of the annotation text file; each non-blank line holds
            a file name and an integer label separated by whitespace.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        # Image paths and labels are kept as parallel lists so that
        # __getitem__ can address a sample by integer index.
        self.img = [os.path.join(self.root_dir, name) for name in self.img_label]
        self.label = list(self.img_label.values())
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is always decoded as 3-channel RGB; ``label`` is a 0-d
        int64 tensor.
        """
        # Force RGB so grayscale / RGBA / palette files still yield three
        # channels, and close the file handle as soon as the pixels are read.
        with Image.open(self.img[idx]) as img_file:
            image = img_file.convert('RGB')
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image, label

    def load_annotations(self):
        """Parse ``self.ann_file`` into a ``{filename: label}`` dict.

        Returns:
            dict mapping each file name to a 0-d ``np.int64`` array, in file
            order.

        Raises:
            ValueError: If a non-blank line has no label or the label is not
                an integer.
        """
        data_infos = {}
        with open(self.ann_file) as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Blank lines carry no sample; skip them instead of
                    # failing on tuple unpacking.
                    continue
                # Split on the last whitespace run only, so repeated
                # separators and file names with spaces are tolerated.
                parts = line.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        '{}:{}: expected "<filename> <label>", got {!r}'.format(
                            self.ann_file, line_no, line))
                filename, gt_label = parts
                data_infos[filename] = np.array(int(gt_label), dtype=np.int64)
        return data_infos
-
-
# Per-channel mean / standard deviation passed to Normalize in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random augmentation, then tensor conversion.
_train_steps = [
    transforms.Resize(64),
    # Rotate by a random angle drawn from [-45, 45] degrees.
    transforms.RandomRotation(45),
    # Crop a 64x64 patch around the image centre.
    transforms.CenterCrop(64),
    # Mirror horizontally / vertically, each with probability 0.5.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    # Jitter brightness, contrast, saturation and hue.
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    # Occasionally convert to grayscale (kept as 3 channels with R = G = B).
    transforms.RandomGrayscale(p=0.025),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

# Validation pipeline: deterministic resize + centre crop, no augmentation.
_valid_steps = [
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

data_transforms = dict(
    train=transforms.Compose(_train_steps),
    valid=transforms.Compose(_valid_steps),
)
-
# Build the training / validation datasets and wrap them in batch loaders.
train_dataset = FlowerDataset(root_dir=train_dir, ann_file='./flower_data/train.txt', transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file='./flower_data/val.txt', transform=data_transforms['valid'])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# NOTE(review): shuffle=True is kept from the original; validation normally
# does not need shuffling -- confirm before changing.
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

# Sanity check: fetch one batch and display its first image with its label.
image, label = next(iter(train_loader))
sample = image[0].squeeze()
# Reorder (C, H, W) -> (H, W, C), the layout imshow expects.
sample = sample.permute((1, 2, 0)).numpy()
# Undo Normalize out of place: Tensor.numpy() shares memory with the batch,
# so in-place `*=` / `+=` would silently overwrite image[0] as well.
norm_std = np.array([0.229, 0.224, 0.225], dtype=sample.dtype)
norm_mean = np.array([0.485, 0.456, 0.406], dtype=sample.dtype)
# Clip to [0, 1]: rounding can push values slightly outside the range that
# imshow accepts for float RGB data.
sample = np.clip(sample * norm_std + norm_mean, 0.0, 1.0)
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。