赞
踩
AST学习
AST简介:
AST(Abstract syntac tree)是编译原理中的概念,是对源代码语法结构的一种抽象表示,它以树的形式表现编程语言的语法结构,树上的每个节点都表示源代码中的一种结构。
下面的代码展示了以demo.py中的ast语法,对source_code.py中的内容进行修改,并将修改后的内容转回源代码并写入到target_code.py中,这个过程可以作为客户化定制的内容。
(mmlab中的config机制,采用了另一种方式,并不对config文件的语法进行解析,而是基于base congfig 对个人的config进行merge和替换,得到最终的config,然后通过底层维护的字符串到类的映射拿到config中字符串字段中type的字符串,从而拿到类及其参数)
以语法规则进行解析和更改后,可以生成可执行的python文件(虽然mmlab中的config也是.py文件,但它只是个config而无实际意义)
demo.py
import ast import astor # source_file 是任何一个.py文件的路径 with open("./ast_learning/source_code.py", 'r', encoding='utf-8') as f: source_code = f.read() tree = ast.parse(source_code) import_nodes = [] empty_lines = [] for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == 'Classification_2d': class_node = node elif isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom): import_nodes.append(node) if isinstance(node, ast.Expr) and isinstance(node.value, ast.Str) and not node.value.s.strip(): empty_lines.append(node.lineno) copied_class_node = ast.copy_location(class_node, ast.ClassDef()) # 替换类节点中的__init__中的内容 for stmt in copied_class_node.body: if isinstance(stmt, ast.FunctionDef) and stmt.name == '__init__': for sub_stmt in stmt.body: # 遍历__init__中的所有操作(super,赋值等) if isinstance(sub_stmt, ast.Assign) and len(sub_stmt.targets) == 1 and isinstance(sub_stmt.targets[0], ast.Attribute) and sub_stmt.targets[0].attr == 'net': sub_stmt.value = ast.parse('models.convnext_large(pretrained=False)').body[0].value # 下面的方式会更改原来的sub_stmt.value 的 type 从_ast.Call object 变为 _ast.Name object 但 也是能用的 # sub_stmt.value = ast.Name(id='models.resnet50(pretrained=False)', ctx=ast.Load(models.resnet50)) if isinstance(sub_stmt, ast.Assign) and len(sub_stmt.targets) == 1 and isinstance(sub_stmt.targets[0], ast.Attribute) and sub_stmt.targets[0].attr == 'loss': sub_stmt.value = ast.parse('nn.CrossEntropyLoss').body[0].value # ast.parse不会改变node的type, # 几种其他方式的mode赋值 # sub_stmt.value = ast.Name(id='nn.L1Loss', ctx=ast.Load()) # 会更改原本的value的type从_ast.Attribute object 变为_ast.Name object code_tree=ast.Module(body=import_nodes+[copied_class_node]) # 四个空格作为每级缩进 copied_code = astor.to_source(code_tree, indent_with=' ' * 4) with open("./ast_learning/target_code.py", 'w') as f: f.write(copied_code)
source_code.py
from PIL import Image import torch from torch.utils.data import Dataset, DataLoader import pytorch_lightning as pl import torchvision.models as models # import 的等级必须是models和nn import torch.nn as nn class Classification_2d(pl.LightningModule): def __init__(self, label_dict={},log_dir=''): super(Classification_2d, self).__init__() self.num_classes = len(label_dict) self.net=models.resnet18(pretrained=True) # resnet 系列 self.fc = nn.Linear(self.net.fc.in_features, self.num_classes) self.net.fc = nn.Identity() self.loss=nn.L1Loss self.label_dict=label_dict self.label_to_name_dict={v:k for k,v in label_dict.items()} self.training_save=True self.log_dir=log_dir
target_code.py运行后的结果
from PIL import Image import torch from torch.utils.data import Dataset, DataLoader import pytorch_lightning as pl import torchvision.models as models import torch.nn as nn class Classification_2d(pl.LightningModule): def __init__(self, label_dict={}, log_dir=''): super(Classification_2d, self).__init__() self.num_classes = len(label_dict) self.net = models.convnext_large(pretrained=False) self.fc = nn.Linear(self.net.fc.in_features, self.num_classes) self.net.fc = nn.Identity() self.loss = nn.CrossEntropyLoss self.label_dict = label_dict self.label_to_name_dict = {v: k for k, v in label_dict.items()} self.training_save = True self.log_dir = log_dir
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。