Python AST `ast` 模块:编写自定义代码转换器

好的,让我们来一场关于 Python AST (Abstract Syntax Tree) ast 模块的深度讲座,主题是编写自定义代码转换器。

各位同学,欢迎来到“代码炼金术”课堂!今天我们要学习的是如何将你的 Python 代码变成橡皮泥,想捏成什么形状就捏成什么形状!而我们使用的工具,就是 Python 的 ast 模块。

第一章:什么是 AST?为什么我们要关心它?

想象一下,你写了一段 Python 代码,计算机是怎么理解它的呢?不是直接“嗖”的一下就运行了,而是要经过一个“翻译”的过程。这个“翻译”的第一步,就是把你的代码变成一棵“抽象语法树”,也就是 AST。

AST 就像是代码的骨架,它用一种树状结构,清晰地表达了代码的语法结构。 举个例子, 1 + 2 * 3 这行代码,对应的 AST 可能是这样的(简化版):

      +
     / 
    1   *
       / 
      2   3

看到了吗?加法是树的根,乘法是加法的右子树。AST 清楚地表达了运算的优先级。

为什么要关心 AST 呢?因为它给了我们一个机会,在代码运行之前,对代码进行“动手术”。我们可以分析 AST,修改 AST,甚至生成全新的 AST。这就是代码转换器的核心思想。

想象一下,你可以:

  • 自动优化代码: 比如,把 x + 0 变成 x,或者把 if True: 里的代码直接拿出来。
  • 静态代码分析: 检查代码风格,发现潜在的 bug。
  • 代码生成: 把一种语言的代码转换成另一种语言的代码(比如,把 Python 转换成 JavaScript)。
  • 自定义代码规范: 强制团队遵循统一的代码风格。

第二章:ast 模块:我们的魔法工具箱

Python 的 ast 模块,就是我们用来操作 AST 的魔法工具箱。它提供了以下几个核心功能:

  • ast.parse(source) 把 Python 代码字符串转换成 AST 对象。
  • ast.dump(node) 把 AST 对象转换成字符串,方便我们查看 AST 的结构。
  • ast.NodeVisitor 一个基类,用于遍历 AST,并执行自定义的操作。
  • ast.NodeTransformer 一个基类,用于修改 AST。
  • ast.unparse(node) (Python 3.9+) 把 AST 对象转换回 Python 代码字符串。

让我们先来玩个简单的例子,看看 ast.parseast.dump 的威力:

import ast

code = "x = 1 + 2 * 3"
tree = ast.parse(code)
print(ast.dump(tree, indent=4))

运行这段代码,你会看到类似这样的输出:

Module(
    body=[
        Assign(
            targets=[
                Name(id='x', ctx=Store())
            ],
            value=BinOp(
                left=Constant(value=1),
                op=Add(),
                right=BinOp(
                    left=Constant(value=2),
                    op=Mult(),
                    right=Constant(value=3))))
    ],
    type_ignores=[])

这就是 x = 1 + 2 * 3 这行代码的 AST 表示。 虽然看起来有点吓人,但仔细观察,你会发现它确实反映了代码的结构。Assign 表示赋值,Name 表示变量名,BinOp 表示二元运算,Constant 表示常量。

第三章:ast.NodeVisitor:AST 探险家

ast.NodeVisitor 就像一个 AST 探险家,它可以带我们遍历 AST 的每一个节点。 我们需要做的,就是继承 ast.NodeVisitor 类,并重写一些 visit_XXX 方法,来处理特定类型的节点。

举个例子,我们想统计一段代码中,所有变量的个数:

import ast

class VariableCounter(ast.NodeVisitor):
    def __init__(self):
        self.count = 0

    def visit_Name(self, node):
        self.count += 1
        # 继续遍历子节点
        self.generic_visit(node)

code = "x = 1 + y * z"
tree = ast.parse(code)
counter = VariableCounter()
counter.visit(tree)
print(f"变量的个数:{counter.count}")

