赞
踩
本文以bert为例,对比了添加Lora模块前后的网络结构图
说明:以下给出最小化的单层 BERT 配置文件与测试脚本。
可参考的点:使用 ONNX、torchviz、torch.fx 和 TensorBoard 四种方式可视化网络结构。
# Write a minimal 1-hidden-layer BERT config; get_model() in bert_lora.py
# loads it via BertConfig.from_pretrained("./config.json").
tee ./config.json <<-'EOF'
{
  "architectures": ["BertForMaskedLM"],
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 1,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
EOF
# Write the test script that compares the BERT graph before and after adding
# LoRA modules, using four visualization backends (ONNX, torchviz, torch.fx,
# TensorBoard).
#
# Fixes vs. original: removed the duplicated `import time` and the unused
# imports (os, torchvision.models, nn, init, numpy, get_peft_config,
# get_peft_model_state_dict); restored proper line structure.
tee bert_lora.py <<-'EOF'
import time

import torch
from peft import LoraConfig, TaskType, get_peft_model
from torch._functorch.partitioners import draw_graph
from torch.utils.tensorboard import SummaryWriter
from torchviz import make_dot


def onnx_infer_shape(onnx_path):
    """Run ONNX shape inference on the file in place so tensor shapes
    appear when the graph is visualized."""
    import onnx
    onnx_model = onnx.load_model(onnx_path)
    new_onnx = onnx.shape_inference.infer_shapes(onnx_model)
    onnx.save_model(new_onnx, onnx_path)


def get_model():
    """Build a randomly-initialized 1-layer BERT MLM from ./config.json.

    Returns:
        (model, config) tuple; seeding makes the init reproducible.
    """
    torch.manual_seed(1)
    from transformers import AutoModelForMaskedLM, BertConfig
    config = BertConfig.from_pretrained("./config.json")
    model = AutoModelForMaskedLM.from_config(config)
    return model, config


def my_compiler(fx_module: torch.fx.GraphModule, _):
    """torch.compile backend that dumps each captured FX graph to an SVG
    (timestamped so repeated compiles don't overwrite each other) and then
    runs the graph unchanged."""
    draw_graph(fx_module, f"bert.{time.time()}.svg")
    return fx_module.forward


if __name__ == "__main__":
    model, config = get_model()
    model.eval()
    # Random token ids, batch=1, seq_len=128.
    input_tokens = torch.randint(0, config.vocab_size, (1, 128))

    # --- Part 1: the original model ---
    # 1. ONNX visualization (export_params=False keeps the file small).
    torch.onnx.export(model, input_tokens, "bert_base.onnx",
                      export_params=False, opset_version=11,
                      do_constant_folding=True)
    onnx_infer_shape("bert_base.onnx")

    # 2. torchviz autograd graph.
    output = model(input_tokens)
    logits = output.logits
    viz = make_dot(logits, params=dict(model.named_parameters()))
    viz.render("bert_base", view=False)

    # 3. torch.fx graph via the custom compile backend above.
    compiled_model = torch.compile(model, backend=my_compiler)
    output = compiled_model(input_tokens)

    # 4. TensorBoard graph.
    writer = SummaryWriter('./runs')
    writer.add_graph(model, input_to_model=input_tokens,
                     use_strict_trace=False)
    writer.close()

    # --- Part 2: the LoRA-wrapped model ---
    # NOTE(review): task_type=CAUSAL_LM on a MaskedLM model looks odd but is
    # kept as-is; it only affects which wrapper PEFT uses — confirm if reused.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=8,
        lora_alpha=32,
        target_modules=['intermediate.dense'],
        lora_dropout=0.1,
    )
    lora_model = get_peft_model(model, peft_config)
    lora_model.eval()

    torch.onnx.export(lora_model, input_tokens,
                      "bert_base_lora_inference_mode.onnx",
                      export_params=False, opset_version=11,
                      do_constant_folding=True)
    onnx_infer_shape("bert_base_lora_inference_mode.onnx")

    compiled_model = torch.compile(lora_model, backend=my_compiler)
    output = compiled_model(input_tokens)

    writer = SummaryWriter('./runs_lora')
    writer.add_graph(lora_model, input_to_model=input_tokens,
                     use_strict_trace=False)
    writer.close()
EOF

# Install visualization dependencies (graphviz binary + python bindings).
apt install graphviz -y
pip install torchviz
pip install pydot

# Run the test program.
python bert_lora.py
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。