赞
踩
目录
在大型深度学习模型的上下文中,.safetensors
、.bin
和 .pth
ckpt
文件的用途和区别如下:
.safetensors
文件:
.safetensors
文件,例如 safetensors.torch.load_file()
函数。ckpt
文件:
.bin
文件:
.bin
文件有时用于存储模型权重或其他二进制数据,但并不特指PyTorch的官方标准格式。.bin
扩展名,加载时需要自定义逻辑读取和应用这些权重到模型结构中。.pth
文件:
state_dict
,包含了模型的所有可学习参数,或者整个模型(包括结构和参数)。torch.load()
函数直接加载 .pth
文件,并通过调用 model.load_state_dict()
将加载的字典应用于模型实例。总结起来:
.safetensors
侧重于安全性和效率,适合于那些希望快速部署且对安全有较高要求的场景,尤其在Hugging Face生态中。.ckpt
文件是 PyTorch Lightning 框架采用的模型存储格式,它不仅包含了模型参数,还包括优化器状态以及可能的训练元数据信息,使得用户可以无缝地恢复训练或执行推理。.bin
文件不是标准化的模型保存格式,但在某些情况下可用于存储原始二进制权重数据,加载时需额外处理。.pth
是PyTorch的标准模型保存格式,方便模型的持久化和复用,支持完整模型结构和参数的保存与恢复。- # 用SDXL举例
- import torch
- from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
- from huggingface_hub import hf_hub_download
- from safetensors.torch import load_file
-
- base = "stabilityai/stable-diffusion-xl-base-1.0"
- repo = "ByteDance/SDXL-Lightning"
- ckpt = "/home/bino/svul/models/sdxl/sdxl_lightning_2step_unet.safetensors" # Use the correct ckpt for your step setting!
-
- # Load model.
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
- unet.load_state_dict(load_file(ckpt, device="cuda"))
- # unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
-
- # Ensure sampler uses "trailing" timesteps.
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-
- # Ensure using the same inference steps as the loaded model and CFG set to 0.
- pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")
- # 保存模型状态字典
- torch.save(model.state_dict(), "model.pth")
-
- # 加载模型状态字典到已有模型结构中
- model = TheModelClass(*args, **kwargs)
- model.load_state_dict(torch.load("model.pth"))
-
- # 或者保存整个模型,包括结构
- torch.save(model, "model.pth")
-
- # 加载整个模型
- model = torch.load("model.pth", map_location=device)
- import pytorch_lightning as pl
-
- # 定义一个 PyTorch Lightning 训练模块
- class MyLightningModel(pl.LightningModule):
- def __init__(self):
- super().__init__()
- self.linear_layer = nn.Linear(10, 1)
- self.loss_function = nn.MSELoss()
-
- def forward(self, inputs):
- return self.linear_layer(inputs)
-
- def training_step(self, batch, batch_idx):
- features, targets = batch
- predictions = self(features)
- loss = self.loss_function(predictions, targets)
- self.log('train_loss', loss)
- return loss
-
- # 初始化 PyTorch Lightning 模型
- lightning_model = MyLightningModel()
-
- # 配置 ModelCheckpoint 回调以定期保存最佳模型至 .ckpt 文件
- checkpoint_callback = pl.callbacks.ModelCheckpoint(
- monitor='val_loss',
- filename='best-model-{epoch:02d}-{val_loss:.2f}',
- save_top_k=3,
- mode='min'
- )
-
- # 创建训练器并启动模型训练
- trainer = pl.Trainer(
- callbacks=[checkpoint_callback],
- max_epochs=10
- )
- trainer.fit(lightning_model)
-
- # 从 .ckpt 文件加载最优模型权重
- best_model = MyLightningModel.load_from_checkpoint(checkpoint_path='best-model.ckpt')
-
- # 使用加载的 .ckpt 文件中的模型进行预测
- sample_input = torch.randn(1, 10)
- predicted_output = best_model(sample_input)
- print(predicted_output)
在此示例中,我们首先定义了一个 PyTorch Lightning 模块,该模块集成了模型训练的逻辑。然后,我们配置了 ModelCheckpoint 回调函数,在训练过程中按照验证损失自动保存最佳模型至 .ckpt 文件。接着,我们展示了如何加载 .ckpt 文件中的最优模型权重,并利用加载后的模型对随机输入数据进行预测,同样输出预测结果。值得注意的是,由于 .ckpt 文件完整记录了训练状态,它在实际应用中常被用于模型微调和进一步训练。
如果.bin
文件是纯二进制权重文件,加载时需要知道模型结构并且手动将权重加载到对应的层中,例如:
- # 假设已经从.bin文件中读取到了模型权重数据
- weights_data = load_binary_weights("weights.bin")
-
- # 手动初始化模型并加载权重
- model = TheModelClass(*args, **kwargs)
- for name, param in model.named_parameters():
- if name in weights_mapping: # 需要预先知道权重映射关系
- param.data.copy_(weights_data[weights_mapping[name]])
由于 PyTorch Lightning 模型本身就是 PyTorch 模型,因此不存在严格意义上的转换过程。你可以直接通过 LightningModule
中定义的神经网络层来进行保存和加载,就像普通的 PyTorch 模型一样:
- # 假设 model 是一个 PyTorch Lightning 模型实例
- model = MyLightningModel()
-
- # 保存模型权重
- torch.save(model.state_dict(), 'lightning_model.pth')
-
- # 加载到一个新的 PyTorch 模型实例
- new_model = MyLightningModel()
- new_model.load_state_dict(torch.load('lightning_model.pth'))
-
- # 或者加载到一个普通的 PyTorch Module 实例(假设结构一致)
- plain_pytorch_model = MyPlainPytorchModel()
- plain_pytorch_model.load_state_dict(torch.load('lightning_model.pth'))
转换后的模型在stable-diffussion-webui中使用过没有问题,不知道有没有错误,或者没转换成功
- import torch
- import os
- import safetensors
- from typing import Dict, List, Optional, Set, Tuple
- from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
-
- def ckpt2safetensors():
- loaded = torch.load('v1-5-pruned-emaonly.ckpt')
- if "state_dict" in loaded:
- loaded = loaded["state_dict"]
- safetensors.torch.save_file(loaded, 'v1-5-pruned-emaonly.safetensors')
-
- def st2ckpt():
- # 加载 .safetensors 文件
- data = safetensors.torch.load_file('v1-5-pruned-emaonly.safetensors.bk')
- data["state_dict"] = data
- # 将数据保存为 .ckpt 文件
- torch.save(data, os.path.splitext('v1-5-pruned-emaonly.safetensors')[0] + '.ckpt')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。