
Beyond Sora: Hands-On Deployment of CogVideo, the Most Powerful Text-to-Video Model


CogVideo is a deep-learning text-to-video generation model developed by Zhipu AI. It generates video clips automatically from text descriptions.

CogVideoX-2B, the first model in the CogVideoX series, has 2 billion parameters and shares its origins with Qingying (清影), Zhipu AI's video generation product.

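CogVideoX-2B combines several recent techniques: a 3D variational autoencoder (3D VAE) that compresses video in both space and time, an end-to-end video understanding model (used to produce detailed captions for the training videos), and an expert Transformer. Zhipu AI credits these components for placing the model among the leaders in video generation.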

The model accepts English prompts. Single-GPU inference uses roughly 18 GB of GPU memory with SAT (SwissArmyTransformer) or 23.9 GB with diffusers.
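To check ahead of time whether a particular GPU can handle single-GPU inference, its total memory can be compared against the figures quoted above. The sketch below is a hypothetical helper, not part of the CogVideo repository: the thresholds are simply the article's 18 GB (SAT) and 23.9 GB (diffusers) numbers, "GB" is taken to mean 10^9 bytes, and real usage also depends on driver overhead and other processes. The weight estimate covers only the 2-billion-parameter transformer in FP16 (2 bytes per parameter); the text encoder, VAE, and activations make up the rest.

```python
from typing import Dict, List, Optional

# Reported single-GPU inference memory for CogVideoX-2B in GB (figures from the article,
# not measured here).
REPORTED_INFERENCE_GB: Dict[str, float] = {"sat": 18.0, "diffusers": 23.9}


def transformer_weight_gb(num_params: float = 2e9, bytes_per_param: int = 2) -> float:
    """Memory for the transformer weights alone (FP16 = 2 bytes per parameter), in 1e9 bytes."""
    return num_params * bytes_per_param / 1e9


def backends_that_fit(vram_gb: float) -> List[str]:
    """Backends whose reported inference footprint fits into vram_gb."""
    return sorted(name for name, need in REPORTED_INFERENCE_GB.items() if need <= vram_gb)


def detect_vram_gb(device_index: int = 0) -> Optional[float]:
    """Total memory of a CUDA device in GB, or None when torch or CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_properties(device_index).total_memory / 1e9


if __name__ == "__main__":
    print(f"Transformer weights in FP16: ~{transformer_weight_gb():.1f} GB")
    vram = detect_vram_gb()
    if vram is None:
        print("No CUDA device detected")
    else:
        print(f"GPU memory: {vram:.1f} GB, backends that fit: {backends_that_fit(vram)}")
```

On a 24 GB card, for example, both backends pass this rough check, while a 20 GB card would only fit the SAT path.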

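Fine-tuning needs about 42 GB of GPU memory. Prompts are limited to 226 tokens, and the model produces 6-second videos at 8 frames per second with a resolution of 720×480.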
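These specifications fit together through some simple arithmetic. The sketch below assumes the settings used by the diffusers CogVideoX implementation rather than anything stated in this article: the 3D VAE compresses time by 4x (keeping the first frame) and height and width by 8x into 16 latent channels, the transformer cuts latents into 2x2 patches, and the text tokens are processed together with the video tokens in one sequence. It also assumes the pipeline's default of 49 frames, i.e. one initial frame plus 6 s × 8 fps = 48, so a clip runs slightly over 6 seconds.

```python
from dataclasses import dataclass

# Assumed CogVideoX-2B settings (from the diffusers implementation, not from this article).
TEMPORAL_COMPRESSION = 4   # 3D VAE compresses the time axis by 4x
SPATIAL_COMPRESSION = 8    # 3D VAE compresses height and width by 8x
LATENT_CHANNELS = 16       # channels in the VAE latent
PATCH_SIZE = 2             # transformer splits each latent frame into 2x2 patches
MAX_TEXT_TOKENS = 226      # prompt length limit quoted in the article


@dataclass
class LatentShape:
    frames: int
    channels: int
    height: int
    width: int


def latent_shape(num_frames: int = 49, height: int = 480, width: int = 720) -> LatentShape:
    """VAE latent shape: the first frame is kept, the remaining frames are grouped by 4."""
    return LatentShape(
        frames=(num_frames - 1) // TEMPORAL_COMPRESSION + 1,
        channels=LATENT_CHANNELS,
        height=height // SPATIAL_COMPRESSION,
        width=width // SPATIAL_COMPRESSION,
    )


def transformer_sequence_length(shape: LatentShape, text_tokens: int = MAX_TEXT_TOKENS) -> int:
    """Video patch tokens plus text tokens handled jointly by the transformer."""
    video_tokens = shape.frames * (shape.height // PATCH_SIZE) * (shape.width // PATCH_SIZE)
    return video_tokens + text_tokens


if __name__ == "__main__":
    fps, seconds = 8, 6
    num_frames = fps * seconds + 1  # 49 frames: one initial frame + 48
    shape = latent_shape(num_frames)
    print(shape)  # LatentShape(frames=13, channels=16, height=60, width=90)
    print(transformer_sequence_length(shape))  # 17776 = 13 * 30 * 45 + 226
```

Under these assumptions, a 720×480 clip becomes a 13×16×60×90 latent, and the transformer attends over roughly 17,800 tokens, which helps explain the large inference memory footprint.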

GitHub repository: https://github.com/THUDM/CogVideo

I. Environment Setup

1. Python environment

Python 3.10 or later is recommended.

2. Installing pip packages

pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118

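Then clone the CogVideo repository (git clone https://github.com/THUDM/CogVideo) and, from its root, install the remaining dependencies listed in requirements.txt (-i points pip at the Tsinghua PyPI mirror):

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple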
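After installation, a quick check can confirm that the Python version and the CUDA build of PyTorch are what the steps above intend. This is a hypothetical helper, not part of the CogVideo repository; it only imports torch and diffusers if they are installed and reads their standard version attributes.

```python
import importlib
import sys
from typing import Dict, Optional, Sequence, Tuple

MIN_PYTHON: Tuple[int, int] = (3, 10)  # minimum version recommended in the article


def python_ok(version: Optional[Sequence[int]] = None, minimum: Tuple[int, int] = MIN_PYTHON) -> bool:
    """True if the given (or current) Python version is at least `minimum`."""
    current = tuple(sys.version_info[:2]) if version is None else tuple(version[:2])
    return current >= minimum


def package_version(name: str) -> Optional[str]:
    """Installed version of a package, or None if it cannot be imported."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        return None
    return getattr(module, "__version__", "unknown")


def environment_report() -> Dict[str, object]:
    report: Dict[str, object] = {
        "python": sys.version.split()[0],
        "python_ok": python_ok(),
        "torch": package_version("torch"),
        "diffusers": package_version("diffusers"),
        "cuda_available": False,
        "torch_cuda": None,
    }
    if report["torch"] is not None:
        import torch

        report["cuda_available"] = torch.cuda.is_available()
        report["torch_cuda"] = torch.version.cuda  # "11.8" for the cu118 wheels, None on CPU builds
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key:15s} {value}")
```

If torch_cuda is None or cuda_available is False, the CPU-only wheel was probably installed instead of the +cu118 build.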

3. Downloading the CogVideoX-2b model (diffusers format)

git lfs install

git clone https://huggingface.co/THUDM/CogVideoX-2b
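If git-lfs is missing or was not initialized, git clone still succeeds but leaves small text pointer files in place of the weights, and loading the model fails later. The hypothetical check below (not part of CogVideo or diffusers) looks for the expected pipeline components and flags files that still begin with the Git LFS pointer header. The list of component folders is an assumption based on the usual layout of a diffusers pipeline repository such as CogVideoX-2b.

```python
from pathlib import Path
from typing import List

LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"
# Assumed top-level entries of the CogVideoX-2b diffusers repository.
EXPECTED_COMPONENTS = ["model_index.json", "scheduler", "text_encoder", "tokenizer", "transformer", "vae"]
MODEL_DIR = "CogVideoX-2b"  # folder created by the git clone above


def missing_components(model_dir: str) -> List[str]:
    """Expected entries that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in EXPECTED_COMPONENTS if not (root / name).exists()]


def lfs_pointer_files(model_dir: str) -> List[str]:
    """Relative paths of files that are still Git LFS pointers instead of real content."""
    root = Path(model_dir)
    if not root.is_dir():
        return []
    pointers = []
    for path in sorted(root.rglob("*")):
        if ".git" in path.relative_to(root).parts or not path.is_file():
            continue  # skip git metadata and directories
        # Pointer files start with a fixed header, so reading the first bytes is enough.
        with path.open("rb") as f:
            if f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX:
                pointers.append(path.relative_to(root).as_posix())
    return pointers


if __name__ == "__main__":
    print("missing components:", missing_components(MODEL_DIR))
    print("LFS pointers (weights not downloaded):", lfs_pointer_files(MODEL_DIR))
```

If pointer files are reported, running git lfs install followed by git lfs pull inside the cloned folder should fetch the actual weights.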

4. Downloading the CogVideoX-2b model (SAT format)

wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1

mv 'index.html?dl=1' vae.zip

unzip vae.zip

wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1

mv 'index.html?dl=1' transformer.zip

unzip transformer.zip
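The wget, mv, and unzip steps above can also be done from Python, which avoids renaming the index.html?dl=1 file that wget creates. This is a sketch that uses only the standard library (urllib.request and zipfile); the URLs are the ones given above, the ?dl=1 suffix requests a direct download from the share page, and the archive names vae.zip and transformer.zip simply mirror the mv commands.

```python
import urllib.request
import zipfile
from pathlib import Path
from typing import List

# Download links from the article, keyed by the archive name used in the mv commands.
SAT_ARCHIVES = {
    "vae.zip": "https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1",
    "transformer.zip": "https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1",
}


def extract_archive(zip_path: Path, dest: Path) -> List[str]:
    """Unzip zip_path into dest (like `unzip`) and return the member names."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return archive.namelist()


def download_sat_weights(dest: Path = Path(".")) -> None:
    """Download each archive under its final name, then extract it into dest."""
    dest.mkdir(parents=True, exist_ok=True)
    for filename, url in SAT_ARCHIVES.items():
        zip_path = dest / filename
        if not zip_path.exists():  # skip archives that were already downloaded
            print(f"Downloading {url} -> {zip_path}")
            urllib.request.urlretrieve(url, zip_path)
        extract_archive(zip_path, dest)
```

Calling download_sat_weights() from a script performs both downloads and both extractions in the current directory, matching the shell commands.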

II. Functional Testing

1. Running from the command line

(1) Calling the model from Python

import argparse
import tempfile
from typing import List, Optional, Union

import imageio
import numpy as np
import PIL.Image
import torch
from diffusers import CogVideoXPipeline


def export_to_video_imageio(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: Optional[str] = None, fps: int = 8
) -> str:
    """
    Export the video frames to a video file using imageio library to avoid the "green screen" issue.
    """
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    # imageio expects numpy arrays, so convert PIL frames first
    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    # Writing .mp4 files through imageio requires the imageio-ffmpeg backend
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)
    return output_video_path


def generate_video(
    prompt: str,
    model_path: str,
    output_path: str = "./output.mp4",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """
    Generates a video based on the given prompt and saves it to the specified path.
    Parameters:
    - prompt (str): The description of the video to be generated.
    - model_path (str): The path of the pre-trained model to be used.
    - output_path (str): The path where the generated video will be saved.
    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
    - device (str): The device to use for computation (e.g., "cuda" or "cpu").
    - dtype (torch.dtype): The data type for computation (default is torch.float16).
    """
    try:
        # Load the pre-trained pipeline (text encoder, transformer, VAE, scheduler)
        pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {model_path}: {e}") from e
    print(f"Model loaded successfully from {model_path}")

    try:
        # Encode the prompt to get embeddings (longer prompts are truncated to 226 tokens)
        prompt_embeds, _ = pipe.encode_prompt(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
            dtype=dtype,
        )
    except Exception as e:
        raise RuntimeError(f"Failed to encode prompt: {e}") from e
    print(f"Prompt encoded successfully: {prompt}")

    try:
        # Generate video frames
        video = pipe(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            prompt_embeds=prompt_embeds,
            # Negative prompts are not supported, so all-zero embeddings are passed instead
            negative_prompt_embeds=torch.zeros_like(prompt_embeds),
        ).frames[0]  # only the first generated video is kept and exported
    except Exception as e:
        raise RuntimeError(f"Failed to generate video: {e}") from e
    print("Video generated successfully")

    try:
        # Export frames to video file
        export_to_video_imageio(video, output_path, fps=8)
    except Exception as e:
        raise RuntimeError(f"Failed to export video: {e}") from e
    print(f"Video saved successfully at {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
    )
    parser.add_argument(
        "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
    )
    args = parser.parse_args()

    # Convert the dtype argument to a torch.dtype
    dtype = torch.float16 if args.dtype == "float16" else torch.float32

    generate_video(
        prompt=args.prompt,
        model_path=args.model_path,
        output_path=args.output_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        device=args.device,
        dtype=dtype,
    )
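Assuming the script above is saved as cli_demo.py (a file name chosen here only for illustration), it is run with python cli_demo.py --prompt "..." --model_path ./CogVideoX-2b. The hypothetical helper below builds that command from Python so several prompts can be rendered in a row; it only uses flags defined by the script's argparse parser, and prompts should be in English, since that is what the model supports.

```python
import subprocess
import sys
from typing import List


def build_command(
    prompt: str,
    script: str = "cli_demo.py",          # hypothetical file name for the script above
    model_path: str = "./CogVideoX-2b",   # local folder from the diffusers download step
    output_path: str = "./output.mp4",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    dtype: str = "float16",
) -> List[str]:
    """Argument list for one run of the generation script."""
    return [
        sys.executable, script,
        "--prompt", prompt,
        "--model_path", model_path,
        "--output_path", output_path,
        "--num_inference_steps", str(num_inference_steps),
        "--guidance_scale", str(guidance_scale),
        "--dtype", dtype,
    ]


def run_prompts(prompts: List[str], **kwargs) -> None:
    """Generate one numbered video per prompt; kwargs must not include output_path."""
    for i, prompt in enumerate(prompts):
        cmd = build_command(prompt, output_path=f"./output_{i}.mp4", **kwargs)
        subprocess.run(cmd, check=True)  # raises CalledProcessError if a run fails
```

Each run reloads the model, which keeps the helper simple but is slower than calling generate_video in a single process.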

To be continued...

For more details, follow 杰哥新技术.

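Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site assumes no legal responsibility. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/酷酷是懒虫/article/detail/1010316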