Python高级技术之:`unittest.mock`的`patch`:如何模拟复杂的外部依赖和`API`调用。

各位观众老爷们,大家好!我是今天的讲师,江湖人称“代码老司机”。今天咱们聊点高级货,关于unittest.mock模块里的patch,看看它怎么帮咱们模拟那些复杂的外部依赖和API调用,让测试变得轻松愉快。

开场白:为什么我们需要模拟?

想象一下,你写了一个函数,这个函数要调用一个外部的API,或者需要连接一个数据库,甚至需要访问一个硬件设备。在测试的时候,你真的想每次都去调用这个API,连接数据库,甚至搬出一台硬件设备吗?

答案当然是:NO!

  • 速度慢: 真实的API调用,数据库连接,IO操作等等,都会消耗大量的时间。
  • 不稳定: 外部依赖可能会宕机,网络可能会不稳定,测试结果也会变得不可预测。
  • 难以控制: 你无法控制API返回什么,数据库里有什么,硬件设备的状态是什么。
  • 环境依赖: 测试环境需要配置好各种依赖,增加了测试的复杂度。

所以,我们需要模拟(Mocking)。模拟就是用假的、可控的替代品来替换真实的依赖,让测试在一个隔离、可预测的环境中运行。

unittest.mock:你的模拟利器

Python的unittest.mock模块就是来帮助我们进行模拟的。它提供了各种工具,包括 Mock, MagicMock, patch 等,今天咱们的重点是 patch

patch:像变魔术一样替换对象

patch 可以像变魔术一样,临时替换掉你代码里的对象,在测试结束后再恢复。它有几种常用的用法:

  1. 装饰器方式 (@patch)
  2. 上下文管理器方式 (with patch(...))
  3. 直接调用方式 (patch(...))

咱们一个一个来细说。

1. 装饰器方式 (@patch):轻量级替换

这种方式最常用,也最简洁。 想象一下,你的代码里有一个函数 get_data_from_api(),它调用了一个外部API api_client.get()

# my_module.py
import requests

def get_data_from_api(url):
    try:
        response = requests.get(url)
        response.raise_for_status()  # 检查状态码
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"API调用失败: {e}")
        return None

# 一些其他的代码...

现在,我们要测试 get_data_from_api(),但是不想真的去调用API。 使用 @patch,我们可以轻松地替换 requests.get

# test_my_module.py
import unittest
from unittest.mock import patch
import my_module
import json

class TestGetApiData(unittest.TestCase):

    @patch('my_module.requests.get')  # 注意这里的字符串!
    def test_get_data_from_api_success(self, mock_get):
        # 配置 mock_get 的行为
        mock_get.return_value.status_code = 200
        mock_get.return_value.json.return_value = {'data': 'some data'}

        # 调用被测函数
        result = my_module.get_data_from_api('http://example.com/api')

        # 断言
        self.assertEqual(result, {'data': 'some data'})
        mock_get.assert_called_once_with('http://example.com/api')

    @patch('my_module.requests.get')
    def test_get_data_from_api_failure(self, mock_get):
         # 配置 mock_get 的行为,模拟API调用失败
        mock_get.side_effect = requests.exceptions.RequestException("API error")

        # 调用被测函数
        result = my_module.get_data_from_api('http://example.com/api')

        # 断言
        self.assertIsNone(result)
        mock_get.assert_called_once_with('http://example.com/api')

if __name__ == '__main__':
    unittest.main()

代码解读:

  • @patch('my_module.requests.get'): 这行代码是关键。它告诉 patch 要替换哪个对象。 字符串 'my_module.requests.get' 指定了要替换的是 my_module 模块里的 requests 模块的 get 函数。 注意:这里必须是字符串,而且要写清楚完整的路径!
  • def test_get_data_from_api_success(self, mock_get):: patch 会把 mock 对象作为参数传递给被装饰的函数。 这里的 mock_get 就是替换 requests.get 的 mock 对象。
  • mock_get.return_value.status_code = 200: 我们配置了 mock_get 的返回值 (return_value) 的 status_code 属性为 200。 这意味着,当我们调用 requests.get() 时,它会返回一个 Response 对象,这个对象的 status_code 属性是 200。
  • mock_get.return_value.json.return_value = {'data': 'some data'}: 同样,我们配置了 Response 对象的 json() 方法的返回值。
  • mock_get.assert_called_once_with('http://example.com/api'): 这是一个断言,用来检查 requests.get() 是否被调用,并且参数是否正确。

2. 上下文管理器方式 (with patch(...)):精确定位

上下文管理器方式和装饰器方式类似,但是它更灵活,可以精确控制 mock 的作用范围。

# my_module.py (保持不变)
import requests

