
Fixing the Stable Diffusion TensorRT model-conversion error: cpu and cuda:0! (when checking argument for argument grid in method wrapper__grid_sampler_2d)

This post records an error hit while using the Stable Diffusion WebUI TensorRT extension:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
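In plain terms, PyTorch was asked to run an operation whose inputs live on different devices: the model weights stayed on the CPU while the input tensors were on the GPU (or the other way around). A minimal, self-contained sketch (plain PyTorch, not the extension code) of how the mismatch arises and how .to() resolves it:

import torch

layer = torch.nn.Linear(4, 4)        # module parameters stay on the CPU
x = torch.randn(1, 4).cuda()         # input tensor lives on cuda:0

# Calling layer(x) at this point raises:
# RuntimeError: Expected all tensors to be on the same device, ...

layer = layer.to('cuda')             # move the module (or move x back with x.cpu())
y = layer(x)                         # now both sides are on cuda:0
print(y.device)                      # prints: cuda:0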

Copy the code below and overwrite export_onnx.py in extensions\stable-diffusion-webui-tensorrt with it. It moves the model and the related tensors to the GPU, which resolves the error.

import os

import torch

from modules import sd_hijack, sd_unet
from modules import shared, devices


def export_current_unet_to_onnx(filename, opset_version=17):
    # Report whether CUDA is usable so device problems are visible in the console
    if torch.cuda.is_available():
        print("CUDA is available")
    else:
        print("CUDA is not available")

    # Choose the target device based on whether CUDA is available, then move the UNet onto it
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    shared.sd_model.model.diffusion_model.to(device)

    # Dummy inputs used to trace the UNet during ONNX export
    x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
    timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
    context = torch.randn(1, 77, 768).to(devices.device, devices.dtype)

    # Make sure the dummy inputs sit on the same device as the model,
    # otherwise the export fails with the cpu/cuda:0 mismatch error above
    x = x.to(device)
    timesteps = timesteps.to(device)
    context = context.to(device)
    print(x.device, timesteps.device, context.device)

    # Gradient checkpointing interferes with tracing, so turn it off on every submodule
    def disable_checkpoint(self):
        if getattr(self, 'use_checkpoint', False):
            self.use_checkpoint = False
        if getattr(self, 'checkpoint', False):
            self.checkpoint = False

    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    # Temporarily disable the WebUI's UNet replacement and optimizations for the export
    sd_unet.apply_unet("None")
    sd_hijack.model_hijack.apply_optimizations('None')

    os.makedirs(os.path.dirname(filename), exist_ok=True)

    with devices.autocast():
        torch.onnx.export(
            shared.sd_model.model.diffusion_model,
            (x, timesteps, context),
            filename,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['x', 'timesteps', 'context'],
            output_names=['output'],
            dynamic_axes={
                'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                'timesteps': {0: 'batch_size'},
                'context': {0: 'batch_size', 1: 'sequence_length'},
                'output': {0: 'batch_size'},
            },
        )

    # Restore the optimizations and UNet that were disabled for the export
    sd_hijack.model_hijack.apply_optimizations()
    sd_unet.apply_unet()
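After the export completes, it is worth a quick sanity check that the ONNX file is structurally valid before the TensorRT extension converts it. A minimal sketch, assuming the onnx package is installed; unet.onnx is only an example path, use the filename you actually exported:

import onnx

onnx.checker.check_model("unet.onnx")          # raises if the graph is invalid
model = onnx.load("unet.onnx")                 # very large exports may need weights stored as external data
print([i.name for i in model.graph.input])     # expect ['x', 'timesteps', 'context']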
