赞
踩
.\quantizers\__init__.py
# 导入自动量化相关模块
from .auto import AutoHfQuantizer, AutoQuantizationConfig
# 导入基础量化器模块
from .base import HfQuantizer
.\safetensors_conversion.py
import json # 导入json模块,用于处理JSON格式数据 import uuid # 导入uuid模块,用于生成唯一标识符 from typing import Optional # 导入Optional类型,用于可选的类型声明 import requests # 导入requests模块,用于发送HTTP请求 from huggingface_hub import Discussion, HfApi, get_repo_discussions # 导入huggingface_hub相关函数和类 from .utils import cached_file, logging # 从当前包中导入cached_file和logging模块 logger = logging.get_logger(__name__) # 获取当前模块的日志记录器对象 def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]: # 获取主提交的commit_id main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id # 遍历当前模型repo中的所有讨论 for discussion in get_repo_discussions(repo_id=model_id, token=token): # 判断讨论是否为打开的PR并且标题为pr_title if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request: # 获取与讨论相关的提交信息 commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token) # 检查主提交是否与PR的第二个提交相同 if main_commit == commits[1].commit_id: return discussion # 如果条件符合,返回此讨论对象 return None # 如果未找到符合条件的讨论,返回None def spawn_conversion(token: str, private: bool, model_id: str): logger.info("Attempting to convert .bin model on the fly to safetensors.") safetensors_convert_space_url = "https://safetensors-convert.hf.space" sse_url = f"{safetensors_convert_space_url}/queue/join" sse_data_url = f"{safetensors_convert_space_url}/queue/data" # 指定fn_index以指示使用Space的run方法 hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())} def start(_sse_connection, payload): # 迭代SSE连接的每一行数据 for line in _sse_connection.iter_lines(): line = line.decode() if line.startswith("data:"): resp = json.loads(line[5:]) # 解析收到的JSON数据 logger.debug(f"Safetensors conversion status: {resp['msg']}") # 处理不同的转换状态 if resp["msg"] == "queue_full": raise ValueError("Queue is full! Please try again.") elif resp["msg"] == "send_data": event_id = resp["event_id"] # 发送数据到sse_data_url response = requests.post( sse_data_url, stream=True, params=hash_data, json={"event_id": event_id, **payload, **hash_data}, ) response.raise_for_status() # 检查响应状态 elif resp["msg"] == "process_completed": return # 如果转换完成,结束函数 with requests.get(sse_url, stream=True, params=hash_data) as sse_connection: data = {"data": [model_id, private, token]} try: logger.debug("Spawning safetensors automatic conversion.") start(sse_connection, data) # 调用start函数开始转换 except Exception as e: logger.warning(f"Error during conversion: {repr(e)}") # 处理转换过程中的异常情况 def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): private = api.model_info(model_id).private # 获取模型信息中的private字段值 logger.info("Attempting to create safetensors variant") pr_title = "Adding `safetensors` variant of this model" token = kwargs.get("token") # 这段代码查找当前repo中是否有关于safetensors的已打开的PR # 调用函数 `previous_pr`,获取先前创建的 pull request 对象 pr = previous_pr(api, model_id, pr_title, token=token) # 如果 pr 为 None 或者(不是私有且 pr 的作者不是 "SFConvertBot"),则执行以下操作: if pr is None or (not private and pr.author != "SFConvertBot"): # 调用函数 `spawn_conversion`,启动转换过程 spawn_conversion(token, private, model_id) # 再次获取先前创建的 pull request 对象 pr = previous_pr(api, model_id, pr_title, token=token) else: # 记录日志,指示安全张量的 pull request 已存在 logger.info("Safetensors PR exists") # 构建 SHA 引用,格式为 "refs/pr/{pr.num}" sha = f"refs/pr/{pr.num}" # 返回 SHA 引用 return sha # 自动转换函数,根据预训练模型名称或路径以及其他缓存文件参数来执行自动转换 def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs): # 使用给定的 token 创建 Hugging Face API 的实例 api = HfApi(token=cached_file_kwargs.get("token")) # 获取转换 Pull Request 的参考 SHA 值 sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs) # 如果没有找到 SHA 值,则返回 None if sha is None: return None, None # 将 SHA 值添加到缓存文件参数中的 revision 键中 cached_file_kwargs["revision"] = sha # 从缓存文件参数中删除 _commit_hash 键 del cached_file_kwargs["_commit_hash"] # 这是一个额外的 HEAD 调用,如果能从 PR 描述中推断出分片/非分片,可以删除这个调用 # 检查指定的模型是否存在分片的 "model.safetensors.index.json" 文件 sharded = api.file_exists( pretrained_model_name_or_path, "model.safetensors.index.json", revision=sha, token=cached_file_kwargs.get("token"), ) # 根据是否存在分片文件,选择相应的文件名 filename = "model.safetensors.index.json" if sharded else "model.safetensors" # 缓存解析后的归档文件 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) # 返回解析后的归档文件路径、SHA 值和是否分片的标志 return resolved_archive_file, sha, sharded
.\sagemaker\trainer_sm.py
# 导入警告模块,用于在特定情况下发出警告 import warnings # 从上级目录中导入 Trainer 类 from ..trainer import Trainer # 从上级目录中的 utils 模块中导入 logging 工具 from ..utils import logging # 获取当前模块的日志记录器 logger = logging.get_logger(__name__) # 定义 SageMakerTrainer 类,继承自 Trainer 类 class SageMakerTrainer(Trainer): def __init__(self, args=None, **kwargs): # 发出警告,提示用户 SageMakerTrainer 类将在 Transformers v5 版本中被移除,建议使用 Trainer 类 warnings.warn( "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` " "instead.", FutureWarning, ) # 调用父类 Trainer 的初始化方法,传递参数 args 和其他关键字参数 super().__init__(args=args, **kwargs)
.\sagemaker\training_args_sm.py
# 导入必要的模块和库 import importlib.util # 导入用于动态加载模块的模块 import json # 导入处理 JSON 数据的模块 import os # 导入与操作系统交互的模块 import warnings # 导入用于处理警告的模块 from dataclasses import dataclass, field # 导入用于创建数据类的装饰器和字段定义 import torch # 导入 PyTorch 库 from ..training_args import TrainingArguments # 从上级目录中导入训练参数类 from ..utils import cached_property, is_sagemaker_dp_enabled, logging # 从上级目录中导入缓存属性装饰器、SageMaker DP 启用状态检查函数和日志模块 logger = logging.get_logger(__name__) # 获取当前模块的日志记录器对象 # TODO: 在 SageMakerTrainer 重构后应移动到 `utils` 模块中 def is_sagemaker_model_parallel_available(): # 从环境变量中获取 SageMaker 的模型并行参数 smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}") try: # 解析 JSON 数据并检查是否包含 "partitions" 字段,模型并行需要此字段 smp_options = json.loads(smp_options) if "partitions" not in smp_options: return False except json.JSONDecodeError: return False # 从环境变量中获取 SageMaker 的框架参数 mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}") try: # 解析 JSON 数据并检查是否包含 "sagemaker_mpi_enabled" 字段 mpi_options = json.loads(mpi_options) if not mpi_options.get("sagemaker_mpi_enabled", False): return False except json.JSONDecodeError: return False # 最后,检查是否存在 `smdistributed` 模块,以确认 SageMaker 是否支持模型并行 return importlib.util.find_spec("smdistributed") is not None # 如果 SageMaker 支持模型并行,则导入相应的模型并行库并进行初始化 if is_sagemaker_model_parallel_available(): import smdistributed.modelparallel.torch as smp # 导入 SageMaker 模型并行的 Torch 扩展库 smp.init() # 初始化 SageMaker 模型并行 @dataclass class SageMakerTrainingArguments(TrainingArguments): mp_parameters: str = field( default="", metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"}, ) def __post_init__(self): super().__post_init__() # 发出警告,提示 `SageMakerTrainingArguments` 将在 Transformers v5 中被移除,建议使用 `TrainingArguments` 替代 warnings.warn( "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use " "`TrainingArguments` instead.", FutureWarning, ) @cached_property # 设置设备 def _setup_devices(self) -> "torch.device": # 打印日志信息 logger.info("PyTorch: setting up devices") # 检查是否启用了torch分布式,并且本地进程的local_rank为-1 if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1: # 打印警告信息 logger.warning( "torch.distributed process group is initialized, but local_rank == -1. " "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" ) # 如果禁用了CUDA if self.no_cuda: # 将设备设置为CPU device = torch.device("cpu") # GPU数量设为0 self._n_gpu = 0 # 如果支持SageMaker模型并行 elif is_sagemaker_model_parallel_available(): local_rank = smp.local_rank() device = torch.device("cuda", local_rank) # GPU数量设为1 self._n_gpu = 1 # 如果启用了SageMaker分布式训练 elif is_sagemaker_dp_enabled(): # 导入SageMaker分布式训练模块 import smdistributed.dataparallel.torch.torch_smddp # noqa: F401 # 初始化进程组 torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta) self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK")) device = torch.device("cuda", self.local_rank) self._n_gpu = 1 # 如果local_rank为-1 elif self.local_rank == -1: # 如果n_gpu大于1,将使用nn.DataParallel。 # 如果只想使用指定的GPU子集,可以使用`CUDA_VISIBLE_DEVICES=0` # 显式设置CUDA到第一个(索引0)CUDA设备,否则`set_device`会触发缺少设备索引的错误。 # 索引0考虑了环境中可用的GPU,因此`CUDA_VISIBLE_DEVICES=1,2`与`cuda:0`将使用该环境中的第一个GPU,即GPU#1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 有时在此之前尚未运行postinit中的行,因此只需检查我们不是默认值。 self._n_gpu = torch.cuda.device_count() else: # 在这里,我们将使用torch分布式。 # 初始化分布式后端,负责同步节点/GPU if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) device = torch.device("cuda", self.local_rank) self._n_gpu = 1 # 如果设备类型为cuda if device.type == "cuda": # 设置当前使用的设备 torch.cuda.set_device(device) # 返回设备 return device @property # 获取world_size属性 def world_size(self): # 如果支持SageMaker模型并行 if is_sagemaker_model_parallel_available(): # 返回并行大小 return smp.dp_size() # 返回基类的world_size return super().world_size @property # 获取place_model_on_device属性 def place_model_on_device(self): # 如果不支持SageMaker模型并行 return not is_sagemaker_model_parallel_available() @property # 获取_no_sync_in_gradient_accumulation属性 def _no_sync_in_gradient_accumulation(self): return False
.\sagemaker\__init__.py
# 导入 SageMakerTrainer 类从 trainer_sm 模块中
from .trainer_sm import SageMakerTrainer
# 导入 SageMakerTrainingArguments 和 is_sagemaker_dp_enabled 从 training_args_sm 模块中
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled
.\testing_utils.py
# 导入必要的标准库和第三方库
import collections # 提供额外的数据容器,如deque(双端队列)
import contextlib # 提供用于管理上下文的工具
import doctest # 提供用于运行文档测试的模块
import functools # 提供函数式编程的工具,如partial函数应用
import importlib # 提供用于动态加载模块的工具
import inspect # 提供用于检查源代码的工具
import logging # 提供用于记录日志消息的功能
import multiprocessing # 提供用于多进程编程的工具
import os # 提供与操作系统交互的功能
import re # 提供支持正则表达式的工具
import shlex # 提供用于解析和操作命令行字符串的工具
import shutil # 提供高级文件操作功能的工具
import subprocess # 提供用于创建子进程的功能
import sys # 提供与Python解释器交互的功能
import tempfile # 提供创建临时文件和目录的功能
import time # 提供时间相关的功能
import unittest # 提供用于编写和运行单元测试的工具
from collections import defaultdict # 提供默认字典的功能
from collections.abc import Mapping # 提供抽象基类,用于检查映射类型
from functools import wraps # 提供用于创建装饰器的工具
from io import StringIO # 提供内存中文本I/O的工具
from pathlib import Path # 提供面向对象的路径操作功能
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union # 提供类型提示支持
from unittest import mock # 提供用于模拟测试的工具
from unittest.mock import patch # 提供用于模拟测试的工具
import urllib3 # 提供HTTP客户端的功能
from transformers import logging as transformers_logging # 导入transformers库中的logging模块
from .integrations import ( # 导入自定义模块中的一系列集成检查函数
is_clearml_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
is_tensorboard_available,
is_wandb_available,
)
from .integrations.deepspeed import is_deepspeed_available # 导入自定义模块中的深度加速集成检查函数
from .utils import ( # 导入自定义模块中的一系列实用工具检查函数
is_accelerate_available,
is_apex_available,
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
is_cv2_available,
is_cython_available,
is_decord_available,
is_detectron2_available,
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
is_jumanpp_available,
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_natten_available,
is_nltk_available,
is_onnx_available,
is_optimum_available,
is_pandas_available,
is_peft_available,
is_phonemizer_available,
is_pretty_midi_available,
is_pyctcdecode_available,
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
is_soundfile_availble,
is_spacy_available,
is_sudachi_available,
is_sudachi_projection_available,
is_tensorflow_probability_available,
is_tensorflow_text_available,
is_tf2onnx_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
)
# 检查当前设备是否支持 Torch 的 BF16 数据类型
is_torch_bf16_available_on_device,
# 检查当前 CPU 是否支持 Torch 的 BF16 数据类型
is_torch_bf16_cpu_available,
# 检查当前 GPU 是否支持 Torch 的 BF16 数据类型
is_torch_bf16_gpu_available,
# 检查当前设备是否支持 Torch 的 FP16 数据类型
is_torch_fp16_available_on_device,
# 检查当前设备是否支持 Torch 的 NeuronCore 加速器
is_torch_neuroncore_available,
# 检查当前设备是否支持 Torch 的 NPU 加速器
is_torch_npu_available,
# 检查当前设备是否支持 Torch 的 SDPA 加速器
is_torch_sdpa_available,
# 检查当前设备是否支持 Torch 的 TensorRT FX 加速器
is_torch_tensorrt_fx_available,
# 检查当前设备是否支持 Torch 的 TF32 数据类型
is_torch_tf32_available,
# 检查当前设备是否支持 Torch 的 XLA 加速器
is_torch_xla_available,
# 检查当前设备是否支持 Torch 的 XPU 加速器
is_torch_xpu_available,
# 检查当前环境是否支持 Torch Audio 库
is_torchaudio_available,
# 检查当前环境是否支持 TorchDynamo 库
is_torchdynamo_available,
# 检查当前环境是否支持 TorchVision 库
is_torchvision_available,
# 检查当前环境是否支持 Torch 的 Vision 扩展
is_vision_available,
# 将字符串转换为布尔值(支持"true", "false", "yes", "no", "1", "0"等)
strtobool,
# 如果加速功能可用,则从 accelerate.state 中导入 AcceleratorState 和 PartialState 类
if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState
# 如果 pytest 可用,则从 _pytest.doctest 中导入以下模块
# Module: 用于表示 Python 模块的类
# _get_checker: 获取 doctest 的检查器
# _get_continue_on_failure: 获取 doctest 的继续失败选项
# _get_runner: 获取 doctest 的运行器
# _is_mocked: 检查是否模拟了对象
# _patch_unwrap_mock_aware: 解除 Mock 对象感知的补丁
# get_optionflags: 获取 doctest 的选项标志
from _pytest.doctest import (
Module,
_get_checker,
_get_continue_on_failure,
_get_runner,
_is_mocked,
_patch_unwrap_mock_aware,
get_optionflags,
)
# 如果 pytest 不可用,则将 Module 和 DoctestItem 设置为 object 类型
else:
Module = object
DoctestItem = object
# 定义了一个小型模型的标识符字符串
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
# 用于测试自动检测模型类型的标识符
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# 用于测试 Hub 的用户和端点
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
# 仅在受控的 CI 实例中可用,用于测试用的令牌
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
# 从环境变量中解析布尔类型的标志
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# 如果 KEY 未设置,则使用默认值 `default`
_value = default
else:
# 如果 KEY 已设置,则尝试将其转换为 True 或 False
try:
_value = strtobool(value)
except ValueError:
# 如果值不是 `yes` 或 `no`,则抛出异常
raise ValueError(f"If set, {key} must be yes or no.")
return _value
# 从环境变量中解析整数类型的值
def parse_int_from_env(key, default=None):
try:
value = os.environ[key]
except KeyError:
_value = default
else:
try:
_value = int(value)
except ValueError:
# 如果值不是整数,则抛出异常
raise ValueError(f"If set, {key} must be a int.")
return _value
# 根据环境变量 `RUN_SLOW` 解析是否运行慢速测试的标志
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
# 根据环境变量 `RUN_PT_TF_CROSS_TESTS` 解析是否运行 PyTorch 和 TensorFlow 交叉测试的标志
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=True)
# 根据环境变量 `RUN_PT_FLAX_CROSS_TESTS` 解析是否运行 PyTorch 和 Flax 交叉测试的标志
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
# 根据环境变量 `RUN_CUSTOM_TOKENIZERS` 解析是否运行自定义分词器测试的标志
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
# 根据环境变量 `HUGGINGFACE_CO_STAGING` 解析是否运行在 Hugging Face CO 预发布环境中的标志
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
# 根据环境变量 `TF_GPU_MEMORY_LIMIT` 解析 TensorFlow GPU 内存限制的值
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
# 根据环境变量 `RUN_PIPELINE_TESTS` 解析是否运行管道测试的标志
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
# 根据环境变量 `RUN_TOOL_TESTS` 解析是否运行工具测试的标志
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
# 根据环境变量 `RUN_THIRD_PARTY_DEVICE_TESTS` 解析是否运行第三方设备测试的标志
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
# 函数装饰器,用于标记 PT+TF 交叉测试
def is_pt_tf_cross_test(test_case):
"""
Decorator marking a test as a test that control interactions between PyTorch and TensorFlow.
PT+TF tests are skipped by default and we can run only them by setting RUN_PT_TF_CROSS_TESTS environment variable
to a truthy value and selecting the is_pt_tf_cross_test pytest mark.
"""
# 如果未设置环境变量 `RUN_PT_TF_CROSS_TESTS` 或者当前环境中没有安装 PyTorch 或 TensorFlow,
# 则跳过 PT+TF 测试
if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
return unittest.skip("test is PT+TF test")(test_case)
else:
# 尝试导入 pytest 模块,避免在主库中硬编码依赖 pytest
try:
import pytest
# 如果导入失败,返回原始的 test_case
except ImportError:
return test_case
# 如果导入成功,应用 pytest.mark.is_pt_tf_cross_test() 装饰器到 test_case 上
else:
return pytest.mark.is_pt_tf_cross_test()(test_case)
# 标记一个测试用例为控制 PyTorch 和 Flax 交互的测试的装饰器
PT+FLAX 测试默认情况下会被跳过,只有当设置了环境变量 RUN_PT_FLAX_CROSS_TESTS 为真值并且选择了 is_pt_flax_cross_test pytest 标记时才会运行。
def is_pt_flax_cross_test(test_case):
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
# 如果不满足运行条件(未设置环境变量或者没有可用的 PyTorch 或 Flax),则跳过测试
return unittest.skip("test is PT+FLAX test")(test_case)
else:
try:
import pytest # 我们不需要在主库中强制依赖 pytest
except ImportError:
return test_case
else:
# 使用 pytest 的 is_pt_flax_cross_test 标记来标记测试用例
return pytest.mark.is_pt_flax_cross_test()(test_case)
# 标记一个测试用例为在 staging 环境下运行的测试的装饰器
这些测试将在 huggingface.co 的 staging 环境下运行,而不是真实的模型中心。
def is_staging_test(test_case):
if not _run_staging:
# 如果不运行 staging 测试,则跳过测试
return unittest.skip("test is staging test")(test_case)
else:
try:
import pytest # 我们不需要在主库中强制依赖 pytest
except ImportError:
return test_case
else:
# 使用 pytest 的 is_staging_test 标记来标记测试用例
return pytest.mark.is_staging_test()(test_case)
# 标记一个测试用例为 pipeline 测试的装饰器
如果未将 RUN_PIPELINE_TESTS 设置为真值,则这些测试将被跳过。
def is_pipeline_test(test_case):
if not _run_pipeline_tests:
# 如果不运行 pipeline 测试,则跳过测试
return unittest.skip("test is pipeline test")(test_case)
else:
try:
import pytest # 我们不需要在主库中强制依赖 pytest
except ImportError:
return test_case
else:
# 使用 pytest 的 is_pipeline_test 标记来标记测试用例
return pytest.mark.is_pipeline_test()(test_case)
# 标记一个测试用例为工具测试的装饰器
如果未将 RUN_TOOL_TESTS 设置为真值,则这些测试将被跳过。
def is_tool_test(test_case):
if not _run_tool_tests:
# 如果不运行工具测试,则跳过测试
return unittest.skip("test is a tool test")(test_case)
else:
try:
import pytest # 我们不需要在主库中强制依赖 pytest
except ImportError:
return test_case
else:
# 使用 pytest 的 is_tool_test 标记来标记测试用例
return pytest.mark.is_tool_test()(test_case)
# 标记一个测试用例为慢速测试的装饰器
慢速测试默认情况下会被跳过。设置 RUN_SLOW 环境变量为真值以运行这些测试。
def slow(test_case):
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
# 标记一个测试用例为太慢测试的装饰器
太慢的测试在修复过程中会被跳过。不应将任何测试标记为 "tooslow",因为这些测试不会被 CI 测试。
def tooslow(test_case):
return unittest.skip("test is too slow")(test_case)
# 标记一个测试用例为自定义分词器测试的装饰器
"""
自定义分词器需要额外的依赖项,默认情况下会被跳过。将环境变量 RUN_CUSTOM_TOKENIZERS
设置为真值,以便运行它们。
"""
# 返回一个装饰器,根据 _run_custom_tokenizers 的真假决定是否跳过测试用例
return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
# 装饰器,用于标记需要 BeautifulSoup4 的测试用例。在未安装 BeautifulSoup4 时跳过这些测试。
def require_bs4(test_case):
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
# 装饰器,用于标记需要 GaLore 的测试用例。在未安装 GaLore 时跳过这些测试。
def require_galore_torch(test_case):
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
# 装饰器,用于标记需要 OpenCV 的测试用例。在未安装 OpenCV 时跳过这些测试。
def require_cv2(test_case):
return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)
# 装饰器,用于标记需要 Levenshtein 的测试用例。在未安装 Levenshtein 时跳过这些测试。
def require_levenshtein(test_case):
return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)
# 装饰器,用于标记需要 NLTK 的测试用例。在未安装 NLTK 时跳过这些测试。
def require_nltk(test_case):
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
# 装饰器,用于标记需要 accelerate 的测试用例。在未安装 accelerate 时跳过这些测试。
def require_accelerate(test_case):
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
# 装饰器,用于标记需要 fsdp 的测试用例。在未安装 fsdp 或版本不符合要求时跳过这些测试。
def require_fsdp(test_case, min_version: str = "1.12.0"):
return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(test_case)
# 装饰器,用于标记需要 g2p_en 的测试用例。在未安装 SentencePiece 时跳过这些测试。
def require_g2p_en(test_case):
return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)
# 装饰器,用于标记需要 safetensors 的测试用例。在未安装 safetensors 时跳过这些测试。
def require_safetensors(test_case):
return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
# 装饰器,用于标记需要 rjieba 的测试用例。在未安装 rjieba 时跳过这些测试。
def require_rjieba(test_case):
return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
# 装饰器,用于标记需要 jieba 的测试用例。在未安装 jieba 时跳过这些测试。
def require_jieba(test_case):
return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)
# 装饰器,用于标记需要 jinja 的测试用例。在此处仅声明函数,实际装饰逻辑未提供。
def require_jinja(test_case):
# Placeholder for decorator marking tests requiring Jinja
pass
# 使用装饰器标记一个需要 jinja 的测试用例。如果 jinja 没有安装,则跳过这些测试。
"""
使用 unittest.skipUnless 函数来动态地装饰测试用例,只有在 jinja 可用时才运行该测试用例。
如果 is_jinja_available() 函数返回 True,则装饰器返回一个可用于跳过测试的装饰器函数,否则返回 None。
"""
return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
# 根据条件判断是否加载 tf2onnx
def require_tf2onnx(test_case):
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
# 根据条件判断是否加载 ONNX
def require_onnx(test_case):
return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
# 根据条件判断是否加载 Timm
def require_timm(test_case):
"""
Decorator marking a test that requires Timm.
These tests are skipped when Timm isn't installed.
"""
return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
# 根据条件判断是否加载 NATTEN
def require_natten(test_case):
"""
Decorator marking a test that requires NATTEN.
These tests are skipped when NATTEN isn't installed.
"""
return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)
# 根据条件判断是否加载 PyTorch
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.
These tests are skipped when PyTorch isn't installed.
"""
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
# 根据条件判断是否加载 Flash Attention
def require_flash_attn(test_case):
"""
Decorator marking a test that requires Flash Attention.
These tests are skipped when Flash Attention isn't installed.
"""
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
# 根据条件判断是否加载 PyTorch's SDPA
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.
These tests are skipped when requirements are not met (torch version).
"""
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
# 根据条件判断是否加载 HF token
def require_read_token(fn):
"""
A decorator that loads the HF token for tests that require to load gated models.
"""
token = os.getenv("HF_HUB_READ_TOKEN")
@wraps(fn)
def _inner(*args, **kwargs):
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
return fn(*args, **kwargs)
return _inner
# 根据条件判断是否加载 PEFT
def require_peft(test_case):
"""
Decorator marking a test that requires PEFT.
These tests are skipped when PEFT isn't installed.
"""
return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
# 根据条件判断是否加载 Torchvision
def require_torchvision(test_case):
"""
Decorator marking a test that requires Torchvision.
These tests are skipped when Torchvision isn't installed.
"""
return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
# 根据条件判断是否加载 PyTorch 或 TensorFlow
def require_torch_or_tf(test_case):
"""
Decorator marking a test that requires PyTorch or TensorFlow.
These tests are skipped when neither PyTorch nor TensorFlow is installed.
"""
return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
test_case
)
# 根据条件判断是否加载 Intel Extension for PyTorch
def require_intel_extension_for_pytorch(test_case):
"""
Decorator marking a test that requires Intel Extension for PyTorch.
"""
# 注释部分未提供
pass
# 当未安装Intel Extension for PyTorch或者其版本与当前PyTorch版本不匹配时,跳过这些测试。
"""
返回一个装饰器,用于根据条件跳过测试。
装饰器检查是否可用Intel Extension for PyTorch(IPEX)。
如果不可用或版本不匹配,则跳过测试,并提供相应的提示信息。
参考链接:https://github.com/intel/intel-extension-for-pytorch
"""
return unittest.skipUnless(
is_ipex_available(),
"test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
" https://github.com/intel/intel-extension-for-pytorch",
)(test_case)
# 装饰器,用于标记一个测试需要 TensorFlow probability
def require_tensorflow_probability(test_case):
# 返回一个装饰器,其功能是当 TensorFlow probability 未安装时跳过测试
return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
test_case
)
# 装饰器,用于标记一个测试需要 torchaudio
def require_torchaudio(test_case):
# 返回一个装饰器,其功能是当 torchaudio 未安装时跳过测试
return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
# 装饰器,用于标记一个测试需要 TensorFlow
def require_tf(test_case):
# 返回一个装饰器,其功能是当 TensorFlow 未安装时跳过测试
return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
# 装饰器,用于标记一个测试需要 JAX & Flax
def require_flax(test_case):
# 返回一个装饰器,其功能是当 JAX 或 Flax 未安装时跳过测试
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
# 装饰器,用于标记一个测试需要 SentencePiece
def require_sentencepiece(test_case):
# 返回一个装饰器,其功能是当 SentencePiece 未安装时跳过测试
return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
# 装饰器,用于标记一个测试需要 Sacremoses
def require_sacremoses(test_case):
# 返回一个装饰器,其功能是当 Sacremoses 未安装时跳过测试
return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)
# 装饰器,用于标记一个测试需要 Seqio
def require_seqio(test_case):
# 返回一个装饰器,其功能是当 Seqio 未安装时跳过测试
return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)
# 装饰器,用于标记一个测试需要 Scipy
def require_scipy(test_case):
# 返回一个装饰器,其功能是当 Scipy 未安装时跳过测试
return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
# 装饰器,用于标记一个测试需要 声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/908516
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。