Python AST `ast` 模块:编写自定义代码转换器

好的,咱们今天来聊聊Python AST ast 模块,以及如何用它来编写自定义代码转换器。这玩意儿听起来很高大上,但其实就像玩乐高积木一样,只要掌握了规则,就能拼出各种奇形怪状的东西。准备好了吗?咱们开始!

开场白:代码的“CT”扫描仪

各位观众,有没有想过,我们写的Python代码,在被Python解释器执行之前,到底经历了什么?它可不是直接就被“duang”的一下运行起来的。实际上,它会被“解剖”成一种叫做抽象语法树(Abstract Syntax Tree,简称AST)的结构。

你可以把AST想象成代码的“CT”扫描仪。它能把代码的每个细节都看得清清楚楚,比如有哪些变量、哪些函数、哪些循环等等。而ast模块,就是Python提供给我们的一个工具,让我们能够访问和操作这个“CT”扫描结果。

第一部分:AST是个什么玩意儿?

1.1 抽象语法树的本质

简单来说,AST是一种树状结构,用来表示代码的语法结构。每个节点代表代码中的一个语法元素,比如变量、运算符、函数调用等等。

举个例子,假设我们有这样一行简单的Python代码:

x = 1 + 2

这行代码对应的AST大概是这个样子(简化版):

Assign(
    targets=[Name(id='x', ctx=Store())],
    value=BinOp(
        left=Constant(value=1),
        op=Add(),
        right=Constant(value=2)
    )
)

看起来有点吓人?别怕,我们来拆解一下:

  • Assign:表示赋值语句。
  • targets:表示赋值的目标,这里是变量x
  • Name:表示一个变量名。id='x'表示变量名是xctx=Store()表示这是一个赋值操作(存储)。
  • value:表示赋的值,这里是一个二元运算。
  • BinOp:表示二元运算。
  • left:表示左边的操作数,这里是常量1
  • op:表示运算符,这里是加号Add()
  • right:表示右边的操作数,这里是常量2
  • Constant:表示一个常量。

可以看到,AST把代码拆解成了非常细粒度的语法元素,并且用树状结构组织起来。

1.2 ast 模块的核心功能

ast模块主要提供了以下几个核心功能:

  • ast.parse(source):将Python代码字符串解析成AST对象。
  • ast.dump(node):将AST对象转换成字符串,方便查看。
  • ast.NodeVisitor:一个基类,用于遍历AST节点。
  • ast.NodeTransformer:一个基类,用于修改AST节点。

第二部分:AST的基本操作:查看和遍历

2.1 初探 ast.parseast.dump

我们先来用ast.parse把代码解析成AST,然后用ast.dump打印出来看看:

import ast

code = "x = 1 + 2"
tree = ast.parse(code)
print(ast.dump(tree, indent=4))

运行结果大概是这样:

Module(
    body=[
        Assign(
            targets=[
                Name(id='x', ctx=Store())
            ],
            value=BinOp(
                left=Constant(value=1),
                op=Add(),
                right=Constant(value=2)
            ),
            type_comment=None
        )
    ],
    type_ignores=[]
)

可以看到,ast.dump把AST的结构以一种嵌套的形式打印了出来。indent=4表示缩进4个空格,让结构更清晰。

2.2 使用 ast.NodeVisitor 遍历AST

ast.NodeVisitor是一个非常有用的类,它可以让我们方便地遍历AST的每个节点。我们需要做的就是继承ast.NodeVisitor,然后重写一些visit_XXX方法,其中XXX是AST节点的类型。

比如,如果我们想遍历AST,并打印出所有的变量名,可以这样做:

import ast

class VariableNameVisitor(ast.NodeVisitor):
    def visit_Name(self, node):
        print(f"发现变量名: {node.id}")

code = "x = 1 + 2ny = x * 3"
tree = ast.parse(code)
visitor = VariableNameVisitor()
visitor.visit(tree)

