
Transformers Source Code Analysis (73)

.\models\deberta_v2\tokenization_deberta_v2.py

# coding=utf-8
# Copyright 2020 Microsoft and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenization class for model DeBERTa.
"""

import os  # standard library os, for filesystem operations
import unicodedata  # access to the Unicode character database
from typing import Any, Dict, List, Optional, Tuple  # type hints

import sentencepiece as sp  # SentencePiece, the underlying subword tokenizer

from ...tokenization_utils import AddedToken, PreTrainedTokenizer  # base tokenizer classes
from ...utils import logging  # library logging helpers

logger = logging.get_logger(__name__)  # module-level logger

# Map from pretrained model names to the download URLs of their SentencePiece model files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model",
        "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model",
        "microsoft/deberta-v2-xlarge-mnli": (
            "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model"
        ),
        "microsoft/deberta-v2-xxlarge-mnli": (
            "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model"
        ),
    }
}

# Map from pretrained model names to their positional embedding sizes
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/deberta-v2-xlarge": 512,
    "microsoft/deberta-v2-xxlarge": 512,
    "microsoft/deberta-v2-xlarge-mnli": 512,
    "microsoft/deberta-v2-xxlarge-mnli": 512,
}

# Map from pretrained model names to their default initialization configuration
PRETRAINED_INIT_CONFIGURATION = {
    "microsoft/deberta-v2-xlarge": {"do_lower_case": False},
    "microsoft/deberta-v2-xxlarge": {"do_lower_case": False},
    "microsoft/deberta-v2-xlarge-mnli": {"do_lower_case": False},
    "microsoft/deberta-v2-xxlarge-mnli": {"do_lower_case": False},
}

# Name of the vocabulary (SentencePiece model) file
VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}


class DebertaV2Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a DeBERTa-v2 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    """

    vocab_files_names = VOCAB_FILES_NAMES  # vocabulary file names
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP  # pretrained vocab file URLs
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION  # pretrained init configs
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES  # maximum model input sizes

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        split_by_punct=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        Initialize a DebertaV2Tokenizer with essential parameters.

        Args:
            vocab_file (str): The vocabulary file path.
            do_lower_case (bool): Whether to convert tokens to lowercase.
            split_by_punct (bool): Whether to split tokens by punctuation.
            bos_token (str): Beginning of sequence token.
            eos_token (str): End of sequence token.
            unk_token (str): Token for unknown or unrecognized tokens.
            sep_token (str): Separator token.
            pad_token (str): Token used for padding sequences.
            cls_token (str): Classification token.
            mask_token (str): Mask token for masked language modeling.
            sp_model_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments for SentencePiece model.
            **kwargs: Additional keyword arguments.

        """
        # Use an empty dict for the SentencePiece model parameters if none were given
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # Make sure the given vocabulary path points to an actual file
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )

        # Whether to lowercase the input text
        self.do_lower_case = do_lower_case
        # Whether to split on punctuation before SentencePiece encoding
        self.split_by_punct = split_by_punct
        # Path to the vocabulary file
        self.vocab_file = vocab_file

        # Build the underlying SPMTokenizer from the vocabulary file, the punctuation flag,
        # and any extra SentencePiece parameters
        self._tokenizer = SPMTokenizer(
            vocab_file, None, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs
        )

        # If unk_token is a plain string, wrap it in an AddedToken marked as special and normalized
        unk_token = AddedToken(unk_token, normalized=True, special=True) if isinstance(unk_token, str) else unk_token

        # Call the parent constructor with the special tokens and remaining keyword arguments
        super().__init__(
            do_lower_case=do_lower_case,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            split_by_punct=split_by_punct,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

        # Expose the full list of special tokens to the underlying tokenizer
        self._tokenizer.special_tokens = self.all_special_tokens

    @property
    def vocab_size(self):
        # Size of the tokenizer's vocabulary
        return len(self.vocab)

    @property
    def vocab(self):
        # The underlying tokenizer's vocabulary mapping
        return self._tokenizer.vocab

    def get_vocab(self):
        # Return the full vocabulary, including any added tokens
        vocab = self.vocab.copy()
        vocab.update(self.get_added_vocab())
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        # Lowercase the input text if requested
        if self.do_lower_case:
            text = text.lower()
        # Delegate to the underlying tokenizer, returning a list of string tokens
        return self._tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Map the token string to its id via the SentencePiece model
        return self._tokenizer.spm.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # Map the id back to its token string; out-of-range ids fall back to unk_token
        return self._tokenizer.spm.IdToPiece(index) if index < self.vocab_size else self.unk_token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Use the underlying tokenizer's decode to join the tokens into a single string
        return self._tokenizer.decode(tokens)
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        从序列或序列对中构建模型输入,用于序列分类任务,通过连接和添加特殊标记。DeBERTa 序列的格式如下:

        - 单个序列:[CLS] X [SEP]
        - 序列对:[CLS] A [SEP] B [SEP]

        Args:
            token_ids_0 (`List[int]`):
                要添加特殊标记的 ID 列表。
            token_ids_1 (`List[int]`, *可选*):
                第二个序列的 ID 列表,用于序列对输入。

        Returns:
            `List[int]`: 包含适当特殊标记的输入 ID 列表。
        """

        if token_ids_1 is None:
            # 返回只含有一个序列的特殊标记的输入列表
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        # 返回包含序列对的特殊标记的输入列表
        return cls + token_ids_0 + sep + token_ids_1 + sep
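
    # For example (illustrative note, not part of the original file):
    #   build_inputs_with_special_tokens([10, 11])       -> [cls_id, 10, 11, sep_id]
    #   build_inputs_with_special_tokens([10], [20, 21]) -> [cls_id, 10, sep_id, 20, 21, sep_id]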

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        从没有添加特殊标记的标记列表中检索序列 ID。当使用 tokenizer 的 `prepare_for_model` 或 `encode_plus` 方法添加特殊标记时调用此方法。

        Args:
            token_ids_0 (`List[int]`):
                ID 列表。
            token_ids_1 (`List[int]`, *可选*):
                第二个序列的 ID 列表,用于序列对输入。
            already_has_special_tokens (`bool`, *可选*, 默认为 `False`):
                标记列表是否已经格式化为模型所需的特殊标记。

        Returns:
            `List[int]`: 一个整数列表,取值为 [0, 1]:1 表示特殊标记,0 表示序列标记。
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            # 返回包含序列对特殊标记掩码的列表
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        # 返回只包含单个序列特殊标记掩码的列表
        return [1] + ([0] * len(token_ids_0)) + [1]
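
    # A sketch of the resulting masks (illustrative note, not part of the original file):
    #   get_special_tokens_mask([10, 11])        -> [1, 0, 0, 1]
    #   get_special_tokens_mask([10], [20, 21])  -> [1, 0, 1, 0, 0, 1]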
    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # Define separator and classification token IDs
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        
        # If only one sequence is provided
        if token_ids_1 is None:
            # Return token type IDs for single sequence (all 0s)
            return len(cls + token_ids_0 + sep) * [0]
        
        # Return token type IDs for two sequences (0s for first sequence, 1s for second sequence)
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
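
    # For instance (illustrative note, not part of the original file):
    #   create_token_type_ids_from_sequences([10, 11])       -> [0, 0, 0, 0]
    #   create_token_type_ids_from_sequences([10], [20, 21]) -> [0, 0, 0, 1, 1, 1]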

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        # Extract 'add_prefix_space' from kwargs
        add_prefix_space = kwargs.pop("add_prefix_space", False)
        
        # Add prefix space if required
        if is_split_into_words or add_prefix_space:
            text = " " + text
        
        # Return text and remaining kwargs
        return (text, kwargs)
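
    # e.g. prepare_for_tokenization("hello", add_prefix_space=True) returns (" hello", {})
    # (illustrative note, not part of the original file)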

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Save vocabulary using the underlying tokenizer's method
        return self._tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)
    r"""
    Constructs a tokenizer based on [SentencePiece](https://github.com/google/sentencepiece).

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
    """

    def __init__(
        self, vocab_file, special_tokens, split_by_punct=False, sp_model_kwargs: Optional[Dict[str, Any]] = None
    ):
        # Whether to split on punctuation before SentencePiece encoding
        self.split_by_punct = split_by_punct
        # Path to the vocabulary file
        self.vocab_file = vocab_file
        # SentencePiece parameters; default to an empty dict when not provided
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Instantiate the SentencePieceProcessor with the given parameters
        spm = sp.SentencePieceProcessor(**self.sp_model_kwargs)
        # Make sure the vocabulary file exists
        if not os.path.exists(vocab_file):
            raise FileNotFoundError(f"{vocab_file} does not exist!")
        # Load the vocabulary file into the SentencePieceProcessor
        spm.load(vocab_file)
        # Size of the SentencePiece vocabulary
        bpe_vocab_size = spm.GetPieceSize()
        # Build the token -> id mapping
        self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)}
        # Build the id -> token mapping
        self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)]
        # Special token ids (kept from the original source as an unused reference)
        # self.vocab['[PAD]'] = 0
        # self.vocab['[CLS]'] = 1
        # self.vocab['[SEP]'] = 2
        # self.vocab['[UNK]'] = 3

        # Store the SentencePieceProcessor and the special tokens
        self.spm = spm
        self.special_tokens = special_tokens

    def __getstate__(self):
        # Copy the object's state but drop the (unpicklable) SentencePieceProcessor
        state = self.__dict__.copy()
        state["spm"] = None
        return state

    def __setstate__(self, d):
        # Restore the object's state
        self.__dict__ = d

        # For backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Re-create the SentencePieceProcessor and reload the vocabulary file
        self.spm = sp.SentencePieceProcessor(**self.sp_model_kwargs)
        self.spm.Load(self.vocab_file)

    def tokenize(self, text):
        # Encode the text into SentencePiece pieces
        return self._encode_as_pieces(text)

    def convert_ids_to_tokens(self, ids):
        # Map each id back to its token
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens
    # Decode a token sequence back into text. If raw_text is None, decode from the tokens
    # themselves; otherwise recover the corresponding span from raw_text.
    def decode(self, tokens, start=-1, end=-1, raw_text=None):
        if raw_text is None:
            current_sub_tokens = []  # sub-tokens accumulated for the current span
            out_string = ""  # the decoded output string
            prev_is_special = False  # whether the previous token was a special token
            for token in tokens:
                # Special tokens are appended verbatim instead of going through SentencePiece
                if token in self.special_tokens:
                    if not prev_is_special:
                        out_string += " "  # separate from the preceding text with a space
                    out_string += self.spm.decode_pieces(current_sub_tokens) + token  # flush the pending sub-tokens, then append the special token
                    prev_is_special = True
                    current_sub_tokens = []  # start a fresh sub-token span
                else:
                    current_sub_tokens.append(token)  # accumulate regular sub-tokens
                    prev_is_special = False
            out_string += self.spm.decode_pieces(current_sub_tokens)  # flush any remaining sub-tokens
            return out_string.strip()  # strip leading/trailing whitespace
        else:
            words = self.split_to_words(raw_text)  # split the raw text into words
            word_tokens = [self.tokenize(w) for w in words]  # tokenize each word
            token2words = [0] * len(tokens)  # maps each token index to its word index
            tid = 0
            for i, w in enumerate(word_tokens):
                for k, t in enumerate(w):
                    token2words[tid] = i  # record which word this token belongs to
                    tid += 1
            word_start = token2words[start]  # word index of the start token
            word_end = token2words[end] if end < len(tokens) else len(words)  # word index of the end token, clamped to the word list
            text = "".join(words[word_start:word_end])  # stitch the words back into text
            return text

    # Register a new special token: append it to the special-token list if missing, and
    # extend the vocabulary and the id-to-token mapping accordingly
    def add_special_token(self, token):
        if token not in self.special_tokens:
            self.special_tokens.append(token)  # record the new special token
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab) - 1  # assign it an id in the vocabulary
                self.ids_to_tokens.append(token)  # extend the id -> token mapping
        return self.id(token)  # return the id of the special token

    # Whether the token continues a word. If is_bos is True, always return True; otherwise
    # decide from the token's leading character.
    def part_of_whole_word(self, token, is_bos=False):
        logger.warning_once(
            "The `DebertaTokenizer.part_of_whole_word` method is deprecated and will be removed in `transformers==4.35`"
        )
        if is_bos:
            return True
        if (
            len(token) == 1
            and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0]))
        ) or token in self.special_tokens:
            return False

        word_start = b"\xe2\x96\x81".decode("utf-8")
        return not token.startswith(word_start)  # tokens starting with "▁" begin a new word

    # Return the padding token
    def pad(self):
        return "[PAD]"

    # Return the beginning-of-sequence token
    def bos(self):
        return "[CLS]"

    # Return the end-of-sequence token
    def eos(self):
        return "[SEP]"

    # Return the unknown token
    def unk(self):
        return "[UNK]"

    # Return the mask token
    def mask(self):
        return "[MASK]"

    # Return the token for a given id
    def sym(self, id):
        return self.ids_to_tokens[id]

    # Return the id for a given token, falling back to 1 if it is not in the vocabulary
    def id(self, sym):
        logger.warning_once(
            "The `DebertaTokenizer.id` method is deprecated and will be removed in `transformers==4.35`"
        )
        return self.vocab[sym] if sym in self.vocab else 1
    # Convert the input text to unicode, then encode it into SentencePiece pieces
    def _encode_as_pieces(self, text):
        text = convert_to_unicode(text)

        # When splitting by punctuation, split first and encode each piece separately
        if self.split_by_punct:
            words = self._run_split_on_punc(text)
            # Encode every split word into SentencePiece pieces (as strings)
            pieces = [self.spm.encode(w, out_type=str) for w in words]
            # Flatten the nested lists into a single list of pieces
            return [p for w in pieces for p in w]
        else:
            # Otherwise encode the whole text at once
            return self.spm.encode(text, out_type=str)
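
    # e.g. with split_by_punct=True, "don't" is first split into ["don", "'", "t"] and
    # each piece is encoded on its own (illustrative note, not part of the original file)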

    # Split the text into words, using the SentencePiece pieces as a guide
    def split_to_words(self, text):
        pieces = self._encode_as_pieces(text)
        # "▁" (U+2581) marks the start of a new word in SentencePiece output
        word_start = b"\xe2\x96\x81".decode("utf-8")
        words = []
        offset = 0
        prev_end = 0

        # Walk over the encoded pieces
        for i, p in enumerate(pieces):
            # A piece starting with the word-start marker opens a new word
            if p.startswith(word_start):
                # If the offset moved past the previous word's end, emit that word
                if offset > prev_end:
                    words.append(text[prev_end:offset])
                prev_end = offset
                # Strip the word-start marker to get the actual content
                w = p.replace(word_start, "")
            else:
                w = p

            try:
                # Locate the piece in the original text, starting from the current offset
                s = text.index(w, offset)
                pn = ""
                k = i + 1
                # Find the next non-empty piece
                while k < len(pieces):
                    pn = pieces[k].replace(word_start, "")
                    if len(pn) > 0:
                        break
                    k += 1

                # If the next piece also occurs before the match, advance by one character only
                if len(pn) > 0 and pn in text[offset:s]:
                    offset = offset + 1
                else:
                    offset = s + len(w)
            except Exception:
                offset = offset + 1

        # Emit the final word
        if prev_end < offset:
            words.append(text[prev_end:offset])

        return words

    # Split punctuation off within a piece of text
    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []

        # Walk over every character in the text
        while i < len(chars):
            char = chars[i]
            # A punctuation character becomes its own chunk and starts a new word
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                # Otherwise append the character to the current chunk
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        # Join each chunk back into a string
        return ["".join(x) for x in output]

    # Save the serialized SentencePiece model under the given path
    def save_pretrained(self, path: str, filename_prefix: str = None):
        # Resolve the output file name
        filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]]
        if filename_prefix is not None:
            filename = filename_prefix + "-" + filename

        # Build the full output path
        full_path = os.path.join(path, filename)

        # Write the serialized model proto to the file
        with open(full_path, "wb") as fs:
            fs.write(self.spm.serialized_model_proto())

        # Return the saved file path
        return (full_path,)


# Check whether a character counts as whitespace
def _is_whitespace(char):
    # Treat space, tab, newline, and carriage return as whitespace
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    # Fall back to the character's Unicode category
    cat = unicodedata.category(char)
    # "Zs" is the space-separator category
    if cat == "Zs":
        return True
    return False


# Check whether a character is a control character
def _is_control(char):
    # Tab, newline, and carriage return are treated as whitespace, not control
    if char == "\t" or char == "\n" or char == "\r":
        return False
    # Unicode categories starting with "C" are control characters
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


# Check whether a character is punctuation
def _is_punctuation(char):
    cp = ord(char)
    # Treat all non-alphanumeric ASCII characters as punctuation, even those (like "$")
    # whose Unicode category is not "P"
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    # Unicode categories starting with "P" are punctuation
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
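
# e.g. _is_punctuation(",") -> True, _is_punctuation("$") -> True (via the ASCII range
# rule), _is_punctuation("8") -> False (illustrative note, not part of the original file)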


# Convert text to unicode (if it is not already)
def convert_to_unicode(text):
    # Strings are returned as-is
    if isinstance(text, str):
        return text
    # Bytes are decoded as UTF-8, ignoring errors
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    # Anything else is rejected
    else:
        raise ValueError(f"Unsupported string type: {type(text)}")

.\models\deberta_v2\tokenization_deberta_v2_fast.py

# coding=utf-8
# Copyright 2020 Microsoft and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization class for model DeBERTa."""

import os
from shutil import copyfile
from typing import Optional, Tuple

from ...file_utils import is_sentencepiece_available
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging

# Check whether the SentencePiece library is installed
if is_sentencepiece_available():
    # If it is, import the slow DebertaV2Tokenizer
    from .tokenization_deberta_v2 import DebertaV2Tokenizer
else:
    # Otherwise there is no slow tokenizer to fall back to
    DebertaV2Tokenizer = None

# Module-level logger
logger = logging.get_logger(__name__)

# Names of the vocabulary files
VOCAB_FILES_NAMES = {"vocab_file": "spm.model", "tokenizer_file": "tokenizer.json"}

# Map from pretrained model names to the download URLs of their SentencePiece model files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model",
        "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model",
        "microsoft/deberta-v2-xlarge-mnli": (
            "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model"
        ),
        "microsoft/deberta-v2-xxlarge-mnli": (
            "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model"
        ),
    }
}

# Map from pretrained model names to their positional embedding sizes
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/deberta-v2-xlarge": 512,
    "microsoft/deberta-v2-xxlarge": 512,
    "microsoft/deberta-v2-xlarge-mnli": 512,
    "microsoft/deberta-v2-xxlarge-mnli": 512,
}

# Map from pretrained model names to their default initialization configuration
PRETRAINED_INIT_CONFIGURATION = {
    "microsoft/deberta-v2-xlarge": {"do_lower_case": False},
    "microsoft/deberta-v2-xxlarge": {"do_lower_case": False},
    "microsoft/deberta-v2-xlarge-mnli": {"do_lower_case": False},
    "microsoft/deberta-v2-xxlarge-mnli": {"do_lower_case": False},
}


class DebertaV2TokenizerFast(PreTrainedTokenizerFast):
    r"""
    Constructs a DeBERTa-v2 fast tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    """

    # Vocabulary file names
    vocab_files_names = VOCAB_FILES_NAMES
    # Pretrained vocabulary file URLs
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Pretrained initialization configurations
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # Maximum model input sizes
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # The corresponding slow tokenizer class
    slow_tokenizer_class = DebertaV2Tokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=False,
        split_by_punct=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs,
    ) -> None:
        # Initialize the fast tokenizer through the parent constructor
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            split_by_punct=split_by_punct,
            **kwargs,
        )

        # Store the initialization parameters on the instance
        self.do_lower_case = do_lower_case
        self.split_by_punct = split_by_punct
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # A slow tokenizer can only be saved if the original vocabulary file is available
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        从一个序列或一对序列构建模型输入,用于序列分类任务,通过连接和添加特殊标记。
        DeBERTa 模型的序列格式如下:

        - 单个序列: [CLS] X [SEP]
        - 一对序列: [CLS] A [SEP] B [SEP]

        Args:
            token_ids_0 (`List[int]`):
                要添加特殊标记的 ID 列表。
            token_ids_1 (`List[int]`, *optional*):
                第二个可选的序列 ID 列表,用于序列对。

        Returns:
            `List[int]`: 包含适当特殊标记的输入 ID 列表。
        """

        if token_ids_1 is None:
            # 如果只有一个序列,则返回加上特殊标记的列表
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        # 如果有两个序列,则分别构建包含特殊标记的列表并连接
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        # If the tokens already have special tokens, delegate to the superclass method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # If token_ids_1 is provided, create a mask with special tokens for sequence pairs
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        # Otherwise, create a mask with special tokens for a single sequence
        return [1] + ([0] * len(token_ids_0)) + [1]


    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """

        # Define special tokens for separation and classification
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If token_ids_1 is None, return a mask with only the first sequence
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]

        # Otherwise, return a mask with special tokens for both sequences
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
    # Save the vocabulary to a file under the given directory and return the saved path
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # A fast tokenizer without the original vocabulary file cannot save a slow one
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        # The target path must be a directory
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Build the output vocabulary file path, with an optional prefix
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Copy the current vocabulary file unless it is already at the target path
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        # Return the saved file path as a tuple
        return (out_vocab_file,)
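
A matching sketch for the fast tokenizer (illustrative, not part of the original file):

from transformers import DebertaV2TokenizerFast

tokenizer = DebertaV2TokenizerFast.from_pretrained("microsoft/deberta-v2-xlarge")
enc = tokenizer("Hello world!", return_token_type_ids=True)
print(enc["input_ids"], enc["token_type_ids"])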

.\models\deberta_v2\__init__.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

# Import shared utilities and availability checks
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# Declare the lazy import structure of this module
_import_structure = {
    "configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config", "DebertaV2OnnxConfig"],
    "tokenization_deberta_v2": ["DebertaV2Tokenizer"],
}

# Check whether tokenizers is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # tokenizers is available: register tokenization_deberta_v2_fast in the import structure
    _import_structure["tokenization_deberta_v2_fast"] = ["DebertaV2TokenizerFast"]

# Check whether TensorFlow is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # TensorFlow is available: register modeling_tf_deberta_v2 in the import structure
    _import_structure["modeling_tf_deberta_v2"] = [
        "TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFDebertaV2ForMaskedLM",
        "TFDebertaV2ForQuestionAnswering",
        "TFDebertaV2ForMultipleChoice",
        "TFDebertaV2ForSequenceClassification",
        "TFDebertaV2ForTokenClassification",
        "TFDebertaV2Model",
        "TFDebertaV2PreTrainedModel",
    ]

# Check whether PyTorch is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # PyTorch is available: register modeling_deberta_v2 in the import structure
    _import_structure["modeling_deberta_v2"] = [
        "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "DebertaV2ForMaskedLM",
        "DebertaV2ForMultipleChoice",
        "DebertaV2ForQuestionAnswering",
        "DebertaV2ForSequenceClassification",
        "DebertaV2ForTokenClassification",
        "DebertaV2Model",
        "DebertaV2PreTrainedModel",
    ]

# During type checking, perform the real imports
if TYPE_CHECKING:
    # Import the configuration classes and constants
    from .configuration_deberta_v2 import (
        DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DebertaV2Config,
        DebertaV2OnnxConfig,
    )
    # Import the slow tokenizer
    from .tokenization_deberta_v2 import DebertaV2Tokenizer

    # Skip the fast tokenizer if tokenizers is unavailable
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # tokenizers is available: import the fast tokenizer
        from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast

    # Skip the TensorFlow models if TensorFlow is unavailable
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # TensorFlow is available: import the TF model classes
        from .modeling_tf_deberta_v2 import (
            TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFDebertaV2ForMaskedLM,
            TFDebertaV2ForMultipleChoice,
            TFDebertaV2ForQuestionAnswering,
            TFDebertaV2ForSequenceClassification,
            TFDebertaV2ForTokenClassification,
            TFDebertaV2Model,
            TFDebertaV2PreTrainedModel,
        )

    try:
        # Check whether PyTorch is available; raise OptionalDependencyNotAvailable if not
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # PyTorch is unavailable; skip the model imports
        pass
    else:
        # PyTorch is available: import the PyTorch model classes
        from .modeling_deberta_v2 import (
            DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
            DebertaV2ForMaskedLM,
            DebertaV2ForMultipleChoice,
            DebertaV2ForQuestionAnswering,
            DebertaV2ForSequenceClassification,
            DebertaV2ForTokenClassification,
            DebertaV2Model,
            DebertaV2PreTrainedModel,
        )
else:
    # At runtime, register this module as a lazily-loading _LazyModule
    import sys

    # Submodules are then imported on first attribute access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
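
A small sketch of what the lazy registration buys (illustrative, not part of the original
file): importing the package is cheap, and a submodule is only imported when one of its
attributes is first accessed.

import transformers.models.deberta_v2 as deberta_v2

# Nothing heavy has been loaded yet; this attribute access triggers the actual import
# of configuration_deberta_v2 through _LazyModule.__getattr__.
config = deberta_v2.DebertaV2Config()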

.\models\decision_transformer\configuration_decision_transformer.py

# coding=utf-8
# Copyright 2022 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Decision Transformer model configuration"""

# Import the configuration base class and logging utilities
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Map from pretrained model names to their configuration file URLs
DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "edbeeching/decision-transformer-gym-hopper-medium": (
        "https://huggingface.co/edbeeching/decision-transformer-gym-hopper-medium/resolve/main/config.json"
    ),
    # See all DecisionTransformer models at https://huggingface.co/models?filter=decision_transformer
}

# Configuration class storing the DecisionTransformer model configuration
class DecisionTransformerConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`DecisionTransformerModel`]. It is used to
    instantiate a Decision Transformer model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the standard
    DecisionTransformer architecture. Many of the config options are used to instantiate the GPT2 model that is used as
    part of the architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import DecisionTransformerConfig, DecisionTransformerModel

    >>> # Initializing a DecisionTransformer configuration
    >>> configuration = DecisionTransformerConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DecisionTransformerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Model type identifier
    model_type = "decision_transformer"
    # Keys to ignore at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Aliases from standard config attribute names to the GPT-2 style names used here
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }
    # Constructor setting every model hyperparameter
    def __init__(
        self,
        state_dim=17,  # state dimensionality, defaults to 17
        act_dim=4,  # action dimensionality, defaults to 4
        hidden_size=128,  # hidden size, defaults to 128
        max_ep_len=4096,  # maximum episode length, defaults to 4096
        action_tanh=True,  # whether to pass actions through tanh, defaults to True
        vocab_size=1,  # vocabulary size, defaults to 1
        n_positions=1024,  # maximum sequence length for position embeddings, defaults to 1024
        n_layer=3,  # number of Transformer layers, defaults to 3
        n_head=1,  # number of attention heads, defaults to 1
        n_inner=None,  # inner dimension of the feed-forward layers, defaults to None
        activation_function="relu",  # activation function, defaults to relu
        resid_pdrop=0.1,  # dropout probability on residual connections, defaults to 0.1
        embd_pdrop=0.1,  # dropout probability on the embeddings, defaults to 0.1
        attn_pdrop=0.1,  # dropout probability in attention, defaults to 0.1
        layer_norm_epsilon=1e-5,  # epsilon used by layer normalization, defaults to 1e-5
        initializer_range=0.02,  # standard deviation for weight initialization, defaults to 0.02
        scale_attn_weights=True,  # whether to scale attention weights, defaults to True
        use_cache=True,  # whether to use the key/value cache, defaults to True
        bos_token_id=50256,  # id of the beginning-of-sequence token, defaults to 50256
        eos_token_id=50256,  # id of the end-of-sequence token, defaults to 50256
        scale_attn_by_inverse_layer_idx=False,  # whether to additionally scale attention by 1/(layer_idx + 1), defaults to False
        reorder_and_upcast_attn=False,  # whether to reorder and upcast attention, defaults to False
        **kwargs,
    ):
        # Store every hyperparameter on the configuration instance
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.hidden_size = hidden_size
        self.max_ep_len = max_ep_len
        self.action_tanh = action_tanh
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        # Forward the token ids and any remaining kwargs to the parent constructor
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
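
A quick sketch of the attribute_map aliases in action (illustrative, not from the
original file):

from transformers import DecisionTransformerConfig

config = DecisionTransformerConfig(n_layer=6, n_head=2)
# The GPT-2 style names are also exposed under the standard names:
assert config.num_hidden_layers == 6
assert config.num_attention_heads == 2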

.\models\decision_transformer\modeling_decision_transformer.py

# coding=utf-8
# Copyright 2022 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DecisionTransformer model."""

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast

# Activation function mapping
from ...activations import ACT2FN
# Base model output that carries past key values and cross attentions
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
# Base class for all pretrained models
from ...modeling_utils import PreTrainedModel
# PyTorch helpers for Conv1D layers and head pruning
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
# Generic model output helpers and docstring utilities
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# Configuration class for the Decision Transformer
from .configuration_decision_transformer import DecisionTransformerConfig

# Module-level logger
logger = logging.get_logger(__name__)

# Checkpoint name used in the documentation
_CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
# Configuration name used in the documentation
_CONFIG_FOR_DOC = "DecisionTransformerConfig"

# Archive list of pretrained Decision Transformer models
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "edbeeching/decision-transformer-gym-hopper-medium",
    # See all DecisionTransformer models at
    # https://huggingface.co/models?filter=decision_transformer
]


# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Load tf checkpoints in a pytorch model"""
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Resolve the absolute path of the TF checkpoint
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load the weights from the TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())
    # For every (name, array) pair, map the TF variable onto the PyTorch module tree
    for name, array in zip(names, arrays):
        # Drop the "model/" prefix from the variable name
        name = name[6:]  # skip "model/"
        # Split the name on slashes
        name = name.split("/")
        # Start walking from the model itself
        pointer = model
        # Walk down each path component
        for m_name in name:
            # Components like "h0" match letters followed by digits
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                # Split the component into its letter and digit parts
                scope_names = re.split(r"(\d+)", m_name)
            else:
                # Otherwise keep the whole component as a single element
                scope_names = [m_name]
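            # e.g. re.split(r"(\d+)", "h0") yields ["h", "0", ""], so scope_names[0]
            # picks the attribute below and scope_names[1] indexes into it
            # (illustrative note, not part of the original file)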
            # Pick the attribute based on the first part of the component
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                # Position ("wpe") and token ("wte") embeddings: descend into their weights
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                # Otherwise follow the attribute named by the first part
                pointer = getattr(pointer, scope_names[0])
            # If the component carried an index, select that element
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            # Check that the pointer's shape matches the array's shape
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            # On mismatch, attach both shapes to the exception and re-raise
            e.args += (pointer.shape, array.shape)
            raise
        # Log the initialization
        logger.info(f"Initialize PyTorch weight {name}")
        # Copy the numpy array into the parameter's data
        pointer.data = torch.from_numpy(array)
    # Return the populated model
    return model
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()

        # Causal mask setup
        max_positions = config.max_position_embeddings
        # Register a lower-triangular boolean buffer used as the causal attention mask
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        # Register the masked-bias buffer
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size  # embedding dimension
        self.num_heads = config.num_attention_heads  # number of attention heads
        self.head_dim = self.embed_dim // self.num_heads  # dimension per head
        self.split_size = self.embed_dim  # size used when splitting the fused QKV projection
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights  # whether to scale attention weights
        self.is_cross_attention = is_cross_attention  # whether this is a cross-attention layer

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            # Cross attention: separate projections for key/value and for the query
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # Self attention: one fused projection for query, key, and value
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)  # output projection

        self.attn_dropout = nn.Dropout(config.attn_pdrop)  # attention dropout
        self.resid_dropout = nn.Dropout(config.resid_pdrop)  # residual dropout

        self.pruned_heads = set()  # set of already pruned attention heads

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # 对 conv1d 层进行修剪
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # 更新超参数
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Compute raw attention scores as the matrix product of queries and keys
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        # Optionally scale the scores by 1/sqrt(head_dim)
        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Optionally apply the additional per-layer scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # Apply the causal mask unless this is cross-attention
        if not self.is_cross_attention:
            # Lengths of the query and key sequences
            query_length, key_length = query.size(-2), key.size(-2)
            # Slice the precomputed lower-triangular mask to the current lengths
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            # Use the smallest representable value to mask out future positions
            mask_value = torch.finfo(attn_weights.dtype).min
            # Create the mask value with the same dtype and device as the scores
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            # Keep scores where the mask is True, replace the rest with mask_value
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        # Add the (additive) attention mask if one was provided
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Normalize the scores with softmax
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Cast the probabilities back to the dtype of `value` (relevant under mixed precision)
        attn_weights = attn_weights.type(value.dtype)

        # Apply attention dropout
        attn_weights = self.attn_dropout(attn_weights)

        # Apply the head mask if one was provided
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # Weighted sum of the values gives the attention output
        attn_output = torch.matmul(attn_weights, value)

        # Return both the attention output and the attention probabilities
        return attn_output, attn_weights
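
To make the causal-masking step above concrete, here is a minimal standalone sketch (illustrative only, not part of the source): each query position may only attend to itself and earlier key positions.

import torch

scores = torch.randn(1, 1, 4, 4)  # (batch, heads, query_len, key_len)
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
mask_value = torch.full([], torch.finfo(scores.dtype).min)
probs = torch.softmax(torch.where(causal_mask, scores, mask_value), dim=-1)
print(probs[0, 0, 0])  # the first query row attends only to position 0: ~tensor([1., 0., 0., 0.])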
    # Upcast query/key/value (and attention_mask, if present) and reorder the matmul, then compute attention
    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Query dimensions: batch size, num_heads, query sequence length, key dimension
        bsz, num_heads, q_seq_len, dk = query.size()
        # Key dimensions: batch size, num_heads, key sequence length, key dimension
        _, _, k_seq_len, _ = key.size()

        # Preallocate the attention-weight tensor for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute the scale factor for the attention weights
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Disable autocast, then upcast and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            # Reshape the query to (-1, q_seq_len, dk)
            q = query.reshape(-1, q_seq_len, dk)
            # Transpose the key and reshape it to (-1, dk, k_seq_len)
            k = key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            # Use `torch.baddbmm` to compute the scores, scaled by the factor above
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            # Reshape the scores back to (bsz, num_heads, q_seq_len, k_seq_len)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        # Apply the causal mask unless this is cross-attention
        if not self.is_cross_attention:
            # Lengths of the query and key sequences
            query_length, key_length = query.size(-2), key.size(-2)
            # Slice the causal mask so positions can only attend to the past
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            # Smallest representable value, with matching dtype and device
            mask_value = torch.finfo(attn_weights.dtype).min
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            # Apply the causal mask
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        # Add the attention mask if one was provided
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Normalize the scores with softmax
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Cast the probabilities back to the dtype of `value` (they must be float32 at this point)
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)

        # Apply attention dropout
        attn_weights = self.attn_dropout(attn_weights)

        # Apply the head mask if one was provided
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # Weighted sum of the values gives the attention output
        attn_output = torch.matmul(attn_weights, value)

        # Return the attention output and the attention probabilities
        return attn_output, attn_weights

    # Split the last dimension of a tensor into multiple attention heads
    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        # New shape: split hidden_size into num_heads and attn_head_size
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        # Reshape, then move the head dimension in front of the sequence dimension
        tensor = tensor.view(new_shape)
        tensor = tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        # Move the head dimension back next to the head-feature dimension
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # New shape: merge num_heads and attn_head_size into hidden_size
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        # View the tensor with the merged shape
        return tensor.view(new_shape)
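
A quick sketch of the shape round-trip these two helpers perform (assuming hidden_size=8 split into 2 heads of size 4):

import torch

x = torch.randn(2, 5, 8)                        # (batch, seq_len, hidden_size)
split = x.view(2, 5, 2, 4).permute(0, 2, 1, 3)  # (batch, heads, seq_len, head_dim)
merged = split.permute(0, 2, 1, 3).contiguous().view(2, 5, 8)
assert torch.equal(x, merged)                   # splitting then merging is lossless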

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                # When used as cross-attention, the `q_attn` weights must be defined; otherwise raise
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
                )

            # Project the hidden states through self.q_attn to obtain the queries
            query = self.q_attn(hidden_states)
            # Project the encoder hidden states through self.c_attn and split into keys and values
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            # Use the encoder's attention mask
            attention_mask = encoder_attention_mask
        else:
            # Project the hidden states through self.c_attn and split into queries, keys and values
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Split the queries into heads
        query = self._split_heads(query, self.num_heads, self.head_dim)
        # Split the keys into heads
        key = self._split_heads(key, self.num_heads, self.head_dim)
        # Split the values into heads
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # If cached states exist, concatenate the past keys and values with the current ones
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        # When caching, keep the current keys and values
        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Use the reordered-and-upcast attention if configured
        if self.reorder_and_upcast_attn:
            # Compute attention output and weights with the upcast/reordered implementation
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            # Compute attention output and weights with the default implementation
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # Merge the heads back into the hidden dimension
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        # Apply the output projection
        attn_output = self.c_proj(attn_output)
        # Apply residual dropout to the attention output
        attn_output = self.resid_dropout(attn_output)

        # The outputs contain the attention output and, optionally, the cached key/value pair
        outputs = (attn_output, present)
        # Also return the attention weights when requested
        if output_attentions:
            outputs += (attn_weights,)

        # Return the final outputs
        return outputs  # a, present, (attentions)
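
A small standalone illustration of how the `layer_past` concatenation above grows the key/value length during incremental decoding:

import torch

past_key = torch.randn(1, 2, 6, 4)  # cached keys: (batch, heads, past_len, head_dim)
new_key = torch.randn(1, 2, 1, 4)   # key of the single newly generated token
key = torch.cat((past_key, new_key), dim=-2)
print(key.shape)                    # torch.Size([1, 2, 7, 4]) -- the cache grew by one step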
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2MLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        # Conv1D layer mapping embed_dim to intermediate_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        # Conv1D layer mapping intermediate_size back to embed_dim
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        # Activation function selected by the configuration
        self.act = ACT2FN[config.activation_function]
        # Dropout layer with probability config.resid_pdrop
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        # Apply the first projection
        hidden_states = self.c_fc(hidden_states)
        # Apply the activation function
        hidden_states = self.act(hidden_states)
        # Apply the second projection
        hidden_states = self.c_proj(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        return hidden_states
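
Note that `Conv1D` here is the GPT-2-style layer (in recent transformers versions it lives in `transformers.pytorch_utils`): despite the name it is a plain affine map whose weight is stored transposed relative to `nn.Linear`. A minimal check:

import torch
from transformers.pytorch_utils import Conv1D

layer = Conv1D(nf=16, nx=8)           # maps the last dimension from 8 to 16
out = layer(torch.randn(2, 5, 8))
print(layer.weight.shape, out.shape)  # torch.Size([8, 16]) torch.Size([2, 5, 16])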


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # Inner (feed-forward) dimension: config.n_inner if set, otherwise 4 * hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        # LayerNorm over hidden_size with eps config.layer_norm_epsilon
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # Self-attention layer
        self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
        # Second LayerNorm over hidden_size
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            # When the configuration asks for cross-attention, add a cross-attention layer
            self.crossattention = DecisionTransformerGPT2Attention(
                config, is_cross_attention=True, layer_idx=layer_idx
            )
            # LayerNorm applied before the cross-attention layer
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        # Feed-forward (MLP) sub-layer
        self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        # Return type: a tuple containing torch.Tensor, optionally with nested tuples of cached/attention tensors
        residual = hidden_states
        # Apply LayerNorm to the hidden states
        hidden_states = self.ln_1(hidden_states)
        # Run the self-attention sub-layer
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # The first element is the attention output
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        # The remaining elements (present, attentions) are carried along
        outputs = attn_outputs[1:]
        # Residual connection: add the attention output to the original hidden states
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # If encoder_hidden_states were passed, run the cross-attention sub-layer
            if not hasattr(self, "crossattention"):
                # Raise if the model was not configured with cross-attention layers
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            # Apply LayerNorm before the cross-attention layer
            hidden_states = self.ln_cross_attn(hidden_states)
            # Run the cross-attention computation
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            # The first element is the cross-attention output
            attn_output = cross_attn_outputs[0]
            # Residual connection around the cross-attention output
            hidden_states = residual + attn_output
            # Append the cross-attention weights to the outputs when attention weights are returned
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        # Apply the second LayerNorm
        hidden_states = self.ln_2(hidden_states)
        # Apply the MLP (feed-forward) sub-layer
        feed_forward_hidden_states = self.mlp(hidden_states)
        # Residual connection around the feed-forward output
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            # With caching, prepend the hidden states to the outputs (which include `present`)
            outputs = (hidden_states,) + outputs
        else:
            # Without caching, drop the `present` slot and return the remaining outputs
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)
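
Stripped of caching, masks and cross-attention, the control flow above is the standard pre-LayerNorm residual pattern, roughly this sketch:

# Schematic pre-LN block (illustrative only; attn, mlp, ln_1, ln_2 stand for the modules above)
def pre_ln_block(x, attn, mlp, ln_1, ln_2):
    x = x + attn(ln_1(x))  # attention sub-layer with residual connection
    x = x + mlp(ln_2(x))   # feed-forward sub-layer with residual connection
    return x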
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Use DecisionTransformerConfig as the configuration class
    config_class = DecisionTransformerConfig
    # Use the load_tf_weights_in_gpt2 function to load TensorFlow weights
    load_tf_weights = load_tf_weights_in_gpt2
    # Prefix of the base model
    base_model_prefix = "transformer"
    # The model can be parallelized
    is_parallelizable = True
    # Gradient checkpointing is supported
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Initialize the weights of linear and Conv1D layers
            # Slightly different from the TF version, which uses a truncated normal distribution
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Initialize the weights of embedding layers
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Initialize the bias and weight of LayerNorm layers
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights following the scheme of the OpenAI GPT-2 paper:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth.
        #   > At initialization, scale the weights of residual layers by a factor of 1/sqrt(N), where N is the number
        #   > of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if "c_proj" in name and "weight" in name:
                # Special scaled initialization --> there are 2 Layer Norms per Transformer block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
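
For instance, with GPT-2's defaults `initializer_range=0.02` and `n_layer=12`, this special initialization draws the `c_proj` weights with a much smaller standard deviation:

import math

initializer_range, n_layer = 0.02, 12  # GPT-2 small defaults
std = initializer_range / math.sqrt(2 * n_layer)
print(round(std, 5))  # 0.00408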


class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        # Token embedding and position embedding layers
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        # Dropout layer
        self.drop = nn.Dropout(config.embd_pdrop)

        # Stack of Transformer blocks
        self.h = nn.ModuleList(
            [DecisionTransformerGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

        # Final LayerNorm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallelism
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
    # Return the input token embedding matrix
    def get_input_embeddings(self):
        return self.wte

    # Replace the input token embedding matrix with new embeddings
    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    # Copied from the forward method of the GPT2Model class in the transformers library
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
# Data class for the Decision Transformer model's outputs, extending the base ModelOutput class
@dataclass
class DecisionTransformerOutput(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
            Environment state predictions
        action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
            Model action predictions
        return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
            Predicted returns for each state
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    
    # Environment state predictions
    state_preds: torch.FloatTensor = None
    # Model action predictions
    action_preds: torch.FloatTensor = None
    # Predicted returns for each state
    return_preds: torch.FloatTensor = None
    # Hidden states of the model
    hidden_states: torch.FloatTensor = None
    # Attention weights
    attentions: torch.FloatTensor = None
    # Last-layer hidden states
    last_hidden_state: torch.FloatTensor = None


# Abstract class for Decision Transformer pretrained models: handles weight initialization and provides a simple interface for downloading and loading pretrained models
class DecisionTransformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    
    # Configuration class for the Decision Transformer
    config_class = DecisionTransformerConfig
    # Prefix of the base model
    base_model_prefix = "decision_transformer"
    # Name of the main input
    main_input_name = "states"
    # Gradient checkpointing is not supported
    supports_gradient_checkpointing = False
    def _init_weights(self, module):
        """Initialize the weights"""
        # Linear layers
        if isinstance(module, nn.Linear):
            # Initialize weights from a normal distribution with mean 0 and the configured std
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Zero the bias if present
            if module.bias is not None:
                module.bias.data.zero_()
        # Embedding layers
        elif isinstance(module, nn.Embedding):
            # Initialize weights from a normal distribution with mean 0 and the configured std
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Zero the row at the padding index if one is defined
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # LayerNorm layers
        elif isinstance(module, nn.LayerNorm):
            # Zero the bias
            module.bias.data.zero_()
            # Fill the weight with ones
            module.weight.data.fill_(1.0)
# Docstring for the Decision Transformer model: it is a PyTorch Module subclass that can be used as a regular PyTorch module; refer to the PyTorch documentation for general usage and behavior.
DECISION_TRANSFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`~DecisionTransformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring describing the model's input arguments and their shapes.
DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
    Args:
        states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
            The states for each step in the trajectory
        actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
            The actions taken by the "expert" policy for the current state, these are masked for auto regressive
            prediction
        rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
            The rewards for each state, action
        returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
            The returns for each state in the trajectory
        timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
            The timestep for each step in the trajectory
        attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
            Masking, used to mask the actions when performing autoregressive prediction
"""

# The @add_start_docstrings decorator prepends the start docstring above, documenting the purpose of the Decision Transformer model.
@add_start_docstrings("The Decision Transformer Model", DECISION_TRANSFORMER_START_DOCSTRING)
class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
    """
    The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL
    setting. Refer to the paper for more details: https://arxiv.org/abs/2106.01345
    """
    # Initialization takes a configuration object
    def __init__(self, config):
        # Call the parent initializer with the configuration
        super().__init__(config)
        # Keep the configuration on the instance
        self.config = config
        # Hidden size taken from the configuration
        self.hidden_size = config.hidden_size

        # Create a DecisionTransformerGPT2Model instance as the encoder
        # note: the only difference from the default Huggingface version is that the positional
        # embeddings are removed (since we will add those ourselves)
        self.encoder = DecisionTransformerGPT2Model(config)

        # Embedding layers for the different input types
        self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)
        self.embed_return = torch.nn.Linear(1, config.hidden_size)
        self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
        self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)

        # LayerNorm over the embedded inputs
        self.embed_ln = nn.LayerNorm(config.hidden_size)

        # note: per the paper's setup, states and returns are not predicted targets

        # Linear layer predicting the next state
        self.predict_state = torch.nn.Linear(config.hidden_size, config.state_dim)
        # Sequential module predicting the action (with optional tanh squashing)
        self.predict_action = nn.Sequential(
            *([nn.Linear(config.hidden_size, config.act_dim)] + ([nn.Tanh()] if config.action_tanh else []))
        )
        # Linear layer predicting the return
        self.predict_return = torch.nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    # Forward pass: takes the trajectory tensors and returns the model outputs
    @add_start_docstrings_to_model_forward(DECISION_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        states: Optional[torch.FloatTensor] = None,
        actions: Optional[torch.FloatTensor] = None,
        rewards: Optional[torch.FloatTensor] = None,
        returns_to_go: Optional[torch.FloatTensor] = None,
        timesteps: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DecisionTransformerOutput]:
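
To round off this file, a minimal end-to-end sketch with random trajectories (shapes follow the inputs docstring above; the config values are arbitrary, chosen only for illustration):

import torch
from transformers import DecisionTransformerConfig, DecisionTransformerModel

config = DecisionTransformerConfig(state_dim=17, act_dim=4)  # arbitrary illustrative dims
model = DecisionTransformerModel(config)

batch, ep_len = 2, 10
outputs = model(
    states=torch.randn(batch, ep_len, config.state_dim),
    actions=torch.randn(batch, ep_len, config.act_dim),
    rewards=torch.randn(batch, ep_len, 1),
    returns_to_go=torch.randn(batch, ep_len, 1),
    timesteps=torch.arange(ep_len).unsqueeze(0).repeat(batch, 1),
    attention_mask=torch.ones(batch, ep_len),
)
print(outputs.action_preds.shape)  # torch.Size([2, 10, 4])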

.\models\decision_transformer\__init__.py

# Copyright and license notice: this file is copyrighted by the HuggingFace team and released under the Apache License 2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TYPE_CHECKING is imported for static type checking
from typing import TYPE_CHECKING

# Import OptionalDependencyNotAvailable, _LazyModule and is_torch_available from the utils module
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Define the module's import structure, starting with parts of the configuration_decision_transformer module
_import_structure = {
    "configuration_decision_transformer": [
        "DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "DecisionTransformerConfig",
    ],
}

# Check whether torch is available; raise OptionalDependencyNotAvailable if it is not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If torch is available, extend _import_structure with the modeling_decision_transformer module
    _import_structure["modeling_decision_transformer"] = [
        "DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "DecisionTransformerGPT2Model",
        "DecisionTransformerGPT2PreTrainedModel",
        "DecisionTransformerModel",
        "DecisionTransformerPreTrainedModel",
    ]

# During static type checking
if TYPE_CHECKING:
    # Import DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP and DecisionTransformerConfig
    # from the configuration_decision_transformer module
    from .configuration_decision_transformer import (
        DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DecisionTransformerConfig,
    )

    # Check torch availability again; skip the modeling imports if it is unavailable
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the archive list and the DecisionTransformer classes from modeling_decision_transformer
        from .modeling_decision_transformer import (
            DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            DecisionTransformerGPT2Model,
            DecisionTransformerGPT2PreTrainedModel,
            DecisionTransformerModel,
            DecisionTransformerPreTrainedModel,
        )

# At runtime (no type checking)
else:
    # Import the sys module
    import sys

    # Replace the current module with a _LazyModule so its contents are loaded lazily
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
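
A short illustration of what this lazy-module setup buys: importing the package is cheap, and the heavyweight modeling/configuration files are only imported when one of their attributes is first accessed.

from transformers.models import decision_transformer

config_cls = decision_transformer.DecisionTransformerConfig  # first access triggers the lazy import
print(config_cls.model_type)  # "decision_transformer"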