当前位置:   article > 正文

Transformers 源码解析(二百九十一)_autoquantizationconfig

autoquantizationconfig

.\quantizers\__init__.py

# 导入自动量化相关模块
from .auto import AutoHfQuantizer, AutoQuantizationConfig
# 导入基础量化器模块
from .base import HfQuantizer
  • 1
  • 2
  • 3
  • 4

.\safetensors_conversion.py

import json  # 导入json模块,用于处理JSON格式数据
import uuid  # 导入uuid模块,用于生成唯一标识符
from typing import Optional  # 导入Optional类型,用于可选的类型声明

import requests  # 导入requests模块,用于发送HTTP请求
from huggingface_hub import Discussion, HfApi, get_repo_discussions  # 导入huggingface_hub相关函数和类

from .utils import cached_file, logging  # 从当前包中导入cached_file和logging模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象


def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
    # 获取主提交的commit_id
    main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
    # 遍历当前模型repo中的所有讨论
    for discussion in get_repo_discussions(repo_id=model_id, token=token):
        # 判断讨论是否为打开的PR并且标题为pr_title
        if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
            # 获取与讨论相关的提交信息
            commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)

            # 检查主提交是否与PR的第二个提交相同
            if main_commit == commits[1].commit_id:
                return discussion  # 如果条件符合,返回此讨论对象
    return None  # 如果未找到符合条件的讨论,返回None


def spawn_conversion(token: str, private: bool, model_id: str):
    logger.info("Attempting to convert .bin model on the fly to safetensors.")

    safetensors_convert_space_url = "https://safetensors-convert.hf.space"
    sse_url = f"{safetensors_convert_space_url}/queue/join"
    sse_data_url = f"{safetensors_convert_space_url}/queue/data"

    # 指定fn_index以指示使用Space的run方法
    hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())}

    def start(_sse_connection, payload):
        # 迭代SSE连接的每一行数据
        for line in _sse_connection.iter_lines():
            line = line.decode()
            if line.startswith("data:"):
                resp = json.loads(line[5:])  # 解析收到的JSON数据
                logger.debug(f"Safetensors conversion status: {resp['msg']}")
                # 处理不同的转换状态
                if resp["msg"] == "queue_full":
                    raise ValueError("Queue is full! Please try again.")
                elif resp["msg"] == "send_data":
                    event_id = resp["event_id"]
                    # 发送数据到sse_data_url
                    response = requests.post(
                        sse_data_url,
                        stream=True,
                        params=hash_data,
                        json={"event_id": event_id, **payload, **hash_data},
                    )
                    response.raise_for_status()  # 检查响应状态
                elif resp["msg"] == "process_completed":
                    return  # 如果转换完成,结束函数

    with requests.get(sse_url, stream=True, params=hash_data) as sse_connection:
        data = {"data": [model_id, private, token]}
        try:
            logger.debug("Spawning safetensors automatic conversion.")
            start(sse_connection, data)  # 调用start函数开始转换
        except Exception as e:
            logger.warning(f"Error during conversion: {repr(e)}")  # 处理转换过程中的异常情况


def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
    private = api.model_info(model_id).private  # 获取模型信息中的private字段值

    logger.info("Attempting to create safetensors variant")
    pr_title = "Adding `safetensors` variant of this model"
    token = kwargs.get("token")

    # 这段代码查找当前repo中是否有关于safetensors的已打开的PR
    # 调用函数 `previous_pr`,获取先前创建的 pull request 对象
    pr = previous_pr(api, model_id, pr_title, token=token)

    # 如果 pr 为 None 或者(不是私有且 pr 的作者不是 "SFConvertBot"),则执行以下操作:
    if pr is None or (not private and pr.author != "SFConvertBot"):
        # 调用函数 `spawn_conversion`,启动转换过程
        spawn_conversion(token, private, model_id)
        # 再次获取先前创建的 pull request 对象
        pr = previous_pr(api, model_id, pr_title, token=token)
    else:
        # 记录日志,指示安全张量的 pull request 已存在
        logger.info("Safetensors PR exists")

    # 构建 SHA 引用,格式为 "refs/pr/{pr.num}"
    sha = f"refs/pr/{pr.num}"

    # 返回 SHA 引用
    return sha