def get_data_from_api(url):
    try:
        response = requests.get(url)
        response.raise_for_status()  # 检查状态码
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"API调用失败: {e}")
        return None

# test_my_module.py
import unittest
from unittest.mock import patch
import my_module
import requests

class TestGetApiData(unittest.TestCase):

    def test_get_data_from_api_success(self):
        with patch('my_module.requests.get') as mock_get:
            # 配置 mock_get 的行为
            mock_get.return_value.status_code = 200
            mock_get.return_value.json.return_value = {'data': 'some data'}

            # 调用被测函数
            result = my_module.get_data_from_api('http://example.com/api')

            # 断言
            self.assertEqual(result, {'data': 'some data'})
            mock_get.assert_called_once_with('http://example.com/api')

    def test_get_data_from_api_failure(self):
        with patch('my_module.requests.get') as mock_get:
            # 配置 mock_get 的行为,模拟API调用失败
            mock_get.side_effect = requests.exceptions.RequestException("API error")

            # 调用被测函数
            result = my_module.get_data_from_api('http://example.com/api')

            # 断言
            self.assertIsNone(result)
            mock_get.assert_called_once_with('http://example.com/api')

if __name__ == '__main__':
    unittest.main()

代码解读:

  • with patch('my_module.requests.get') as mock_get:: patch 创建一个上下文管理器,as mock_get 把 mock 对象赋值给 mock_get 变量。 在这个 with 语句块里,requests.get 会被替换成 mock_get。 当 with 语句块结束时,requests.get 会自动恢复成原来的样子。
  • 其余部分和装饰器方式类似,只是使用 mock 对象的方式略有不同。

3. 直接调用方式 (patch(...)):手动启动和停止

这种方式最灵活,但是也最麻烦。 你需要手动启动 mock,并在测试结束后手动停止 mock。

# my_module.py (保持不变)
import requests

def get_data_from_api(url):
    try:
        response = requests.get(url)
        response.raise_for_status()  # 检查状态码
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"API调用失败: {e}")
        return None

# test_my_module.py
import unittest
from unittest.mock import patch
import my_module
import requests

class TestGetApiData(unittest.TestCase):

    def test_get_data_from_api_success(self):
        # 启动 mock
        mock_get = patch('my_module.requests.get')
        mock = mock_get.start()

        # 配置 mock 的行为
        mock.return_value.status_code = 200
        mock.return_value.json.return_value = {'data': 'some data'}

        # 调用被测函数
        result = my_module.get_data_from_api('http://example.com/api')

        # 断言
        self.assertEqual(result, {'data': 'some data'})
        mock.assert_called_once_with('http://example.com/api')

        # 停止 mock
        mock_get.stop()

    def test_get_data_from_api_failure(self):
        # 启动 mock
        mock_get = patch('my_module.requests.get')
        mock = mock_get.start()

        # 配置 mock 的行为,模拟API调用失败
        mock.side_effect = requests.exceptions.RequestException("API error")

        # 调用被测函数
        result = my_module.get_data_from_api('http://example.com/api')

        # 断言
        self.assertIsNone(result)
        mock.assert_called_once_with('http://example.com/api')

        # 停止 mock
        mock_get.stop()

if __name__ == '__main__':
    unittest.main()

代码解读:

  • mock_get = patch('my_module.requests.get'): 创建一个 patch 对象,但是并没有立即启动 mock。
  • mock = mock_get.start(): 启动 mock,并把 mock 对象赋值给 mock 变量。
  • mock_get.stop(): 停止 mock,恢复原来的对象。 重要: 必须手动调用 stop(),否则会影响后续的测试!

总结:三种方式的对比

方式 优点 缺点 适用场景
装饰器方式 简洁易用,代码量少 只能替换函数级别的对象 简单的替换,只需要替换一个函数,并且作用范围是整个测试函数
上下文管理器方式 灵活,可以精确控制 mock 的作用范围 代码量稍多 需要精确控制 mock 的作用范围,或者需要在同一个测试函数里多次替换不同的对象
直接调用方式 最灵活,可以手动启动和停止 mock 最麻烦,需要手动管理 mock 的生命周期,容易出错 特殊场景,例如需要在测试函数外部启动 mock,或者需要在多个测试函数之间共享 mock 对象。 不推荐新手使用

模拟复杂的外部依赖和API调用

现在,让我们来看一些更复杂的例子,看看 patch 如何帮助我们模拟复杂的外部依赖和API调用。

场景1:模拟一个复杂的API客户端

假设你有一个复杂的API客户端,它有很多方法,并且需要进行身份验证。

# api_client.py
import requests

