当前位置:   article > 正文

Pytorch学习-利用Dataset类定义自己的数据集

Pytorch学习-利用Dataset类定义自己的数据集

定义自己的数据集类需要继承torch.utils.data中的Dataset类

主要需要实现两个方法:`__len__`(返回数据集中样本的数量)和 `__getitem__`(根据索引 idx 返回对应的样本及其标签)。

  1. from torch.utils.data import Dataset
  class VOCDataSet(Dataset):
      """Skeleton of a custom dataset derived from ``torch.utils.data.Dataset``.

      Every method here is a placeholder that does nothing and returns
      ``None``; a real dataset fills in ``__len__`` and ``__getitem__``.
      """

      def __init__(self):
          """Initialize the dataset (placeholder: no state is set up)."""
          return None

      def __len__(self):
          """Return the length of the dataset (placeholder: returns ``None``)."""
          return None

      def __getitem__(self, idx):
          """Return the sample and label at ``idx`` (placeholder: returns ``None``)."""
          return None

下面以 PASCAL VOC 数据集为例进行演示(运行时需要当前工作目录下存在类别映射文件 pascal_voc_classes.json):

  1. from torch.utils.data import Dataset
  2. import os
  3. import torch
  4. import json
  5. from PIL import Image
  6. from lxml import etree
  class VOCDataSet(Dataset):
      """Read and parse the PASCAL VOC2007/2012 dataset for object detection.

      Each item is an ``(image, target)`` pair, where ``target`` is a dict with
      the keys ``boxes``, ``labels``, ``image_id``, ``area`` and ``iscrowd``.
      """

      def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
          """Index the annotation files listed in ``ImageSets/Main/<txt_name>``.

          Args:
              voc_root: Dataset root. If the path does not contain "VOCdevkit",
                  that directory is appended automatically.
              year: Dataset year, "2007" or "2012".
              transforms: Optional callable invoked as
                  ``transforms(image, target)`` that returns the transformed
                  ``(image, target)`` pair.
              txt_name: Name of the split file, e.g. "train.txt" or "val.txt".

          Raises:
              AssertionError: If ``year`` is invalid, the split file or
                  ``./pascal_voc_classes.json`` is missing, or no usable
                  annotation file is found.
          """
          assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
          # Be tolerant: accept a root with or without the "VOCdevkit" component.
          if "VOCdevkit" in voc_root:
              self.root = os.path.join(voc_root, f"VOC{year}")
          else:
              self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
          self.img_root = os.path.join(self.root, "JPEGImages")
          self.annotations_root = os.path.join(self.root, "Annotations")

          # Read train.txt or val.txt: one image id per non-blank line.
          txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
          assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
          with open(txt_path) as read:
              xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                          for line in read if line.strip()]

          self.xml_list = []
          # Keep only annotation files that exist and contain at least one object.
          for xml_path in xml_list:
              if not os.path.exists(xml_path):
                  print(f"Warning: not found '{xml_path}', skip this annotation file.")
                  continue

              data = self._load_annotation(xml_path)
              if "object" not in data:
                  print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                  continue

              self.xml_list.append(xml_path)

          assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

          # Load the class-name -> label mapping.
          # NOTE(review): the path is relative to the current working directory.
          json_file = './pascal_voc_classes.json'
          assert os.path.exists(json_file), "{} file not exist.".format(json_file)
          with open(json_file, 'r') as f:
              self.class_dict = json.load(f)

          self.transforms = transforms

      def __len__(self):
          """Return the number of usable annotation files."""
          return len(self.xml_list)

      def __getitem__(self, idx):
          """Load the image and detection target for the sample at ``idx``.

          Returns:
              ``(image, target)``; if ``transforms`` was given, the pair
              returned by ``transforms(image, target)``.

          Raises:
              ValueError: If the image file is not in JPEG format.
              AssertionError: If the annotation contains no "object" entry.
          """
          xml_path = self.xml_list[idx]
          data = self._load_annotation(xml_path)

          img_path = os.path.join(self.img_root, data["filename"])
          image = Image.open(img_path)
          if image.format != "JPEG":
              raise ValueError("Image '{}' format not JPEG".format(img_path))

          target = self._build_target(data, idx, xml_path)

          if self.transforms is not None:
              image, target = self.transforms(image, target)

          return image, target

      def _load_annotation(self, xml_path):
          """Read one XML file and return the dict under its "annotation" tag."""
          with open(xml_path) as fid:
              xml_str = fid.read()
          xml = etree.fromstring(xml_str.encode("utf-8"))
          return self.parse_xml_to_dict(xml)["annotation"]

      def _build_target(self, data, idx, xml_path):
          """Convert a parsed annotation dict into a dict of tensors."""
          boxes = []
          labels = []
          iscrowd = []
          assert "object" in data, "{} lack of object information.".format(xml_path)
          for obj in data["object"]:
              bndbox = obj["bndbox"]
              xmin = float(bndbox["xmin"])
              xmax = float(bndbox["xmax"])
              ymin = float(bndbox["ymin"])
              ymax = float(bndbox["ymax"])

              # Some annotations contain boxes with zero width or height; such
              # boxes would make the regression loss NaN, so skip them.
              if xmax <= xmin or ymax <= ymin:
                  print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                  continue

              boxes.append([xmin, ymin, xmax, ymax])
              labels.append(self.class_dict[obj["name"]])
              # The "difficult" flag, when present, is stored under "iscrowd".
              iscrowd.append(int(obj["difficult"]) if "difficult" in obj else 0)

          # reshape(-1, 4) keeps the boxes 2-D even when every box was skipped;
          # otherwise the column indexing below fails on an empty 1-D tensor.
          boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
          labels = torch.as_tensor(labels, dtype=torch.int64)
          iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
          image_id = torch.tensor([idx])
          area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

          return {
              "boxes": boxes,
              "labels": labels,
              "image_id": image_id,
              "area": area,
              "iscrowd": iscrowd,
          }

      def parse_xml_to_dict(self, xml):
          """Recursively convert an XML element tree into nested dicts.

          Args:
              xml: An element that supports ``len()``, iteration over its
                  children, and the ``tag`` / ``text`` attributes.

          Returns:
              ``{xml.tag: value}``, where ``value`` is the element text for a
              leaf, or a dict of its children otherwise. Children tagged
              "object" are always collected into a list, because one
              annotation may describe several objects.
          """
          if len(xml) == 0:
              # Leaf element: return its text directly.
              return {xml.tag: xml.text}

          result = {}
          for child in xml:
              child_result = self.parse_xml_to_dict(child)
              if child.tag != "object":
                  result[child.tag] = child_result[child.tag]
              else:
                  result.setdefault(child.tag, []).append(child_result[child.tag])
          return {xml.tag: result}

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/606224
推荐阅读
相关标签
  

闽ICP备14008679号