
PyTorch Dataset and DataLoader: generating custom training data



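Contents

PyTorch Dataset and DataLoader: generating custom training data

1. torch.utils.data.Dataset

2. torch.utils.data.DataLoader

3. Generating custom training data with Dataset and DataLoader

3.1 A custom Dataset

3.2 Producing training batches with DataLoader

3.3 Appendix: image_processing.py

3.4 Complete code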


1. torch.utils.data.Dataset

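  datasets is the collection of Dataset source code provided by PyTorch, and it is a good reference. Below is the basic skeleton of a custom Dataset. Initialization goes in __init__(). The two methods __getitem__() and __len__() must be overridden. __getitem__() returns one training sample, such as an image and its label. __len__() returns the number of samples in the dataset.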

```python
from torch.utils import data


class CustomDataset(data.Dataset):  # must inherit from data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # Note: step 1 reads ONE sample (the one at `index`), not the whole dataset.
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
```
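To make the skeleton concrete, here is a minimal sketch that fills in the three methods for a toy dataset held entirely in memory. The class name SquaresDataset and its data (sample i is the pair (i, i²)) are hypothetical. A real dataset would read a file inside __getitem__ instead of computing the value.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy map-style dataset: sample i is (i, i**2)."""

    def __init__(self, n):
        # Step 1 of the skeleton: store what is needed to locate samples.
        self.n = n

    def __getitem__(self, index):
        # Build and return exactly ONE sample pair (x, y).
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y

    def __len__(self):
        # Total number of samples; samplers use it to generate indices.
        return self.n


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor([3.]), tensor([9.]))
```

ds[3] calls __getitem__(3) and len(ds) calls __len__(). For a map-style dataset, this indexing protocol is what DataLoader relies on.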

2. torch.utils.data.DataLoader

Parameters accepted by DataLoader(object):

  1. dataset (Dataset): the dataset to load from
  2. batch_size (int, optional): how many samples each batch contains
  3. shuffle (bool, optional): reshuffle the data at the start of every epoch
  4. sampler (Sampler, optional): a custom strategy for drawing samples from the dataset. If it is specified, shuffle must be False.
  5. batch_sampler (Sampler, optional): like sampler, but returns the indices of a whole batch at a time. Once this is specified, batch_size, shuffle, sampler and drop_last can no longer be set (they are mutually exclusive).
  6. num_workers (int, optional): how many subprocesses are used for data loading. 0 means all data is loaded in the main process (default: 0).
  7. collate_fn (callable, optional): a function that merges a list of samples into a mini-batch
  8. pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned (page-locked) memory before returning them
  9. drop_last (bool, optional): controls the final, incomplete batch. If True, that batch is dropped. For example, with batch_size=64 and only 100 samples per epoch, the last 36 samples are discarded during training. If False (default), iteration proceeds normally and the last batch is simply smaller.
  10. timeout (numeric, optional): if positive, how long to wait for a batch to be collected from the worker processes. If no batch arrives within that time, an error is raised. It must always be >= 0 (default: 0).
  11. worker_init_fn (callable, optional): a per-worker initialization function. If not None, it is called in each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None).
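The effect of batch_size and drop_last can be checked with a small sketch. It uses only the standard torch.utils.data API. TensorDataset wraps ten made-up samples, and the code records the size of every batch in one epoch for drop_last=False and for drop_last=True.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Ten made-up samples, matching train_data_nums=10 used later in the article.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # shape (10, 1)
targets = torch.arange(10)                                     # shape (10,)
dataset = TensorDataset(features, targets)

results = {}
for drop_last in (False, True):
    loader = DataLoader(dataset, batch_size=7, shuffle=False, drop_last=drop_last)
    # Record the size of every batch the loader yields in one epoch.
    results[drop_last] = [x.shape[0] for x, _ in loader]
    print("drop_last={}: batch sizes {}, len(loader)={}".format(
        drop_last, results[drop_last], len(loader)))
# drop_last=False: batch sizes [7, 3], len(loader)=2
# drop_last=True: batch sizes [7], len(loader)=1
```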

3. Generating custom training data with Dataset and DataLoader

Suppose a TXT file stores the image names and labels. Each line holds the image file name in the first column and the label in the second:

```text
0.jpg 0
1.jpg 1
2.jpg 2
3.jpg 3
4.jpg 4
5.jpg 5
6.jpg 6
7.jpg 7
8.jpg 8
9.jpg 9
```

The data can also be multi-label, for example:

```text
0.jpg 0 10
1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18
9.jpg 9 19
```
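Both formats can be parsed with the same rule. The first whitespace-separated field is the image name, and every remaining field is an integer label. The helper parse_label_lines below is a hypothetical stand-alone sketch of that rule; the article's own version is TorchDataset.read_file in 3.1. It takes a list of lines, so it can be tried without any files.

```python
def parse_label_lines(lines):
    """Hypothetical helper: parse 'image_name label1 [label2 ...]' lines
    into (name, [int labels]) tuples. Blank lines are skipped."""
    samples = []
    for line in lines:
        fields = line.split()  # split on any run of whitespace, drops '\n'
        if not fields:
            continue           # ignore empty / whitespace-only lines
        name, label_fields = fields[0], fields[1:]
        samples.append((name, [int(v) for v in label_fields]))
    return samples


single = parse_label_lines(["0.jpg 0\n", "1.jpg 1\n"])
multi = parse_label_lines(["0.jpg 0 10\n", "1.jpg 1 11\n", "\n"])
print(single)  # [('0.jpg', [0]), ('1.jpg', [1])]
print(multi)   # [('0.jpg', [0, 10]), ('1.jpg', [1, 11])]
```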

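The ten original images are stored in the ./dataset/images directory. We can then define a custom Dataset that parses this label file and reads the images, and use the DataLoader class to produce batches of training data.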


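3.1 A custom Dataset

First, define a custom TorchDataset class that reads the image data and produces the labels.

Pay attention to the initialization function (__init__):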

```python
import os

import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from utils import image_processing  # the author's helper module (see 3.3)


class TorchDataset(Dataset):
    def __init__(self, filename, image_dir, resize_height=256, resize_width=256, repeat=1):
        '''
        :param filename: label TXT file, one sample per line: image_name.jpg label1_id label2_id ...
        :param image_dir: image directory; os.path.join(image_dir, image_name.jpg) is the full path
        :param resize_height: if None, the height is not resized
        :param resize_width: if None, the width is not resized.
            PS: if exactly one of resize_height/resize_width is None, the image is
            scaled proportionally (aspect ratio preserved)
        :param repeat: how many times the whole sample list is repeated (default 1);
            None means "loop indefinitely" (__len__ then returns 10000000)
        '''
        self.image_label_list = self.read_file(filename)
        self.image_dir = image_dir
        self.len = len(self.image_label_list)
        self.repeat = repeat
        self.resize_height = resize_height
        self.resize_width = resize_width
        # Preprocessing setup.
        # transforms.ToTensor converts a PIL.Image or numpy.ndarray of shape (H, W, C)
        # with pixel values in [0, 255] into a torch.FloatTensor of shape (C, H, W)
        # scaled to [0.0, 1.0].
        self.toTensor = transforms.ToTensor()
        # transforms.Normalize(mean, std) acts on tensors: given per-channel means (R, G, B)
        # and standard deviations (R, G, B), it computes channel = (channel - mean) / std.
        # self.normalize = transforms.Normalize(mean, std)

    def __getitem__(self, i):
        index = i % self.len  # wrap around so indices past the list (repeat > 1 or None) stay valid
        # print("i={},index={}".format(i, index))
        image_name, label = self.image_label_list[index]
        image_path = os.path.join(self.image_dir, image_name)
        img = self.load_data(image_path, self.resize_height, self.resize_width, normalization=False)
        if img is None:  # read_image returns None for a missing/unreadable file
            raise FileNotFoundError(image_path)
        img = self.data_preproccess(img)
        label = np.array(label)
        return img, label

    def __len__(self):
        if self.repeat is None:
            data_len = 10000000  # effectively infinite
        else:
            data_len = len(self.image_label_list) * self.repeat
        return data_len

    def read_file(self, filename):
        image_label_list = []
        with open(filename, 'r') as f:
            for line in f:
                # split() with no argument splits on any whitespace and drops the trailing \n, \r
                content = line.split()
                if not content:  # skip blank lines
                    continue
                name = content[0]
                labels = [int(value) for value in content[1:]]
                image_label_list.append((name, labels))
        return image_label_list

    def load_data(self, path, resize_height, resize_width, normalization):
        '''
        Load one image
        :param path:
        :param resize_height:
        :param resize_width:
        :param normalization: whether to normalize to [0.0, 1.0]
        :return:
        '''
        image = image_processing.read_image(path, resize_height, resize_width, normalization)
        return image

    def data_preproccess(self, data):
        '''
        Data preprocessing
        :param data:
        :return:
        '''
        data = self.toTensor(data)
        return data
```
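The two transforms described in the comments above can be reproduced with plain NumPy. This sketch only mimics their documented behavior for a uint8 H×W×C array. The helper names to_tensor_like and normalize_like are hypothetical, and the real transforms return torch tensors.

The sketch also explains why load_data is called with normalization=False. ToTensor rescales by 1/255 only for uint8 (and PIL) input. An array that is already float is converted without rescaling.

```python
import numpy as np


def to_tensor_like(hwc_uint8):
    """Hypothetical mimic of transforms.ToTensor for a uint8 (H, W, C) array:
    move channels first -> (C, H, W) and scale [0, 255] -> [0.0, 1.0] as float32."""
    return hwc_uint8.transpose(2, 0, 1).astype(np.float32) / 255.0


def normalize_like(chw, mean, std):
    """Hypothetical mimic of transforms.Normalize: channel = (channel - mean) / std."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)  # broadcast per channel
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (chw - mean) / std


img = np.full((2, 3, 3), 255, dtype=np.uint8)  # toy 2x3 RGB image, all white
chw = to_tensor_like(img)
print(chw.shape)                                # (3, 2, 3)
out = normalize_like(chw, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(out.min(), out.max())                     # 1.0 1.0
```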

3.2 Producing training batches with DataLoader

```python
if __name__ == '__main__':
    train_filename = "../dataset/train.txt"
    # test_filename = "../dataset/test.txt"
    image_dir = '../dataset/images'
    epoch_num = 2        # number of passes over the whole dataset
    batch_size = 7       # number of samples in one training batch
    train_data_nums = 10
    # total iterations: ceil(train_data_nums / batch_size) batches per epoch, times epoch_num
    max_iterate = (train_data_nums + batch_size - 1) // batch_size * epoch_num
    train_data = TorchDataset(filename=train_filename, image_dir=image_dir, repeat=1)
    # test_data = TorchDataset(filename=test_filename, image_dir=image_dir, repeat=1)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
    # [1] Iterate by epoch, with TorchDataset(repeat=1)
    for epoch in range(epoch_num):
        for batch_image, batch_label in train_loader:
            image = batch_image[0, :]
            image = image.numpy()  # image = np.array(image)
            image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
            image_processing.cv_show_image("image", image)
            print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
            # No Variable wrapping is needed since PyTorch 0.4: feed batch_image/batch_label to the model directly
```

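The iteration above uses two nested for loops. epoch_num is the number of passes over all samples: with epoch_num=2, every sample is visited twice. This raises a problem when the total sample count train_data_nums is not divisible by batch_size: the last batch is smaller than batch_size. Here train_data_nums=10 and batch_size=7. Each epoch therefore produces a first batch of 7 samples and, because only 3 samples remain, a second batch of just 3.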
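The batch sizes and the iteration count can be worked out with integer arithmetic. The sketch below uses hypothetical helpers and needs no PyTorch. It lists the batch sizes of one non-shuffled epoch and computes the number of iterations with integer ceiling division, (N + B - 1) // B * epochs.

Dividing with / first and truncating with int() at the end is not equivalent in general. For N=11, B=7 and 3 epochs it gives 7 instead of 6.

```python
def batch_sizes_per_epoch(num_samples, batch_size, drop_last=False):
    """Hypothetical helper: sizes of the batches a non-shuffling loader yields in one epoch."""
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # the short final batch
    return sizes


def total_iterations(num_samples, batch_size, epochs):
    """Hypothetical helper: integer ceiling division gives the batches per epoch."""
    return (num_samples + batch_size - 1) // batch_size * epochs


print(batch_sizes_per_epoch(10, 7))                  # [7, 3]
print(batch_sizes_per_epoch(10, 7, drop_last=True))  # [7]
print(total_iterations(10, 7, 2))                    # 4
print(total_iterations(11, 7, 3))                    # 6
print(int((11 + 7 - 1) / 7 * 3))                     # 7 (float formula overshoots)
```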

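We want every iteration to produce a batch of the same size. I designed TorchDataset with this kind of looping in mind. Setting repeat=None makes the dataset loop indefinitely: __len__ returns 10000000, and __getitem__ wraps the index with i % self.len. Use it as follows: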

```python
    # The next two approaches set repeat=None on TorchDataset, which makes it loop
    # indefinitely; the loop is ended after max_iterate iterations.
    train_data = TorchDataset(filename=train_filename, image_dir=image_dir, repeat=None)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # [2] Second way: enumerate the loader and stop after max_iterate steps
    for step, (batch_image, batch_label) in enumerate(train_loader):
        if step >= max_iterate:  # checked first, so exactly max_iterate batches are processed
            break
        image = batch_image[0, :]
        image = image.numpy()  # image = np.array(image)
        image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
        image_processing.cv_show_image("image", image)
        print("step:{},batch_image.shape:{},batch_label:{}".format(step, batch_image.shape, batch_label))
    # [3] Third way: fetch batches manually with next()
    # Create the iterator ONCE. Calling train_loader.__iter__() on every step would restart
    # the loader and, with shuffle=False, return the same first batch every time.
    # loader_iter = iter(train_loader)
    # for step in range(max_iterate):
    #     batch_image, batch_label = next(loader_iter)
    #     image = batch_image[0, :]
    #     image = image.numpy()  # image = np.array(image)
    #     image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
    #     image_processing.cv_show_image("image", image)
    #     print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
```
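The indices alone show why repeat=None yields full batches. __getitem__ maps any index i to i % self.len, so a sequential sampler over the long virtual range simply cycles through the real samples. The sketch below uses a hypothetical pure-Python helper, cyclic_batches. It lists the real sample indices that land in each batch for 10 samples and batch_size=7.

```python
def cyclic_batches(num_samples, batch_size, num_batches):
    """Hypothetical helper: real sample indices per batch when a virtual index i
    is mapped to i % num_samples (as TorchDataset.__getitem__ does) and read in order."""
    batches = []
    for b in range(num_batches):
        start = b * batch_size
        batches.append([i % num_samples for i in range(start, start + batch_size)])
    return batches


for batch in cyclic_batches(10, 7, 3):
    print(batch)
# [0, 1, 2, 3, 4, 5, 6]
# [7, 8, 9, 0, 1, 2, 3]
# [4, 5, 6, 7, 8, 9, 0]
```

Every batch is full-sized. The trade-off is that a batch may straddle the boundary between two passes over the data.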

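3.3 Appendix: image_processing.py

The code above uses image_processing, an image-processing package I wrote myself. It contains basic helpers for reading, displaying and drawing on images: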

```python
# -*- coding: utf-8 -*-
"""
@Project: IntelligentManufacture
@File   : image_processing.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2019-02-14 15:34:50
"""
import os
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt


def show_image(title, image):
    '''
    Display an RGB image with matplotlib
    :param title: image title
    :param image: image data
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')  # use 'off' to hide the axes
    plt.title(title)  # image title
    plt.show()


def cv_show_image(title, image):
    '''
    Display an RGB image with OpenCV
    :param title: window title
    :param image: input RGB image
    :return:
    '''
    channels = image.shape[-1]
    if channels == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # OpenCV displays BGR, so convert RGB -> BGR
    cv2.imshow(title, image)
    cv2.waitKey(0)


def read_image(filename, resize_height=None, resize_width=None, normalization=False):
    '''
    Read an image; by default returns uint8 data in [0, 255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param normalization: whether to normalize to [0.0, 1.0]
    :return: RGB image data, or None if the file cannot be read
    '''
    bgr_image = cv2.imread(filename)
    # bgr_image = cv2.imread(filename, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
    if bgr_image is None:
        print("Warning: file does not exist: {}".format(filename))
        return None
    if len(bgr_image.shape) == 2:  # grayscale image: convert to three channels
        print("Warning: gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
    # show_image(filename, rgb_image)
    # rgb_image = Image.open(filename)
    rgb_image = resize_image(rgb_image, resize_height, resize_width)
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # use 255.0, not 255: under Python 2, "/ 255" on a uint8 array is integer division
        rgb_image = rgb_image / 255.0
    # show_image("src resize image", image)
    return rgb_image


def fast_read_image_roi(filename, orig_rect, ImreadModes=cv2.IMREAD_COLOR, normalization=False):
    '''
    Fast way to read a region of interest (ROI) of an image
    :param filename: image path
    :param orig_rect: ROI rect [x, y, w, h] in the original (full-size) image
    :param ImreadModes: IMREAD_UNCHANGED
                        IMREAD_GRAYSCALE
                        IMREAD_COLOR
                        IMREAD_ANYDEPTH
                        IMREAD_ANYCOLOR
                        IMREAD_LOAD_GDAL
                        IMREAD_REDUCED_GRAYSCALE_2
                        IMREAD_REDUCED_COLOR_2
                        IMREAD_REDUCED_GRAYSCALE_4
                        IMREAD_REDUCED_COLOR_4
                        IMREAD_REDUCED_GRAYSCALE_8
                        IMREAD_REDUCED_COLOR_8
                        IMREAD_IGNORE_ORIENTATION
    :param normalization: whether to normalize to [0.0, 1.0]
    :return: the ROI image
    '''
    # IMREAD_REDUCED_* modes decode a downscaled image, so the rect must be scaled too
    scale = 1
    if ImreadModes == cv2.IMREAD_REDUCED_GRAYSCALE_2 or ImreadModes == cv2.IMREAD_REDUCED_COLOR_2:
        scale = 1 / 2
    elif ImreadModes == cv2.IMREAD_REDUCED_GRAYSCALE_4 or ImreadModes == cv2.IMREAD_REDUCED_COLOR_4:
        scale = 1 / 4
    elif ImreadModes == cv2.IMREAD_REDUCED_GRAYSCALE_8 or ImreadModes == cv2.IMREAD_REDUCED_COLOR_8:
        scale = 1 / 8
    rect = np.array(orig_rect) * scale
    rect = rect.astype(int).tolist()
    bgr_image = cv2.imread(filename, flags=ImreadModes)
    if bgr_image is None:
        print("Warning: file does not exist: {}".format(filename))
        return None
    if len(bgr_image.shape) == 3:  # color image
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
    else:
        rgb_image = bgr_image  # grayscale image
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # use 255.0, not 255: under Python 2, "/ 255" on a uint8 array is integer division
        rgb_image = rgb_image / 255.0
    roi_image = get_rect_image(rgb_image, rect)
    # show_image_rect("src resize image", rgb_image, rect)
    # cv_show_image("reROI", roi_image)
    return roi_image


def resize_image(image, resize_height, resize_width):
    '''
    :param image:
    :param resize_height: target height, or None to derive it from resize_width
    :param resize_width: target width, or None to derive it from resize_height
    :return:
    '''
    image_shape = np.shape(image)
    height = image_shape[0]
    width = image_shape[1]
    if (resize_height is None) and (resize_width is None):  # wrong: "resize_height and resize_width is None"
        return image
    if resize_height is None:
        resize_height = int(height * resize_width / width)
    elif resize_width is None:
        resize_width = int(width * resize_height / height)
    image = cv2.resize(image, dsize=(resize_width, resize_height))  # dsize is (width, height)
    return image


def scale_image(image, scale):
    '''
    :param image:
    :param scale: (scale_w, scale_h)
    :return:
    '''
    image = cv2.resize(image, dsize=None, fx=scale[0], fy=scale[1])
    return image


def get_rect_image(image, rect):
    '''
    :param image:
    :param rect: [x, y, w, h]
    :return:
    '''
    x, y, w, h = rect
    cut_img = image[y:(y + h), x:(x + w)]
    return cut_img


def scale_rect(orig_rect, orig_shape, dest_shape):
    '''
    When an image is resized, its rectangle must be scaled accordingly
    :param orig_rect: rect of the original image, [x, y, w, h]
    :param orig_shape: shape of the original image, [h, w]
    :param dest_shape: shape of the resized image, [h, w]
    :return: the scaled rectangle
    '''
    new_x = int(orig_rect[0] * dest_shape[1] / orig_shape[1])
    new_y = int(orig_rect[1] * dest_shape[0] / orig_shape[0])
    new_w = int(orig_rect[2] * dest_shape[1] / orig_shape[1])
    new_h = int(orig_rect[3] * dest_shape[0] / orig_shape[0])
    dest_rect = [new_x, new_y, new_w, new_h]
    return dest_rect


def show_image_rect(win_name, image, rect):
    '''
    :param win_name:
    :param image:
    :param rect: [x, y, w, h]
    :return:
    '''
    x, y, w, h = rect
    point1 = (x, y)
    point2 = (x + w, y + h)
    cv2.rectangle(image, point1, point2, (0, 0, 255), thickness=2)  # draws in place
    cv_show_image(win_name, image)


def rgb_to_gray(image):
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image


def save_image(image_path, rgb_image, toUINT8=True):
    if toUINT8:  # expects values in [0.0, 1.0]
        rgb_image = np.asanyarray(rgb_image * 255, dtype=np.uint8)
    if len(rgb_image.shape) == 2:  # grayscale image: convert to three channels
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_GRAY2BGR)
    else:
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(image_path, bgr_image)


def combime_save_image(orig_image, dest_image, out_dir, name, prefix):
    '''
    Naming convention: out_dir/name_prefix.jpg
    :param orig_image:
    :param dest_image:
    :param out_dir:
    :param name:
    :param prefix:
    :return:
    '''
    dest_path = os.path.join(out_dir, name + "_" + prefix + ".jpg")
    save_image(dest_path, dest_image)
    dest_image = np.hstack((orig_image, dest_image))
    save_image(os.path.join(out_dir, "{}_src_{}.jpg".format(name, prefix)), dest_image)
```
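The sizing rule that resize_image applies, and that the TorchDataset docstring describes, can be isolated as a small pure-Python function. target_size is a hypothetical helper that only computes the output size and does not resize anything. cv2.resize itself takes dsize in (width, height) order, which is why resize_image passes (resize_width, resize_height).

```python
def target_size(height, width, resize_height=None, resize_width=None):
    """Hypothetical helper returning (new_height, new_width) by the same rule as resize_image:
    both None -> keep the original size; one None -> derive it from the other,
    preserving the aspect ratio (truncated to int)."""
    if resize_height is None and resize_width is None:
        return height, width
    if resize_height is None:
        resize_height = int(height * resize_width / width)
    elif resize_width is None:
        resize_width = int(width * resize_height / height)
    return resize_height, resize_width


print(target_size(480, 640, resize_width=320))   # (240, 320)
print(target_size(480, 640, resize_height=256))  # (256, 341)
print(target_size(480, 640, 256, 256))           # (256, 256)
```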

3.4 Complete code

```python
# -*- coding: utf-8 -*-
"""
@Project: pytorch-learning-tutorials
@File   : dataset.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2019-03-07 18:45:06
"""
import os

import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from utils import image_processing  # the author's helper module (section 3.3)


class TorchDataset(Dataset):
    def __init__(self, filename, image_dir, resize_height=256, resize_width=256, repeat=1):
        '''
        :param filename: label TXT file, one sample per line: image_name.jpg label1_id label2_id ...
        :param image_dir: image directory; os.path.join(image_dir, image_name.jpg) is the full path
        :param resize_height: if None, the height is not resized
        :param resize_width: if None, the width is not resized.
            PS: if exactly one of resize_height/resize_width is None, the image is
            scaled proportionally (aspect ratio preserved)
        :param repeat: how many times the whole sample list is repeated (default 1);
            None means "loop indefinitely" (__len__ then returns 10000000)
        '''
        self.image_label_list = self.read_file(filename)
        self.image_dir = image_dir
        self.len = len(self.image_label_list)
        self.repeat = repeat
        self.resize_height = resize_height
        self.resize_width = resize_width
        # transforms.ToTensor: (H, W, C) uint8 in [0, 255] -> (C, H, W) FloatTensor in [0.0, 1.0]
        self.toTensor = transforms.ToTensor()
        # transforms.Normalize(mean, std): channel = (channel - mean) / std, per channel (R, G, B)
        # self.normalize = transforms.Normalize(mean, std)

    def __getitem__(self, i):
        index = i % self.len  # wrap around so indices past the list (repeat > 1 or None) stay valid
        # print("i={},index={}".format(i, index))
        image_name, label = self.image_label_list[index]
        image_path = os.path.join(self.image_dir, image_name)
        img = self.load_data(image_path, self.resize_height, self.resize_width, normalization=False)
        if img is None:  # read_image returns None for a missing/unreadable file
            raise FileNotFoundError(image_path)
        img = self.data_preproccess(img)
        label = np.array(label)
        return img, label

    def __len__(self):
        if self.repeat is None:
            data_len = 10000000  # effectively infinite
        else:
            data_len = len(self.image_label_list) * self.repeat
        return data_len

    def read_file(self, filename):
        image_label_list = []
        with open(filename, 'r') as f:
            for line in f:
                # split() with no argument splits on any whitespace and drops the trailing \n, \r
                content = line.split()
                if not content:  # skip blank lines
                    continue
                name = content[0]
                labels = [int(value) for value in content[1:]]
                image_label_list.append((name, labels))
        return image_label_list

    def load_data(self, path, resize_height, resize_width, normalization):
        '''
        Load one image
        :param normalization: whether to normalize to [0.0, 1.0]
        '''
        image = image_processing.read_image(path, resize_height, resize_width, normalization)
        return image

    def data_preproccess(self, data):
        '''
        Data preprocessing
        '''
        data = self.toTensor(data)
        return data


if __name__ == '__main__':
    train_filename = "../dataset/train.txt"
    # test_filename = "../dataset/test.txt"
    image_dir = '../dataset/images'
    epoch_num = 2        # number of passes over the whole dataset
    batch_size = 7       # number of samples in one training batch
    train_data_nums = 10
    # total iterations: ceil(train_data_nums / batch_size) batches per epoch, times epoch_num
    max_iterate = (train_data_nums + batch_size - 1) // batch_size * epoch_num
    train_data = TorchDataset(filename=train_filename, image_dir=image_dir, repeat=1)
    # test_data = TorchDataset(filename=test_filename, image_dir=image_dir, repeat=1)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
    # [1] Iterate by epoch, with TorchDataset(repeat=1)
    for epoch in range(epoch_num):
        for batch_image, batch_label in train_loader:
            image = batch_image[0, :]
            image = image.numpy()  # image = np.array(image)
            image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
            image_processing.cv_show_image("image", image)
            print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
            # No Variable wrapping is needed since PyTorch 0.4: feed the batch to the model directly

    # The next two approaches set repeat=None on TorchDataset, which makes it loop
    # indefinitely; the loop is ended after max_iterate iterations.
    train_data = TorchDataset(filename=train_filename, image_dir=image_dir, repeat=None)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # [2] Second way: enumerate the loader and stop after max_iterate steps
    for step, (batch_image, batch_label) in enumerate(train_loader):
        if step >= max_iterate:  # checked first, so exactly max_iterate batches are processed
            break
        image = batch_image[0, :]
        image = image.numpy()  # image = np.array(image)
        image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
        image_processing.cv_show_image("image", image)
        print("step:{},batch_image.shape:{},batch_label:{}".format(step, batch_image.shape, batch_label))
    # [3] Third way: fetch batches manually with next()
    # Create the iterator ONCE. Calling train_loader.__iter__() on every step would restart
    # the loader and, with shuffle=False, return the same first batch every time.
    # loader_iter = iter(train_loader)
    # for step in range(max_iterate):
    #     batch_image, batch_label = next(loader_iter)
    #     image = batch_image[0, :]
    #     image = image.numpy()  # image = np.array(image)
    #     image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
    #     image_processing.cv_show_image("image", image)
    #     print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
```

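Disclaimer: this article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author. If you find infringing content, please contact the site. When reposting, please credit the source: https://www.wpsshop.cn/w/不正经/article/detail/455101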