赞
踩
在打包自己处理的数据时有两种方法:
1.写个数据集的类(myDataset ),并继承Dataset
在myDataset 类中实现__len__()和__getitem__两个函数, __len__返回数据集的总长,__getitem__返回每次按照 索引所取的项,即x, y
比如:在处理序列问题时:
__len__返回的是:all_len/seq_len
__getitem__返回的是:一个输入序列,一个输出序列,即:x_seq, y_seq
if index + input_seq_len + prediction_seq_len +1 < all_Data_Len:
train = allData[j:j+input_seq_len,:,:]
...
...
return train, label1, label2, ...
myDataset = TensorDataset(trainDataX,trainDataY)
综上任意一种处理完毕后将处理后的数据集放入DataLoader,就可以在训练的时候直接用了
myloader = DataLoader(dataset=myDataset , batch_size=1, shuffle=False)
训练中:
for i, data in enumerate(train_loader):
torch.autograd为tensor的所有操作自动求导(Variable类是核心),所有Tensor必须转换为Variable
PyTorch数据加载模块一共涉及到Dataset,Sampler,Dataloader三个类
Dataset负责对raw data source封装,将其封装成Python可识别的数据结构,其必须提供提取数据个体的接口。Dataset共有Map-style datasets和Iterable-style datasets两种:
1.1 map-style dataset:实现了__getitem__和__len__接口,表示一个从索引/key到样本数据的map。比如:datasets[10],就表示第10个样本。
1.2 iterable-style dataset:实现了__iter__接口,表示在data samples上的一个Iterable(可迭代对象),这种形式的dataset非常不适合随机存取(代价太高),但非常适合处理流数据。比如:iter(datasets)获得迭代器,然后不断使用next迭代从而实现遍历。
Sampler负责提供一种遍历数据集所有元素索引的方式。
Dataloader负责加载数据,同时支持map-style和iterable-style Dataset,支持单进程/多进程,还可以设置loading order, batch size, pin memory等加载参数。
总结一下步骤:
归纳一下:即Dataloader负责总的调度,命令Sampler定义遍历索引的方式,然后用索引去Dataset中提取元素。于是就实现了对给定数据集的遍历。
所有设计的Dataset类必须继承torch.utils.data.Dataset这个类。
前者很重要,是Dataset及其子类的核心,定义了数据元素提取(即通过索引获取样本,实际代码中常使用[]输入索引)
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
具体实践中,我们需要使用Dataset的子类,自己实现的或者现成的。
我们可以来看看PyTorch为我们提供的现成的Dataset子类
下面着重介绍
TensorDataset和IterableDataset.
*CLASS torch.utils.data.TensorDataset(tensors)
包装了Tensor的Dataset子类,map-style dataset
每个样本可以通过tensors第一个维度的索引获取
class TensorDataset(Dataset):
r"""
Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
def __init__(self, *tensors):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
如上源码:
__init__的形参是*tensors,因此是可以传入多个tensor变量的,但需要保证每个tensor的第一个维度均是一样的。
例子
正确输入:(100*64*64*3,100*32*32*3,100*16*16*31)
错误输入:(100*64*64*3,200*32*32*3,100*16*16*31)
__getitem__提取的就是*tensors中每个张量的第index个样本(因为每个张量第一维度都是一样的)
__len__即*tensors每个张量第一个维度长度
常见用法:*tensors指定我们可以输入多个张量,我们可以同时输入train_data和train_label
dataset = TensorDataset(train_data, train_label)
CLASS torch.utils.data.IterableDataset
内部样本的组织形式是Iterable的所有dataset类都是IterableDataset类的子类,
即:所有iterable-style dataset都是IterableDataset的子类
这种形式的dataset对于处理流数据是非常有用的。
所有这些子类需要实现__iter__方法(而不是__getitem__方法了),需要据此来返回样本的迭代器,从而遍历dataset(实际代码中常使用iter+next来遍历)
关于Python中Iterable和Iterator的介绍见我的另一篇文章:刘昕宸:彻底搞懂Python的__iter__和__next__,Iterable和Iteration
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
关于多进程的问题:
IterableDataset的某个子类被DataLoader使用时,dataset中的每个item可以通过DataLoader的Iterator迭代获取。
当num_works>0时就是多进程模式,每个工作进程都有一个不同的dataset对象的拷贝,因此我们需要独立安排每一份拷贝该如何处理(后面会有例子),以防止不同的进程会返回重复的元素。(有MPI编程经验的同学应该更能理解!)
可以通过get_worker_info方法,在某一当前进程中调用,获得当前进程信息。这个方法要么在dataset类的__iter__方法中使用,要么在DataLoader的worker_init_fn方法中设置并使用。
举2个例子(来自官网文档):
例1:在dataset类的__iter__方法中使用get_worker_info方法,划分工作空间,获得当前进程id,并根据进程id分配其需要处理的工作空间
class MyIterableDataset(torch.utils.data.IterableDataset): def __init__(self, start, end): super(MyIterableDataset).__init__() assert end > start, "this example code only works with end >= start" self.start = start self.end = end def __iter__(self): worker_info = torch.utils.data.get_worker_info() if worker_info is None: # 单进程:一个进程处理全部样本 iter_start = self.start iter_end = self.end else: # 多进程,在当前进程中 # 划分工作空间 per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) worker_id = worker_info.id iter_start = self.start + worker_id * per_worker iter_end = min(iter_start + per_worker, self.end) return iter(range(iter_start, iter_end))
具体使用:
>>> # 给定样本集range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # 单进程加载
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>> # 2个进程加载
>>> # 进程0负责[3, 4]. 进程1负责[5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]
>>> # 更多的进程
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
例2:先来个反例:不手动配置每个进程的工作空间的话,默认每个进程的工作空间是整个dataset,因此每个进程都会遍历一次整个数据集,导致产生重复数据。
>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # 给定样本集range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # 单进程加载 >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # 直接多进程加载会产生重复数据 >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6]
除了上面在MyIterableDataset的__iter__方法中依靠get_work_info分配工作空间,还可以事先定义函数worker_init_fn分配工作空间(分配策略与例1完全一致),再将该函数传给dataloader生效:
>>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # 多进程加载,使用自定义的`worker_init_fn` >>> # 进程0负责[3, 4]. 进程1负责[5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # 更多的进程 >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
你以为这就完了吗???当然不!!!
贴心的PyTorch小可爱还为我们提供了计算机视觉常用的数据集,并将它们包装成了Dataset!!!
这些数据集都在torchvision.datasets下,共有这么多:
我们以CIFAR-10数据集为例来看一看:
CLASS torchvision.datasets.CIFAR10
使用举例:
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, pin_memory=True)
CLASS torch.utils.data.Sampler(data_source: Optional[collections.abc.Sized])
所有Samplers的基类
Sampler的所有子类都需要实现__iter__,用来提供遍历dataset索引的方式。我们获得不同的索引遍历,就能以不同的方式遍历dataset,这就是samplers的目的。
PyTorch为我们提供了几种现成的Sampler子类:
SequentialSampler
RandomSampler
SubsetRandomSampler
WeightedRandomSampler
BatchSampler
DistributedSampler
下面我着重介绍一下SequentialSampler,RandomSampler和BatchSampler
CLASS SequentialSampler(Sampler[int])
SequentialSampler指定总是按照相同的次序,顺序地采样元素
关注方法__iter__,直接range生成顺序的索引,也就是为dataloader提供了顺序遍历dataset的方式。
class SequentialSampler(Sampler[int]):
r"""
Arguments:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
CLASS torch.utils.data.RandomSampler
RandomSampler提供了随机采样元素的方式。
如果replacement==False
,则随机采样整个数据集,即num_samples==len(dataset)。此时sampler提供给dataloader以一种随机的次序遍历dataset.
如果replacement==True
,则从数据集中随机采样num_samples个样本
仅贴出__iter__实现:
@property def num_samples(self) -> int: # dataset size might change at runtime if self._num_samples is None: return len(self.data_source) return self._num_samples def __iter__(self): n = len(self.data_source) if self.generator is None: generator = torch.Generator() generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) else: generator = self.generator if self.replacement: for _ in range(self.num_samples // 32): yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() else: yield from torch.randperm(n, generator=self.generator).tolist() def __len__(self): return self.num_samples
CLASS torch.utils.data.BatchSampler
BatchSampler包装另一个sampler(输入参数),用来产生一个mini-batch大小的索引,相当于是为dataloader提供了提取dataset的1个mini-batch样本的索引。
关注__iter__和__len__方法:
class BatchSampler(Sampler[List[int]]): r"""Wraps another sampler to yield a mini-batch of indices. Args: sampler (Sampler or Iterable): Base sampler. Can be any iterable object batch_size (int): Size of mini-batch. drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be less than ``batch_size`` Example: >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """ def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None: # Since collections.abc.Iterable does not check for `__getitem__`, which # is one way for an object to be an iterable, we don't do an `isinstance` # check here. if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ batch_size <= 0: raise ValueError("batch_size should be a positive integer value, " "but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): raise ValueError("drop_last should be a boolean value, but got " "drop_last={}".format(drop_last)) self.sampler = sampler self.batch_size = batch_size self.drop_last = drop_last def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch def __len__(self): # Can only be called if self.sampler has __len__ implemented # We cannot enforce this condition, so we turn off typechecking for the # implementation below. # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] if self.drop_last: return len(self.sampler) // self.batch_size # type: ignore else: return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore
铺垫了这么多,终于讲到DataLoader了。
在训练/测试深度学习网络的程序中,我们直接遍历Dataloader来获取数据(data,label等),并将数据feed给网络用于前向传播和反向传播。
代码形如:
for data, label in train_loader:
data, label = data.to(device), label.to(device).squeeze()
opt.zero_grad()
logits = model(data)
loss = criterion(logits, label)
那么在for data, label in train_loader这个过程中究竟发生了什么呢?一起探索!
for循环会调用dataloader iter:
以此获得迭代器来遍历dataset
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
其中调用了self._get_iterator()获得迭代器:
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
为了简单起见,我们只考虑单进程的代码,那我们看一下_SingleProcessDataLoaderIter的实现:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
_SingleProcessDataLoaderIter继承自_BaseDataLoaderIter,因此_BaseDataLoaderIter的代码也需要看一下:
class _BaseDataLoaderIter(object): def __init__(self, loader: DataLoader) -> None: self._dataset = loader.dataset self._dataset_kind = loader._dataset_kind self._IterableDataset_len_called = loader._IterableDataset_len_called self._auto_collation = loader._auto_collation self._drop_last = loader.drop_last self._index_sampler = loader._index_sampler self._num_workers = loader.num_workers self._prefetch_factor = loader.prefetch_factor self._pin_memory = loader.pin_memory and torch.cuda.is_available() self._timeout = loader.timeout self._collate_fn = loader.collate_fn self._sampler_iter = iter(self._index_sampler) self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item() self._persistent_workers = loader.persistent_workers self._num_yielded = 0 def __iter__(self) -> '_BaseDataLoaderIter': return self def _reset(self, loader, first_iter=False): self._sampler_iter = iter(self._index_sampler) self._num_yielded = 0 self._IterableDataset_len_called = loader._IterableDataset_len_called def _next_index(self): return next(self._sampler_iter) # may raise StopIteration def _next_data(self): raise NotImplementedError def __next__(self) -> Any: if self._sampler_iter is None: self._reset() data = self._next_data() self._num_yielded += 1 if self._dataset_kind == _DatasetKind.Iterable and \ self._IterableDataset_len_called is not None and \ self._num_yielded > self._IterableDataset_len_called: warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, self._num_yielded) if self._num_workers > 0: warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " "IterableDataset replica at each worker. Please see " "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") warnings.warn(warn_msg) return data next = __next__ # Python 2 compatibility def __len__(self) -> int: return len(self._index_sampler) def __getstate__(self): # TODO: add limited pickling support for sharing an iterator # across multiple threads for HOGWILD. # Probably the best way to do this is by moving the sample pushing # to a separate thread and then just sharing the data queue # but signalling the end is tricky without a non-blocking API raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
dataloader获得了迭代器之后,我们的for循环需要调用__next__来获得下一个对象,从而实现遍历。
我们看一下_BaseDataLoaderIter的__next__:
def __next__(self) -> Any: if self._sampler_iter is None: self._reset() data = self._next_data() self._num_yielded += 1 if self._dataset_kind == _DatasetKind.Iterable and \ self._IterableDataset_len_called is not None and \ self._num_yielded > self._IterableDataset_len_called: warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, self._num_yielded) if self._num_workers > 0: warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " "IterableDataset replica at each worker. Please see " "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") warnings.warn(warn_msg) return data
__next__需要调用_next_data,因此我们还需要看一下_SingleProcessDataLoaderIter的_next_data:
_next_data需要_next_index获得索引,并通过索引fetch到对应的样本。
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
关于_next_index:
_sampler_iter来自_index_sampler,来自loader
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
再看dataloader中的_index_sampler,一切就明白了:
@property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
总结来说就是dataloader提供了sampler,然后_SingleProcessDataLoaderIter迭代sampler获得索引。
下面我们来看看Fetch:
pytorch在Dataset上又封装了一层Fetcher。
这样做是使得iterable Dataset(对应_IterableDatasetFetcher)和map Dataset(对应_MapDatasetFetcher)在Dataloader内能使用相同的接口fetch,代码更加简洁。
fetcher需要index获取数据元素。
针对map-style fetcher:
关注fetch方法:直接输入索引index,作为map的key,获得对应的样本(即value)
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
注意:这里的index可能不只是一个索引,而是一个batch的索引。
这取决于_auto_collation,_auto_collation的取值在Dataloader中定义:
有batch_sampler,_auto_collation就为True,就优先使用batch_sampler,对应在fetcher中传入的就是一个batch的索引。
@property
def _auto_collation(self):
return self.batch_sampler is not None
针对iterable-style fetcher:
__init__方法内设置了dataset初始的迭代器
fetch方法内获取元素,index其实已经没有多大作用了。
对于batch_sampler(即auto_collation==True):直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
对于sampler:直接往后遍历并提取1个样本
class _IterableDatasetFetcher(_BaseDatasetFetcher): def __init__(self, dataset, auto_collation, collate_fn, drop_last): super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) self.dataset_iter = iter(dataset) def fetch(self, possibly_batched_index): if self.auto_collation: data = [] for _ in possibly_batched_index: try: data.append(next(self.dataset_iter)) except StopIteration: break if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)): raise StopIteration else: data = next(self.dataset_iter) return self.collate_fn(data)
另外对python中Iterable,Iterator,iter,__next__等的详细解释,参见我另一篇文章:彻底搞懂Python的__iter__和__next__,Iterable和Iteration
最后,我们通过索引传入fetcher,fetch得到想要的样本!
我们的目标终于实现了!!!
整个过程调用关系总结:
loader.iter–> _get_iterator --> _SingleProcessDataLoaderIter --> _BaseDataLoaderIter --> next --> _next_data–> self._dataset_fetcher.fetch(index) --> _next_index -->_sampler_iter --> loader._index_sampler
但愿这么细致的讲解,能真正搞清楚Dataset,Sampler,DataLoader三者的机理及其运行关系。
总结:
Dataset封装数据集(可通过索引获取元素)
Sampler提供索引次序(可迭代,用于遍历)
DataLoader是一个调度器,迭代DataLoaderIter的过程中,迭代Sampler获得下一索引,并通过该索引使用fetcher(fetcher是对dataset的封装,使得dataloader代码与iterable-style/map-style dataset解耦)获得对应元素。
2.1 实战建议
Dataset
通常使用TensorDataset,或者我们自行实现一个其继承类。
Sampler
我们一般不用管,直接使用DataLoader默认指定的就行:
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
# Cannot statically verify that dataset is Sized
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
sampler = RandomSampler(dataset, generator=generator) # type: ignore
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
如上来自DataLoader的代码,
如果是iterable-style dataset,默认使用_InfiniteConstantSampler:
其实这个_InfiniteConstantSampler啥也没干,因为我们遍历iterable-style dataset依靠的是迭代器,根本就不需要索引!(上面介绍的_IterableDatasetFetcher已经说明了这一点!)
class _InfiniteConstantSampler(Sampler):
r"""
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self):
super(_InfiniteConstantSampler, self).__init__(None)
def __iter__(self):
while True:
yield None
如果是map-style dataset,有shuffle则默认使用RandomSampler;没有shuffle则默认使用SequentialSampler
batch_sampler就是对上面已经生成的sampler,进一步包装。
DataLoader
直接用就完事了!
2.2 具体例子
这是来自于DGCNN的PyTorch版本官方实现:WangYueFt/dgcnn
DGCNN是非常著名的点云特征学习网络,感兴趣的朋友可以参考我这一篇文章的解读:搞懂DGCNN,这篇就够了!论文及代码完全解析
自己实现Dataset,用于装载ModelNet40数据集:
class ModelNet40(Dataset): def __init__(self, num_points, partition='train'): self.data, self.label = load_data(partition) self.num_points = num_points self.partition = partition def __getitem__(self, item): pointcloud = self.data[item][:self.num_points] label = self.label[item] if self.partition == 'train': pointcloud = translate_pointcloud(pointcloud) np.random.shuffle(pointcloud) return pointcloud, label def __len__(self): return self.data.shape[0]
将ModelNet40装载至DataLoader:
Sampler使用默认的,因为shuffle==True,因此使用的应该是RandomSampler:
train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8,
batch_size=args.batch_size, shuffle=True, drop_last=True)
使用DataLoader:
直接for循环遍历就完事了:使用DataLoader:
直接for循环遍历就完事了:
for data, label in train_loader:
data, label = data.to(device), label.to(device).squeeze()
data = data.permute(0, 2, 1)
batch_size = data.size()[0]
opt.zero_grad()
logits = model(data)
loss = criterion(logits, label)
loss.backward()
opt.step()
preds = logits.max(dim=1)[1]
count += batch_size
train_loss += loss.item() * batch_size
train_true.append(label.cpu().numpy())
train_pred.append(preds.detach().cpu().numpy())
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。