当前位置:   article > 正文

diffusion训练提示Sizes of tensors must match except in dimension 1. Expected size 20 but got size 21 for_return self._call_impl(*args, **kwargs)

return self._call_impl(*args, **kwargs)
class TrainSet(Dataset):
    def __init__(self, classes):
        reader = pd.read_csv("D:/新建文件夹/train_data42.csv")
        data = reader['data'].tolist()
        label = reader['label'].tolist()
        datas = []
        for item in data:
            item = list(map(float, item.split(",")))
            datas.append(np.array(item).reshape(42, 1))
        datas = np.stack(datas)[np.array(label) == classes]
        self.datas = torch.from_numpy(datas).transpose(-1, -2).half()

    def __getitem__(self, index):
        return self.datas[index]

    def __len__(self):
        return len(self.datas)

if __name__ == '__main__':
    for i in range(10):
        # dataset = TrainSet(i)
        dim = 32
        batch_size = 32
        seq_length = 42
        channels = 1
        epoches = 100
        model = Unet1D(
            dim = dim,
            dim_mults = (1, 2, 4, 8),
            channels = channels
        )

        diffusion = GaussianDiffusion1D(
            model,
            seq_length = seq_length,
            timesteps = 1000,
            objective = 'pred_v'
        )

        dataset = TrainSet(i)

        trainer = Trainer1D(
            diffusion.to(device),
            dataset = dataset,
            train_batch_size = batch_size,
            train_lr = 8e-5,
            train_num_steps = epoches,         # total training steps
            gradient_accumulate_every = 2,    # gradient accumulation steps
            ema_decay = 0.995,                # exponential moving average decay
            amp = True,                       # turn on mixed precision
        )
        print(trainer.device)
        trainer.train()
        trainer.save(f"classes_{i}")


        # after a lot of training
        sampled_seqs = []
        for j in range(1024):
            sampled_seq = diffusion.sample(batch_size = 32)
            sampled_seqs.append(sampled_seq)
        # print(sampled_seq.shape) # (4, 32, 128)
        data = torch.cat(sampled_seqs).cpu().numpy()
        # print(data.shape)
        if not os.path.exists("results_unswnb15"):
            os.mkdir("results_unswnb15")
        # torch.save(data, f"results/data_{i}.pt")

        np.save(f"results_unswnb15\\data_{i}.npy", data)

代码如图,使用diffusion训练时提示

Traceback (most recent call last):
  File "D:\afeu\denoising-diffusion-pytorch-main\main-unswnb15.py", line 78, in <module>
    trainer.train()
  File "D:\afeu\denoising-diffusion-pytorch-main\denoising_diffusion_pytorch\denoising_diffusion_pytorch_1d.py", line 840, in train
    loss = self.model(data)
  File "D:\LenovoSoftstore\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\LenovoSoftstore\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\afeu\denoising-diffusion-pytorch-main\denoising_diffusion_pytorch\denoising_diffusion_pytorch_1d.py", line 710, in forward
    return self.p_losses(img, t, *args, **kwargs)
  File "D:\afeu\denoising-diffusion-pytorch-main\denoising_diffusion_pytorch\denoising_diffusion_pytorch_1d.py", line 686, in p_losses
    model_out = self.model(x, t, x_self_cond)
  File "D:\LenovoSoftstore\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\LenovoSoftstore\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\afeu\denoising-diffusion-pytorch-main\denoising_diffusion_pytorch\denoising_diffusion_pytorch_1d.py", line 370, in forward
    x = torch.cat((x, h.pop()), dim = 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 20 but got size 21 for tensor number 1 in the list.

哪位大神帮帮忙呢?数据是42*1,10分类。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/149575
推荐阅读
相关标签
  

闽ICP备14008679号