# 自动转换函数,根据预训练模型名称或路径以及其他缓存文件参数来执行自动转换
def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
    # 使用给定的 token 创建 Hugging Face API 的实例
    api = HfApi(token=cached_file_kwargs.get("token"))
    
    # 获取转换 Pull Request 的参考 SHA 值
    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)

    # 如果没有找到 SHA 值,则返回 None
    if sha is None:
        return None, None
    
    # 将 SHA 值添加到缓存文件参数中的 revision 键中
    cached_file_kwargs["revision"] = sha
    
    # 从缓存文件参数中删除 _commit_hash 键
    del cached_file_kwargs["_commit_hash"]

    # 这是一个额外的 HEAD 调用,如果能从 PR 描述中推断出分片/非分片,可以删除这个调用
    # 检查指定的模型是否存在分片的 "model.safetensors.index.json" 文件
    sharded = api.file_exists(
        pretrained_model_name_or_path,
        "model.safetensors.index.json",
        revision=sha,
        token=cached_file_kwargs.get("token"),
    )
    
    # 根据是否存在分片文件,选择相应的文件名
    filename = "model.safetensors.index.json" if sharded else "model.safetensors"

    # 缓存解析后的归档文件
    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
    
    # 返回解析后的归档文件路径、SHA 值和是否分片的标志
    return resolved_archive_file, sha, sharded
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131

.\sagemaker\trainer_sm.py

# 导入警告模块,用于在特定情况下发出警告
import warnings

# 从上级目录中导入 Trainer 类
from ..trainer import Trainer