class APIClient:
    def __init__(self, api_key):
        self.api_key = api_key
        self.base_url = "https://api.example.com"

    def get_resource(self, resource_id):
        url = f"{self.base_url}/resources/{resource_id}"
        headers = {'Authorization': f'Bearer {self.api_key}'}
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        return response.json()

    def create_resource(self, data):
        url = f"{self.base_url}/resources"
        headers = {'Authorization': f'Bearer {self.api_key}'}
        response = requests.post(url, headers=headers, json=data)
        response.raise_for_status()
        return response.json()

    # 更多的方法...

要测试使用这个 APIClient 的代码,我们可以模拟整个 APIClient 类。

# my_module.py
from api_client import APIClient

def process_resource(api_key, resource_id):
    client = APIClient(api_key)
    resource = client.get_resource(resource_id)
    # 对 resource 进行一些处理
    processed_data = {'name': resource['name'].upper()}
    return processed_data
# test_my_module.py
import unittest
from unittest.mock import patch, MagicMock
import my_module
from api_client import APIClient

class TestProcessResource(unittest.TestCase):

    @patch('my_module.APIClient')
    def test_process_resource_success(self, MockAPIClient):
        # 配置 MockAPIClient 的行为
        mock_client = MagicMock() # 使用 MagicMock 方便模拟各种方法
        MockAPIClient.return_value = mock_client
        mock_client.get_resource.return_value = {'name': 'test'}

        # 调用被测函数
        result = my_module.process_resource('fake_api_key', 123)

        # 断言
        self.assertEqual(result, {'name': 'TEST'})
        mock_client.get_resource.assert_called_once_with(123)
        MockAPIClient.assert_called_once_with('fake_api_key')

代码解读:

  • @patch('my_module.APIClient'): 我们模拟了 my_module 模块里的 APIClient 类。
  • mock_client = MagicMock(): 我们创建了一个 MagicMock 对象,用来模拟 APIClient 的实例。 MagicMockMock 更强大,它可以自动创建属性和方法,方便我们模拟复杂的对象。
  • MockAPIClient.return_value = mock_client: 我们设置了 MockAPIClientreturn_value 属性为 mock_client。 这意味着,当我们调用 APIClient() 时,它会返回 mock_client
  • mock_client.get_resource.return_value = {'name': 'test'}: 我们配置了 mock_clientget_resource() 方法的返回值。
  • MockAPIClient.assert_called_once_with('fake_api_key'): 我们断言APIClient 的构造函数被正确调用,参数为 fake_api_key

场景2:模拟一个数据库连接

假设你的代码需要连接数据库,并执行一些查询。

# database.py
import sqlite3

def get_user_name(user_id):
    conn = sqlite3.connect('users.db')
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM users WHERE id = ?", (user_id,))
    result = cursor.fetchone()
    conn.close()
    if result:
        return result[0]
    else:
        return None

我们可以模拟 sqlite3.connect 来避免真实的数据库连接。

# test_database.py
import unittest
from unittest.mock import patch, MagicMock
import database

class TestGetUserName(unittest.TestCase):

    @patch('database.sqlite3.connect')
    def test_get_user_name_success(self, mock_connect):
        # 配置 mock 的行为
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_connect.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = ('John',)

        # 调用被测函数
        result = database.get_user_name(1)

        # 断言
        self.assertEqual(result, 'John')
        mock_connect.assert_called_once_with('users.db')
        mock_cursor.execute.assert_called_once_with("SELECT name FROM users WHERE id = ?", (1,))
        mock_conn.close.assert_called_once()

    @patch('database.sqlite3.connect')
    def test_get_user_name_not_found(self, mock_connect):
        # 配置 mock 的行为,模拟用户不存在
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_connect.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = None

        # 调用被测函数
        result = database.get_user_name(1)

        # 断言
        self.assertIsNone(result)

进阶技巧:side_effect 的妙用

side_effectMock 对象的一个非常有用的属性,它可以让你定义一个函数,这个函数会在每次调用 mock 对象时被执行。 这可以让你模拟更复杂的行为,例如:

  • 模拟抛出异常
  • 模拟根据不同的参数返回不同的值
  • 模拟有副作用的操作

例子:模拟根据不同的URL返回不同的数据

import unittest
from unittest.mock import patch
import my_module

def get_data_from_api(url):
    # 这个函数已经被修改, 直接返回 requests.get 的结果
    import requests
    response = requests.get(url)
    return response.json()
