当前位置:   article > 正文

ast在python架构中的使用_ast import

ast import

AST学习
AST简介:
AST(Abstract syntac tree)是编译原理中的概念,是对源代码语法结构的一种抽象表示,它以树的形式表现编程语言的语法结构,树上的每个节点都表示源代码中的一种结构。

下面的代码展示了以demo.py中的ast语法,对source_code.py中的内容进行修改,并将修改后的内容转回源代码并写入到target_code.py中,这个过程可以作为客户化定制的内容。
(mmlab中的config机制,采用了另一种方式,并不对config文件的语法进行解析,而是基于base congfig 对个人的config进行merge和替换,得到最终的config,然后通过底层维护的字符串到类的映射拿到config中字符串字段中type的字符串,从而拿到类及其参数)
以语法规则进行解析和更改后,可以生成可执行的python文件(虽然mmlab中的config也是.py文件,但它只是个config而无实际意义)
在这里插入图片描述

demo.py

import ast
import astor

# source_file 是任何一个.py文件的路径

with open("./ast_learning/source_code.py", 'r', encoding='utf-8') as f:
	source_code = f.read()
tree = ast.parse(source_code)

import_nodes = []
empty_lines = []
for node in ast.walk(tree):
	if isinstance(node, ast.ClassDef) and node.name == 'Classification_2d':
		class_node = node
	elif isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
		import_nodes.append(node)
	if isinstance(node, ast.Expr) and isinstance(node.value, ast.Str) and not node.value.s.strip():
		empty_lines.append(node.lineno)

copied_class_node = ast.copy_location(class_node, ast.ClassDef())
# 替换类节点中的__init__中的内容
for stmt in copied_class_node.body:
	if isinstance(stmt, ast.FunctionDef) and stmt.name == '__init__':
		for sub_stmt in stmt.body:
		# 遍历__init__中的所有操作(super,赋值等)
			if isinstance(sub_stmt, ast.Assign) and len(sub_stmt.targets) == 1 and isinstance(sub_stmt.targets[0], ast.Attribute) and sub_stmt.targets[0].attr == 'net':
				sub_stmt.value = ast.parse('models.convnext_large(pretrained=False)').body[0].value
				# 下面的方式会更改原来的sub_stmt.value 的 type 从_ast.Call object 变为 _ast.Name object 但 也是能用的
				# sub_stmt.value = ast.Name(id='models.resnet50(pretrained=False)', ctx=ast.Load(models.resnet50))
			if isinstance(sub_stmt, ast.Assign) and len(sub_stmt.targets) == 1 and isinstance(sub_stmt.targets[0], ast.Attribute) and sub_stmt.targets[0].attr == 'loss':
				sub_stmt.value = ast.parse('nn.CrossEntropyLoss').body[0].value
				# ast.parse不会改变node的type,
				# 几种其他方式的mode赋值
				# sub_stmt.value = ast.Name(id='nn.L1Loss', ctx=ast.Load()) # 会更改原本的value的type从_ast.Attribute object 变为_ast.Name object

code_tree=ast.Module(body=import_nodes+[copied_class_node])
# 四个空格作为每级缩进
copied_code = astor.to_source(code_tree, indent_with=' ' * 4)
with open("./ast_learning/target_code.py", 'w') as f:
	f.write(copied_code)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

source_code.py

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchvision.models as models
# import 的等级必须是models和nn
import torch.nn as nn

class Classification_2d(pl.LightningModule):
    def __init__(self, label_dict={},log_dir=''):
        super(Classification_2d, self).__init__()
        self.num_classes = len(label_dict)
        self.net=models.resnet18(pretrained=True)
        # resnet 系列
        self.fc = nn.Linear(self.net.fc.in_features, self.num_classes)
        self.net.fc = nn.Identity()
        
        self.loss=nn.L1Loss
        self.label_dict=label_dict
        self.label_to_name_dict={v:k for k,v in label_dict.items()}


        self.training_save=True
        self.log_dir=log_dir
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

target_code.py运行后的结果

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchvision.models as models
import torch.nn as nn


class Classification_2d(pl.LightningModule):

    def __init__(self, label_dict={}, log_dir=''):
        super(Classification_2d, self).__init__()
        self.num_classes = len(label_dict)
        self.net = models.convnext_large(pretrained=False)
        self.fc = nn.Linear(self.net.fc.in_features, self.num_classes)
        self.net.fc = nn.Identity()
        self.loss = nn.CrossEntropyLoss
        self.label_dict = label_dict
        self.label_to_name_dict = {v: k for k, v in label_dict.items()}
        self.training_save = True
        self.log_dir = log_dir

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号