# 从上级目录中的 utils 模块中导入 logging 工具
from ..utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义 SageMakerTrainer 类,继承自 Trainer 类
class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        # 发出警告,提示用户 SageMakerTrainer 类将在 Transformers v5 版本中被移除,建议使用 Trainer 类
        warnings.warn(
            "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
            "instead.",
            FutureWarning,
        )
        # 调用父类 Trainer 的初始化方法,传递参数 args 和其他关键字参数
        super().__init__(args=args, **kwargs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

.\sagemaker\training_args_sm.py

# 导入必要的模块和库
import importlib.util  # 导入用于动态加载模块的模块
import json  # 导入处理 JSON 数据的模块
import os  # 导入与操作系统交互的模块
import warnings  # 导入用于处理警告的模块
from dataclasses import dataclass, field  # 导入用于创建数据类的装饰器和字段定义

import torch  # 导入 PyTorch 库

from ..training_args import TrainingArguments  # 从上级目录中导入训练参数类
from ..utils import cached_property, is_sagemaker_dp_enabled, logging  # 从上级目录中导入缓存属性装饰器、SageMaker DP 启用状态检查函数和日志模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象


# TODO: 在 SageMakerTrainer 重构后应移动到 `utils` 模块中


def is_sagemaker_model_parallel_available():
    # 从环境变量中获取 SageMaker 的模型并行参数
    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
    try:
        # 解析 JSON 数据并检查是否包含 "partitions" 字段,模型并行需要此字段
        smp_options = json.loads(smp_options)
        if "partitions" not in smp_options:
            return False
    except json.JSONDecodeError:
        return False

    # 从环境变量中获取 SageMaker 的框架参数
    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # 解析 JSON 数据并检查是否包含 "sagemaker_mpi_enabled" 字段
        mpi_options = json.loads(mpi_options)
        if not mpi_options.get("sagemaker_mpi_enabled", False):
            return False
    except json.JSONDecodeError:
        return False

    # 最后,检查是否存在 `smdistributed` 模块,以确认 SageMaker 是否支持模型并行
    return importlib.util.find_spec("smdistributed") is not None


# 如果 SageMaker 支持模型并行,则导入相应的模型并行库并进行初始化
if is_sagemaker_model_parallel_available():
    import smdistributed.modelparallel.torch as smp  # 导入 SageMaker 模型并行的 Torch 扩展库

    smp.init()  # 初始化 SageMaker 模型并行


@dataclass
class SageMakerTrainingArguments(TrainingArguments):
    mp_parameters: str = field(
        default="",
        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
    )

    def __post_init__(self):
        super().__post_init__()
        # 发出警告,提示 `SageMakerTrainingArguments` 将在 Transformers v5 中被移除,建议使用 `TrainingArguments` 替代
        warnings.warn(
            "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use "
            "`TrainingArguments` instead.",
            FutureWarning,
        )

    @cached_property
    # 设置设备
    def _setup_devices(self) -> "torch.device":
        # 打印日志信息
        logger.info("PyTorch: setting up devices")
        # 检查是否启用了torch分布式,并且本地进程的local_rank为-1
        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
            # 打印警告信息
            logger.warning(
                "torch.distributed process group is initialized, but local_rank == -1. "
                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
            )
        # 如果禁用了CUDA
        if self.no_cuda:
            # 将设备设置为CPU
            device = torch.device("cpu")
            # GPU数量设为0
            self._n_gpu = 0
        # 如果支持SageMaker模型并行
        elif is_sagemaker_model_parallel_available():
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            # GPU数量设为1
            self._n_gpu = 1
        # 如果启用了SageMaker分布式训练
        elif is_sagemaker_dp_enabled():
            # 导入SageMaker分布式训练模块
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
            # 初始化进程组
            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        # 如果local_rank为-1
        elif self.local_rank == -1:
            # 如果n_gpu大于1,将使用nn.DataParallel。
            # 如果只想使用指定的GPU子集,可以使用`CUDA_VISIBLE_DEVICES=0`
            # 显式设置CUDA到第一个(索引0)CUDA设备,否则`set_device`会触发缺少设备索引的错误。
            # 索引0考虑了环境中可用的GPU,因此`CUDA_VISIBLE_DEVICES=1,2`与`cuda:0`将使用该环境中的第一个GPU,即GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # 有时在此之前尚未运行postinit中的行,因此只需检查我们不是默认值。
            self._n_gpu = torch.cuda.device_count()
        else:
            # 在这里,我们将使用torch分布式。
            # 初始化分布式后端,负责同步节点/GPU
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        # 如果设备类型为cuda
        if device.type == "cuda":
            # 设置当前使用的设备
            torch.cuda.set_device(device)

        # 返回设备
        return device

    @property
    # 获取world_size属性
    def world_size(self):
        # 如果支持SageMaker模型并行
        if is_sagemaker_model_parallel_available():
            # 返回并行大小
            return smp.dp_size()

        # 返回基类的world_size
        return super().world_size

    @property
    # 获取place_model_on_device属性
    def place_model_on_device(self):
        # 如果不支持SageMaker模型并行
        return not is_sagemaker_model_parallel_available()

    @property
    # 获取_no_sync_in_gradient_accumulation属性
    def _no_sync_in_gradient_accumulation(self):
        return False
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145

.\sagemaker\__init__.py

# 导入 SageMakerTrainer 类从 trainer_sm 模块中
from .trainer_sm import SageMakerTrainer
# 导入 SageMakerTrainingArguments 和 is_sagemaker_dp_enabled 从 training_args_sm 模块中
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled
  • 1
  • 2
  • 3
  • 4

.\testing_utils.py

# 导入必要的标准库和第三方库
import collections  # 提供额外的数据容器,如deque(双端队列)
import contextlib  # 提供用于管理上下文的工具
import doctest  # 提供用于运行文档测试的模块
import functools  # 提供函数式编程的工具,如partial函数应用
import importlib  # 提供用于动态加载模块的工具
import inspect  # 提供用于检查源代码的工具
import logging  # 提供用于记录日志消息的功能
import multiprocessing  # 提供用于多进程编程的工具
import os  # 提供与操作系统交互的功能
import re  # 提供支持正则表达式的工具
import shlex  # 提供用于解析和操作命令行字符串的工具
import shutil  # 提供高级文件操作功能的工具
import subprocess  # 提供用于创建子进程的功能
import sys  # 提供与Python解释器交互的功能
import tempfile  # 提供创建临时文件和目录的功能
import time  # 提供时间相关的功能
import unittest  # 提供用于编写和运行单元测试的工具
from collections import defaultdict  # 提供默认字典的功能
from collections.abc import Mapping  # 提供抽象基类,用于检查映射类型
from functools import wraps  # 提供用于创建装饰器的工具
from io import StringIO  # 提供内存中文本I/O的工具
from pathlib import Path  # 提供面向对象的路径操作功能
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union  # 提供类型提示支持
from unittest import mock  # 提供用于模拟测试的工具
from unittest.mock import patch  # 提供用于模拟测试的工具

import urllib3  # 提供HTTP客户端的功能

from transformers import logging as transformers_logging  # 导入transformers库中的logging模块

from .integrations import (  # 导入自定义模块中的一系列集成检查函数
    is_clearml_available,
    is_optuna_available,
    is_ray_available,
    is_sigopt_available,
    is_tensorboard_available,
    is_wandb_available,
)
from .integrations.deepspeed import is_deepspeed_available  # 导入自定义模块中的深度加速集成检查函数
from .utils import (  # 导入自定义模块中的一系列实用工具检查函数
    is_accelerate_available,
    is_apex_available,
    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_bs4_available,
    is_cv2_available,
    is_cython_available,
    is_decord_available,
    is_detectron2_available,
    is_essentia_available,
    is_faiss_available,
    is_flash_attn_2_available,
    is_flax_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_galore_torch_available,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
    is_jumanpp_available,
    is_keras_nlp_available,
    is_levenshtein_available,
    is_librosa_available,
    is_natten_available,
    is_nltk_available,
    is_onnx_available,
    is_optimum_available,
    is_pandas_available,
    is_peft_available,
    is_phonemizer_available,
    is_pretty_midi_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quanto_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
    is_soundfile_availble,
    is_spacy_available,
    is_sudachi_available,
    is_sudachi_projection_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_timm_available,
    is_tokenizers_available,
    is_torch_available,
)
    # 检查当前设备是否支持 Torch 的 BF16 数据类型
    is_torch_bf16_available_on_device,
    # 检查当前 CPU 是否支持 Torch 的 BF16 数据类型
    is_torch_bf16_cpu_available,
    # 检查当前 GPU 是否支持 Torch 的 BF16 数据类型
    is_torch_bf16_gpu_available,
    # 检查当前设备是否支持 Torch 的 FP16 数据类型
    is_torch_fp16_available_on_device,
    # 检查当前设备是否支持 Torch 的 NeuronCore 加速器
    is_torch_neuroncore_available,
    # 检查当前设备是否支持 Torch 的 NPU 加速器
    is_torch_npu_available,
    # 检查当前设备是否支持 Torch 的 SDPA 加速器
    is_torch_sdpa_available,
    # 检查当前设备是否支持 Torch 的 TensorRT FX 加速器
    is_torch_tensorrt_fx_available,
    # 检查当前设备是否支持 Torch 的 TF32 数据类型
    is_torch_tf32_available,
    # 检查当前设备是否支持 Torch 的 XLA 加速器
    is_torch_xla_available,
    # 检查当前设备是否支持 Torch 的 XPU 加速器
    is_torch_xpu_available,
    # 检查当前环境是否支持 Torch Audio 库
    is_torchaudio_available,
    # 检查当前环境是否支持 TorchDynamo 库
    is_torchdynamo_available,
    # 检查当前环境是否支持 TorchVision 库
    is_torchvision_available,
    # 检查当前环境是否支持 Torch 的 Vision 扩展
    is_vision_available,
    # 将字符串转换为布尔值(支持"true", "false", "yes", "no", "1", "0"等)
    strtobool,
# 如果加速功能可用,则从 accelerate.state 中导入 AcceleratorState 和 PartialState 类
if is_accelerate_available():
    from accelerate.state import AcceleratorState, PartialState


# 如果 pytest 可用,则从 _pytest.doctest 中导入以下模块
# Module: 用于表示 Python 模块的类
# _get_checker: 获取 doctest 的检查器
# _get_continue_on_failure: 获取 doctest 的继续失败选项
# _get_runner: 获取 doctest 的运行器
# _is_mocked: 检查是否模拟了对象
# _patch_unwrap_mock_aware: 解除 Mock 对象感知的补丁
# get_optionflags: 获取 doctest 的选项标志
from _pytest.doctest import (
    Module,
    _get_checker,
    _get_continue_on_failure,
    _get_runner,
    _is_mocked,
    _patch_unwrap_mock_aware,
    get_optionflags,
)

# 如果 pytest 不可用,则将 Module 和 DoctestItem 设置为 object 类型
else:
    Module = object
    DoctestItem = object


# 定义了一个小型模型的标识符字符串
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"

# 用于测试自动检测模型类型的标识符
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"

# 用于测试 Hub 的用户和端点
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"

# 仅在受控的 CI 实例中可用,用于测试用的令牌
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"


# 从环境变量中解析布尔类型的标志
def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # 如果 KEY 未设置,则使用默认值 `default`
        _value = default
    else:
        # 如果 KEY 已设置,则尝试将其转换为 True 或 False
        try:
            _value = strtobool(value)
        except ValueError:
            # 如果值不是 `yes` 或 `no`,则抛出异常
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value


# 从环境变量中解析整数类型的值
def parse_int_from_env(key, default=None):
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            # 如果值不是整数,则抛出异常
            raise ValueError(f"If set, {key} must be a int.")
    return _value


# 根据环境变量 `RUN_SLOW` 解析是否运行慢速测试的标志
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
# 根据环境变量 `RUN_PT_TF_CROSS_TESTS` 解析是否运行 PyTorch 和 TensorFlow 交叉测试的标志
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=True)
# 根据环境变量 `RUN_PT_FLAX_CROSS_TESTS` 解析是否运行 PyTorch 和 Flax 交叉测试的标志
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
# 根据环境变量 `RUN_CUSTOM_TOKENIZERS` 解析是否运行自定义分词器测试的标志
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
# 根据环境变量 `HUGGINGFACE_CO_STAGING` 解析是否运行在 Hugging Face CO 预发布环境中的标志
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
# 根据环境变量 `TF_GPU_MEMORY_LIMIT` 解析 TensorFlow GPU 内存限制的值
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
# 根据环境变量 `RUN_PIPELINE_TESTS` 解析是否运行管道测试的标志
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
# 根据环境变量 `RUN_TOOL_TESTS` 解析是否运行工具测试的标志
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
# 根据环境变量 `RUN_THIRD_PARTY_DEVICE_TESTS` 解析是否运行第三方设备测试的标志
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)


