赞
踩
定义自己的数据集类需要继承 torch.utils.data 中的 Dataset 类，
主要实现两个方法：__len__（返回数据集中的样本数）和 __getitem__（按索引返回一个样本及其标签）。
from torch.utils.data import Dataset


class VOCDataSet(Dataset):
    """Skeleton of a custom dataset.

    Subclass ``torch.utils.data.Dataset`` and fill in ``__len__`` and
    ``__getitem__``. The stubs raise ``NotImplementedError`` so that a
    forgotten method fails loudly instead of returning ``None``.
    """

    # Initialization: load file lists, annotations, transforms, etc.
    def __init__(self):
        pass

    # Return the number of samples in the dataset.
    def __len__(self):
        raise NotImplementedError("__len__ is not implemented: return the number of samples.")

    # Return the (sample, label) pair at index ``idx``.
    def __getitem__(self, idx):
        raise NotImplementedError("__getitem__ is not implemented: return (sample, label) for idx.")
下面以读取 PASCAL VOC 2007/2012 目标检测数据集为例进行演示：
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree


class VOCDataSet(Dataset):
    """Read and parse the PASCAL VOC 2007/2012 detection dataset.

    Each item is ``(image, target)``. ``image`` is the PIL image opened from
    ``JPEGImages``. ``target`` is a dict of tensors:
    ``boxes`` (float32, shape ``(N, 4)`` as ``xmin, ymin, xmax, ymax``),
    ``labels`` (int64), ``image_id``, ``area`` and ``iscrowd`` (int64).
    If ``transforms`` is given, both are passed through it.
    """

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt",
                 json_file: str = "./pascal_voc_classes.json"):
        """
        Args:
            voc_root: Directory containing ``VOCdevkit``, or a path that already
                includes ``VOCdevkit``.
            year: Dataset year, "2007" or "2012".
            transforms: Optional callable ``(image, target) -> (image, target)``.
            txt_name: Split file under ``ImageSets/Main`` (e.g. "train.txt", "val.txt").
            json_file: JSON file mapping class name -> integer label. The default
                keeps the original path, which is resolved against the current
                working directory.

        Raises:
            AssertionError: unsupported ``year``; split file or class file missing;
                or no annotation file with at least one object was found.
        """
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # Tolerate voc_root pointing either at or above the VOCdevkit directory.
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read if line.strip()]

        # Keep only annotation files that exist and contain at least one object.
        self.xml_list = []
        for xml_path in xml_list:
            if not os.path.exists(xml_path):
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            data = self._read_annotation(xml_path)
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # read class_indict (class name -> label id)
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r', encoding="utf-8") as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    def __len__(self):
        """Return the number of usable annotation files."""
        return len(self.xml_list)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th annotation file.

        Raises:
            ValueError: if the referenced image is not a JPEG.
            AssertionError: if the annotation has no ``<object>`` entries.
            KeyError: if an object's class name is missing from the class dict.
        """
        xml_path = self.xml_list[idx]
        data = self._read_annotation(xml_path)
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            bndbox = obj["bndbox"]
            xmin = float(bndbox["xmin"])
            xmax = float(bndbox["xmax"])
            ymin = float(bndbox["ymin"])
            ymax = float(bndbox["ymax"])

            # Some annotations have w or h <= 0; such boxes make the
            # box-regression loss NaN, so they are skipped.
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            # The VOC "difficult" flag is stored as "iscrowd"; absent -> 0.
            iscrowd.append(int(obj["difficult"]) if "difficult" in obj else 0)

        # convert everything into a torch.Tensor.
        # reshape(-1, 4) keeps boxes 2-D even if every box was filtered out,
        # otherwise boxes[:, 3] below would raise IndexError on an empty tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def _read_annotation(self, xml_path):
        """Parse one VOC XML file and return the dict under its ``annotation`` root."""
        # Hand lxml the raw bytes so it decodes them itself (honouring any XML
        # encoding declaration). Decoding with the platform default encoding
        # first would garble UTF-8 annotations on non-UTF-8 locales.
        with open(xml_path, "rb") as fid:
            xml = etree.fromstring(fid.read())
        return self.parse_xml_to_dict(xml)["annotation"]

    def parse_xml_to_dict(self, xml):
        """Recursively convert an lxml element into nested dicts.

        A leaf element maps to its text. ``<object>`` children are collected
        into a list, because one annotation may contain several objects. Any
        other child tag is stored as a single entry; if it repeats, the last
        occurrence wins.

        Args:
            xml: lxml element (e.g. the result of ``etree.fromstring``).

        Returns:
            ``{xml.tag: text}`` for a leaf, otherwise ``{xml.tag: dict}``.
        """
        if len(xml) == 0:  # leaf node: no children, return its text
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)
            if child.tag != "object":
                result[child.tag] = child_result[child.tag]
            else:
                result.setdefault(child.tag, []).append(child_result[child.tag])
        return {xml.tag: result}

    @staticmethod
    def collate_fn(batch):
        """Collate ``[(image, target), ...]`` into ``(images, targets)`` tuples.

        Detection targets have a different number of boxes per image, so they
        cannot be stacked by the default collate function. Pass this as
        ``DataLoader(..., collate_fn=VOCDataSet.collate_fn)``.
        """
        return tuple(zip(*batch))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。