运行结果:

发现变量名: x
发现变量名: x
发现变量名: y

解释一下:

  • 我们定义了一个VariableNameVisitor类,继承自ast.NodeVisitor
  • 我们重写了visit_Name方法。当遍历到Name节点时,这个方法会被调用。
  • visit_Name方法中,我们打印出变量名node.id
  • 我们创建了一个VariableNameVisitor实例,并调用它的visit方法,传入AST对象tree

2.3 更复杂的遍历:统计函数调用次数

我们再来一个稍微复杂一点的例子:统计代码中每个函数的调用次数。

import ast
import collections

class FunctionCallCounter(ast.NodeVisitor):
    def __init__(self):
        self.counts = collections.Counter()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            self.counts[node.func.id] += 1
        elif isinstance(node.func, ast.Attribute):
            # 处理像 obj.method() 这样的调用
            self.counts[node.func.attr] += 1
        self.generic_visit(node)  # 确保继续遍历子节点

code = """
def foo():
    print("Hello")

foo()
foo()
bar = "World".lower()
"""

tree = ast.parse(code)
counter = FunctionCallCounter()
counter.visit(tree)

for func, count in counter.counts.items():
    print(f"函数 {func} 被调用了 {count} 次")

运行结果:

函数 foo 被调用了 2 次
函数 print 被调用了 1 次
函数 lower 被调用了 1 次

解释一下:

  • 我们定义了一个FunctionCallCounter类,继承自ast.NodeVisitor
  • 我们使用collections.Counter来存储每个函数的调用次数。
  • 我们重写了visit_Call方法。当遍历到Call节点时,这个方法会被调用。
  • visit_Call方法中,我们判断函数调用方式,然后增加对应函数的计数。
  • self.generic_visit(node)这句很重要,它会继续遍历Call节点的子节点,确保我们能找到所有的函数调用。

第三部分:AST的修改:代码转换的魔法

3.1 ast.NodeTransformer 的威力

ast.NodeTransformer是一个更强大的类,它可以让我们修改AST的节点,从而实现代码转换。和ast.NodeVisitor类似,我们需要继承ast.NodeTransformer,然后重写一些visit_XXX方法。但是,visit_XXX方法需要返回一个新的节点,用来替换原来的节点。

3.2 简单的代码转换:将 x + 1 替换成 x + 2

我们先来一个简单的例子:将代码中所有的 x + 1 替换成 x + 2

import ast

class AddOneToTwo(ast.NodeTransformer):
    def visit_BinOp(self, node):
        if isinstance(node.left, ast.Name) and node.left.id == 'x' and 
           isinstance(node.op, ast.Add) and 
           isinstance(node.right, ast.Constant) and node.right.value == 1:
            return ast.BinOp(left=node.left, op=node.op, right=ast.Constant(value=2))
        return self.generic_visit(node)

code = "y = x + 1"
tree = ast.parse(code)
transformer = AddOneToTwo()
new_tree = transformer.visit(tree)
new_code = ast.unparse(new_tree)  # Python 3.9+

print(new_code)

运行结果:

y = x + 2

解释一下:

  • 我们定义了一个AddOneToTwo类,继承自ast.NodeTransformer
  • 我们重写了visit_BinOp方法。当遍历到BinOp节点时,这个方法会被调用。
  • visit_BinOp方法中,我们判断是否是 x + 1 这样的表达式。
  • 如果是,我们创建一个新的BinOp节点,将右边的操作数改成 2,然后返回这个新节点。
  • 如果不是,我们调用self.generic_visit(node),继续遍历子节点。
  • 我们使用ast.unparse将修改后的AST转换回代码字符串。(注意:ast.unparse是Python 3.9+才有的)

3.3 更复杂的代码转换:自动添加日志

我们再来一个更复杂的例子:在每个函数入口处自动添加日志。

import ast

