各位观众老爷们,大家好!我是今天的讲师,江湖人称“代码老司机”。今天咱们聊点高级货,关于unittest.mock
模块里的patch
,看看它怎么帮咱们模拟那些复杂的外部依赖和API调用,让测试变得轻松愉快。
开场白:为什么我们需要模拟?
想象一下,你写了一个函数,这个函数要调用一个外部的API,或者需要连接一个数据库,甚至需要访问一个硬件设备。在测试的时候,你真的想每次都去调用这个API,连接数据库,甚至搬出一台硬件设备吗?
答案当然是:NO!
- 速度慢: 真实的API调用,数据库连接,IO操作等等,都会消耗大量的时间。
- 不稳定: 外部依赖可能会宕机,网络可能会不稳定,测试结果也会变得不可预测。
- 难以控制: 你无法控制API返回什么,数据库里有什么,硬件设备的状态是什么。
- 环境依赖: 测试环境需要配置好各种依赖,增加了测试的复杂度。
所以,我们需要模拟(Mocking)。模拟就是用假的、可控的替代品来替换真实的依赖,让测试在一个隔离、可预测的环境中运行。
unittest.mock
:你的模拟利器
Python的unittest.mock
模块就是来帮助我们进行模拟的。它提供了各种工具,包括 Mock
, MagicMock
, patch
等,今天咱们的重点是 patch
。
patch
:像变魔术一样替换对象
patch
可以像变魔术一样,临时替换掉你代码里的对象,在测试结束后再恢复。它有几种常用的用法:
- 装饰器方式 (
@patch
) - 上下文管理器方式 (
with patch(...)
) - 直接调用方式 (
patch(...)
)
咱们一个一个来细说。
1. 装饰器方式 (@patch
):轻量级替换
这种方式最常用,也最简洁。 想象一下,你的代码里有一个函数 get_data_from_api()
,它调用了一个外部API api_client.get()
。
# my_module.py
import requests
def get_data_from_api(url):
try:
response = requests.get(url)
response.raise_for_status() # 检查状态码
return response.json()
except requests.exceptions.RequestException as e:
print(f"API调用失败: {e}")
return None
# 一些其他的代码...
现在,我们要测试 get_data_from_api()
,但是不想真的去调用API。 使用 @patch
,我们可以轻松地替换 requests.get
。
# test_my_module.py
import unittest
from unittest.mock import patch
import my_module
import json
class TestGetApiData(unittest.TestCase):
@patch('my_module.requests.get') # 注意这里的字符串!
def test_get_data_from_api_success(self, mock_get):
# 配置 mock_get 的行为
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {'data': 'some data'}
# 调用被测函数
result = my_module.get_data_from_api('http://example.com/api')
# 断言
self.assertEqual(result, {'data': 'some data'})
mock_get.assert_called_once_with('http://example.com/api')
@patch('my_module.requests.get')
def test_get_data_from_api_failure(self, mock_get):
# 配置 mock_get 的行为,模拟API调用失败
mock_get.side_effect = requests.exceptions.RequestException("API error")
# 调用被测函数
result = my_module.get_data_from_api('http://example.com/api')
# 断言
self.assertIsNone(result)
mock_get.assert_called_once_with('http://example.com/api')
if __name__ == '__main__':
unittest.main()
代码解读:
@patch('my_module.requests.get')
: 这行代码是关键。它告诉patch
要替换哪个对象。 字符串'my_module.requests.get'
指定了要替换的是my_module
模块里的requests
模块的get
函数。 注意:这里必须是字符串,而且要写清楚完整的路径!def test_get_data_from_api_success(self, mock_get):
:patch
会把 mock 对象作为参数传递给被装饰的函数。 这里的mock_get
就是替换requests.get
的 mock 对象。mock_get.return_value.status_code = 200
: 我们配置了mock_get
的返回值 (return_value
) 的status_code
属性为 200。 这意味着,当我们调用requests.get()
时,它会返回一个Response
对象,这个对象的status_code
属性是 200。mock_get.return_value.json.return_value = {'data': 'some data'}
: 同样,我们配置了Response
对象的json()
方法的返回值。mock_get.assert_called_once_with('http://example.com/api')
: 这是一个断言,用来检查requests.get()
是否被调用,并且参数是否正确。
2. 上下文管理器方式 (with patch(...)
):精确定位
上下文管理器方式和装饰器方式类似,但是它更灵活,可以精确控制 mock 的作用范围。
# my_module.py (保持不变)
import requests
def get_data_from_api(url):
try:
response = requests.get(url)
response.raise_for_status() # 检查状态码
return response.json()
except requests.exceptions.RequestException as e:
print(f"API调用失败: {e}")
return None
# test_my_module.py
import unittest
from unittest.mock import patch
import my_module
import requests
class TestGetApiData(unittest.TestCase):
def test_get_data_from_api_success(self):
with patch('my_module.requests.get') as mock_get:
# 配置 mock_get 的行为
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {'data': 'some data'}
# 调用被测函数
result = my_module.get_data_from_api('http://example.com/api')
# 断言
self.assertEqual(result, {'data': 'some data'})
mock_get.assert_called_once_with('http://example.com/api')
def test_get_data_from_api_failure(self):
with patch('my_module.requests.get') as mock_get:
# 配置 mock_get 的行为,模拟API调用失败
mock_get.side_effect = requests.exceptions.RequestException("API error")
# 调用被测函数
result = my_module.get_data_from_api('http://example.com/api')
# 断言
self.assertIsNone(result)
mock_get.assert_called_once_with('http://example.com/api')
if __name__ == '__main__':
unittest.main()
代码解读:
with patch('my_module.requests.get') as mock_get:
:patch
创建一个上下文管理器,as mock_get
把 mock 对象赋值给mock_get
变量。 在这个with
语句块里,requests.get
会被替换成mock_get
。 当with
语句块结束时,requests.get
会自动恢复成原来的样子。- 其余部分和装饰器方式类似,只是使用 mock 对象的方式略有不同。
3. 直接调用方式 (patch(...)
):手动启动和停止
这种方式最灵活,但是也最麻烦。 你需要手动启动 mock,并在测试结束后手动停止 mock。
# my_module.py (保持不变)
import requests
def get_data_from_api(url):
try:
response = requests.get(url)
response.raise_for_status() # 检查状态码
return response.json()
except requests.exceptions.RequestException as e:
print(f"API调用失败: {e}")
return None
# test_my_module.py
import unittest
from unittest.mock import patch
import my_module
import requests
class TestGetApiData(unittest.TestCase):
def test_get_data_from_api_success(self):
# 启动 mock
mock_get = patch('my_module.requests.get')
mock = mock_get.start()
# 配置 mock 的行为
mock.return_value.status_code = 200
mock.return_value.json.return_value = {'data': 'some data'}
# 调用被测函数
result = my_module.get_data_from_api('http://example.com/api')
# 断言
self.assertEqual(result, {'data': 'some data'})
mock.assert_called_once_with('http://example.com/api')
# 停止 mock
mock_get.stop()
def test_get_data_from_api_failure(self):
# 启动 mock
mock_get = patch('my_module.requests.get')
mock = mock_get.start()
# 配置 mock 的行为,模拟API调用失败
mock.side_effect = requests.exceptions.RequestException("API error")
# 调用被测函数
result = my_module.get_data_from_api('http://example.com/api')
# 断言
self.assertIsNone(result)
mock.assert_called_once_with('http://example.com/api')
# 停止 mock
mock_get.stop()
if __name__ == '__main__':
unittest.main()
代码解读:
mock_get = patch('my_module.requests.get')
: 创建一个patch
对象,但是并没有立即启动 mock。mock = mock_get.start()
: 启动 mock,并把 mock 对象赋值给mock
变量。mock_get.stop()
: 停止 mock,恢复原来的对象。 重要: 必须手动调用stop()
,否则会影响后续的测试!
总结:三种方式的对比
方式 | 优点 | 缺点 | 适用场景 |
---|---|---|---|
装饰器方式 | 简洁易用,代码量少 | 只能替换函数级别的对象 | 简单的替换,只需要替换一个函数,并且作用范围是整个测试函数 |
上下文管理器方式 | 灵活,可以精确控制 mock 的作用范围 | 代码量稍多 | 需要精确控制 mock 的作用范围,或者需要在同一个测试函数里多次替换不同的对象 |
直接调用方式 | 最灵活,可以手动启动和停止 mock | 最麻烦,需要手动管理 mock 的生命周期,容易出错 | 特殊场景,例如需要在测试函数外部启动 mock,或者需要在多个测试函数之间共享 mock 对象。 不推荐新手使用 |
模拟复杂的外部依赖和API调用
现在,让我们来看一些更复杂的例子,看看 patch
如何帮助我们模拟复杂的外部依赖和API调用。
场景1:模拟一个复杂的API客户端
假设你有一个复杂的API客户端,它有很多方法,并且需要进行身份验证。
# api_client.py
import requests
class APIClient:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://api.example.com"
def get_resource(self, resource_id):
url = f"{self.base_url}/resources/{resource_id}"
headers = {'Authorization': f'Bearer {self.api_key}'}
response = requests.get(url, headers=headers)
response.raise_for_status()
return response.json()
def create_resource(self, data):
url = f"{self.base_url}/resources"
headers = {'Authorization': f'Bearer {self.api_key}'}
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
return response.json()
# 更多的方法...
要测试使用这个 APIClient
的代码,我们可以模拟整个 APIClient
类。
# my_module.py
from api_client import APIClient
def process_resource(api_key, resource_id):
client = APIClient(api_key)
resource = client.get_resource(resource_id)
# 对 resource 进行一些处理
processed_data = {'name': resource['name'].upper()}
return processed_data
# test_my_module.py
import unittest
from unittest.mock import patch, MagicMock
import my_module
from api_client import APIClient
class TestProcessResource(unittest.TestCase):
@patch('my_module.APIClient')
def test_process_resource_success(self, MockAPIClient):
# 配置 MockAPIClient 的行为
mock_client = MagicMock() # 使用 MagicMock 方便模拟各种方法
MockAPIClient.return_value = mock_client
mock_client.get_resource.return_value = {'name': 'test'}
# 调用被测函数
result = my_module.process_resource('fake_api_key', 123)
# 断言
self.assertEqual(result, {'name': 'TEST'})
mock_client.get_resource.assert_called_once_with(123)
MockAPIClient.assert_called_once_with('fake_api_key')
代码解读:
@patch('my_module.APIClient')
: 我们模拟了my_module
模块里的APIClient
类。mock_client = MagicMock()
: 我们创建了一个MagicMock
对象,用来模拟APIClient
的实例。MagicMock
比Mock
更强大,它可以自动创建属性和方法,方便我们模拟复杂的对象。MockAPIClient.return_value = mock_client
: 我们设置了MockAPIClient
的return_value
属性为mock_client
。 这意味着,当我们调用APIClient()
时,它会返回mock_client
。mock_client.get_resource.return_value = {'name': 'test'}
: 我们配置了mock_client
的get_resource()
方法的返回值。MockAPIClient.assert_called_once_with('fake_api_key')
: 我们断言APIClient
的构造函数被正确调用,参数为fake_api_key
。
场景2:模拟一个数据库连接
假设你的代码需要连接数据库,并执行一些查询。
# database.py
import sqlite3
def get_user_name(user_id):
conn = sqlite3.connect('users.db')
cursor = conn.cursor()
cursor.execute("SELECT name FROM users WHERE id = ?", (user_id,))
result = cursor.fetchone()
conn.close()
if result:
return result[0]
else:
return None
我们可以模拟 sqlite3.connect
来避免真实的数据库连接。
# test_database.py
import unittest
from unittest.mock import patch, MagicMock
import database
class TestGetUserName(unittest.TestCase):
@patch('database.sqlite3.connect')
def test_get_user_name_success(self, mock_connect):
# 配置 mock 的行为
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_connect.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = ('John',)
# 调用被测函数
result = database.get_user_name(1)
# 断言
self.assertEqual(result, 'John')
mock_connect.assert_called_once_with('users.db')
mock_cursor.execute.assert_called_once_with("SELECT name FROM users WHERE id = ?", (1,))
mock_conn.close.assert_called_once()
@patch('database.sqlite3.connect')
def test_get_user_name_not_found(self, mock_connect):
# 配置 mock 的行为,模拟用户不存在
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_connect.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = None
# 调用被测函数
result = database.get_user_name(1)
# 断言
self.assertIsNone(result)
进阶技巧:side_effect
的妙用
side_effect
是 Mock
对象的一个非常有用的属性,它可以让你定义一个函数,这个函数会在每次调用 mock 对象时被执行。 这可以让你模拟更复杂的行为,例如:
- 模拟抛出异常
- 模拟根据不同的参数返回不同的值
- 模拟有副作用的操作
例子:模拟根据不同的URL返回不同的数据
import unittest
from unittest.mock import patch
import my_module
def get_data_from_api(url):
# 这个函数已经被修改, 直接返回 requests.get 的结果
import requests
response = requests.get(url)
return response.json()
class TestGetDataFromApi(unittest.TestCase):
@patch('my_module.requests.get')
def test_get_data_from_api_with_side_effect(self, mock_get):
def side_effect(url):
if url == 'http://example.com/api/users':
return MockResponse([{'id': 1, 'name': 'John'}, {'id': 2, 'name': 'Jane'}])
elif url == 'http://example.com/api/products':
return MockResponse([{'id': 101, 'name': 'Product A'}, {'id': 102, 'name': 'Product B'}])
else:
return MockResponse(None) # 或者抛出异常
mock_get.side_effect = side_effect
# 调用被测函数,获取用户数据
users = my_module.get_data_from_api('http://example.com/api/users')
self.assertEqual(users, [{'id': 1, 'name': 'John'}, {'id': 2, 'name': 'Jane'}])
# 调用被测函数,获取产品数据
products = my_module.get_data_from_api('http://example.com/api/products')
self.assertEqual(products, [{'id': 101, 'name': 'Product A'}, {'id': 102, 'name': 'Product B'}])
# 确保 mock_get 被调用了两次
self.assertEqual(mock_get.call_count, 2)
mock_get.assert_any_call('http://example.com/api/users')
mock_get.assert_any_call('http://example.com/api/products')
class MockResponse:
def __init__(self, json_data, status_code=200):
self.json_data = json_data
self.status_code = status_code
def json(self):
return self.json_data
def raise_for_status(self):
if self.status_code >= 400:
raise Exception("Request failed")
if __name__ == '__main__':
unittest.main()
代码解读:
mock_get.side_effect = side_effect
: 我们把side_effect
函数赋值给mock_get
的side_effect
属性。side_effect(url)
:side_effect
函数接收一个参数url
,这个参数就是requests.get()
被调用时传入的 URL。 根据不同的 URL,我们返回不同的MockResponse
对象。- 我们创建了一个
MockResponse
类来模拟requests.Response
对象, 以简化测试代码.
重要提示:patch
的作用域
patch
的作用域是动态作用域,而不是词法作用域。 这意味着,patch
会替换运行时找到的对象,而不是定义时找到的对象。 这可能会导致一些意想不到的结果。
例子:patch
错误的位置
# module_a.py
import module_b
def foo():
return module_b.bar()
# module_b.py
def bar():
return "real bar"
# test_module_a.py
import unittest
from unittest.mock import patch
import module_a
import module_b
class TestModuleA(unittest.TestCase):
@patch('module_b.bar') # 正确的 patch 位置
def test_foo(self, mock_bar):
mock_bar.return_value = "mocked bar"
result = module_a.foo()
self.assertEqual(result, "mocked bar")
@patch('module_a.module_b.bar') # 错误的 patch 位置
def test_foo_wrong_patch(self, mock_bar):
mock_bar.return_value = "mocked bar"
result = module_a.foo()
self.assertEqual(result, "real bar") # 断言失败
在这个例子里,test_foo
使用 @patch('module_b.bar')
,这才是正确的patch
位置。因为foo
函数内部调用的是module_b.bar
。
而test_foo_wrong_patch
使用 @patch('module_a.module_b.bar')
,虽然在 module_a
里看到了 module_b.bar
,但这并不是 bar
真正被调用的地方,所以 patch
并没有生效。
总结:patch
的正确用法
- 永远要
patch
对象被使用的地方,而不是对象被定义的地方。 - 仔细分析你的代码,找到真正调用外部依赖的地方。
- 使用完整的路径来指定要
patch
的对象。
最佳实践:提高测试代码的可读性和可维护性
- 使用有意义的 mock 对象名称。 例如,
mock_get
比mock
更好。 - 把 mock 对象的配置放在测试函数的开头。 这样可以更清楚地了解 mock 对象的行为。
- 使用
assert_called_once_with()
来检查 mock 对象是否被调用,并且参数是否正确。 - 尽量避免使用
side_effect
,除非真的需要模拟复杂的行为。side_effect
会使测试代码更难理解。 - 保持测试代码的简洁和可读性。 测试代码也是代码,也需要维护。
Q&A环节:
-
Q:
Mock
和MagicMock
有什么区别?- A:
MagicMock
是Mock
的子类,它增加了对 magic methods (例如__str__
,__len__
,__iter__
等) 的支持。 如果你需要模拟的对象需要使用 magic methods,那么应该使用MagicMock
。
- A:
-
Q:
patch
可以patch
属性吗?- A: 可以的。 例如,
@patch('my_module.MyClass.my_attribute')
可以替换MyClass
的my_attribute
属性。
- A: 可以的。 例如,
-
Q: 如何
patch
一个全局变量?- A: 和
patch
函数一样, 只是需要指定全局变量的完整路径。 例如,@patch('my_module.GLOBAL_VARIABLE')
。
- A: 和
结束语:
unittest.mock
模块的 patch
是一个强大的工具,它可以帮助我们模拟复杂的外部依赖和API调用,让测试变得更加容易。 但是,patch
也有一些陷阱,需要小心使用。 掌握了 patch
的正确用法,你就可以写出更健壮、更可靠的测试代码。
希望今天的讲座对大家有所帮助! 下次再见!