Streaming Middleware:FastAPI 中 LangChain 流式输出的 WebSocket 封装
引言:流式输出与现代 Web 应用的需求
在现代 Web 应用,尤其是涉及人工智能和大型语言模型(LLM)的场景中,用户体验已成为设计的核心。传统的“请求-响应”模式在处理耗时操作时,会导致用户界面长时间卡顿,或者在等待整个响应完成之前无法获取任何信息,这极大地损害了用户体验。想象一下,向一个 LLM 提问,然后等待几十秒甚至几分钟才能看到完整的答案,这无疑是令人沮丧的。
为了解决这一问题,流式输出(Streaming Output)应运而生。流式输出允许服务器在生成响应的同时,逐步将数据发送给客户端。这意味着客户端可以在接收到第一个可用数据块时立即开始处理和显示,从而实现实时反馈和更流畅的用户体验。对于 LLM 应用而言,这意味着用户可以“看着”模型逐字逐句地生成答案,就像与人类对话一样。
在实现流式输出时,我们通常会遇到几种技术:
- Server-Sent Events (SSE):一种基于 HTTP 的单向流协议,服务器可以持续向客户端发送事件。它简单易用,但只能单向通信,且在某些浏览器或代理中可能存在限制。
- HTTP Long Polling:客户端发起请求后,服务器会保持连接打开,直到有新数据可用或超时,然后发送响应并关闭连接。客户端收到响应后立即发起新的请求。效率相对较低,且实现复杂。
- WebSockets:一种全双工、持久化的通信协议,允许服务器和客户端之间进行双向实时通信。它提供了低延迟、高效率的连接,非常适合需要频繁、双向数据交换的场景,如聊天应用、实时协作工具以及本文将重点探讨的 LLM 流式输出。
本文将深入探讨如何在 FastAPI 框架中,利用 WebSocket 协议,构建一个兼容 LangChain 流式输出的“流式中间件”层。这里的“流式中间件”并非指 ASGI 规范中的传统中间件(如 CORSMiddleware),而是指一个位于 LangChain 核心逻辑与客户端 WebSocket 接口之间的适配层,负责捕获 LangChain 的流式事件并将其格式化后通过 WebSocket 发送出去。
理解 LangChain 的流式机制
LangChain 是一个强大的框架,用于开发由语言模型驱动的应用程序。它支持多种组件的链式(Chain)和代理(Agent)模式,并且内置了对流式输出的良好支持。
stream() 方法与 BaseCallbackHandler
LangChain 的许多组件(如 LLM、Runnable 等)都提供了 stream() 方法。这个方法不会一次性返回所有结果,而是会异步地逐个生成“块”(chunks)。这些块可以是文本片段、工具调用信息、或整个链的中间步骤。
为了捕获这些流式块并进行处理,LangChain 引入了回调(Callbacks)机制。所有回调都继承自 BaseCallbackHandler 类。通过实现 BaseCallbackHandler 中的特定方法,我们可以在 LangChain 运行的不同阶段插入自定义逻辑,包括在流式输出过程中接收到新块时。
常用的回调方法包括:
on_llm_start/on_llm_end:LLM 调用开始/结束时。on_llm_new_token:LLM 生成新 token 时(仅限 LLM 自身流式输出)。on_chain_start/on_chain_end:链执行开始/结束时。on_agent_action/on_agent_finish:Agent 执行动作/完成时。on_retriever_start/on_retriever_end:检索器调用开始/结束时。on_chat_model_start/on_chat_model_end:聊天模型调用开始/结束时。on_tool_start/on_tool_end:工具调用开始/结束时。on_tool_error:工具调用出错时。on_llm_error:LLM 调用出错时。on_chain_error:链执行出错时。
对于流式输出,on_llm_new_token 和 LangChain 表达式语言 (LCEL) 中的 stream() 方法直接返回的 AIMessageChunk 或 ToolCallChunk 等对象是关键。当我们使用 chain.stream() 时,它会产生一系列的 chunk 对象,这些对象可能是 AIMessageChunk(包含文本内容)、ToolCallChunk(包含工具调用信息)等。我们的任务就是捕获这些 chunk 并将其发送给客户端。
FastAPI 的 WebSocket 编程基础
FastAPI 是一个现代、快速(高性能)的 Web 框架,用于使用 Python 3.7+ 构建 API。它基于 Starlette 和 Pydantic,提供了对异步编程的强大支持,包括 WebSockets。
核心概念
WebSocket对象:在 FastAPI 的 WebSocket 路由处理函数中,你会接收到一个WebSocket对象。这个对象代表了与客户端建立的 WebSocket 连接。websocket.accept():在开始通信之前,服务器必须通过调用await websocket.accept()来接受客户端的 WebSocket 连接请求。websocket.receive_text()/websocket.receive_json():用于接收客户端发送的文本或 JSON 数据。websocket.send_text()/websocket.send_json():用于向客户端发送文本或 JSON 数据。websocket.close():用于关闭 WebSocket 连接。通常在处理函数结束时或发生错误时调用。WebSocketDisconnect异常:当客户端断开连接时,FastAPI 会抛出WebSocketDisconnect异常。我们需要捕获这个异常以进行清理。
基本 WebSocket 路由
一个最简单的 FastAPI WebSocket 路由如下所示:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
app = FastAPI()
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_text()
await websocket.send_text(f"Message text was: {data}")
except WebSocketDisconnect:
print("Client disconnected")
except Exception as e:
print(f"An error occurred: {e}")
在这个例子中,一旦客户端连接到 /ws,服务器会接受连接,然后进入一个无限循环,不断接收客户端的文本消息并原样发回。当客户端断开连接时,WebSocketDisconnect 异常会被捕获,循环终止。
构建 LangChain 流式输出的 WebSocket 封装
现在,我们将把 LangChain 的流式机制与 FastAPI 的 WebSocket 功能结合起来。核心思想是创建一个自定义的 LangChain 回调处理器,它能够将接收到的 LangChain 块通过 WebSocket 发送给客户端。
1. 定义 WebSocket 输出数据格式
在 LangChain 的流式输出中,我们可能会收到不同类型的块:文本、工具调用、错误信息等。为了让客户端能够清晰地理解和处理这些信息,我们需要定义一个统一的、结构化的数据格式。JSON 是一个理想的选择。
我们可以定义一个通用的 WebSocket 消息结构,包含 type 字段来指示消息类型,以及 data 字段来承载实际内容。
| 字段 | 类型 | 描述 | 示例值 |
|---|---|---|---|
type |
str |
消息类型(e.g., llm_new_token, tool_call, error, end) |
llm_new_token |
data |
dict |
实际的消息内容,格式取决于 type |
{"content": "Hello"} |
status |
str |
请求状态(e.g., success, error) |
success |
detail |
str |
错误或额外信息 | An unexpected error occurred. |
id |
str |
(可选)用于客户端关联响应到特定请求的 ID | req_12345 |
event |
str |
(可选)LangChain内部事件名称,如on_llm_new_token |
on_llm_new_token |
例如,一个 LLM 新 token 消息可能看起来像这样:
{
"type": "llm_new_token",
"data": {
"content": "世界"
},
"status": "success",
"event": "on_llm_new_token"
}
一个工具调用块可能:
{
"type": "tool_call",
"data": {
"name": "search_tool",
"args": {
"query": "FastAPI latest version"
},
"id": "tool_call_..."
},
"status": "success",
"event": "on_tool_start"
}
一个错误消息可能:
{
"type": "error",
"data": {},
"status": "error",
"detail": "Failed to connect to LLM provider.",
"event": "on_llm_error"
}
2. 创建自定义 WebSocketCallbackHandler
这个回调处理器将继承自 BaseCallbackHandler,并且会持有 WebSocket 实例,以便在接收到 LangChain 事件时发送消息。
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from fastapi import WebSocket
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.runnables import RunnableConfig
class WebSocketCallbackHandler(BaseCallbackHandler):
"""
一个自定义的 LangChain 回调处理器,用于通过 WebSocket 发送流式数据。
"""
def __init__(self, websocket: WebSocket, conversation_id: str = None):
self.websocket = websocket
self.conversation_id = conversation_id
async def _send_websocket_message(self, message_type: str, data: Dict[str, Any], status: str = "success", detail: str = None, event: str = None):
"""
封装发送 WebSocket 消息的逻辑。
"""
payload = {
"type": message_type,
"data": data,
"status": status,
"event": event,
}
if self.conversation_id:
payload["conversation_id"] = self.conversation_id
if detail:
payload["detail"] = detail
try:
await self.websocket.send_json(payload)
except Exception as e:
print(f"Error sending message over WebSocket: {e}")
# 可以在这里处理 WebSocket 连接中断的情况,例如记录日志或尝试关闭连接
async def on_llm_new_token(self, token: str, *, chunk: Optional[GenerationChunk] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在 LLM 生成新 token 时调用。
"""
# print(f"on_llm_new_token: {token}") # 调试用
await self._send_websocket_message(
"llm_new_token",
{"content": token},
event="on_llm_new_token"
)
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
"""
在 LLM 调用开始时调用。
"""
await self._send_websocket_message(
"llm_start",
{"prompts": prompts, "serialized": serialized},
event="on_llm_start"
)
async def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在 LLM 调用结束时调用。
"""
# 我们可以选择发送最终的 LLM 响应,但对于流式输出,通常主要关注 token。
# 这里的 response 可能包含完整的输出,我们可以提取出来。
final_output = response.generations[0][0].text if response.generations else ""
await self._send_websocket_message(
"llm_end",
{"final_output": final_output},
event="on_llm_end"
)
async def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在 LLM 调用出错时调用。
"""
await self._send_websocket_message(
"error",
{},
status="error",
detail=f"LLM Error: {str(error)}",
event="on_llm_error"
)
# 针对 LangChain Expression Language (LCEL) 的流式输出
# LCEL 的 stream() 方法直接返回 AIMessageChunk, ToolCallChunk 等
# 这些不会触发 on_llm_new_token,而是在 Runnable.stream() 中直接迭代。
# 因此,我们需要在 WebSocket 路由中处理这些 chunk。
# 如果需要更细粒度的控制,可以实现 on_chain_* 或 on_tool_* 等方法
# 例如,捕捉工具调用的开始和结束
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
"""
在工具调用开始时调用。
"""
await self._send_websocket_message(
"tool_start",
{"name": serialized.get("name"), "input": input_str},
event="on_tool_start"
)
async def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在工具调用结束时调用。
"""
await self._send_websocket_message(
"tool_end",
{"output": str(output)}, # 确保 output 是可序列化的
event="on_tool_end"
)
async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在工具调用出错时调用。
"""
await self._send_websocket_message(
"error",
{},
status="error",
detail=f"Tool Error: {str(error)}",
event="on_tool_error"
)
async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
"""
在链执行开始时调用。
"""
await self._send_websocket_message(
"chain_start",
{"name": serialized.get("lc_kwargs", {}).get("name") or serialized.get("name"), "inputs": inputs},
event="on_chain_start"
)
async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在链执行结束时调用。
"""
await self._send_websocket_message(
"chain_end",
{"outputs": outputs},
event="on_chain_end"
)
async def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""
在链执行出错时调用。
"""
await self._send_websocket_message(
"error",
{},
status="error",
detail=f"Chain Error: {str(error)}",
event="on_chain_error"
)
# 针对 LangChain Expression Language (LCEL) 的 stream() 方法直接返回的 chunk
# 这里的 on_llm_new_token 适用于传统的 LLM.stream()
# 对于 LCEL 的 Runnable.stream(),它会返回 AIMessageChunk, ToolCallChunk 等。
# 这些需要我们直接在 WebSocket 路由中迭代处理。
# 因此,我们还需要一个方法来发送这些直接的 chunk。
async def send_chunk(self, chunk: Any, event_name: str = "chunk"):
"""
发送 LangChain stream() 方法直接返回的 chunk。
"""
# LangChain_core 的 chunk 对象是 Pydantic 模型,可以直接转换为字典
chunk_dict = {}
if hasattr(chunk, 'dict'): # Pydantic v1
chunk_dict = chunk.dict()
elif hasattr(chunk, 'model_dump'): # Pydantic v2
chunk_dict = chunk.model_dump()
else:
# 兼容非 Pydantic chunk,尽量转换为字符串
chunk_dict = {"content": str(chunk)}
await self._send_websocket_message(
"lc_chunk",
{"chunk_type": chunk.__class__.__name__, "content": chunk_dict},
event=event_name
)
重要说明:LangChain 的 stream() 方法,特别是对于 LCEL 构建的 Runnable 对象,它会直接生成 AIMessageChunk、ToolCallChunk 等对象。这些对象并不会触发 on_llm_new_token 回调(on_llm_new_token 主要针对底层 LLM 的 token 生成)。因此,我们的 WebSocketCallbackHandler 主要是为了捕获更高层次的事件(如 on_tool_start, on_chain_end),而对于 chain.stream() 直接产生的 chunk,我们需要在 FastAPI 路由中显式地迭代并使用 send_chunk 方法发送。
3. FastAPI WebSocket 路由集成
现在,我们将在 FastAPI 路由中结合 LangChain 链的执行和自定义回调处理器。
import os
import asyncio
from uuid import uuid4
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI # 假设使用 OpenAI
# 导入上面定义的 WebSocketCallbackHandler
# from .callbacks import WebSocketCallbackHandler
# 设置 OpenAI API 密钥
# 可以通过环境变量设置,或直接在这里赋值
# os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"
app = FastAPI()
# 示例 LangChain 组件
# 定义一个简单的 LangChain 链
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个乐于助人的AI助手。"),
("user", "{question}")
])
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7, streaming=True)
output_parser = StrOutputParser()
# 创建一个 LCEL chain
simple_chain = {"question": RunnablePassthrough()} | prompt | llm | output_parser
# 定义 WebSocket 消息模型 (可选,但推荐用于清晰性)
class WebSocketRequest(BaseModel):
conversation_id: str = Field(default_factory=lambda: str(uuid4()))
question: str
@app.websocket("/ws/chat")
async def websocket_chat_endpoint(websocket: WebSocket):
await websocket.accept()
conversation_id = None
try:
# 接收客户端的初始化消息,通常包含问题和会话ID
initial_message = await websocket.receive_json()
request = WebSocketRequest(**initial_message)
conversation_id = request.conversation_id
user_question = request.question
print(f"[{conversation_id}] Received question: {user_question}")
# 创建自定义回调处理器
handler = WebSocketCallbackHandler(websocket, conversation_id=conversation_id)
# 发送开始消息
await handler._send_websocket_message(
"start",
{"message": "开始处理请求...", "question": user_question},
event="request_start"
)
# 定义 LangChain 链的配置,将自定义回调处理器传递进去
config = {"callbacks": [handler]}
# 执行 LangChain 链,并流式迭代结果
# 注意:对于 LCEL 的 .stream() 方法,它直接返回 AIMessageChunk/ToolCallChunk 等
# on_llm_new_token 仅在底层 LLM 流式生成时触发(如 ChatOpenAI 的 streaming=True)。
# 对于 LCEL 链,我们会接收到 AIMessageChunk,其中包含了 token。
# 我们需要在循环中解析这些 chunk。
full_response_content = ""
async for chunk in simple_chain.stream(user_question, config=config):
# LangChain LCEL 的 stream() 直接返回的是 output_parser 后的结果
# 对于 StrOutputParser,chunk 就是字符串
# 如果 output_parser 没有,chunk 可能是 AIMessageChunk
# 检查 chunk 的类型并发送
if isinstance(chunk, str):
# 如果是 StrOutputParser,直接发送文本
await handler._send_websocket_message(
"llm_new_token",
{"content": chunk},
event="on_llm_new_token_from_parser"
)
full_response_content += chunk
elif hasattr(chunk, 'content') and isinstance(chunk.content, str): # AIMessageChunk
await handler._send_websocket_message(
"llm_new_token",
{"content": chunk.content},
event="on_llm_new_token_from_chunk"
)
full_response_content += chunk.content
else:
# 处理其他类型的 chunk,例如 ToolCallChunk
await handler.send_chunk(chunk) # 使用我们自定义的 send_chunk 方法
# 发送结束消息
await handler._send_websocket_message(
"end",
{"message": "请求处理完成", "full_response": full_response_content},
event="request_end"
)
except WebSocketDisconnect:
print(f"[{conversation_id}] Client disconnected.")
except Exception as e:
print(f"[{conversation_id}] An error occurred: {e}")
# 发送错误消息给客户端
if websocket.client_state == status.WS_CONNECTED:
await handler._send_websocket_message(
"error",
{},
status="error",
detail=f"Server error: {str(e)}",
event="server_error"
)
# 确保关闭连接
await websocket.close(code=status.WS_1011_INTERNAL_ERROR) # 内部错误
代码解释:
@app.websocket("/ws/chat"):定义了一个 WebSocket 路由。await websocket.accept():接受客户端的连接。initial_message = await websocket.receive_json():我们假定客户端会先发送一个包含question和conversation_id的 JSON 消息来启动对话。这有助于在服务器端管理会话状态。WebSocketCallbackHandler实例化:传入websocket对象和conversation_id。config = {"callbacks": [handler]}:这是将自定义回调处理器注入 LangChain 链的关键。LangChain 在执行stream()方法时,会自动调用注册的回调处理器。async for chunk in simple_chain.stream(user_question, config=config)::异步迭代 LangChain 链的流式输出。这里的chunk是 LangChain 生成的每个小片段。if isinstance(chunk, str): ... elif hasattr(chunk, 'content'): ... else: handler.send_chunk(chunk):这里是核心逻辑。因为 LangChain LCEL 的stream()方法可能返回不同类型的对象(取决于链的结构和output_parser),我们需要检查chunk的类型。- 如果链的末端是
StrOutputParser,chunk将直接是字符串。 - 如果链的末端是
ChatOpenAI等 LLM,且没有output_parser,则chunk可能是AIMessageChunk,它有一个content属性。 - 对于其他更复杂的 LangChain chunk 类型(如
ToolCallChunk),我们使用handler.send_chunk(chunk)来发送其 Pydantic 字典表示。
- 如果链的末端是
- 错误处理:
try...except WebSocketDisconnect用于处理客户端断开连接,而except Exception用于捕获服务器端处理过程中的其他错误,并尝试通过 WebSocket 将错误信息发送给客户端。
4. 客户端示例(JavaScript)
虽然本文主要关注 FastAPI 服务端,但为了完整性,这里提供一个简单的 JavaScript 客户端示例,展示如何连接到 WebSocket 并处理流式消息。
<!DOCTYPE html>
<html>
<head>
<title>FastAPI LangChain Streaming Chat</title>
<style>
body { font-family: sans-serif; margin: 20px; }
#chat-window { border: 1px solid #ccc; padding: 10px; height: 300px; overflow-y: scroll; margin-bottom: 10px; }
.message { margin-bottom: 5px; }
.user-message { color: blue; }
.ai-message { color: green; }
.system-message { color: gray; font-size: 0.8em; }
.error-message { color: red; font-weight: bold; }
</style>
</head>
<body>
<h1>FastAPI LangChain Streaming Chat</h1>
<div id="chat-window"></div>
<input type="text" id="question-input" placeholder="输入你的问题..." style="width: 80%;">
<button id="send-button">发送</button>
<button id="connect-button">连接 WebSocket</button>
<button id="disconnect-button" disabled>断开 WebSocket</button>
<script>
let ws;
let conversationId = null;
const chatWindow = document.getElementById('chat-window');
const questionInput = document.getElementById('question-input');
const sendButton = document.getElementById('send-button');
const connectButton = document.getElementById('connect-button');
const disconnectButton = document.getElementById('disconnect-button');
function appendMessage(sender, message, type = 'text') {
const msgDiv = document.createElement('div');
msgDiv.classList.add('message');
if (sender === 'user') {
msgDiv.classList.add('user-message');
msgDiv.innerHTML = `<strong>你:</strong> ${message}`;
} else if (sender === 'ai') {
msgDiv.classList.add('ai-message');
msgDiv.innerHTML = `<strong>AI:</strong> ${message}`;
} else if (sender === 'system') {
msgDiv.classList.add('system-message');
msgDiv.innerHTML = `<em>系统:</em> ${message}`;
} else if (sender === 'error') {
msgDiv.classList.add('error-message');
msgDiv.innerHTML = `<strong>错误:</strong> ${message}`;
}
chatWindow.appendChild(msgDiv);
chatWindow.scrollTop = chatWindow.scrollHeight; // 滚动到底部
}
function connectWebSocket() {
if (ws && ws.readyState === WebSocket.OPEN) {
appendMessage('system', 'WebSocket 已连接。');
return;
}
ws = new WebSocket("ws://localhost:8000/ws/chat"); // 替换为你的 FastAPI 地址
ws.onopen = (event) => {
appendMessage('system', 'WebSocket 连接成功!');
connectButton.disabled = true;
disconnectButton.disabled = false;
sendButton.disabled = false;
conversationId = crypto.randomUUID(); // 生成新的会话ID
};
ws.onmessage = (event) => {
const msg = JSON.parse(event.data);
console.log("Received:", msg);
if (msg.type === "llm_new_token") {
// 模拟AI逐字输出
const lastAiMessage = chatWindow.querySelector('.ai-message:last-child span.content');
if (lastAiMessage && msg.event === 'on_llm_new_token_from_parser') { // 假设这是持续的AI文本
lastAiMessage.textContent += msg.data.content;
} else {
// 如果是新消息或者不同类型的AI chunk,则创建新行
const newMsgDiv = document.createElement('div');
newMsgDiv.classList.add('message', 'ai-message');
newMsgDiv.innerHTML = `<strong>AI:</strong> <span class="content">${msg.data.content}</span>`;
chatWindow.appendChild(newMsgDiv);
}
chatWindow.scrollTop = chatWindow.scrollHeight;
} else if (msg.type === "start") {
appendMessage('system', `会话开始 (ID: ${msg.conversation_id || 'N/A'})。问题: "${msg.data.question}"`);
// 清除之前的AI消息,准备接收新消息
const lastAiMessage = chatWindow.querySelector('.ai-message:last-child');
if (lastAiMessage) {
lastAiMessage.remove();
}
} else if (msg.type === "end") {
appendMessage('system', `会话结束。完整响应: "${msg.data.full_response}"`);
} else if (msg.type === "error") {
appendMessage('error', `错误: ${msg.detail}`);
} else if (msg.type === "tool_start") {
appendMessage('system', `调用工具: ${msg.data.name} (输入: ${msg.data.input})`);
} else if (msg.type === "tool_end") {
appendMessage('system', `工具 ${msg.data.name} 完成。输出: ${msg.data.output}`);
} else {
appendMessage('system', `未知消息类型 (${msg.type}): ${JSON.stringify(msg.data)}`);
}
};
ws.onclose = (event) => {
appendMessage('system', `WebSocket 连接关闭。Code: ${event.code}, Reason: ${event.reason}`);
connectButton.disabled = false;
disconnectButton.disabled = true;
sendButton.disabled = true;
ws = null;
conversationId = null;
};
ws.onerror = (error) => {
appendMessage('error', `WebSocket 错误: ${error.message}`);
};
}
function sendMessage() {
const question = questionInput.value.trim();
if (question && ws && ws.readyState === WebSocket.OPEN) {
appendMessage('user', question);
const messagePayload = {
conversation_id: conversationId,
question: question
};
ws.send(JSON.stringify(messagePayload));
questionInput.value = ''; // 清空输入框
} else if (!ws || ws.readyState !== WebSocket.OPEN) {
appendMessage('system', 'WebSocket 未连接,请先连接。');
}
}
connectButton.addEventListener('click', connectWebSocket);
disconnectButton.addEventListener('click', () => {
if (ws) ws.close();
});
sendButton.addEventListener('click', sendMessage);
questionInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter') {
sendMessage();
}
});
// 页面加载时尝试自动连接
// connectWebSocket();
sendButton.disabled = true; // 默认禁用发送按钮
disconnectButton.disabled = true;
</script>
</body>
</html>
客户端代码解释:
connectWebSocket():建立与 FastAPI 服务端的 WebSocket 连接。ws.onmessage:这是核心处理逻辑。当收到服务器发送的 JSON 消息时,解析它并根据type字段更新聊天界面。llm_new_token类型消息的data.content包含 LLM 生成的文本片段。客户端可以将其追加到当前 AI 消息中,实现逐字显示效果。start和end消息用于显示会话的开始和结束。error消息显示服务器端产生的错误。
sendMessage():将用户输入的问题封装成 JSON 格式(包含conversation_id和question),并通过 WebSocket 发送给服务器。
部署与运行
- 安装依赖:
pip install fastapi uvicorn "langchain_openai>=0.1.0" "langchain_core>=0.1.0" pydantic - 保存代码:将 FastAPI 代码保存为
main.py。 - 设置环境变量:确保
OPENAI_API_KEY环境变量已设置,或在代码中直接赋值。export OPENAI_API_KEY="your_openai_api_key_here" - 运行 FastAPI 应用:
uvicorn main:app --reload - 打开客户端:在浏览器中打开保存的
index.html文件。点击“连接 WebSocket”,然后在输入框中输入问题,点击“发送”即可看到流式输出。
进阶考量与最佳实践
1. 错误处理与连接管理
- 健壮的错误消息:服务器端应捕获 LangChain 内部可能抛出的所有异常(如 LLM API 错误、工具执行错误),并通过 WebSocket 发送结构化的错误消息给客户端,以便客户端能够显示友好的错误提示。
- WebSocket 状态码:在
websocket.close()时,使用适当的 WebSocket 状态码(如status.WS_1011_INTERNAL_ERROR表示服务器内部错误,status.WS_1000_NORMAL_CLOSURE表示正常关闭)可以帮助客户端理解连接关闭的原因。 - 心跳机制 (Ping/Pong):对于长时间连接,为了防止中间网络设备断开不活跃的连接,可以实现心跳机制。FastAPI/Starlette 的 WebSocket 实现通常会处理底层的 Ping/Pong 帧,但在应用层也可以定期发送自定义心跳消息。
2. 可扩展性与并发
- FastAPI 的异步优势:FastAPI 和 LangChain 的异步特性(
async/await)使得在单个进程中处理大量并发 WebSocket 连接成为可能,而不会阻塞事件循环。 - 多进程部署:使用 Gunicorn 等 ASGI 服务器配合 Uvicorn worker 可以充分利用多核 CPU,进一步提升并发处理能力。
gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000 - 会话管理:对于更复杂的应用,可能需要将会话状态(如聊天历史)存储在外部存储(如 Redis、数据库)中,而不是仅依赖内存,以便在多个服务器实例之间共享或在服务器重启后恢复。
3. 安全性
- 身份验证与授权:WebSocket 连接也需要身份验证和授权。可以在 WebSocket 连接建立时,通过查询参数、HTTP 头(在连接升级阶段)或第一次消息传递时进行用户身份验证。例如,在 FastAPI 中使用
Depends依赖注入来验证 JWT token。 - 输入验证:对客户端发送的任何输入(如
question)进行严格验证,防止注入攻击或其他恶意输入。 - 限流:为 WebSocket 连接设置速率限制,防止滥用和拒绝服务攻击。
4. 更复杂 LangChain 链的适配
- Agent 输出:如果 LangChain 链包含 Agent,它们可能会产生更复杂的输出,包括工具思考过程、中间步骤、工具调用参数等。自定义回调处理器需要实现更多
on_agent_*和on_tool_*方法来捕获这些信息,并将其结构化发送给客户端。 - Runnable.stream() 的多样性:如前所述,
Runnable.stream()会根据链的结构返回不同类型的chunk。在处理这些chunk时,需要根据其类型(AIMessageChunk,ToolCallChunk,FunctionCallChunk等)进行相应的序列化和发送。 StreamingStdOutCallbackHandler的启发:LangChain 自带的StreamingStdOutCallbackHandler是一个很好的参考,它展示了如何捕获不同类型的流式事件并进行处理。我们的WebSocketCallbackHandler基本上就是将其输出重定向到 WebSocket。
5. 用户体验优化
- 前端重组:客户端在接收到
llm_new_token类型的消息时,需要将这些文本片段逐步追加到 UI 元素中,而不是每次都创建一个新元素。 - 错误提示:当服务器发送错误消息时,客户端应以醒目的方式显示错误,并提供用户友好的建议。
- 加载指示:在等待第一个 token 到来之前,显示加载动画或文本,告知用户请求正在处理中。
总结
通过本文的讲解与示例,我们深入理解了如何在 FastAPI 中构建一个强大的流式中间件,以支持 LangChain 的流式输出并通过 WebSocket 实时传输给客户端。核心在于创建自定义的 WebSocketCallbackHandler 来桥接 LangChain 的内部事件与 WebSocket 协议,并通过 FastAPI 的异步 WebSocket 路由高效地处理并发连接和数据传输。
这种架构不仅极大地提升了 LLM 应用的用户体验,提供了实时的反馈,也为构建高性能、可扩展的 AI 驱动型 Web 应用奠定了坚实基础。