class LogAdder(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        log_statement = ast.Expr(
            value=ast.Call(
                func=ast.Name(id='print', ctx=ast.Load()),
                args=[ast.Constant(value=f"Entering function {node.name}")],
                keywords=[]
            )
        )
        node.body.insert(0, log_statement)  # 在函数体开头插入日志语句
        return self.generic_visit(node)

code = """
def foo(a, b):
    return a + b

def bar():
    pass
"""

tree = ast.parse(code)
transformer = LogAdder()
new_tree = transformer.visit(tree)
new_code = ast.unparse(new_tree)

print(new_code)

运行结果:

def foo(a, b):
    print('Entering function foo')
    return a + b

def bar():
    print('Entering function bar')
    pass

解释一下:

  • 我们定义了一个LogAdder类,继承自ast.NodeTransformer
  • 我们重写了visit_FunctionDef方法。当遍历到FunctionDef节点时,这个方法会被调用。
  • visit_FunctionDef方法中,我们创建一个表示日志语句的AST节点。
  • 我们使用node.body.insert(0, log_statement)在函数体的开头插入日志语句。
  • 我们调用self.generic_visit(node),继续遍历子节点。

第四部分:实战案例:代码风格检查器

我们来做一个简单的代码风格检查器,检查函数名是否符合命名规范(只能包含小写字母和下划线)。

import ast

class FunctionNameChecker(ast.NodeVisitor):
    def __init__(self):
        self.errors = []

    def visit_FunctionDef(self, node):
        if not node.name.islower() or not all(c.isalnum() or c == '_' for c in node.name):
            self.errors.append(f"函数名 '{node.name}' 不符合命名规范")

code = """
def my_function():
    pass

def myFunction():
    pass

def _my_function():
    pass

def my_function123():
    pass

def MyFunction():
    pass
"""

tree = ast.parse(code)
checker = FunctionNameChecker()
checker.visit(tree)

for error in checker.errors:
    print(error)

运行结果:

函数名 'myFunction' 不符合命名规范
函数名 'MyFunction' 不符合命名规范

第五部分:总结与展望

今天我们一起学习了Python AST ast 模块的基本用法,包括AST的结构、ast.parseast.dumpast.NodeVisitorast.NodeTransformer。我们还通过几个实战案例,展示了如何使用AST来实现代码遍历、代码转换和代码检查。

AST的应用场景非常广泛,比如:

  • 代码分析工具:静态代码分析、代码复杂度分析、代码质量评估等。
  • 代码转换工具:代码优化、代码重构、代码混淆等。
  • 代码生成工具:从DSL(领域特定语言)生成Python代码。
  • 自动化测试工具:生成测试用例、修改测试代码等。

希望今天的分享能帮助大家打开AST的大门,探索更多有趣的玩法。记住,代码的世界就像乐高积木,只要掌握了规则,就能创造无限可能!

表格总结

功能 描述
ast.parse(code) 将Python代码字符串解析为AST对象。
ast.dump(node) 将AST对象转换为字符串,便于查看AST结构。
ast.NodeVisitor 用于遍历AST节点的基类。通过继承并重写visit_XXX方法,可以自定义遍历逻辑。
ast.NodeTransformer 用于修改AST节点的基类。通过继承并重写visit_XXX方法,可以替换AST节点,实现代码转换。
ast.unparse(tree) (Python 3.9+)将AST对象转换回Python代码字符串。

一些小提示

  • AST的结构非常复杂,需要耐心学习和理解。
  • 可以使用ast.dump来查看AST的结构,方便调试。
  • ast.NodeVisitorast.NodeTransformer是编写自定义代码转换器的核心。
  • self.generic_visit(node)非常重要,它可以确保遍历到所有的节点。
  • 多看一些开源项目的代码,学习他们是如何使用AST的。

好了,今天的讲座就到这里。希望大家都能成为AST大师,用代码改变世界!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注