在这个例子中:

  • 我们定义了一个 VariableCounter 类,继承自 ast.NodeVisitor
  • 我们重写了 visit_Name 方法,这个方法会在遍历到 Name 节点时被调用。
  • visit_Name 方法中,我们将计数器 count 加 1。
  • self.generic_visit(node) 保证了会继续遍历 node 的子节点。

visit_XXX 方法的命名规则很简单: visit_ 加上 AST 节点的类型名(比如,NameAssignBinOp)。

第四章:ast.NodeTransformer:AST 整容师

ast.NodeTransformerast.NodeVisitor 更强大,它可以修改 AST。 我们需要做的,也是继承 ast.NodeTransformer 类,并重写一些 visit_XXX 方法。 但是,visit_XXX 方法需要返回一个新的 AST 节点,或者 None(表示删除该节点)。

举个例子,我们想把代码中所有的 x + 0 替换成 x

import ast

class RemoveAddZero(ast.NodeTransformer):
    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add) and 
           isinstance(node.right, ast.Constant) and 
           node.right.value == 0:
            return node.left
        return node

code = "y = x + 0"
tree = ast.parse(code)
transformer = RemoveAddZero()
new_tree = transformer.visit(tree)
new_code = ast.unparse(new_tree)
print(f"转换后的代码:{new_code}")

在这个例子中:

  • 我们定义了一个 RemoveAddZero 类,继承自 ast.NodeTransformer
  • 我们重写了 visit_BinOp 方法,这个方法会在遍历到 BinOp 节点时被调用。
  • visit_BinOp 方法中,我们检查是否是 x + 0 的形式。
  • 如果是,我们就返回 node.left,也就是 x
  • 如果不是,我们就返回 node,表示不修改该节点。
  • 最后,我们使用 ast.unparse 把修改后的 AST 转换回代码。

注意:ast.NodeTransformer 会递归地遍历 AST,所以我们只需要处理最顶层的节点,它会自动处理子节点。

第五章:实战演练:代码自动优化器

现在,让我们来做一个更复杂的例子:一个简单的代码自动优化器。 它可以:

  • x + 0 变成 x
  • x * 1 变成 x
  • if True: 里的代码直接拿出来。
import ast

class Optimizer(ast.NodeTransformer):
    def visit_BinOp(self, node):
        # x + 0 -> x
        if isinstance(node.op, ast.Add) and 
           isinstance(node.right, ast.Constant) and 
           node.right.value == 0:
            return node.left
        # x * 1 -> x
        if isinstance(node.op, ast.Mult) and 
           isinstance(node.right, ast.Constant) and 
           node.right.value == 1:
            return node.left
        return node

    def visit_If(self, node):
        # if True: -> 把里面的代码拿出来
        if isinstance(node.test, ast.Constant) and node.test.value == True:
            return ast.Module(body=node.body, type_ignores=[]) # 变成module
        return node

code = """
x = 1 + 0
y = z * 1
if True:
    print("Hello, world!")
"""
tree = ast.parse(code)
optimizer = Optimizer()
new_tree = optimizer.visit(tree)
new_code = ast.unparse(new_tree)
print(f"优化后的代码:n{new_code}")

运行这段代码,你会看到:

优化后的代码:
x = 1
y = z
print('Hello, world!')

第六章:一些重要的 AST 节点类型

为了编写更强大的代码转换器,我们需要了解一些重要的 AST 节点类型。 下面是一些常用的节点类型:

