赞
踩
- def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
- self.model = model
- self.tokenizer = tokenizer
- super().__init__(
- chat_prompt_template=chat_prompt_template,
- run_prompt_template=run_prompt_template,
- additional_tools=additional_tools,
- )
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
- """
- Convenience method to build a `LocalAgent` from a pretrained checkpoint.
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
- kwargs:
- Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
- Example:
- ```py
- import torch
- from transformers import LocalAgent
- agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
- agent.run("Draw me a picture of rivers and lakes.")
- ```
- """
- model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
- return cls(model, tokenizer)
-
- @property
- def _model_device(self):
- if hasattr(self.model, "hf_device_map"):
- return list(self.model.hf_device_map.values())[0]
- for param in self.mode.parameters():
- return param.device
-
- def generate_one(self, prompt, stop):
- encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
- src_len = encoded_inputs["input_ids"].shape[1]
- stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
- outputs = self.model.generate(
- encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
- )
-
- result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
- # Inference API returns the stop sequence
- for stop_seq in stop:
- if result.endswith(stop_seq):
- result = result[: -len(stop_seq)]
- return result
-
-
- class StopSequenceCriteria(StoppingCriteria):
- """
- This class can be used to stop generation whenever a sequence of tokens is encountered.
- Args:
- stop_sequences (`str` or `List[str]`):
- The sequence (or list of sequences) on which to stop execution.
- tokenizer:
- The tokenizer used to decode the model outputs.
- """
-
- def __init__(self, stop_sequences, tokenizer):
- if isinstance(stop_sequences, str):
- stop_sequences = [stop_sequences]
- self.stop_sequences = stop_sequences
- self.tokenizer = tokenizer
-
- def __call__(self, input_ids, scores, **kwargs) -> bool:
- decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
- return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)

这段代码主要定义了两个类:一个是LocalAgent,另一个是StopSequenceCriteria。下面我会逐行解释代码。
StopSequenceCriteria
类的定义。def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
这是LocalAgent类的构造函数,接受五个参数。model和tokenizer是用于生成文本的模型和分词器。chat_prompt_template、run_prompt_template和additional_tools这三个参数是可选的,分别代表聊天提示模板、运行提示模板和附加工具。
self.model = model
和 self.tokenizer = tokenizer
这两行将输入的model和tokenizer保存为LocalAgent对象的属性。
super().__init__(chat_prompt_template=chat_prompt_template, run_prompt_template=run_prompt_template, additional_tools=additional_tools)
调用父类的构造函数,传入chat_prompt_template、run_prompt_template和additional_tools。
@classmethod
这是一个修饰器,表示下面的方法是类方法。
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
这是一个类方法,用于从预训练模型中创建LocalAgent对象。pretrained_model_name_or_path是预训练模型的名称或者路径,**kwargs是其他的关键字参数。
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
和 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
从预训练模型中加载模型和分词器。
return cls(model, tokenizer)
使用加载的模型和分词器创建LocalAgent对象,并返回。
@property
是一个修饰器,表示下面的方法是一个属性。
def _model_device(self):
这个方法返回模型的设备。
if hasattr(self.model, "hf_device_map"):
判断模型是否有hf_device_map属性。
return list(self.model.hf_device_map.values())[0]
如果有hf_device_map属性,返回第一个设备。
for param in self.mode.parameters():
如果没有hf_device_map属性,遍历模型的参数。
return param.device
返回第一个参数的设备。
def generate_one(self, prompt, stop):
这个方法用于根据给定的提示生成一段文本。prompt是提示,stop是停止标志。
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
对提示进行编码,并将编码结果移动到模型的设备上。
src_len = encoded_inputs["input_ids"].shape[1]
获取输入的长度。
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
创建停止条件,当生成的文本包含停止标志时停止生成。
outputs = self.model.generate(encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria)
生成文本。
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
解码生成的文本。
for stop_seq in20.
for stop_seq in stop:` 遍历每一个停止序列。
if result.endswith(stop_seq):
检查生成的结果是否以停止序列结尾。
result = result[: -len(stop_seq)]
如果是,则将这个停止序列从结果中去掉。
return result
返回生成的结果。
class StopSequenceCriteria(StoppingCriteria):
定义一个名为StopSequenceCriteria的类,它继承自StoppingCriteria。这个类用于在生成过程中遇到特定序列时停止生成。
def __init__(self, stop_sequences, tokenizer):
这是StopSequenceCriteria的构造函数,接受两个参数:停止序列和分词器。
if isinstance(stop_sequences, str):
如果stop_sequences是字符串,那么将其转化为列表。
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
和 self.tokenizer = tokenizer
将输入的停止序列和分词器保存为StopSequenceCriteria对象的属性。
def __call__(self, input_ids, scores, **kwargs) -> bool:
定义了该类的调用方法,输入参数为输入的id、得分以及其他关键字参数,返回值是布尔值。
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
将输入的id解码为文本。
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
如果解码出的文本以任何一个停止序列结尾,那么返回True,否则返回False。
对最后两个函数有更深入的理解:
generate_one
是 LocalAgent
类的一个实例方法。这个方法通过使用实例的 tokenizer
对输入的提示 prompt
进行编码,然后调用模型的 generate
方法生成新的文本。生成的文本通过使用 StopSequenceCriteria
停止条件进行控制,如果生成的文本满足停止条件(即包含某个特定序列),则停止生成新的文本。生成的新文本通过 tokenizer
解码成字符串。如果解码出的结果以任何一个停止序列结尾,那么该停止序列将被去掉。最后,方法返回生成的结果。
StopSequenceCriteria
是一个继承自 StoppingCriteria
的子类。它重写了父类的 __call__
方法。在这个方法中,输入的 id 通过使用实例的 tokenizer
解码为字符串,然后检查解码出的字符串是否以任何一个停止序列结尾,如果是,则返回 True
,否则返回 False
。这个返回值会被用来决定是否需要停止生成新的文本。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。