
Image Understanding with the llava-v1.5-7b Multimodal Model


Multimodal model: llava-v1.5-7b, used mainly for image understanding. The scenario in this article is generating a title and a content description for an image.

1. Check the server configuration

nvidia-smi

CUDA version 12.2, with eight 24 GB RTX 3090 GPUs.
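The same check can be done from Python with PyTorch, which also confirms that the installed CUDA build sees the GPUs (a minimal sketch, not part of the original steps):

import torch

print(torch.__version__)          # PyTorch version
print(torch.version.cuda)         # CUDA version PyTorch was built against
print(torch.cuda.device_count())  # should report 8 on this machine
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_name(i))  # e.g. "NVIDIA GeForce RTX 3090"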

2. Environment setup

2.1 Clone this repository and navigate to the LLaVA folder

git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA

2.2 Install the package

conda create -n llava python=3.10 -y
conda activate llava
pip install --upgrade pip  # enable PEP 660 support
pip install -e .

2.3 Install additional packages for training cases

pip install -e ".[train]"
pip install flash-attn --no-build-isolation

2.4 Upgrade to the latest code base

git pull
pip install -e .

# if you see some import errors when you upgrade, please try running the command below (without #)
# pip install flash-attn --no-build-isolation --no-cache-dir
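As a quick sanity check after upgrading, you can confirm that the editable install resolves to the cloned repository (a minimal check, not part of the upstream instructions):

import llava
print(llava.__file__)  # should point into the cloned LLaVA directory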

3. Download the models

3.1 Download llava-v1.5-7b

https://huggingface.co/liuhaotian/llava-v1.5-7b/tree/main
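The weights can be downloaded from the web page, or programmatically with huggingface_hub. A minimal sketch, assuming the files are stored under /models (the local paths are assumptions; adjust them to your server):

from huggingface_hub import snapshot_download

# llava-v1.5-7b weights (this section)
snapshot_download(repo_id="liuhaotian/llava-v1.5-7b",
                  local_dir="/models/llava-v1.5-7b")

# CLIP vision tower (section 3.2)
snapshot_download(repo_id="openai/clip-vit-large-patch14-336",
                  local_dir="/models/clip-vit-large-patch14-336")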
Important: the LLaVA model relies on the clip-vit-large-patch14-336 vision model, so you also need to download clip-vit-large-patch14-336 and then edit config.json in the llava-v1.5-7b directory so it points at the local copy (see the sketch after section 3.2).

3.2 Download clip-vit-large-patch14-336

https://huggingface.co/openai/clip-vit-large-patch14-336/tree/main
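After both downloads, edit config.json in the llava-v1.5-7b directory so that the vision tower points at the local CLIP copy rather than the Hugging Face repo id. A minimal sketch, assuming the local paths used above:

import json

cfg_path = "/models/llava-v1.5-7b/config.json"
with open(cfg_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)

# originally "openai/clip-vit-large-patch14-336"; replace with the local directory
cfg["mm_vision_tower"] = "/models/clip-vit-large-patch14-336"

with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)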

4. API service

Reference code: https://github.com/haotian-liu/LLaVA/tree/main/llava/serve/cli.py. The script below adapts that CLI inference loop into a FastAPI service.

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
import torch
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
from pydantic import BaseModel
import uvicorn
from fastapi import FastAPI
# Load the model once at service startup
disable_torch_init()
model_path = "/models/llava-v1.5-7b"  # local path from section 3.1
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, False, False, device="cuda")

def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def model_infer(image_file, inp):
    """
    Run one round of inference: pick the conversation template for the loaded
    model, build the prompt with the image tokens, and generate a response.
    """
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    conv = conv_templates[conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    image = load_image(image_file)
    image_size = image.size
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    if image is not None:
        # first message
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        conv.append_message(conv.roles[0], inp)
        image = None
    else:
        # later messages
        conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=True,
            temperature=0.1,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True)

    outputs = tokenizer.decode(output_ids[0, 1:-1]).strip()
    conv.messages[-1][-1] = outputs
    print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
    return outputs


app = FastAPI()


class ImageInput(BaseModel):
    url: str
    ocr_result: str


@app.get('/')
def home():
    return 'hello world'


@app.post('/img_desc')
def image_desc(image_input: ImageInput):
    # Ask the model for a Chinese title, prepending the OCR text when available
    title_string = "请为这张图片生成一个中文标题。" if not image_input.ocr_result else \
        f'这张图片中的文字为"{image_input.ocr_result}"。请为这张图片生成一个中文标题。'

    title_output = model_infer(image_input.url, title_string)

    # Ask the model for a detailed Chinese description, prepending the OCR text when available
    desc_string = "请详细描述这张图片中的内容。" if not image_input.ocr_result else \
        f'这张图片中的文字为"{image_input.ocr_result}"。请详细描述这张图片中的内容。'

    desc_output = model_infer(image_input.url, desc_string)
    return {"url": image_input.url, "title": title_output, "desc": desc_output}


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=50075)
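Once the service is running, the /img_desc endpoint can be exercised with a small client. A minimal sketch (the image URL and OCR text below are placeholders, not values from the original article):

import requests

payload = {
    "url": "https://example.com/sample.jpg",  # hypothetical image URL
    "ocr_result": "",                         # optional OCR text; empty if none
}
resp = requests.post("http://127.0.0.1:50075/img_desc", json=payload, timeout=600)
print(resp.json())  # {"url": ..., "title": ..., "desc": ...}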

5. API demo

The model's understanding is still somewhat limited; later we plan to combine it with PaddleOCR to improve results, or load llava-v1.6-34b for testing.