节点类型 描述 示例
Module AST 的根节点,表示一个模块(一个 Python 文件)。 ast.parse("x = 1") 返回的 AST 的根节点就是 Module
Assign 赋值语句。 x = 1 -> Assign(targets=[Name(id='x', ctx=Store())], value=Constant(value=1))
Name 变量名。 x -> Name(id='x', ctx=Load())ctx 表示变量的使用场景,Load 表示读取,Store 表示赋值。)
Constant 常量。 1 -> Constant(value=1)
BinOp 二元运算(比如,加法、减法、乘法、除法)。 x + 1 -> BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Constant(value=1))
UnaryOp 一元运算(比如,正号、负号、取反)。 -x -> UnaryOp(op=USub(), operand=Name(id='x', ctx=Load()))
Call 函数调用。 print("Hello") -> Call(func=Name(id='print', ctx=Load()), args=[Constant(value='Hello')], keywords=[])
If if 语句。 if x > 0: -> If(test=Compare(left=Name(id='x', ctx=Load()), ops=[Gt()], comparators=[Constant(value=0)]), body=[...], orelse=[...])
For for 循环。 for i in range(10): -> For(target=Name(id='i', ctx=Store()), iter=Call(func=Name(id='range', ctx=Load()), args=[Constant(value=10)], keywords=[]), body=[...], orelse=[...])
FunctionDef 函数定义。 def foo(x): -> FunctionDef(name='foo', args=arguments(args=[arg(arg='x', annotation=None, type_comment=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[...], decorator_list=[], returns=None, type_comment=None)
Return return 语句。 return x -> Return(value=Name(id='x', ctx=Load()))
List 列表。 [1, 2, 3] -> List(elts=[Constant(value=1), Constant(value=2), Constant(value=3)], ctx=Load())
Dict 字典。 {'a': 1, 'b': 2} -> Dict(keys=[Constant(value='a'), Constant(value='b')], values=[Constant(value=1), Constant(value=2)])
Attribute 属性访问 (比如 object.attribute) obj.attr -> Attribute(value=Name(id='obj', ctx=Load()), attr='attr', ctx=Load())
Expr 表达式语句 (单独一个表达式,比如函数调用,它不是赋值的一部分)。 print("hello") -> Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Constant(value='hello')], keywords=[]))

第七章:进阶技巧:代码生成

除了修改现有的 AST,我们还可以生成全新的 AST。 这需要我们手动创建 AST 节点,并将它们组合成一棵完整的 AST。

举个例子,我们想生成代码 x = 1 + 2

import ast

# 创建 AST 节点
x = ast.Name(id='x', ctx=ast.Store())
one = ast.Constant(value=1)
two = ast.Constant(value=2)
add = ast.BinOp(left=one, op=ast.Add(), right=two)
assign = ast.Assign(targets=[x], value=add)
module = ast.Module(body=[assign], type_ignores=[])

# 转换成代码
code = ast.unparse(module)
print(f"生成的代码:{code}")

第八章:一些注意事项

  • AST 的结构可能很复杂。 可以使用 ast.dump 来查看 AST 的结构,或者使用一些 AST 可视化工具。
  • 修改 AST 时要小心。 如果修改不当,可能会导致代码无法运行,或者产生意想不到的结果。
  • ast.unparse 的输出可能和原始代码略有不同。 比如,空格、换行等可能会发生变化。
  • ast 模块的功能有限。 它只能处理 Python 的语法结构,无法处理语义信息(比如,变量的类型)。

第九章:总结

今天我们学习了 Python AST ast 模块的基本用法,包括:

  • 什么是 AST,以及为什么要关心它。
  • ast 模块的核心功能:ast.parseast.dumpast.NodeVisitorast.NodeTransformerast.unparse
  • 如何编写自定义的代码转换器,实现代码自动优化等功能。
  • 一些重要的 AST 节点类型。
  • 如何生成全新的 AST。

希望今天的课程能帮助大家打开代码炼金术的大门! 记住,代码不仅仅是死的文本,它也是可以被我们随意操纵的“橡皮泥”。 只要掌握了 AST,你就可以创造出无限的可能性!

作业:

  1. 编写一个代码转换器,将代码中的所有 print 语句替换成 logging.info 语句。
  2. 编写一个代码分析器,检查代码中是否存在未使用的变量。

下课!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注