赞
踩
在深度学习中,模型的保存和加载是非常重要的环节。不同的格式有不同的特点和适用场景。本文将为新手朋友们介绍几种常见的模型格式,包括它们的简介、保存方式、加载方式、优缺点以及应用场景。
简介:PyTorch 的默认保存格式,灵活支持保存整个模型、模型的权重和优化器状态。
保存方式:
import torch
torch.save(model.state_dict(), 'model.pth')
加载方式:
model.load_state_dict(torch.load('model.pth'))
model.eval()
部署代码:
from flask import Flask, request, jsonify app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): data = request.json text = data['text'] inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512) with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits predictions = torch.argmax(logits, dim=-1) return jsonify({'prediction': predictions.item()}) if __name__ == '__main__': from transformers import BertTokenizer, BertForSequenceClassification tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertForSequenceClassification.from_pretrained('bert-base-uncased') model.load_state_dict(torch.load('model.pth')) model.eval() app.run(host='0.0.0.0', port=5000)
优点:
.pth
和 .pt
是 PyTorch 的原生格式。缺点:
应用场景:
简介:TensorFlow 和 Keras 的保存格式,支持保存模型的权重、架构和优化器状态。
保存方式:
model.save('model.h5')
加载方式:
from tensorflow.keras.models import load_model
model = load_model('model.h5')
部署代码:
from flask import Flask, request, jsonify import tensorflow as tf app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): data = request.json text = data['text'] inputs = tokenizer(text, return_tensors='tf', truncation=True, padding=True, max_length=512) outputs = model(inputs) logits = outputs.logits predictions = tf.argmax(logits, axis=-1) return jsonify({'prediction': int(predictions.numpy()[0])}) if __name__ == '__main__': from transformers import BertTokenizer, TFBertForSequenceClassification tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased') model.load_weights('model.h5') app.run(host='0.0.0.0', port=5000)
优点:
缺点:
.h5
文件包含了完整的模型架构和权重。应用场景:
简介:开放格式,旨在实现不同深度学习框架之间的互操作性。
保存方式:
import torch.onnx
torch.onnx.export(model, dummy_input, 'model.onnx')
加载方式:
import onnx
import onnxruntime as ort
onnx_model = onnx.load('model.onnx')
ort_session = ort.InferenceSession('model.onnx')
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
outputs = ort_session.run(None, {ort_session.get_inputs()[0].name: to_numpy(dummy_input)})
部署代码:
from flask import Flask, request, jsonify import onnxruntime as ort import numpy as np app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): data = request.json text = data['text'] inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512) ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(inputs['input_ids'])} ort_outs = ort_session.run(None, ort_inputs) predictions = np.argmax(ort_outs[0], axis=1) return jsonify({'prediction': int(predictions[0])}) if __name__ == '__main__': import onnx from transformers import BertTokenizer tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') ort_session = ort.InferenceSession('model.onnx') app.run(host='0.0.0.0', port=5000)
优点:
缺点:
应用场景:
简介:专门为移动和嵌入式设备设计的轻量级模型格式。
保存方式:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
加载方式:
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
部署代码:
TensorFlow Lite 模型主要用于移动设备和嵌入式设备,下面是一个简化的示例,展示如何在 Python 环境中进行推理:
import tensorflow as tf import numpy as np # 加载模型 interpreter = tf.lite.Interpreter(model_path='model.tflite') interpreter.allocate_tensors() # 获取模型输入和输出的详细信息 input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() # 准备输入数据 input_data = np.array([...], dtype=np.float32) # 根据模型输入需求准备数据 # 设置模型输入 interpreter.set_tensor(input_details[0]['index'], input_data) # 推理 interpreter.invoke() # 获取输出数据 output_data = interpreter.get_tensor(output_details[0]['index']) print(output_data)
优点:
缺点:
应用场景:
简介:苹果公司为 iOS 和 macOS 设备提供的模型格式。
保存方式:
import coremltools as ct
coreml_model = ct.convert(model)
coreml_model.save('model.mlmodel')
加载方式:
在 iOS/macOS 应用中使用 CoreML 框架加载。
部署代码:
CoreML 模型主要用于 iOS 和 macOS 应用开发,下面是一个简化的示例,展示如何在 Swift 中使用 CoreML 模型进行推理:
import CoreML
import Foundation
// 加载模型
let model = try! MyCoreMLModel(configuration: MLModelConfiguration())
// 准备输入数据
let input = MyCoreMLModelInput(text: "your input text")
// 获取模型预测结果
let prediction = try! model.prediction(input: input)
print(prediction.label)
优点:
缺点:
应用场景:
简介:百度开发的深度学习框架 PaddlePaddle 的保存格式。
保存方式:
import paddle
paddle.save(model.state_dict(), 'model.pdparams')
加载方式:
model.set_state_dict(paddle.load('model.pdparams'))
部署代码:
from flask import Flask, request, jsonify app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): data = request.json text = data['text'] inputs = tokenizer(text, return_tensors='pd', truncation=True, padding=True, max_length=512) with paddle.no_grad(): outputs = model(**inputs) logits = outputs.logits predictions = paddle.argmax(logits, axis=-1) return jsonify({'prediction': predictions.item()}) if __name__ == '__main__': from paddlenlp.transformers import BertTokenizer, BertForSequenceClassification tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertForSequenceClassification.from_pretrained('bert-base-uncased') model.set_state_dict(paddle.load('model.pdparams')) model.eval() app.run(host='0.0.0.0', port=5000)
优点:
缺点:
应用场景:
简介:一种用于存储大型数据集的文件格式,Keras 默认支持这种格式。
保存方式:
model.save('model.h5')
加载方式:
from tensorflow.keras.models import load_model
model = load_model('model.h5')
部署代码:
与 TensorFlow/Keras 的 .h5
部署代码相同,参考 TensorFlow/Keras 部分的部署代码。
优点:
缺点:
应用场景:
简介:一种新型的格式,旨在提高模型保存和加载的安全性和速度。
保存方式:
from safetensors.torch import save_file
save_file(model.state_dict(), 'model.safetensors')
加载方式:
from safetensors.torch import load_file
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict)
部署代码:
与 PyTorch 部署代码类似,可使用 Flask 或其他框架创建 API 服务。
优点:
safetensors
格式不允许在加载模型时执行任意代码。缺点:
safetensors
库才能使用这种格式。应用场景:
下面是各深度学习模型保存和加载格式的汇总表,包括格式、简介、优点、缺点和应用场景:
格式 | 简介 | 优点 | 缺点 | 应用场景 |
---|---|---|---|---|
PyTorch (.pth, .pt) | PyTorch 的默认保存格式,支持保存整个模型、权重和优化器状态 | 高度灵活,支持复杂的模型和训练过程。与 PyTorch 框架紧密集成。 | 只能在 PyTorch 环境中加载和使用,限制了跨平台和跨框架的兼容性。 | 研究和开发环境,频繁保存和加载模型的场景 |
TensorFlow/Keras (.h5, SavedModel) | TensorFlow 和 Keras 的保存格式,支持保存模型的权重、架构和优化器状态 | 适用于 TensorFlow 和 Keras 环境,支持多种部署方式(如 TensorFlow Serving)。 | 模型文件较大,可能影响加载速度。 | 生产环境中的模型部署,与 TensorFlow 生态系统集成的应用 |
ONNX | 开放格式,实现不同深度学习框架之间的互操作性 | 跨平台兼容,支持多种深度学习框架。统一格式,简化了在不同框架之间转换模型的复杂性。 | 需要额外的工具链来转换和部署模型。 | 跨平台模型部署,在不同框架之间转换模型 |
TensorFlow Lite | 专为移动和嵌入式设备设计的轻量级模型格式 | 轻量级,适合资源受限的设备。快速加载和推理。 | 支持的操作有限,可能需要调整模型架构以适应 TensorFlow Lite 的限制。 | 移动设备上的应用,物联网和嵌入式设备 |
CoreML | 苹果公司为 iOS 和 macOS 设备提供的模型格式 | 与苹果生态系统深度集成,在 iOS 和 macOS 设备上运行非常高效。易于部署,适合苹果开发者。 | 仅限于苹果设备,无法在其他平台上运行。 | iOS 应用开发,macOS 应用开发 |
PaddlePaddle (.pdparams) | 百度开发的深度学习框架 PaddlePaddle 的保存格式 | 与 PaddlePaddle 框架集成,适用于百度生态系统。优化的中国市场支持。 | 只能在 PaddlePaddle 环境中加载和使用,限制了在其他深度学习框架中的兼容性。 | 使用百度深度学习工具的项目,在中国市场的应用 |
HDF5 (.h5) | 一种用于存储大型数据集的文件格式,Keras 默认支持这种格式 | 方便存储和管理大型数据集,HDF5 格式擅长处理大规模数据,并支持压缩和并行 I/O 操作。与 Keras 深度集成。 | 模型文件较大,包含了完整的模型架构和权重,加载速度可能较慢。 | Keras 环境下的模型存储和加载,需要保存大型模型的场景 |
SafeTensors | 一种新型格式,提高模型保存和加载的安全性和速度 | 安全性高,消除潜在执行代码风险。加载速度快,特别适用于大型模型。 | 需要额外的库支持,必须安装 safetensors 库才能使用这种格式。 | 需要高安全性和快速加载的环境,大型模型的存储和部署 |
希望这张表格能够帮助新手朋友们更好地理解不同格式的特点,并根据自己的需求选择合适的格式来保存和部署模型。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。