当前位置:   article > 正文

一起学大模型 - 动手写一写langchain调用本地大模型(2)_langchain调用本地xinference的模型

langchain调用本地xinference的模型


前言

前一篇文章里,from transformers import GPT2LMHeadModel, GPT2Tokenizer 如果模型替换了,就得更改代码,很麻烦,那有没有更简单的方法呢?


一、自动选择

transformers 库中的 AutoTokenizerAutoModel 可以根据配置文件自动选择适当的分词器 和 大模型,而无需明确指定特定的模型分词器 和模型。这使得代码更加通用和简洁。

1. 使用 AutoTokenizer 和AutoModel的示例

假设你已经将模型和分词器下载到本地目录 /path/to/local/model,你可以使用 AutoTokenizerAutoModel 来加载模型和分词器:

from transformers import AutoTokenizer, AutoModel
import openai
import torch
from langchain.llms import BaseLLM

# 定义 BERT 嵌入模型类,使用 AutoTokenizer 和 AutoModel
class BERTEmbedder:
    def __init__(self, model_path='/path/to/local/model'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)

    def embed(self, text):
        inputs = self.tokenizer(text, return_tensors='pt')
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

# 定义 GPT-3 生成模型类
class GPT3LLM(BaseLLM):
    def __init__(self, temperature=0.7, max_tokens=150):
        self.temperature = temperature
        self.max_tokens = max_tokens

    def generate(self, prompt):
        response = openai.Completion.create(
            engine="text-davinci-003",
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )
        return response.choices[0].text.strip()

# 使用 LangChain 的 LLMChain
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.embeddings import Embeddings

# 定义一个模板,将 BERT 的嵌入作为 GPT-3 的输入
prompt_template = PromptTemplate(
    template="User input embedding: {embedding}\nGenerate response:",
    input_variables=["embedding"]
)

# 实现一个嵌入模型类,用于在 LLMChain 中使用
class EmbeddingsWrapper(Embeddings):
    def __init__(self, embedder):
        self.embedder = embedder

    def embed(self, text):
        return self.embedder.embed(text)

# 初始化本地 BERT 嵌入模型和 GPT-3 生成模型
bert_embedder = BERTEmbedder(model_path='/path/to/local/model')
gpt3_llm = GPT3LLM()

# 包装 BERT 嵌入模型
embedding_wrapper = EmbeddingsWrapper(bert_embedder)

# 创建 LLMChain
llm_chain = LLMChain(
    prompt_template=prompt_template,
    llm=gpt3_llm,
    embeddings=embedding_wrapper
)

# 用户输入
user_input = "The quick brown fox jumps over the lazy dog."

# 处理用户输入并生成输出
embedding = bert_embedder.embed(user_input)
output = llm_chain.run({"embedding": embedding})
print(f"Output: {output}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72

2. 解释

  • AutoTokenizer 和 AutoModel:使用 AutoTokenizer.from_pretrainedAutoModel.from_pretrained,可以自动加载适当的分词器和模型,而不需要明确指定模型类型(例如,BERT、GPT-2等)。
  • model_path:加载本地路径中的模型和分词器。
  • BERTEmbedder 类:利用 AutoTokenizerAutoModel 处理文本并生成嵌入。
  • LLMChain:与之前相同,使用 LangChain 的 LLMChain 来整合 BERT 嵌入和 GPT-3 生成。

通过这种方式,可以更加简洁和通用地加载模型,适应不同的模型配置,而无需修改代码。

二、怎么实现自动选择的呢

AutoTokenizerAutoModeltransformers 库中的自动化工具,它们可以根据模型目录中的配置文件(通常是 config.json 文件)来自动识别并加载适当的模型和分词器。这意味着你在使用它们时,只需提供模型所在的目录路径,它们会根据配置文件确定使用哪个具体的模型类型和分词器。

为了确保 AutoTokenizerAutoModel 能正确加载模型,目录结构应包含必要的文件,例如:


/path/to/local/model/
    ├── config.json
    ├── pytorch_model.bin
    ├── vocab.txt
    └── tokenizer_config.json

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • config.json:模型配置文件,包含模型架构和超参数信息。
  • pytorch_model.bin:预训练的模型权重。
  • vocab.txt:分词器词汇表。
  • tokenizer_config.json:分词器配置文件。

总结

通过这种方式,我们可以使用 AutoTokenizerAutoModel 更加简洁和通用地加载模型,适应不同的模型配置,而无需修改代码。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/821652
推荐阅读
相关标签
  

闽ICP备14008679号