当前位置:   article > 正文

联邦学习开山代码报错整理_batch_size should be a positive integer value, but

batch_size should be a positive integer value, but got batch_size=0

参考:联邦学习开山之作代码解读与收获


最近学习联邦学习开山代码,运行上面汇总的代码时,遇到了一些警告或报错,遂记录下来。

一、UserWarning警告

        正常运行代码就会在初始发出警告。

报错信息:To copy construct from a tensor...

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).   y_support=torch.tensor(y_support,dtype=torch.int64)

原因:由于update.py下的datasplit类里torch.tensor函数为深拷贝,不记录历史更新数据,所以有可能导致意想不到的错误,故警告。

改正:将返回值改为下面as_tensor就不会报错,该函数会共享后面对象的历史数据。

return torch.as_tensor(image), torch.as_tensor(label)

二、ValueError错误

        当我将options里的数据集分布默认改成mnist的非独立同分布且非平均分配时,极大概率发生报错。

报错信息:batch_size should be a positive integer value, but got batch_size=0

  0%|          | 0/10 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "F:\1-offical\FLearn\src\federated_main.py", line 90, in <module>
    local_model = LocalUpdate(args=args, dataset=train_dataset,
  File "F:\1-offical\FLearn\src\update.py", line 39, in __init__
    self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
  File "F:\1-offical\FLearn\src\update.py", line 54, in train_val_test
    validloader = DataLoader(DatasetSplit(dataset, idxs_val),
  File "C:\Users\lenovo\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\dataloader.py", line 357, in __init__
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  File "C:\Users\lenovo\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\sampler.py", line 232, in __init__
    raise ValueError("batch_size should be a positive integer value, "
ValueError: batch_size should be a positive integer value, but got batch_size=0

原因:我的理解是,由于验证集列表太短(仅有5个数据,batch/10=0),导致一个批次的训练都进行不了,这样即产生报错。归根结底是因为用户整个数据集就很短,导致验证集分到的太少。

改正:最直接的方法就是调整update里面的batch_size大小,这样使得每轮都有合适的训练数据。

  1. #parser.add_argument('--local_ep', type=int, default=5,help="本地训练轮次E,默认为10轮")
  2. validloader = DataLoader(DatasetSplit(dataset, idxs_val),
  3. batch_size=int(len(idxs_val)/self.args.local_ep), shuffle=False)
  4. testloader = DataLoader(DatasetSplit(dataset, idxs_test),
  5. batch_size=int(len(idxs_test)/self.args.local_ep), shuffle=False)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/272302
推荐阅读
相关标签
  

闽ICP备14008679号