# 函数装饰器,用于标记 PT+TF 交叉测试
def is_pt_tf_cross_test(test_case):
    """
    Decorator marking a test as a test that control interactions between PyTorch and TensorFlow.

    PT+TF tests are skipped by default and we can run only them by setting RUN_PT_TF_CROSS_TESTS environment variable
    to a truthy value and selecting the is_pt_tf_cross_test pytest mark.

    """
    # 如果未设置环境变量 `RUN_PT_TF_CROSS_TESTS` 或者当前环境中没有安装 PyTorch 或 TensorFlow,
    # 则跳过 PT+TF 测试
    if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
        return unittest.skip("test is PT+TF test")(test_case)
    else:
        # 尝试导入 pytest 模块,避免在主库中硬编码依赖 pytest
        try:
            import pytest  
        # 如果导入失败,返回原始的 test_case
        except ImportError:
            return test_case
        # 如果导入成功,应用 pytest.mark.is_pt_tf_cross_test() 装饰器到 test_case 上
        else:
            return pytest.mark.is_pt_tf_cross_test()(test_case)
# 标记一个测试用例为控制 PyTorch 和 Flax 交互的测试的装饰器

PT+FLAX 测试默认情况下会被跳过,只有当设置了环境变量 RUN_PT_FLAX_CROSS_TESTS 为真值并且选择了 is_pt_flax_cross_test pytest 标记时才会运行。