class TestGetDataFromApi(unittest.TestCase):

    @patch('my_module.requests.get')
    def test_get_data_from_api_with_side_effect(self, mock_get):
        def side_effect(url):
            if url == 'http://example.com/api/users':
                return MockResponse([{'id': 1, 'name': 'John'}, {'id': 2, 'name': 'Jane'}])
            elif url == 'http://example.com/api/products':
                return MockResponse([{'id': 101, 'name': 'Product A'}, {'id': 102, 'name': 'Product B'}])
            else:
                return MockResponse(None) # 或者抛出异常

        mock_get.side_effect = side_effect

        # 调用被测函数,获取用户数据
        users = my_module.get_data_from_api('http://example.com/api/users')
        self.assertEqual(users, [{'id': 1, 'name': 'John'}, {'id': 2, 'name': 'Jane'}])

        # 调用被测函数,获取产品数据
        products = my_module.get_data_from_api('http://example.com/api/products')
        self.assertEqual(products, [{'id': 101, 'name': 'Product A'}, {'id': 102, 'name': 'Product B'}])

        # 确保 mock_get 被调用了两次
        self.assertEqual(mock_get.call_count, 2)
        mock_get.assert_any_call('http://example.com/api/users')
        mock_get.assert_any_call('http://example.com/api/products')

class MockResponse:
    def __init__(self, json_data, status_code=200):
        self.json_data = json_data
        self.status_code = status_code

    def json(self):
        return self.json_data

    def raise_for_status(self):
        if self.status_code >= 400:
            raise Exception("Request failed")

if __name__ == '__main__':
    unittest.main()

代码解读:

  • mock_get.side_effect = side_effect: 我们把 side_effect 函数赋值给 mock_getside_effect 属性。
  • side_effect(url): side_effect 函数接收一个参数 url,这个参数就是 requests.get() 被调用时传入的 URL。 根据不同的 URL,我们返回不同的 MockResponse 对象。
  • 我们创建了一个 MockResponse 类来模拟 requests.Response 对象, 以简化测试代码.

重要提示:patch 的作用域

patch 的作用域是动态作用域,而不是词法作用域。 这意味着,patch 会替换运行时找到的对象,而不是定义时找到的对象。 这可能会导致一些意想不到的结果。

例子:patch 错误的位置

# module_a.py
import module_b

def foo():
    return module_b.bar()

# module_b.py
def bar():
    return "real bar"

# test_module_a.py
import unittest
from unittest.mock import patch
import module_a
import module_b

class TestModuleA(unittest.TestCase):

    @patch('module_b.bar') # 正确的 patch 位置
    def test_foo(self, mock_bar):
        mock_bar.return_value = "mocked bar"
        result = module_a.foo()
        self.assertEqual(result, "mocked bar")

    @patch('module_a.module_b.bar') # 错误的 patch 位置
    def test_foo_wrong_patch(self, mock_bar):
        mock_bar.return_value = "mocked bar"
        result = module_a.foo()
        self.assertEqual(result, "real bar") # 断言失败

在这个例子里,test_foo 使用 @patch('module_b.bar'),这才是正确的patch位置。因为foo函数内部调用的是module_b.bar

test_foo_wrong_patch 使用 @patch('module_a.module_b.bar'),虽然在 module_a 里看到了 module_b.bar,但这并不是 bar 真正被调用的地方,所以 patch 并没有生效。

总结:patch 的正确用法

  • 永远要 patch 对象被使用的地方,而不是对象被定义的地方。
  • 仔细分析你的代码,找到真正调用外部依赖的地方。
  • 使用完整的路径来指定要 patch 的对象。

最佳实践:提高测试代码的可读性和可维护性

  • 使用有意义的 mock 对象名称。 例如,mock_getmock 更好。
  • 把 mock 对象的配置放在测试函数的开头。 这样可以更清楚地了解 mock 对象的行为。
  • 使用 assert_called_once_with() 来检查 mock 对象是否被调用,并且参数是否正确。
  • 尽量避免使用 side_effect,除非真的需要模拟复杂的行为。 side_effect 会使测试代码更难理解。
  • 保持测试代码的简洁和可读性。 测试代码也是代码,也需要维护。

Q&A环节:

  • Q: MockMagicMock 有什么区别?

    • A: MagicMockMock 的子类,它增加了对 magic methods (例如 __str__, __len__, __iter__ 等) 的支持。 如果你需要模拟的对象需要使用 magic methods,那么应该使用 MagicMock
  • Q: patch 可以 patch 属性吗?

    • A: 可以的。 例如,@patch('my_module.MyClass.my_attribute') 可以替换 MyClassmy_attribute 属性。
  • Q: 如何 patch 一个全局变量?

    • A: 和 patch 函数一样, 只是需要指定全局变量的完整路径。 例如,@patch('my_module.GLOBAL_VARIABLE')

结束语:

unittest.mock 模块的 patch 是一个强大的工具,它可以帮助我们模拟复杂的外部依赖和API调用,让测试变得更加容易。 但是,patch 也有一些陷阱,需要小心使用。 掌握了 patch 的正确用法,你就可以写出更健壮、更可靠的测试代码。

希望今天的讲座对大家有所帮助! 下次再见!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注