‘Batch Inference’ 优化:利用 RunnableBatch 实现跨模型供应商的并行请求合并
随着人工智能技术,特别是大型语言模型(LLM)的飞速发展,越来越多的企业和开发者开始将LLM集成到他们的应用中。然而,与这些强大模型交互时,效率和成本始终是核心考量。面对高并发、多用户请求的场景,以及需要整合来自不同供应商的模型服务时,如何有效优化推理性能,降低运营成本,并提高系统吞吐量,成为了一个迫切需要解决的问题。
今天,我们将深入探讨一种强大的优化策略:批量推理(Batch Inference),并重点介绍 LangChain 框架中一个专门为此设计的组件——RunnableBatch。我们将从基础概念出发,逐步深入到其工作原理、实际应用场景,特别是如何利用它实现跨模型供应商的并行请求合并,最终提升我们应用的整体性能和可扩展性。
一、批量推理(Batch Inference)的基石:为什么我们需要它?
在分布式系统和微服务架构中,每一次对外部服务的调用都伴随着一定的固定开销:网络握手、协议协商、数据序列化/反序列化、API 鉴权等等。对于LLM调用而言,这些开销是不可忽视的。当应用需要处理大量独立但结构相似的请求时,如果每个请求都单独调用一次LLM API,这些固定开销就会被反复叠加,导致:
- 高延迟 (High Latency):尽管单个请求的推理时间可能很短,但网络和协议开销可能占据主导地位,导致用户体验不佳。
- 低吞吐量 (Low Throughput):服务器资源(如CPU、内存、网络带宽)在处理大量小请求时,频繁的上下文切换和资源争用会降低整体的处理能力。
- 高成本 (High Cost):许多LLM供应商的计费模式可能包含请求次数的考量,即使主要是按 Token 计费,减少请求次数也能摊薄固定成本。
- 资源利用率低下 (Poor Resource Utilization):对于提供模型服务的硬件(如GPU),小批量请求可能无法充分利用其并行计算能力。
批量推理的核心思想是将多个独立的输入数据打包成一个单一的请求,然后发送给模型进行一次性处理。模型处理完整个批次后,再将结果返回。这种方式带来的优势显而易见:
- 分摊固定开销:一次网络传输、一次鉴权、一次模型加载,处理多个请求的数据。
- 提高吞吐量:模型可以更高效地利用底层硬件的并行性,特别是GPU。
- 降低成本:减少API调用次数,可能节省计费费用。
- 更好的资源利用率:模型服务可以一次性处理更多数据,减少空闲时间。
然而,批量推理也并非没有挑战:
- 延迟增加:为了凑齐一个批次,可能需要等待一段时间,这会增加单个请求的端到端延迟。需要在吞吐量和延迟之间找到平衡。
- 异构输入处理:批次中的输入可能长度不同,需要填充(padding)或更复杂的处理。
- 错误处理:批次中某个输入失败时,如何处理整个批次或只处理失败项?
- 供应商API兼容性:并非所有LLM供应商都提供原生批处理API,或者其接口各异。
正是为了应对这些挑战,LangChain 引入了 RunnableBatch。
二、LangChain 与 LCEL:构建可组合的AI应用
在深入 RunnableBatch 之前,我们有必要简要回顾一下 LangChain 及其核心概念 LangChain Expressive Language (LCEL)。
LangChain 是一个用于开发由语言模型驱动的应用程序的框架。它提供了一系列工具、组件和接口,使得开发者能够轻松地构建复杂的AI应用,如问答系统、聊天机器人、数据分析工具等。LCEL 则是 LangChain 中用于构建可组合链(chains)的声明式方式。
LCEL 的核心是 Runnable 接口。任何实现 Runnable 接口的对象都定义了 invoke 方法(用于处理单个输入)和 batch 方法(用于处理多个输入)。这种设计使得 LCEL 链天然支持并行化和批处理。通过 LCEL,我们可以像搭积木一样,将不同的组件(如 LLM、PromptTemplate、OutputParser、自定义函数等)连接起来,形成一个完整的处理流程。
LCEL 链的优势在于:
- 可组合性:所有组件都是
Runnable,可以轻松地进行组合。 - 异步支持:原生支持
async/await,便于构建高性能应用。 - 流式处理:支持数据流式传输,提升用户体验。
- 并行执行:自动识别并优化并行执行的机会。
- 可观测性:易于集成日志、追踪和监控。
RunnableBatch 正是 LCEL 生态系统中的一个重要成员,它利用了 Runnable 接口的 batch 能力,并在此基础上提供了更高级的批处理管理功能。
三、RunnableBatch 的深度解析:工作原理与参数
RunnableBatch 的核心职责是作为代理,将多个针对底层 Runnable 的 invoke 调用聚合成一个 batch 调用。它在内部维护一个队列,收集传入的请求,并在达到特定条件时(例如,队列中的请求数量达到阈值或等待时间超过阈值)触发底层的批处理操作。
3.1 RunnableBatch 的构造函数与核心参数
让我们来看看 RunnableBatch 的主要构造函数签名及其关键参数:
from typing import Any, Callable, List, Optional, Sequence, Union
from langchain_core.runnables import Runnable, RunnableBatch as _RunnableBatch
class RunnableBatch(_RunnableBatch):
def __init__(
self,
bound: Runnable[Sequence[Any], Sequence[Any]], # 强制要求bound runnable支持batch方法
*,
max_batch_size: int = 64,
max_batch_time: float = 0.1, # seconds
wait_until_full: bool = False,
default_response: Optional[Any] = None,
batch_fn: Optional[Callable[[List[Any]], Any]] = None,
):
# ...
-
bound(Runnable[Sequence[Any], Sequence[Any]]):
这是RunnableBatch包装的底层Runnable。关键在于,这个bound的Runnable必须能够处理Sequence[Any]类型的输入并返回Sequence[Any]类型的输出,也就是说它必须支持其自身的batch方法。如果它只支持invoke,那么RunnableBatch就无法将其转换为批处理调用。但是,RunnableBatch提供了batch_fn参数来解决这个问题,我们后面会详细讨论。 -
max_batch_size(int, default=64):
一个批次中可以包含的最大请求数量。当收集到的请求数量达到这个值时,RunnableBatch会立即触发底层的batch调用。 -
max_batch_time(float, default=0.1):
一个批次可以等待的最长时间(秒)。如果在这个时间内没有达到max_batch_size,但时间已到,RunnableBatch也会触发底层的batch调用,即使批次不满。这个参数在平衡延迟和吞吐量之间起着关键作用。 -
wait_until_full(bool, default=False):
如果设置为True,RunnableBatch会一直等待直到批次达到max_batch_size才触发调用,即使max_batch_time已经过期。这通常在对延迟不敏感但对批次效率要求极高的场景中使用。 -
default_response(Optional[Any], default=None):
当底层bound的batch调用中某个子请求发生错误时,RunnableBatch会用这个default_response来填充对应位置的结果,而不是让整个批次失败。这对于构建容错系统非常有用。 -
batch_fn(Optional[Callable[[List[Any]], Any]], default=None):
这是一个非常强大的参数。如果你的boundRunnable没有实现batch方法,或者你希望对批处理的输入/输出进行自定义的预处理/后处理,你可以提供一个batch_fn。这个函数会接收一个List[Any]作为输入(即聚合后的批次),并期望返回一个Any类型的结果(通常也是一个List[Any],对应批次中的每个输入)。RunnableBatch会用这个batch_fn来替代bound.batch()调用。
3.2 RunnableBatch 的内部工作机制
RunnableBatch 在幕后做的工作可以概括为以下几个步骤:
- 请求入队:当一个针对
RunnableBatch实例的invoke或ainvoke调用发生时,它不会立即调用底层boundRunnable的invoke方法。相反,它会将这个请求(及其上下文)放入一个内部的等待队列中。 - 定时器与计数器:
RunnableBatch会启动一个定时器(基于max_batch_time)并维护一个计数器(基于max_batch_size)。 - 批次触发条件:
- 当队列中的请求数量达到
max_batch_size时(且wait_until_full为False)。 - 当
max_batch_time到期时(且wait_until_full为False)。 - 如果
wait_until_full为True,则只在达到max_batch_size时触发。
- 当队列中的请求数量达到
- 执行批处理:一旦触发条件满足,
RunnableBatch会从队列中取出所有等待的请求,将它们的输入聚合成一个列表。- 如果提供了
batch_fn,则调用batch_fn(aggregated_inputs)。 - 否则,调用
bound.batch(aggregated_inputs)。
- 如果提供了
- 结果分发:底层批处理操作完成后,
RunnableBatch会将返回的结果(通常是一个列表)与原始的请求一一对应,然后将每个子请求的结果返回给各自的调用方。如果某个子请求失败,并且default_response已设置,则返回default_response。
这种机制使得 RunnableBatch 能够透明地将多个零散的 invoke 调用转换为高效的 batch 调用,极大地简化了批处理逻辑的实现。
四、实践:单模型供应商的基础批量推理
让我们从最简单的场景开始:对单个LLM供应商的模型进行批量推理。我们将使用 OpenAI 的 ChatOpenAI 作为示例。
首先,确保你已经安装了必要的库并设置了API密钥:
pip install langchain langchain-openai python-dotenv
# .env 文件
# OPENAI_API_KEY="your_openai_api_key"
import os
from dotenv import load_dotenv
load_dotenv() # 加载 .env 文件中的环境变量
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
import time
import asyncio
import random
# 定义一个基础的LLM链
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7)
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个专业的文案助手。"),
("user", "{text}")
])
output_parser = StrOutputParser()
# 基础链,支持 invoke 和 batch
basic_chain = prompt | llm | output_parser
print("--- 基础链的单次调用 (invoke) ---")
start_time = time.perf_counter()
result_invoke = basic_chain.invoke({"text": "请为我生成一个关于智能家居的短广告语。"})
end_time = time.perf_counter()
print(f"单个 invoke 结果: {result_invoke[:50]}...")
print(f"单个 invoke 耗时: {end_time - start_time:.4f} 秒n")
print("--- 基础链的批量调用 (batch) ---")
texts_for_batch = [
"请为我生成一个关于智能家居的短广告语。",
"请为我生成一个关于环保出行的新闻标题。",
"请为我生成一个关于健康饮食的社交媒体帖子。",
"请为我生成一个关于儿童教育的口号。",
"请为我生成一个关于未来科技趋势的摘要。"
]
inputs_for_batch = [{"text": t} for t in texts_for_batch]
start_time = time.perf_counter()
results_batch = basic_chain.batch(inputs_for_batch)
end_time = time.perf_counter()
print(f"批量 batch 结果 (前20字符): {[r[:20] for r in results_batch]}")
print(f"批量 batch 耗时: {end_time - start_time:.4f} 秒")
print(f"平均每个请求耗时 (batch): {(end_time - start_time) / len(texts_for_batch):.4f} 秒n")
从上面的输出可以看出,batch 调用显著降低了平均每个请求的耗时,体现了批量推理的优势。但是,如果我们的应用逻辑是零散地触发这些请求,而不是一次性收集好一个批次再调用 batch 呢?这就是 RunnableBatch 发挥作用的地方。
4.1 使用 RunnableBatch 包装基础链
现在,我们用 RunnableBatch 来包装 basic_chain。
from langchain_core.runnables import RunnableBatch
# 使用 RunnableBatch 包装我们的基础链
# 设定 max_batch_size 为 3,max_batch_time 为 0.5 秒
batched_chain = RunnableBatch(
bound=basic_chain,
max_batch_size=3,
max_batch_time=0.5,
default_response="抱歉,处理失败。"
)
async def simulate_concurrent_requests(chain_to_test, num_requests=10):
print(f"n--- 模拟 {num_requests} 个并发请求到 {'RunnableBatch' if isinstance(chain_to_test, RunnableBatch) else '原始链'} ---")
start_time_total = time.perf_counter()
async def single_request(i):
text = f"请生成一个关于主题 {i} 的简短描述。"
input_data = {"text": text}
try:
result = await chain_to_test.ainvoke(input_data)
# print(f"请求 {i} 结果: {result[:30]}...")
return f"请求 {i} 成功"
except Exception as e:
# print(f"请求 {i} 失败: {e}")
return f"请求 {i} 失败"
tasks = [single_request(i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
end_time_total = time.perf_counter()
total_duration = end_time_total - start_time_total
print(f"总耗时: {total_duration:.4f} 秒")
print(f"平均每个请求耗时: {total_duration / num_requests:.4f} 秒")
# print("所有请求结果状态:", results)
# 模拟直接对原始链进行并发调用 (每个都是单独的 invoke)
# 注意:这可能会导致API限速或性能瓶颈
# await simulate_concurrent_requests(basic_chain, num_requests=10)
# 模拟通过 RunnableBatch 进行并发调用
# RunnableBatch 会在内部将这些 invoke 聚合为 batch 调用
asyncio.run(simulate_concurrent_requests(batched_chain, num_requests=10))
# 尝试模拟一个导致 batch_time 超时的场景
batched_chain_small_batch_time = RunnableBatch(
bound=basic_chain,
max_batch_size=10, # 设置一个较大的批次大小,确保不会轻易达到
max_batch_time=0.1, # 设置一个较小的超时时间
default_response="抱歉,处理失败。"
)
asyncio.run(simulate_concurrent_requests(batched_chain_small_batch_time, num_requests=5))
通过 RunnableBatch 包装后,即使我们以 ainvoke 的方式发起多个看似独立的异步请求,RunnableBatch 也会在幕后智能地将它们聚合成批次,然后调用底层 basic_chain 的 batch 方法。这使得我们的应用代码可以保持简单的 invoke 逻辑,而底层的性能优化则由 RunnableBatch 自动完成。
在上面的例子中,当 num_requests=10 时,max_batch_size=3 和 max_batch_time=0.5 意味着 RunnableBatch 会尝试创建 3 个批次(3, 3, 3, 1)。每次批次调用都会分摊固定开销,从而降低平均请求延迟。如果 max_batch_time 很短,即使请求不多,也会尽快发出小批次。
五、跨模型供应商的并行请求合并
现在我们进入更复杂的场景:如何利用 RunnableBatch 实现跨不同模型供应商的并行请求合并。这在需要结合不同模型能力、进行模型A/B测试、或作为多模型路由策略的一部分时非常有用。
假设我们的应用需要从两个不同的模型供应商(例如,OpenAI 和 Anthropic)获取文本生成结果,并且希望这些请求也能被批量处理。
首先,我们需要 Anthropic 的模型。
pip install langchain-anthropic
import os
from dotenv import load_dotenv
load_dotenv() # 确保加载所有API密钥
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic # 导入 Anthropic 模型
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda, RunnableBatch
import time
import asyncio
import random
# --- 模型和链定义 ---
# 1. OpenAI 模型链
llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7)
prompt_openai = ChatPromptTemplate.from_messages([
("system", "你是一个简洁的助手。"),
("user", "{text}")
])
chain_openai = prompt_openai | llm_openai | StrOutputParser()
# 2. Anthropic 模型链
# 确保 ANTHROPIC_API_KEY 已在 .env 中设置
llm_anthropic = ChatAnthropic(model_name="claude-3-haiku-20240307", temperature=0.7)
prompt_anthropic = ChatPromptTemplate.from_messages([
("system", "你是一个富有创意和详细的助手。"),
("user", "{text}")
])
chain_anthropic = prompt_anthropic | llm_anthropic | StrOutputParser()
print("--- 准备跨供应商批量处理 ---")
# --- 使用 RunnableBatch 包装每个供应商的链 ---
# 为 OpenAI 链创建批处理器
batched_openai_chain = RunnableBatch(
bound=chain_openai,
max_batch_size=5,
max_batch_time=0.2,
default_response="OpenAI 响应失败"
)
# 为 Anthropic 链创建批处理器
batched_anthropic_chain = RunnableBatch(
bound=chain_anthropic,
max_batch_size=5,
max_batch_time=0.2,
default_response="Anthropic 响应失败"
)
# --- 组合这些批处理器以实现并行请求合并 ---
# 使用 RunnableParallel 将相同的输入发送到两个批处理器
# 这里的 "并行" 指的是从应用层面看,两个模型供应商的批处理是同时启动的
# 内部每个 batched_xxx_chain 会各自收集请求并进行批处理
combined_batched_chain = RunnableParallel(
openai_result=batched_openai_chain,
anthropic_result=batched_anthropic_chain
)
async def simulate_cross_vendor_requests(num_requests=10):
print(f"n--- 模拟 {num_requests} 个并发请求到跨供应商批处理链 ---")
start_time_total = time.perf_counter()
async def single_request_to_combined(i):
text = f"请为主题 '{chr(65 + i % 26)}' 生成一个简短的创意描述。"
input_data = {"text": text}
try:
# 调用 combined_batched_chain.ainvoke 会同时触发 batched_openai_chain 和 batched_anthropic_chain 的 ainvoke
# 它们各自会把请求放入自己的批处理队列中
result = await combined_batched_chain.ainvoke(input_data)
# print(f"请求 {i} 结果 (OpenAI): {result['openai_result'][:30]}...")
# print(f"请求 {i} 结果 (Anthropic): {result['anthropic_result'][:30]}...")
return f"请求 {i} 成功"
except Exception as e:
# print(f"请求 {i} 失败: {e}")
return f"请求 {i} 失败"
tasks = [single_request_to_combined(i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
end_time_total = time.perf_counter()
total_duration = end_time_total - start_time_total
print(f"总耗时: {total_duration:.4f} 秒")
print(f"平均每个请求耗时: {total_duration / num_requests:.4f} 秒")
print("所有请求结果状态:", results)
asyncio.run(simulate_cross_vendor_requests(num_requests=10))
# 进一步的例子:如何处理不同的输入路径
print("n--- 模拟路由到特定供应商的批处理链 ---")
# 假设我们有一个路由器,根据输入决定使用哪个模型
def decide_model(input_data: dict) -> str:
if "creative" in input_data["query"].lower():
return "anthropic"
return "openai"
# 定义一个路由器,将请求导向不同的批处理器
router_chain = RunnablePassthrough.assign(
model_choice=RunnableLambda(lambda x: decide_model(x))
) | {
"openai_response": RunnableLambda(lambda x: batched_openai_chain.ainvoke({"text": x["query"]}))
.when(lambda x: x["model_choice"] == "openai"),
"anthropic_response": RunnableLambda(lambda x: batched_anthropic_chain.ainvoke({"text": x["query"]}))
.when(lambda x: x["model_choice"] == "anthropic"),
"original_query": RunnablePassthrough()
}
async def simulate_routed_batched_requests(num_requests=10):
print(f"n--- 模拟 {num_requests} 个路由到供应商的并发请求 ---")
start_time_total = time.perf_counter()
async def single_routed_request(i):
query = f"描述一个普通物体,比如椅子,主题 '{chr(65 + i % 26)}'"
if i % 3 == 0:
query = f"给我一个非常creative的关于科幻主题 '{chr(65 + i % 26)}' 的故事开头。"
input_data = {"query": query}
try:
result = await router_chain.ainvoke(input_data)
# print(f"请求 {i} 路由结果: {result}")
return f"请求 {i} 成功"
except Exception as e:
# print(f"请求 {i} 失败: {e}")
return f"请求 {i} 失败"
tasks = [single_routed_request(i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
end_time_total = time.perf_counter()
total_duration = end_time_total - start_time_total
print(f"总耗时: {total_duration:.4f} 秒")
print(f"平均每个请求耗时: {total_duration / num_requests:.4f} 秒")
print("所有请求结果状态:", results)
asyncio.run(simulate_routed_batched_requests(num_requests=10))
在这个示例中:
- 我们为 OpenAI 和 Anthropic 各自创建了一个
RunnableBatch实例 (batched_openai_chain和batched_anthropic_chain)。每个实例都独立地管理其各自供应商的批处理队列。 - 通过
RunnableParallel,我们创建了一个combined_batched_chain。当对combined_batched_chain调用ainvoke时,它会同时向batched_openai_chain和batched_anthropic_chain发送请求。 - 关键点:尽管
combined_batched_chain看似并行地调用了两个批处理器,但这两个批处理器本身是独立的。它们各自会收集请求,并在达到自己的max_batch_size或max_batch_time时,向其对应的 LLM 供应商发起一次批处理请求。 - 后续的路由示例展示了如何根据输入动态选择使用哪个批处理器。这使得我们可以在一个统一的接口下,根据业务逻辑将请求智能地分发到不同的模型供应商,并同时享受到批处理带来的性能优势。
这种模式极大地简化了跨供应商模型集成的复杂性。开发者无需手动管理批处理队列、定时器和结果映射,RunnableBatch 会在 LCEL 链中自动处理这些细节。
六、高级用法:自定义批处理函数与错误处理
RunnableBatch 的 batch_fn 参数和 default_response 参数提供了强大的定制和容错能力。
6.1 batch_fn:自定义批处理逻辑
有时,你包装的 Runnable 可能没有原生的 batch 方法,或者你希望在发送批次请求之前/之后进行一些特殊的处理(例如,对所有输入进行统一的格式转换,或者在返回结果时进行聚合)。这时,batch_fn 就派上用场了。
import time
import asyncio
from typing import List, Any
# 模拟一个没有原生 batch 方法的慢速 Runnable
class SlowTextProcessor(Runnable):
def __init__(self, delay_per_item: float = 0.1):
self.delay_per_item = delay_per_item
def invoke(self, input: str, config=None) -> str:
time.sleep(self.delay_per_item) # 模拟处理时间
return f"Processed: {input.upper()}"
# 注意:这里没有实现 batch 方法
# 定义一个自定义的批处理函数
def custom_batch_processor(inputs: List[str]) -> List[str]:
print(f"n[Custom Batch Processor] 正在处理批次,大小: {len(inputs)}")
results = []
# 模拟批处理的并行或优化处理
# 实际应用中,这里可能是调用一个不支持batch但我们希望批量发送的API
# 或者对所有输入进行一些预处理再批量发送到某个服务
for i, item in enumerate(inputs):
time.sleep(0.05) # 模拟一些处理时间,但比单个 invoke 快
results.append(f"Custom Batched Processed: {item.lower()} (Index: {i})")
print("[Custom Batch Processor] 批次处理完成。")
return results
# 使用 RunnableBatch 包装 SlowTextProcessor,并提供 custom_batch_processor 作为 batch_fn
# max_batch_size 设置为 3,max_batch_time 设置为 0.5 秒
custom_batched_processor = RunnableBatch(
bound=SlowTextProcessor(delay_per_item=0.3), # 单个处理很慢
max_batch_size=3,
max_batch_time=0.5,
batch_fn=custom_batch_processor,
default_response="Custom Batch Processor 失败"
)
async def simulate_custom_batch_requests(num_requests=10):
print(f"n--- 模拟 {num_requests} 个并发请求到带自定义 batch_fn 的链 ---")
start_time_total = time.perf_counter()
async def single_request(i):
text = f"item {i}"
try:
result = await custom_batched_processor.ainvoke(text)
# print(f"请求 {i} 结果: {result}")
return f"请求 {i} 成功"
except Exception as e:
# print(f"请求 {i} 失败: {e}")
return f"请求 {i} 失败"
tasks = [single_request(i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
end_time_total = time.perf_counter()
total_duration = end_time_total - start_time_total
print(f"总耗时: {total_duration:.4f} 秒")
print(f"平均每个请求耗时: {total_duration / num_requests:.4f} 秒")
print("所有请求结果状态:", results)
asyncio.run(simulate_custom_batch_requests(num_requests=10))
在这个例子中,SlowTextProcessor 没有实现 batch 方法。但我们通过 batch_fn=custom_batch_processor 为 RunnableBatch 提供了一个自定义的批处理逻辑。RunnableBatch 会收集请求,然后将它们作为一个列表传递给 custom_batch_processor。custom_batch_processor 可以对这个列表进行任何处理,并返回一个结果列表,RunnableBatch 会将这些结果映射回原始的 ainvoke 调用。这使得 RunnableBatch 的适用范围大大扩展。
6.2 default_response:优雅处理批次内错误
在批处理中,如果批次中的某一个或几个输入导致底层模型失败,我们通常不希望整个批次都失败。default_response 参数就是为此设计的。
import time
import asyncio
from typing import List, Any
from langchain_core.runnables import RunnableBatch, Runnable
# 模拟一个会随机失败的 Runnable
class FailingProcessor(Runnable):
def invoke(self, input: str, config=None) -> str:
if "fail" in input.lower():
raise ValueError(f"故意失败: {input}")
time.sleep(0.1)
return f"Processed: {input}"
def batch(self, inputs: List[str], config=None) -> List[str]:
results = []
for input_item in inputs:
try:
# 模拟批处理中某个项目失败
if "batch_fail" in input_item.lower():
raise ValueError(f"批次内故意失败: {input_item}")
time.sleep(0.05) # 模拟批处理中的单个项目处理时间
results.append(f"Batched Processed: {input_item}")
except Exception as e:
# 在实际的 batch 实现中,你可能需要将错误捕获并返回一个特定的标记
# 或者让 RunnableBatch 的 default_response 处理
# 这里我们让它抛出,看 RunnableBatch 如何处理
results.append(e) # 返回错误对象,让 RunnableBatch 替换
return results
# 使用 RunnableBatch 包装 FailingProcessor
# 设置 default_response
batched_failing_processor = RunnableBatch(
bound=FailingProcessor(),
max_batch_size=5,
max_batch_time=0.5,
default_response="--- 错误已处理 ---" # 当子请求失败时返回此值
)
async def simulate_failing_batch_requests(num_requests=10):
print(f"n--- 模拟 {num_requests} 个并发请求到带 default_response 的链 ---")
start_time_total = time.perf_counter()
async def single_request(i):
text = f"item {i}"
if i % 3 == 0: # 模拟部分请求会失败
text = f"batch_fail item {i}"
input_data = text
try:
result = await batched_failing_processor.ainvoke(input_data)
print(f"请求 {i} 结果: {result}")
return f"请求 {i} 成功" if result != "--- 错误已处理 ---" else f"请求 {i} 失败 (被 default_response 捕获)"
except Exception as e:
print(f"请求 {i} 真的失败了: {e}") # 只有当 default_response 没有捕获到时才会到这里
return f"请求 {i} 真的失败了"
tasks = [single_request(i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
end_time_total = time.perf_counter()
total_duration = end_time_total - start_time_total
print(f"总耗时: {total_duration:.4f} 秒")
print(f"所有请求结果状态:", results)
asyncio.run(simulate_failing_batch_requests(num_requests=10))
在这个例子中,FailingProcessor 的 batch 方法在处理包含 "batch_fail" 的输入时会抛出异常。RunnableBatch 捕获到这些异常,并用我们定义的 default_response ("— 错误已处理 —") 替换了这些失败的结果,从而保证了批处理的整体流程不会中断,同时调用方也能收到一个明确的错误指示。这对于构建高可用的系统至关重要,因为它可以防止单个问题导致整个批次甚至整个应用程序崩溃。
七、性能考量与权衡
使用 RunnableBatch 进行批量推理确实能带来显著的性能提升,但这种提升并非没有代价,我们需要仔细权衡几个关键参数。
7.1 关键参数的权衡
| 参数 | 描述 | 影响 | 建议与考量 |
|---|---|---|---|
max_batch_size |
一个批次中最大请求数。 | 吞吐量: 越大,潜在吞吐量越高,API固定开销分摊越充分。 延迟: 越大,单个请求等待批次填满的时间可能越长,导致延迟增加。 内存/计算: 越大,单次模型推理的资源消耗越大。 |
根据模型提供商的推荐批次大小、模型本身的计算效率以及系统可用内存来设置。如果请求量高且稳定,可以设置较大值。如果请求量波动大,可能需要设置较小值或配合 max_batch_time。 |
max_batch_time |
一个批次等待的最大时间(秒)。 | 延迟: 越小,单个请求的等待时间越短,延迟越低。 吞吐量: 越小,可能导致批次不满就发出,降低批次效率,减少吞吐量。 API调用频率: 越小,API调用频率可能越高。 |
平衡延迟和吞吐量的关键。对于实时性要求高的应用,设置较小值。对于后台任务,可以设置较大值以最大化批次效率。通常建议从 0.1-0.5 秒开始尝试。 |
wait_until_full |
是否等待批次完全填满才触发。 | 吞吐量: True 时最大化批次效率,可能带来最高吞吐量。 延迟: True 时可能导致无限等待或极高延迟,如果请求量不足以填满批次。 |
仅在确定请求流足够稳定且能快速填满批次,并且对延迟不敏感的场景下使用 True。绝大多数实时应用应设置为 False。 |
default_response |
批次中单个请求失败时的默认响应。 | 容错性: 增强系统容错能力,防止单个失败影响整个批次。 调试: 可能掩盖底层错误,需要配合日志和监控。 |
强烈建议设置,以提高系统的健壮性。返回的 default_response 应该足够清晰,表示该请求失败,以便上层应用进行处理。 |
batch_fn |
自定义批处理函数。 | 灵活性: 允许处理不提供原生 batch 方法的 Runnable 或进行自定义预/后处理。 复杂性: 引入额外的逻辑,需要开发者自行管理批处理的输入输出,并确保其高效性。 |
当底层 Runnable 不支持 batch 或需要特殊处理时使用。确保 batch_fn 本身是高效的,否则会抵消 RunnableBatch 带来的优化。 |
7.2 性能基准测试与监控
要真正理解 RunnableBatch 在您的特定应用场景下的效果,进行实际的基准测试至关重要。
-
定义明确的指标:
- 吞吐量 (Throughput):每秒处理的请求数 (RPS)。
- 平均延迟 (Average Latency):每个请求从发出到收到结果的平均时间。
- P90/P99 延迟 (Percentile Latency):90% 或 99% 的请求所花费的时间。这对于衡量用户体验的稳定性非常重要。
- 成本 (Cost):在不同批处理策略下的API调用成本。
-
模拟真实负载:使用工具(如 Locust、JMeter、k6 或简单的并发脚本)模拟您的应用可能面临的并发请求模式。
-
迭代调整参数:
- 从小批次大小和短超时时间开始。
- 逐渐增加
max_batch_size,观察吞吐量和延迟的变化。 - 调整
max_batch_time,观察它如何平衡批次大小和请求延迟。 - 在每次调整后,重新运行基准测试并记录结果。
-
监控:在生产环境中,集成监控系统(如 Prometheus、Grafana)来跟踪
RunnableBatch实例的实际批次大小、批处理延迟、成功率和错误率。这有助于识别瓶颈并动态调整配置。
7.3 RunnableBatch 的局限性
尽管 RunnableBatch 非常强大,但它并非万能药。它主要适用于以下场景:
- 输入相互独立:批次中的每个输入请求的计算不依赖于批次中其他请求的结果。
- 输出顺序与输入顺序一致:底层
batch方法或batch_fn必须保证返回结果的顺序与输入顺序一致,RunnableBatch才能正确地将结果映射回原始请求。 - 同质或可统一处理的输入:批次中的输入虽然可能内容不同,但结构和处理方式应足够相似,以便于一次性处理。
对于需要复杂依赖关系、动态批次重组或更高级调度策略的场景,可能需要结合队列系统(如 Kafka、RabbitMQ)和更复杂的自定义批处理服务来实现。
八、最佳实践与设计模式
为了最大限度地发挥 RunnableBatch 的优势并构建健壮的系统,请遵循以下最佳实践:
-
按供应商和模型粒度创建
RunnableBatch实例:
为每个不同的 LLM 供应商或模型(如果您有多个模型实例)创建独立的RunnableBatch实例。这有助于隔离配置、错误处理和性能指标。# 错误示例:共享一个 RunnableBatch 实例处理不同的模型 # BAD: all_models_batcher = RunnableBatch(some_generic_chain) # GOOD: openai_batcher = RunnableBatch(openai_specific_chain, max_batch_size=50, max_batch_time=0.1) anthropic_batcher = RunnableBatch(anthropic_specific_chain, max_batch_size=30, max_batch_time=0.2) -
合理设置
max_batch_size和max_batch_time:
根据您的应用场景(实时性要求、预期负载、模型提供商的速率限制)进行调优。对于低延迟应用,max_batch_time应该较小;对于高吞吐量应用,max_batch_size可以较大。 -
利用
default_response进行容错:
始终为RunnableBatch配置default_response,以优雅地处理批次中个别请求的失败,避免整个批次或链条中断。 -
结合 LCEL 的路由能力:
当需要根据输入动态选择模型时,将RunnableBatch与RunnableLambda和when方法结合,构建智能路由器。from langchain_core.runnables import RunnableLambda, RunnableBranch router = RunnableBranch( (lambda x: x["type"] == "creative", creative_model_batcher), (lambda x: x["type"] == "factual", factual_model_batcher), default_model_batcher # 默认处理器 ) -
异步优先:
在与RunnableBatch交互时,尽可能使用ainvoke和abatch方法。RunnableBatch本身是为异步操作设计的,这将确保您的应用程序能够充分利用并发性。 -
监控和日志:
对RunnableBatch实例的运行情况进行监控。记录批次大小、批处理时间、错误率等关键指标。这有助于您理解其性能特点并在生产环境中进行问题排查。 -
预热(Warm-up):
在生产部署后,通过发送少量请求预热RunnableBatch实例。这可以确保内部队列和定时器正常启动,并避免冷启动带来的初始延迟高峰。
九、未来的展望
RunnableBatch 作为 LangChain LCEL 的一部分,为LLM应用的性能优化提供了开箱即用的解决方案。展望未来,我们可以期待:
- 更智能的自适应批处理:根据实时负载和性能指标动态调整
max_batch_size和max_batch_time,进一步优化资源利用。 - 与更广泛的生态系统集成:例如,与消息队列系统(Kafka、RabbitMQ)的更深层集成,以处理跨服务边界的异步批处理流。
- 批处理的透明度与可观测性增强:提供更丰富的钩子和指标,让开发者能够更细致地了解批处理内部的运作。
- 针对特定模型类型的优化:例如,为长上下文或多模态输入提供更专业的批处理策略。
RunnableBatch 极大地简化了在 LangChain 应用中实现高效批量推理的复杂性,它提供了一个统一、声明式的方式来聚合对底层 Runnable 的调用,无论是针对单个模型供应商还是跨多个。通过合理配置和应用,开发者可以显著提升其LLM应用的吞吐量,降低延迟,并优化运营成本,从而构建出更具弹性、高性能和成本效益的AI驱动解决方案。它是 LangChain LCEL 强大可组合性理念的一个绝佳体现,也是构建未来AI应用不可或缺的工具之一。