def is_pt_flax_cross_test(test_case):
    if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
        # 如果不满足运行条件(未设置环境变量或者没有可用的 PyTorch 或 Flax),则跳过测试
        return unittest.skip("test is PT+FLAX test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_pt_flax_cross_test 标记来标记测试用例
            return pytest.mark.is_pt_flax_cross_test()(test_case)


# 标记一个测试用例为在 staging 环境下运行的测试的装饰器

这些测试将在 huggingface.co 的 staging 环境下运行,而不是真实的模型中心。

def is_staging_test(test_case):
    if not _run_staging:
        # 如果不运行 staging 测试,则跳过测试
        return unittest.skip("test is staging test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_staging_test 标记来标记测试用例
            return pytest.mark.is_staging_test()(test_case)


# 标记一个测试用例为 pipeline 测试的装饰器

如果未将 RUN_PIPELINE_TESTS 设置为真值,则这些测试将被跳过。

def is_pipeline_test(test_case):
    if not _run_pipeline_tests:
        # 如果不运行 pipeline 测试,则跳过测试
        return unittest.skip("test is pipeline test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_pipeline_test 标记来标记测试用例
            return pytest.mark.is_pipeline_test()(test_case)


# 标记一个测试用例为工具测试的装饰器

如果未将 RUN_TOOL_TESTS 设置为真值,则这些测试将被跳过。

def is_tool_test(test_case):
    if not _run_tool_tests:
        # 如果不运行工具测试,则跳过测试
        return unittest.skip("test is a tool test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_tool_test 标记来标记测试用例
            return pytest.mark.is_tool_test()(test_case)


# 标记一个测试用例为慢速测试的装饰器

慢速测试默认情况下会被跳过。设置 RUN_SLOW 环境变量为真值以运行这些测试。

def slow(test_case):
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


# 标记一个测试用例为太慢测试的装饰器

太慢的测试在修复过程中会被跳过。不应将任何测试标记为 "tooslow",因为这些测试不会被 CI 测试。

def tooslow(test_case):
    return unittest.skip("test is too slow")(test_case)


# 标记一个测试用例为自定义分词器测试的装饰器
    """
    自定义分词器需要额外的依赖项,默认情况下会被跳过。将环境变量 RUN_CUSTOM_TOKENIZERS
    设置为真值,以便运行它们。
    """
    # 返回一个装饰器,根据 _run_custom_tokenizers 的真假决定是否跳过测试用例
    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
# 装饰器,用于标记需要 BeautifulSoup4 的测试用例。在未安装 BeautifulSoup4 时跳过这些测试。
def require_bs4(test_case):
    return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)


# 装饰器,用于标记需要 GaLore 的测试用例。在未安装 GaLore 时跳过这些测试。
def require_galore_torch(test_case):
    return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


# 装饰器,用于标记需要 OpenCV 的测试用例。在未安装 OpenCV 时跳过这些测试。
def require_cv2(test_case):
    return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)


# 装饰器,用于标记需要 Levenshtein 的测试用例。在未安装 Levenshtein 时跳过这些测试。
def require_levenshtein(test_case):
    return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)


# 装饰器,用于标记需要 NLTK 的测试用例。在未安装 NLTK 时跳过这些测试。
def require_nltk(test_case):
    return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)


