当前位置:   article > 正文

超详细!DALL · E 文生图模型实践指南_dalle

dalle

最近需要用到 DALL·E的推断功能,在现有开源代码基础上发现还有几个问题需要注意,谨以此篇博客记录之。

我用的源码主要是 https://github.com/borisdayma/dalle-mini 仓库中的Inference pipeline.ipynb 文件。

在这里插入图片描述

运行环境:Ubuntu服务器

⚠️注意:本博客仅涉及 DALL · E 推断,不涉及训练过程。



一、环境配置

建议使用anaconda新建一个dalle环境,然后在该环境中进行相关配置,避免与环境中的其他库产生版本冲突。

使用下述命令新建名为dalle的环境:

conda create -n dalle python==3.8.0
  • 1

在终端分别运行下述命令,安装所需的python库:

# 安装 dalle运行需要的依赖库(注意版本只能是0.3.25)# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 安装 dalle特定的库
pip install dalle-mini
# 安装 VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

PS:如果由于网络连接问题无法通过pip命令下载VQGAN,就采取Plan-B:将仓库 https://github.com/patil-suraj/vqgan-jax 下载到服务器并解压,然后使用cd命令将当前目录到对应的仓库下载路径下,在终端运行python setup.py install安装VQGAN即可。


二、模型下载

由于网络连接问题,我采取「事先把模型下载到本地」的策略对模型进行直接调用,首先要明确的一点是,本项目中使用DALL · E 对图像进行编码,使用VQGAN对图像进行解码,所以我们需要分别下载DALL · E 和 VQGAN 两个模型。

DALL · E 模型下载地址:
mini版本:https://huggingface.co/dalle-mini/dalle-mini/tree/main
mega版本:https://huggingface.co/dalle-mini/dalle-mega/tree/main

VQGAN 模型下载地址:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

下载完毕后,将模型部署到服务器,注意保存路径。


三、程序转换

相较于ipynb文件,我个人更加喜欢操作py文件,所以对于给定的ipynb文件,首先使用命令jupyter nbconvert --to script Inference pipeline.ipynb 将其转为同名py文件,该文件的主要内容如下(不含CLIP排序部分),其中模型路径 DALLE_MODEL和VQGAN_REPO 已改为本地路径(就是第二步中两个模型的保存路径),可以看到文件的注释也比较详细。

# dalle-mini
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)

# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

# Keys are passed to the model on each device to generate unique inference per device.
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

# ## 
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/481199
推荐阅读
相关标签