当前位置:   article > 正文

Gemma生态又添大将——3B VLM的paligemma_paligemma-3b-mix-224

paligemma-3b-mix-224

通过 IO 2024 大会可以充分看出 Google 正在加大 AI 领域的投资和研发。而目前主流的是闭源的 GPT 生态和开源的 Llama 生态。而在今年年初 DeepMind 就发布了 CodeGemma 和 RecurrentGemma,其中 RecurrentGemma 使用了最新研发的 Griffin 框架 (尚在实验性),有意替代Transformers。而此次,Paligemma 也是为了填补了 Gemma 在 VLM 的空缺。

介绍

PaliGemma 是一个多功能、轻量级的视觉语言模型(VLM),它受 PaLI-3 的启发,基于 SigLIP 视觉模型和 Gemma 语言模型等开放组件。它将图像和文本作为输入,并生成文本作为输出,支持多种语言。它的设计目的是在图像和短视频字幕、视觉问题解答、文本阅读、对象检测和对象分割等各种视觉语言任务中实现同类领先的微调性能。

Transformers PaliGemma 3B 使用 448*448(224*224) 输入图像和 512 (256)标记输入/输出文本序列,在混合下游学术数据集上进行微调。模型采用 float32、bfloat16 和 float16 格式,仅供研究使用。

相关文档:

PaliGemma 在以下混合数据集上进行了预训练:

WebLI: WebLI(网络语言图像)是一个网络规模的多语言图像-文本数据集,由公共网络构建而成。通过对 WebLI 的广泛拆分,可获得多种模型功能,如视觉语义理解、对象定位、视觉定位文本理解、多语言性等。

CC3M-35L: 从网页中策划的英文图像-alt_文本对(Sharma 等人,2018 年)。我们使用谷歌云翻译 API 翻译成另外 34 种语言。

VQ²A-CC3M-35L/VQG-CC3M-35L: VQ2A-CC3M 的子集(Changpinyo 等人,2022a),使用 Google Cloud Translation API 翻译成与 CC3M-35L 相同的 34 种语言。

OpenImages: 检测和对象感知问答(Piergiovanni 等人,2022 年)由 OpenImages 数据集上的手工规则生成。

WIT: 从维基百科收集的图片和文本(Srinivasan 等人,2021 年)。

代码

GPU

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Instruct the model to create a caption in Spanish
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

我在RTX2070试过,运行正常。

注意:Flash Attention2 由于某些原因会发生“/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi”的错误,可以执行重装代码

pip install flash_attn -U --force-reinstall
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/916992
推荐阅读
相关标签
  

闽ICP备14008679号