# 装饰器,用于标记需要 accelerate 的测试用例。在未安装 accelerate 时跳过这些测试。
def require_accelerate(test_case):
    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


# 装饰器,用于标记需要 fsdp 的测试用例。在未安装 fsdp 或版本不符合要求时跳过这些测试。
def require_fsdp(test_case, min_version: str = "1.12.0"):
    return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(test_case)


# 装饰器,用于标记需要 g2p_en 的测试用例。在未安装 SentencePiece 时跳过这些测试。
def require_g2p_en(test_case):
    return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)


# 装饰器,用于标记需要 safetensors 的测试用例。在未安装 safetensors 时跳过这些测试。
def require_safetensors(test_case):
    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


# 装饰器,用于标记需要 rjieba 的测试用例。在未安装 rjieba 时跳过这些测试。
def require_rjieba(test_case):
    return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)


# 装饰器,用于标记需要 jieba 的测试用例。在未安装 jieba 时跳过这些测试。
def require_jieba(test_case):
    return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)


# 装饰器,用于标记需要 jinja 的测试用例。在此处仅声明函数,实际装饰逻辑未提供。
def require_jinja(test_case):
    # Placeholder for decorator marking tests requiring Jinja
    pass
    # 使用装饰器标记一个需要 jinja 的测试用例。如果 jinja 没有安装,则跳过这些测试。
    """
    使用 unittest.skipUnless 函数来动态地装饰测试用例,只有在 jinja 可用时才运行该测试用例。
    如果 is_jinja_available() 函数返回 True,则装饰器返回一个可用于跳过测试的装饰器函数,否则返回 None。
    """
    return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
# 根据条件判断是否加载 tf2onnx
def require_tf2onnx(test_case):
    return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)


# 根据条件判断是否加载 ONNX
def require_onnx(test_case):
    return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)


# 根据条件判断是否加载 Timm
def require_timm(test_case):
    """
    Decorator marking a test that requires Timm.

    These tests are skipped when Timm isn't installed.
    """
    return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)


