好的,让我们来一场关于 Python AST (Abstract Syntax Tree) ast
模块的深度讲座,主题是编写自定义代码转换器。
各位同学,欢迎来到“代码炼金术”课堂!今天我们要学习的是如何将你的 Python 代码变成橡皮泥,想捏成什么形状就捏成什么形状!而我们使用的工具,就是 Python 的 ast
模块。
第一章:什么是 AST?为什么我们要关心它?
想象一下,你写了一段 Python 代码,计算机是怎么理解它的呢?不是直接“嗖”的一下就运行了,而是要经过一个“翻译”的过程。这个“翻译”的第一步,就是把你的代码变成一棵“抽象语法树”,也就是 AST。
AST 就像是代码的骨架,它用一种树状结构,清晰地表达了代码的语法结构。 举个例子, 1 + 2 * 3
这行代码,对应的 AST 可能是这样的(简化版):
+
/
1 *
/
2 3
看到了吗?加法是树的根,乘法是加法的右子树。AST 清楚地表达了运算的优先级。
为什么要关心 AST 呢?因为它给了我们一个机会,在代码运行之前,对代码进行“动手术”。我们可以分析 AST,修改 AST,甚至生成全新的 AST。这就是代码转换器的核心思想。
想象一下,你可以:
- 自动优化代码: 比如,把
x + 0
变成x
,或者把if True:
里的代码直接拿出来。 - 静态代码分析: 检查代码风格,发现潜在的 bug。
- 代码生成: 把一种语言的代码转换成另一种语言的代码(比如,把 Python 转换成 JavaScript)。
- 自定义代码规范: 强制团队遵循统一的代码风格。
第二章:ast
模块:我们的魔法工具箱
Python 的 ast
模块,就是我们用来操作 AST 的魔法工具箱。它提供了以下几个核心功能:
ast.parse(source)
: 把 Python 代码字符串转换成 AST 对象。ast.dump(node)
: 把 AST 对象转换成字符串,方便我们查看 AST 的结构。ast.NodeVisitor
: 一个基类,用于遍历 AST,并执行自定义的操作。ast.NodeTransformer
: 一个基类,用于修改 AST。ast.unparse(node)
: (Python 3.9+) 把 AST 对象转换回 Python 代码字符串。
让我们先来玩个简单的例子,看看 ast.parse
和 ast.dump
的威力:
import ast
code = "x = 1 + 2 * 3"
tree = ast.parse(code)
print(ast.dump(tree, indent=4))
运行这段代码,你会看到类似这样的输出:
Module(
body=[
Assign(
targets=[
Name(id='x', ctx=Store())
],
value=BinOp(
left=Constant(value=1),
op=Add(),
right=BinOp(
left=Constant(value=2),
op=Mult(),
right=Constant(value=3))))
],
type_ignores=[])
这就是 x = 1 + 2 * 3
这行代码的 AST 表示。 虽然看起来有点吓人,但仔细观察,你会发现它确实反映了代码的结构。Assign
表示赋值,Name
表示变量名,BinOp
表示二元运算,Constant
表示常量。
第三章:ast.NodeVisitor
:AST 探险家
ast.NodeVisitor
就像一个 AST 探险家,它可以带我们遍历 AST 的每一个节点。 我们需要做的,就是继承 ast.NodeVisitor
类,并重写一些 visit_XXX
方法,来处理特定类型的节点。
举个例子,我们想统计一段代码中,所有变量的个数:
import ast
class VariableCounter(ast.NodeVisitor):
def __init__(self):
self.count = 0
def visit_Name(self, node):
self.count += 1
# 继续遍历子节点
self.generic_visit(node)
code = "x = 1 + y * z"
tree = ast.parse(code)
counter = VariableCounter()
counter.visit(tree)
print(f"变量的个数:{counter.count}")
在这个例子中:
- 我们定义了一个
VariableCounter
类,继承自ast.NodeVisitor
。 - 我们重写了
visit_Name
方法,这个方法会在遍历到Name
节点时被调用。 - 在
visit_Name
方法中,我们将计数器count
加 1。 self.generic_visit(node)
保证了会继续遍历node
的子节点。
visit_XXX
方法的命名规则很简单: visit_
加上 AST 节点的类型名(比如,Name
,Assign
,BinOp
)。
第四章:ast.NodeTransformer
:AST 整容师
ast.NodeTransformer
比 ast.NodeVisitor
更强大,它可以修改 AST。 我们需要做的,也是继承 ast.NodeTransformer
类,并重写一些 visit_XXX
方法。 但是,visit_XXX
方法需要返回一个新的 AST 节点,或者 None
(表示删除该节点)。
举个例子,我们想把代码中所有的 x + 0
替换成 x
:
import ast
class RemoveAddZero(ast.NodeTransformer):
def visit_BinOp(self, node):
if isinstance(node.op, ast.Add) and
isinstance(node.right, ast.Constant) and
node.right.value == 0:
return node.left
return node
code = "y = x + 0"
tree = ast.parse(code)
transformer = RemoveAddZero()
new_tree = transformer.visit(tree)
new_code = ast.unparse(new_tree)
print(f"转换后的代码:{new_code}")
在这个例子中:
- 我们定义了一个
RemoveAddZero
类,继承自ast.NodeTransformer
。 - 我们重写了
visit_BinOp
方法,这个方法会在遍历到BinOp
节点时被调用。 - 在
visit_BinOp
方法中,我们检查是否是x + 0
的形式。 - 如果是,我们就返回
node.left
,也就是x
。 - 如果不是,我们就返回
node
,表示不修改该节点。 - 最后,我们使用
ast.unparse
把修改后的 AST 转换回代码。
注意:ast.NodeTransformer
会递归地遍历 AST,所以我们只需要处理最顶层的节点,它会自动处理子节点。
第五章:实战演练:代码自动优化器
现在,让我们来做一个更复杂的例子:一个简单的代码自动优化器。 它可以:
- 把
x + 0
变成x
。 - 把
x * 1
变成x
。 - 把
if True:
里的代码直接拿出来。
import ast
class Optimizer(ast.NodeTransformer):
def visit_BinOp(self, node):
# x + 0 -> x
if isinstance(node.op, ast.Add) and
isinstance(node.right, ast.Constant) and
node.right.value == 0:
return node.left
# x * 1 -> x
if isinstance(node.op, ast.Mult) and
isinstance(node.right, ast.Constant) and
node.right.value == 1:
return node.left
return node
def visit_If(self, node):
# if True: -> 把里面的代码拿出来
if isinstance(node.test, ast.Constant) and node.test.value == True:
return ast.Module(body=node.body, type_ignores=[]) # 变成module
return node
code = """
x = 1 + 0
y = z * 1
if True:
print("Hello, world!")
"""
tree = ast.parse(code)
optimizer = Optimizer()
new_tree = optimizer.visit(tree)
new_code = ast.unparse(new_tree)
print(f"优化后的代码:n{new_code}")
运行这段代码,你会看到:
优化后的代码:
x = 1
y = z
print('Hello, world!')
第六章:一些重要的 AST 节点类型
为了编写更强大的代码转换器,我们需要了解一些重要的 AST 节点类型。 下面是一些常用的节点类型:
节点类型 | 描述 | 示例 |
---|---|---|
Module |
AST 的根节点,表示一个模块(一个 Python 文件)。 | ast.parse("x = 1") 返回的 AST 的根节点就是 Module 。 |
Assign |
赋值语句。 | x = 1 -> Assign(targets=[Name(id='x', ctx=Store())], value=Constant(value=1)) |
Name |
变量名。 | x -> Name(id='x', ctx=Load()) (ctx 表示变量的使用场景,Load 表示读取,Store 表示赋值。) |
Constant |
常量。 | 1 -> Constant(value=1) |
BinOp |
二元运算(比如,加法、减法、乘法、除法)。 | x + 1 -> BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Constant(value=1)) |
UnaryOp |
一元运算(比如,正号、负号、取反)。 | -x -> UnaryOp(op=USub(), operand=Name(id='x', ctx=Load())) |
Call |
函数调用。 | print("Hello") -> Call(func=Name(id='print', ctx=Load()), args=[Constant(value='Hello')], keywords=[]) |
If |
if 语句。 |
if x > 0: -> If(test=Compare(left=Name(id='x', ctx=Load()), ops=[Gt()], comparators=[Constant(value=0)]), body=[...], orelse=[...]) |
For |
for 循环。 |
for i in range(10): -> For(target=Name(id='i', ctx=Store()), iter=Call(func=Name(id='range', ctx=Load()), args=[Constant(value=10)], keywords=[]), body=[...], orelse=[...]) |
FunctionDef |
函数定义。 | def foo(x): -> FunctionDef(name='foo', args=arguments(args=[arg(arg='x', annotation=None, type_comment=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[...], decorator_list=[], returns=None, type_comment=None) |
Return |
return 语句。 |
return x -> Return(value=Name(id='x', ctx=Load())) |
List |
列表。 | [1, 2, 3] -> List(elts=[Constant(value=1), Constant(value=2), Constant(value=3)], ctx=Load()) |
Dict |
字典。 | {'a': 1, 'b': 2} -> Dict(keys=[Constant(value='a'), Constant(value='b')], values=[Constant(value=1), Constant(value=2)]) |
Attribute |
属性访问 (比如 object.attribute ) |
obj.attr -> Attribute(value=Name(id='obj', ctx=Load()), attr='attr', ctx=Load()) |
Expr |
表达式语句 (单独一个表达式,比如函数调用,它不是赋值的一部分)。 | print("hello") -> Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Constant(value='hello')], keywords=[])) |
第七章:进阶技巧:代码生成
除了修改现有的 AST,我们还可以生成全新的 AST。 这需要我们手动创建 AST 节点,并将它们组合成一棵完整的 AST。
举个例子,我们想生成代码 x = 1 + 2
:
import ast
# 创建 AST 节点
x = ast.Name(id='x', ctx=ast.Store())
one = ast.Constant(value=1)
two = ast.Constant(value=2)
add = ast.BinOp(left=one, op=ast.Add(), right=two)
assign = ast.Assign(targets=[x], value=add)
module = ast.Module(body=[assign], type_ignores=[])
# 转换成代码
code = ast.unparse(module)
print(f"生成的代码:{code}")
第八章:一些注意事项
- AST 的结构可能很复杂。 可以使用
ast.dump
来查看 AST 的结构,或者使用一些 AST 可视化工具。 - 修改 AST 时要小心。 如果修改不当,可能会导致代码无法运行,或者产生意想不到的结果。
ast.unparse
的输出可能和原始代码略有不同。 比如,空格、换行等可能会发生变化。ast
模块的功能有限。 它只能处理 Python 的语法结构,无法处理语义信息(比如,变量的类型)。
第九章:总结
今天我们学习了 Python AST ast
模块的基本用法,包括:
- 什么是 AST,以及为什么要关心它。
ast
模块的核心功能:ast.parse
,ast.dump
,ast.NodeVisitor
,ast.NodeTransformer
,ast.unparse
。- 如何编写自定义的代码转换器,实现代码自动优化等功能。
- 一些重要的 AST 节点类型。
- 如何生成全新的 AST。
希望今天的课程能帮助大家打开代码炼金术的大门! 记住,代码不仅仅是死的文本,它也是可以被我们随意操纵的“橡皮泥”。 只要掌握了 AST,你就可以创造出无限的可能性!
作业:
- 编写一个代码转换器,将代码中的所有
print
语句替换成logging.info
语句。 - 编写一个代码分析器,检查代码中是否存在未使用的变量。
下课!