# 根据条件判断是否加载 NATTEN
def require_natten(test_case):
    """
    Decorator marking a test that requires NATTEN.

    These tests are skipped when NATTEN isn't installed.
    """
    return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)


# 根据条件判断是否加载 PyTorch
def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed.
    """
    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


# 根据条件判断是否加载 Flash Attention
def require_flash_attn(test_case):
    """
    Decorator marking a test that requires Flash Attention.

    These tests are skipped when Flash Attention isn't installed.
    """
    return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)


# 根据条件判断是否加载 PyTorch's SDPA
def require_torch_sdpa(test_case):
    """
    Decorator marking a test that requires PyTorch's SDPA.

    These tests are skipped when requirements are not met (torch version).
    """
    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)


# 根据条件判断是否加载 HF token
def require_read_token(fn):
    """
    A decorator that loads the HF token for tests that require to load gated models.
    """
    token = os.getenv("HF_HUB_READ_TOKEN")

    @wraps(fn)
    def _inner(*args, **kwargs):
        with patch("huggingface_hub.utils._headers.get_token", return_value=token):
            return fn(*args, **kwargs)

    return _inner


# 根据条件判断是否加载 PEFT
def require_peft(test_case):
    """
    Decorator marking a test that requires PEFT.

    These tests are skipped when PEFT isn't installed.
    """
    return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)


# 根据条件判断是否加载 Torchvision
def require_torchvision(test_case):
    """
    Decorator marking a test that requires Torchvision.

    These tests are skipped when Torchvision isn't installed.
    """
    return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)


# 根据条件判断是否加载 PyTorch 或 TensorFlow
def require_torch_or_tf(test_case):
    """
    Decorator marking a test that requires PyTorch or TensorFlow.

    These tests are skipped when neither PyTorch nor TensorFlow is installed.
    """
    return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
        test_case
    )


# 根据条件判断是否加载 Intel Extension for PyTorch
def require_intel_extension_for_pytorch(test_case):
    """
    Decorator marking a test that requires Intel Extension for PyTorch.
    """
    # 注释部分未提供
    pass
    # 当未安装Intel Extension for PyTorch或者其版本与当前PyTorch版本不匹配时,跳过这些测试。
    """
    返回一个装饰器,用于根据条件跳过测试。
    装饰器检查是否可用Intel Extension for PyTorch(IPEX)。
    如果不可用或版本不匹配,则跳过测试,并提供相应的提示信息。
    参考链接:https://github.com/intel/intel-extension-for-pytorch
    """
    return unittest.skipUnless(
        is_ipex_available(),
        "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
        " https://github.com/intel/intel-extension-for-pytorch",
    )(test_case)
# 装饰器,用于标记一个测试需要 TensorFlow probability
def require_tensorflow_probability(test_case):
    # 返回一个装饰器,其功能是当 TensorFlow probability 未安装时跳过测试
    return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
        test_case
    )


# 装饰器,用于标记一个测试需要 torchaudio
def require_torchaudio(test_case):
    # 返回一个装饰器,其功能是当 torchaudio 未安装时跳过测试
    return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)


# 装饰器,用于标记一个测试需要 TensorFlow
def require_tf(test_case):
    # 返回一个装饰器,其功能是当 TensorFlow 未安装时跳过测试
    return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)


# 装饰器,用于标记一个测试需要 JAX & Flax
def require_flax(test_case):
    # 返回一个装饰器,其功能是当 JAX 或 Flax 未安装时跳过测试
    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


# 装饰器,用于标记一个测试需要 SentencePiece
def require_sentencepiece(test_case):
    # 返回一个装饰器,其功能是当 SentencePiece 未安装时跳过测试
    return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)


# 装饰器,用于标记一个测试需要 Sacremoses
def require_sacremoses(test_case):
    # 返回一个装饰器,其功能是当 Sacremoses 未安装时跳过测试
    return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)


# 装饰器,用于标记一个测试需要 Seqio
def require_seqio(test_case):
    # 返回一个装饰器,其功能是当 Seqio 未安装时跳过测试
    return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)


# 装饰器,用于标记一个测试需要 Scipy
def require_scipy(test_case):
    # 返回一个装饰器,其功能是当 Scipy 未安装时跳过测试
    return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)


# 装饰器,用于标记一个测试需要 
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/908516
推荐阅读
相关标签