解析 ‘Cost Tracking’:如何在大规模并发下精准计算每个用户、每个会话消耗的 Token 账单?

各位同仁,下午好!

今天我们来深入探讨一个在当前AI大模型时代至关重要的议题:如何在面对大规模并发请求时,精准、可靠地计算每一个用户、每一个会话所消耗的Token账单,也就是我们常说的“Cost Tracking”。这不仅仅是一个技术挑战,更直接关系到产品的商业模式、用户体验以及运营的健康度。作为一名编程专家,我将从架构设计、核心算法到容错机制,为大家剖析这一复杂问题。

高并发下LLM Token账单的精准计算:Cost Tracking 深度解析

引言:挑战与机遇并存

随着大型语言模型(LLM)能力的飞速发展和应用场景的日益广泛,无论是开发者平台、SaaS产品还是企业内部应用,都面临着一个核心问题:如何计量和管理用户对LLM资源的消耗。Token作为LLM交互的基本单位,其消耗量直接决定了成本。想象一下,一个拥有数百万用户的平台,每秒处理成千上万个来自不同用户、不同会话的LLM请求,其中包含复杂的流式响应、不同模型的计费策略以及潜在的网络波动和系统故障。在这种高并发、高复杂度的环境下,要做到Token账单的“精准”和“实时”,绝非易事。

今天的讲座,我们的目标是构建一个健壮、可扩展、高可用的Token计费系统,它能够:

  1. 精准计量: 确保每个请求的输入和输出Token都被准确地统计,不遗漏,不重复。
  2. 用户/会话归属: 明确每个Token消耗属于哪个用户、哪个会话,便于成本归因和账单生成。
  3. 高并发支持: 在海量请求下依然稳定运行,保持低延迟。
  4. 容错性: 应对网络中断、服务宕机等异常情况,不丢失数据。
  5. 可审计性: 提供详细的日志和数据,支持对账和争议解决。

Part 1: 基础概念与挑战

在深入技术细节之前,我们先回顾一下Token计费的基础,并明确我们即将面对的挑战。

1.1 Tokenisation 回顾

什么是Token?
Token是LLM处理文本的基本单位。它通常比一个单词小,比如“unbelievable”可能会被拆分成“un”、“believe”、“able”。不同的模型使用不同的分词器(tokenizer),因此相同的文本在不同模型下可能会产生不同数量的Token。

Token的计算方式:
通常,Token分为输入Token输出Token

  • 输入Token: 用户提供给模型的prompt、历史对话上下文等。
  • 输出Token: 模型生成并返回给用户的响应内容。

定价策略:
LLM提供商通常会为不同的模型、不同的Token类型(输入/输出)设置不同的价格。

表1:示例模型定价结构

模型名称 输入 Token 价格 (USD / 1K Token) 输出 Token 价格 (USD / 1K Token)
GPT-4o $0.005 $0.015
GPT-4 $0.03 $0.06
GPT-3.5 Turbo $0.0005 $0.0015
Claude 3 Sonnet $0.003 $0.015

1.2 核心挑战

  1. 预估与实际:

    • 用户发送请求前,我们可以预估输入Token数,但实际消耗的输入Token数和模型返回的输出Token数,都必须以LLM API的实际响应为准。预估值仅供参考,不能用于计费。
    • 有些API在请求体中不直接返回Token数,我们需要在应用层自行统计。
  2. 流式输出 (Streaming Output):

    • LLM的流式响应是提升用户体验的关键。但它也带来了计费的复杂性:模型会分块(chunk)返回内容,如何实时、准确地计算这些分块累积的输出Token数?我们不能等到整个响应结束后再计算,因为用户可能提前中断。
  3. 高并发与状态管理:

    • 每秒数千乃至数万的请求涌入,每个请求都需要关联到特定的用户和会话。如何在分布式系统中高效、准确地维护这些状态,并进行Token的累加,同时保证数据一致性?
  4. 容错性与数据一致性:

    • 网络瞬断、LLM服务不稳定、我们自己的服务重启,都可能导致Token数据丢失或重复。如何设计一个健壮的系统来应对这些问题,并确保最终账单的准确无误?

Part 2: 核心架构设计

为了应对上述挑战,我们需要一个分层、解耦、高可用的系统架构。

2.1 系统组件概览

我们的Cost Tracking系统将由以下核心组件构成:

  • API Gateway / Load Balancer: 负责流量分发、认证、限流等。
  • Application Server (App Server): 业务逻辑核心,接收用户请求,调用LLM API,处理响应。它是Token计数的“生产者”。
  • Token Counter Service: 一个专门的服务,负责接收Token消费事件,进行精确计数,并更新存储。它是Token计数的“消费者”。
  • Message Queue (MQ): 作为App Server和Token Counter Service之间的缓冲和桥梁,确保事件的可靠传递和异步处理。
  • Data Storage: 存储原始请求日志、用户会话信息、Token消费记录以及聚合账单。
  • Billing Service: 定期从Data Storage中聚合数据,生成用户账单,并可能与支付系统集成。

![Cost Tracking System Architecture Diagram – Text Representation]

+----------------+       +-------------------+       +--------------------+
|   User Client  |<----->|    API Gateway    |<----->| Application Server |
+----------------+       +-------------------+       +--------------------+
                                    |                           |
                                    |                           | (1. User Request)
                                    |                           |
                                    v                           v
                      +-----------------------------+     +----------------+
                      |     External LLM API        |<--->|   LLM Proxy    |
                      | (e.g., OpenAI, Anthropic)   |     | (Optional, for |
                      +-----------------------------+     |  centralized   |
                                                          |  API management)|
                                                          +----------------+
                                                                  |
                                                                  | (2. LLM Response & Token Count)
                                                                  v
                      +-------------------------------------------------+
                      | Message Queue (Kafka/RabbitMQ/Redis Streams)    |
                      | (Token Consumption Events)                      |
                      +-------------------------------------------------+
                                    ^                                   |
                                    | (3. Publish Event)                |
                                    |                                   | (4. Consume Event)
                      +-------------------------------------------------+
                      | Token Counter Service (Worker Pool)             |<--->| Redis (Atomic Counters, Cache) |
                      | (Processes events, updates storage)             |     +--------------------------------+
                      +-------------------------------------------------+
                                    |
                                    | (5. Persist Data)
                                    v
                      +-------------------------------------------------+
                      | Data Storage (PostgreSQL/MongoDB/ClickHouse)    |
                      | (Raw Logs, Aggregated Consumption, User Data)   |
                      +-------------------------------------------------+
                                    |
                                    | (6. Billing Aggregation)
                                    v
                      +-------------------------------------------------+
                      | Billing Service (Generates Invoices, Reports)   |
                      +-------------------------------------------------+

2.2 数据模型设计

清晰的数据模型是确保数据一致性和可审计性的基础。

表2:关键数据模型示例

模型名称 字段 类型 说明
User id UUID 用户唯一标识
username String 用户名
balance Decimal 用户当前余额 (用于预付费或限额)
created_at DateTime
Session id UUID 会话唯一标识 (用于跟踪用户在特定对话中的消耗)
user_id UUID 关联用户ID
start_time DateTime 会话开始时间
end_time DateTime 会话结束时间
total_input_tokens Integer 会话总输入Token数
total_output_tokens Integer 会话总输出Token数
RequestLog id UUID 请求唯一标识 (用于幂等性处理)
session_id UUID 关联会话ID
user_id UUID 关联用户ID
model_name String 使用的LLM模型名称 (e.g., "gpt-4o", "claude-3-sonnet")
request_payload JSONB 原始请求内容 (部分或摘要,用于审计)
response_payload JSONB 原始响应内容 (部分或摘要,用于审计)
input_tokens Integer 本次请求的输入Token数
output_tokens Integer 本次请求的输出Token数
cost Decimal 本次请求的预估或实际成本
status String 请求状态 (SUCCESS, FAILED, PENDING)
request_time DateTime 请求发起时间
response_time DateTime 响应接收时间
TokenConsumption id UUID 消费记录ID
user_id UUID 关联用户ID
session_id UUID 关联会话ID
request_log_id UUID 关联请求日志ID
model_name String 使用的LLM模型名称
token_type Enum INPUT / OUTPUT
tokens_count Integer 消耗的Token数量
unit_price Decimal 每Token的单价
total_cost Decimal 本次消费总成本
consumed_at DateTime 消费发生时间

2.3 高并发下的状态管理

在高并发场景下,直接在应用服务器中维护用户或会话的Token累计状态是不可行的,因为应用服务器通常是无状态的,且请求会被负载均衡到不同的实例上。我们需要一个集中式、高性能的状态存储。

  • Redis:

    • 优点: 内存数据库,读写速度极快,支持原子操作(如INCRBY),非常适合作为短期、高频更新的Token计数器。
    • 用途: 存储实时会话的Token累计数(例如,session:{session_id}:input_tokens),用于实时显示或快速判断是否达到会话限额。也可以用于存储用户余额的临时缓存,并结合CAS(Compare-And-Swap)操作进行扣费。
    • 缺点: 内存存储,数据持久化需要配置,但通常作为最终数据库的补充,不作为唯一数据源。
  • 关系型数据库 (PostgreSQL等):

    • 优点: 事务支持,数据持久化,强一致性,适用于存储最终的、可靠的Token消费记录和聚合数据。
    • 用途: 存储RequestLogTokenConsumptionUserSession等,作为最终账单的源头。
    • 缺点: 写入性能在高并发下可能成为瓶颈,不适合每秒数万次的原子计数。

我们的策略是:Redis用于实时、高性能的临时计数和缓存,关系型数据库用于持久化、可靠的最终记录。 通过消息队列将Redis中的临时计数异步同步到关系型数据库。

Part 3: 精准计算与实时处理

这是整个系统的核心。我们将探讨如何在应用层、代理层以及通过异步机制实现Token的精准计算。

3.1 同步 vs 异步处理

  • 同步处理: 在处理用户请求的同一线程/进程中完成Token计数和记录。

    • 优点: 简单直接,实时性高。
    • 缺点: 阻塞主请求流程,增加延迟;如果计数或存储失败,可能影响用户请求,且难以重试;在高并发下可能成为瓶颈。
    • 适用场景: 简单系统,或对实时性要求极高且失败可接受(例如,仅用于前端展示的预估值)。
  • 异步处理: 将Token计数和记录操作解耦,通过消息队列发送事件,由独立的Worker服务消费处理。

    • 优点:
      • 不阻塞主请求,提升用户响应速度。
      • 削峰填谷,系统更稳定。
      • 易于扩展,增加Worker即可提升处理能力。
      • 通过MQ的持久化和重试机制,确保数据不丢失,提升容错性。
      • 简化App Server逻辑,专注于业务。
    • 缺点: 引入MQ增加系统复杂度;数据可能存在最终一致性延迟(但对于账单系统通常可接受)。
    • 适用场景: 几乎所有高并发、需要高可靠性的Cost Tracking系统。

我们强烈推荐采用异步处理模式。

3.2 Token计数器服务

Token Counter Service是异步处理的核心。它的主要职责是:

  1. 从消息队列中消费Token消费事件。
  2. 解析事件,提取user_id, session_id, model_name, input_tokens, output_tokens等信息。
  3. 执行原子性的Token累加操作(例如,更新Redis或数据库中的累计值)。
  4. 将详细的Token消费记录(TokenConsumption)和请求日志(RequestLog)持久化到数据库。
  5. 处理幂等性,避免重复计数。

3.3 流式输出 Token 计算

这是最具挑战性的部分。LLM流式响应意味着我们无法一次性获取所有输出Token。

方法1: 累积式计数 (在应用服务器/代理层)

这是最常用且可靠的方法。在处理LLM流式响应时,应用服务器或一个专门的LLM代理服务会逐块接收数据。对于每个接收到的数据块,我们进行解码,然后使用相应的LLM分词器计算该块的Token数,并累加到当前请求的输出Token总数中。

import openai
import tiktoken
import uuid
import time
import json
from datetime import datetime
from typing import Generator, Dict, Any

# 假设这是一个简化的LLM代理服务或应用服务器的一部分
class LLMProxyService:
    def __init__(self, message_queue_producer):
        self.message_queue_producer = message_queue_producer
        self.model_tokenizers = {
            "gpt-4o": tiktoken.encoding_for_model("gpt-4o"),
            "gpt-3.5-turbo": tiktoken.encoding_for_model("gpt-3.5-turbo"),
            # 根据需要添加更多模型
        }
        self.model_prices = {
            "gpt-4o": {"input": 0.005 / 1000, "output": 0.015 / 1000},
            "gpt-3.5-turbo": {"input": 0.0005 / 1000, "output": 0.0015 / 1000},
        }

    def _count_tokens(self, text: str, model_name: str) -> int:
        """根据模型名称计算文本的Token数"""
        tokenizer = self.model_tokenizers.get(model_name)
        if not tokenizer:
            print(f"Warning: No tokenizer found for model {model_name}. Using default.")
            tokenizer = tiktoken.get_encoding("cl100k_base") # Fallback
        return len(tokenizer.encode(text))

    def _publish_token_event(self, event_type: str, data: Dict[str, Any]):
        """将Token消费事件发布到消息队列"""
        event = {
            "event_id": str(uuid.uuid4()), # 保证事件唯一性,用于幂等
            "timestamp": datetime.now().isoformat(),
            "event_type": event_type,
            "data": data
        }
        self.message_queue_producer.publish(json.dumps(event), topic="token_consumption_events")
        print(f"Published event: {event_type} for request {data.get('request_log_id')}")

    def call_llm_streaming(
        self,
        user_id: str,
        session_id: str,
        model_name: str,
        prompt: str,
        max_tokens: int = 500
    ) -> Generator[str, None, None]:
        """
        模拟调用LLM并处理流式响应,计算Token。
        这是一个生成器函数,每次yield一个响应块。
        """
        request_log_id = str(uuid.uuid4())
        start_time = time.time()

        # 1. 计算输入Token
        input_tokens = self._count_tokens(prompt, model_name)

        # 预估成本(仅供参考,实际成本以输出Token为准)
        estimated_input_cost = input_tokens * self.model_prices.get(model_name, {}).get("input", 0)

        print(f"[{request_log_id}] Input Tokens: {input_tokens}, Estimated Input Cost: {estimated_input_cost:.4f}")

        # 发布输入Token事件
        self._publish_token_event(
            "input_token_consumed",
            {
                "request_log_id": request_log_id,
                "user_id": user_id,
                "session_id": session_id,
                "model_name": model_name,
                "token_type": "INPUT",
                "tokens_count": input_tokens,
                "unit_price": self.model_prices.get(model_name, {}).get("input", 0),
                "prompt": prompt # 可选择记录部分prompt
            }
        )

        total_output_tokens = 0
        full_response_content = []

        try:
            # 模拟OpenAI流式API调用
            # For a real scenario, replace with actual openai.chat.completions.create
            # Example:
            # stream = openai.chat.completions.create(
            #     model=model_name,
            #     messages=[{"role": "user", "content": prompt}],
            #     max_tokens=max_tokens,
            #     stream=True,
            # )
            # For demonstration, we'll simulate a stream:

            simulated_stream_chunks = [
                {"choices": [{"delta": {"content": "Hello"}}]},
                {"choices": [{"delta": {"content": ", "}}]},
                {"choices": [{"delta": {"content": "this"}}]},
                {"choices": [{"delta": {"content": " is"}}]},
                {"choices": [{"delta": {"content": " a"}}]},
                {"choices": [{"delta": {"content": " streaming"}}]},
                {"choices": [{"delta": {"content": " response."}}]},
                {"choices": [{"delta": {"content": ""}}]}, # End of stream
            ]

            for chunk in simulated_stream_chunks: # Replace with `for chunk in stream:`
                if chunk and chunk.get("choices"):
                    delta_content = chunk["choices"][0].get("delta", {}).get("content")
                    if delta_content:
                        # 累积计算输出Token
                        chunk_tokens = self._count_tokens(delta_content, model_name)
                        total_output_tokens += chunk_tokens
                        full_response_content.append(delta_content)

                        # 实时发布输出Token块事件(可选,粒度较细)
                        # 对于精细计费,可以每收到一个chunk就发布一次
                        # 但通常我们会等整个请求完成后再发布一个总的output_token_consumed事件
                        # 或者在每次`yield`之前发布,并在结束后进行最终确认

                        yield delta_content # 将内容块返回给客户端

                        # 如果需要实时扣费,可以在这里触发一个扣费操作,或者更新Redis中的临时计数
                        # self._publish_token_event(
                        #     "partial_output_token_consumed",
                        #     { ... chunk_tokens ... }
                        # )

                time.sleep(0.05) # 模拟网络延迟

        except Exception as e:
            print(f"Error during LLM streaming for request {request_log_id}: {e}")
            # 处理错误,发布错误事件
            self._publish_token_event(
                "request_failed",
                {
                    "request_log_id": request_log_id,
                    "user_id": user_id,
                    "session_id": session_id,
                    "error_message": str(e),
                    "status": "FAILED"
                }
            )
            raise # 重新抛出异常,让上层处理
        finally:
            end_time = time.time()
            duration = end_time - start_time

            # 2. 完成后发布最终输出Token事件
            if total_output_tokens > 0:
                final_output_cost = total_output_tokens * self.model_prices.get(model_name, {}).get("output", 0)
                print(f"[{request_log_id}] Total Output Tokens: {total_output_tokens}, Final Output Cost: {final_output_cost:.4f}")
                self._publish_token_event(
                    "output_token_consumed",
                    {
                        "request_log_id": request_log_id,
                        "user_id": user_id,
                        "session_id": session_id,
                        "model_name": model_name,
                        "token_type": "OUTPUT",
                        "tokens_count": total_output_tokens,
                        "unit_price": self.model_prices.get(model_name, {}).get("output", 0),
                        "full_response": "".join(full_response_content) # 可选择记录完整响应
                    }
                )

            # 发布请求完成事件,包含所有统计信息
            self._publish_token_event(
                "request_completed",
                {
                    "request_log_id": request_log_id,
                    "user_id": user_id,
                    "session_id": session_id,
                    "model_name": model_name,
                    "input_tokens": input_tokens,
                    "output_tokens": total_output_tokens,
                    "total_cost": estimated_input_cost + (final_output_cost if total_output_tokens > 0 else 0),
                    "duration": duration,
                    "status": "SUCCESS"
                }
            )

# 模拟一个消息队列生产者
class MockMessageQueueProducer:
    def publish(self, message: str, topic: str):
        print(f"Mock MQ: Publishing to topic '{topic}': {message[:100]}...") # 打印前100字符

# 示例用法
if __name__ == "__main__":
    mock_mq = MockMessageQueueProducer()
    llm_service = LLMProxyService(mock_mq)

    user_id = "user_123"
    session_id = "session_abc"
    model = "gpt-4o"
    prompt = "Tell me a short story about a brave knight and a dragon."

    print("n--- Starting LLM Streaming Call ---")
    response_generator = llm_service.call_llm_streaming(user_id, session_id, model, prompt)
    full_response = []
    try:
        for chunk in response_generator:
            print(f"Received chunk: {chunk}", end="")
            full_response.append(chunk)
        print("n--- LLM Streaming Call Finished ---")
    except Exception as e:
        print(f"nStreaming call failed: {e}")

    print(f"nFull Response: {''.join(full_response)}")

    # 模拟另一个请求
    print("n--- Starting Another LLM Streaming Call ---")
    response_generator_2 = llm_service.call_llm_streaming("user_456", "session_def", "gpt-3.5-turbo", "What is the capital of France?")
    full_response_2 = []
    try:
        for chunk in response_generator_2:
            print(f"Received chunk: {chunk}", end="")
            full_response_2.append(chunk)
        print("n--- Another LLM Streaming Call Finished ---")
    except Exception as e:
        print(f"nStreaming call failed: {e}")
    print(f"nFull Response: {''.join(full_response_2)}")

代码示例1解释:

  • LLMProxyService 负责模拟LLM调用和Token计数。
  • _count_tokens 使用 tiktoken 库根据模型名称计算Token。
  • call_llm_streaming 是核心方法。它首先计算输入Token并发布事件。
  • 在模拟的流式响应循环中,每次接收到delta_content(一个文本块),就计算其Token数并累加到total_output_tokens
  • 最后,当流式响应结束时,发布一个output_token_consumed事件,包含最终的输出Token总数。
  • _publish_token_event 将结构化的事件发送到消息队列,由后续的Worker处理。
  • request_log_id 是每个请求的唯一标识,在事件中传递,用于幂等性处理。

方法2: LLM API 返回计数:
一些LLM API(如OpenAI在非流式模式下或流式响应的[DONE]标记中)会直接返回usage对象,其中包含prompt_tokenscompletion_tokens。这是最方便的,因为它避免了我们自己维护分词器和计算逻辑。

  • 优点: 简单,可靠性高,因为是API提供方计算。
  • 缺点: 并非所有API都提供,尤其是在流式响应的中间状态,通常只提供最终计数。对于用户提前中断的情况,可能无法获得准确的中间计数。

结论: 对于流式输出,方法1(累积式计数)是更通用的解决方案,它赋予我们对Token计算的完全控制权,并支持实时或准实时的计量。

3.4 并发安全计数

Token Counter Service在处理来自MQ的事件时,可能会有多个Worker实例同时尝试更新同一个用户或会话的Token总数。这需要并发安全机制。

  1. 数据库的原子操作:
    大多数关系型数据库支持原子更新操作,例如PostgreSQL的UPDATE ... SET total_input_tokens = total_input_tokens + X WHERE id = Y;。这确保了在数据库层面的并发安全。

  2. Redis原子命令:
    Redis的INCRBY命令是原子性的,非常适合在高并发下对计数器进行增量操作。

import redis
import json
import uuid
from datetime import datetime
from typing import Dict, Any

# 假设这是一个简化的Token Counter Service Worker
class TokenCounterWorker:
    def __init__(self, redis_client: redis.Redis, db_connection_pool):
        self.redis_client = redis_client
        self.db_connection_pool = db_connection_pool
        self.processed_events_set = "processed_token_events" # Redis Set for幂等性
        self.model_prices = {
            "gpt-4o": {"input": 0.005 / 1000, "output": 0.015 / 1000},
            "gpt-3.5-turbo": {"input": 0.0005 / 1000, "output": 0.0015 / 1000},
        }

    def _calculate_cost(self, model_name: str, token_type: str, tokens_count: int) -> float:
        """根据模型和Token类型计算成本"""
        price_per_token = self.model_prices.get(model_name, {}).get(token_type.lower(), 0)
        return tokens_count * price_per_token

    def process_event(self, event_json: str):
        """处理从消息队列接收到的Token消费事件"""
        event: Dict[str, Any] = json.loads(event_json)
        event_id = event.get("event_id")
        event_type = event.get("event_type")
        data = event.get("data", {})

        if not event_id:
            print(f"Error: Event missing event_id: {event_json}")
            return

        # 1. 幂等性检查
        # 使用Redis的SETNX或Lua脚本确保原子性
        # 如果event_id已存在于processed_events_set中,则说明已处理过
        if self.redis_client.sismember(self.processed_events_set, event_id):
            print(f"Skipping already processed event: {event_id}")
            return

        # 记录已处理事件,设置过期时间防止集合无限增长
        self.redis_client.sadd(self.processed_events_set, event_id)
        self.redis_client.expire(self.processed_events_set, 7 * 24 * 3600) # 记录7天

        print(f"Processing event [{event_id}] type: {event_type} for user {data.get('user_id')}")

        try:
            user_id = data.get("user_id")
            session_id = data.get("session_id")
            request_log_id = data.get("request_log_id")
            model_name = data.get("model_name")

            if event_type == "input_token_consumed" or event_type == "output_token_consumed":
                token_type = data.get("token_type")
                tokens_count = data.get("tokens_count")

                if not all([user_id, session_id, model_name, token_type, tokens_count is not None]):
                    print(f"Error: Missing essential data for token consumption event: {data}")
                    return

                total_cost = self._calculate_cost(model_name, token_type, tokens_count)

                # 2. 更新Redis中的实时计数 (原子操作)
                # 用户总Token
                self.redis_client.incrby(f"user:{user_id}:total_tokens", tokens_count)
                # 用户分模型Token
                self.redis_client.incrby(f"user:{user_id}:model:{model_name}:{token_type.lower()}_tokens", tokens_count)
                # 会话总Token
                self.redis_client.incrby(f"session:{session_id}:total_{token_type.lower()}_tokens", tokens_count)
                self.redis_client.incrbyfloat(f"user:{user_id}:total_cost", total_cost)

                # 3. 持久化到数据库 (示例使用伪代码,实际需用DB ORM或SQL)
                with self.db_connection_pool.get_connection() as conn:
                    cursor = conn.cursor()

                    # 记录详细的TokenConsumption
                    cursor.execute(
                        """
                        INSERT INTO TokenConsumption (id, user_id, session_id, request_log_id, model_name, token_type, tokens_count, unit_price, total_cost, consumed_at)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON CONFLICT (id) DO NOTHING; -- 确保幂等性
                        """,
                        (
                            str(uuid.uuid4()), user_id, session_id, request_log_id, model_name, token_type, 
                            tokens_count, self.model_prices.get(model_name, {}).get(token_type.lower(), 0), 
                            total_cost, datetime.now()
                        )
                    )

                    # 更新Session的汇总Token
                    cursor.execute(
                        f"""
                        UPDATE Session
                        SET total_{token_type.lower()}_tokens = COALESCE(total_{token_type.lower()}_tokens, 0) + %s
                        WHERE id = %s;
                        """,
                        (tokens_count, session_id)
                    )

                    # 更新User的汇总余额 (如果采用预付费模式,这里进行扣减)
                    # cursor.execute(
                    #     """
                    #     UPDATE User
                    //     SET balance = balance - %s
                    //     WHERE id = %s AND balance >= %s;
                    //     """,
                    //     (total_cost, user_id, total_cost)
                    // )
                    # if cursor.rowcount == 0:
                    #     print(f"Warning: User {user_id} balance insufficient for cost {total_cost}")
                    #     # 这里可能需要触发预警或回滚,取决于业务逻辑

                    conn.commit()
                    print(f"DB: Updated DB for request {request_log_id}, {token_type} tokens: {tokens_count}")

            elif event_type == "request_completed" or event_type == "request_failed":
                # 更新RequestLog的状态和最终统计
                with self.db_connection_pool.get_connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute(
                        """
                        INSERT INTO RequestLog (id, user_id, session_id, model_name, input_tokens, output_tokens, total_cost, status, request_time, response_time)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON CONFLICT (id) DO UPDATE SET
                            input_tokens = EXCLUDED.input_tokens,
                            output_tokens = EXCLUDED.output_tokens,
                            total_cost = EXCLUDED.total_cost,
                            status = EXCLUDED.status,
                            response_time = EXCLUDED.response_time;
                        """,
                        (
                            request_log_id, user_id, session_id, model_name,
                            data.get("input_tokens"), data.get("output_tokens"), data.get("total_cost"),
                            data.get("status"), data.get("request_time", datetime.now()), data.get("response_time", datetime.now())
                        )
                    )
                    conn.commit()
                    print(f"DB: Updated RequestLog for request {request_log_id} with status {data.get('status')}")

            else:
                print(f"Unknown event type: {event_type}")

        except Exception as e:
            print(f"Error processing event {event_id}: {e}")
            # 将消息重新放回队列或发送到死信队列 (DLQ)
            # 对于实际系统,这里需要更复杂的错误处理和重试逻辑
            # self.message_queue_producer.requeue(event_json, topic="token_consumption_events")
            # print(f"Event {event_id} requeued.")

# 模拟数据库连接池
class MockDBConnectionPool:
    def get_connection(self):
        # 返回一个模拟的数据库连接对象
        class MockCursor:
            def execute(self, query, params=None):
                print(f"Executing DB query: {query} with params {params}")
            def rowcount(self): return 1 # 模拟更新成功
        class MockConnection:
            def cursor(self): return MockCursor()
            def commit(self): print("DB: Committing transaction.")
            def __enter__(self): return self
            def __exit__(self, exc_type, exc_val, exc_tb): pass
        return MockConnection()

# 示例用法
if __name__ == "__main__":
    r = redis.Redis(host='localhost', port=6379, db=0)
    mock_db_pool = MockDBConnectionPool()
    worker = TokenCounterWorker(r, mock_db_pool)

    # 模拟几个事件
    event_1_data = {
        "request_log_id": "req_1", "user_id": "user_123", "session_id": "session_abc", 
        "model_name": "gpt-4o", "token_type": "INPUT", "tokens_count": 100
    }
    event_2_data = {
        "request_log_id": "req_1", "user_id": "user_123", "session_id": "session_abc", 
        "model_name": "gpt-4o", "token_type": "OUTPUT", "tokens_count": 50
    }
    event_3_data = {
        "request_log_id": "req_1", "user_id": "user_123", "session_id": "session_abc", 
        "model_name": "gpt-4o", "input_tokens": 100, "output_tokens": 50, 
        "total_cost": 0.1, "status": "SUCCESS", "request_time": datetime.now().isoformat(),
        "response_time": datetime.now().isoformat()
    }
    event_4_data = { # 另一个请求
        "request_log_id": "req_2", "user_id": "user_456", "session_id": "session_def", 
        "model_name": "gpt-3.5-turbo", "token_type": "INPUT", "tokens_count": 20
    }

    mock_events = [
        {"event_id": str(uuid.uuid4()), "timestamp": datetime.now().isoformat(), "event_type": "input_token_consumed", "data": event_1_data},
        {"event_id": str(uuid.uuid4()), "timestamp": datetime.now().isoformat(), "event_type": "output_token_consumed", "data": event_2_data},
        {"event_id": str(uuid.uuid4()), "timestamp": datetime.now().isoformat(), "event_type": "request_completed", "data": event_3_data},
        {"event_id": str(uuid.uuid4()), "timestamp": datetime.now().isoformat(), "event_type": "input_token_consumed", "data": event_4_data},
    ]

    print("n--- Processing Events with Worker ---")
    for event in mock_events:
        worker.process_event(json.dumps(event))
        time.sleep(0.1) # 模拟处理间隔

    # 检查Redis计数
    print("n--- Checking Redis Counters ---")
    print(f"User user_123 total tokens: {r.get('user:user_123:total_tokens')}")
    print(f"User user_123 GPT-4o input tokens: {r.get('user:user_123:model:gpt-4o:input_tokens')}")
    print(f"Session session_abc total output tokens: {r.get('session:session_abc:total_output_tokens')}")
    print(f"User user_123 total cost: {r.get('user:user_123:total_cost')}")
    print(f"User user_456 total tokens: {r.get('user:user_456:total_tokens')}")

代码示例2解释:

  • TokenCounterWorker 负责从MQ消费事件。
  • 幂等性: 使用processed_events_set(Redis Set)来记录已经处理过的事件ID。sismembersadd操作保证了即使消息重复投递,同一个事件也只会被处理一次。
  • Redis 原子计数: redis_client.incrby()redis_client.incrbyfloat() 用于原子性地增加整数或浮点数计数,避免并发问题。
  • 数据库持久化: 示例中展示了如何将详细的TokenConsumption记录插入数据库,并更新SessionRequestLog的汇总数据。数据库的ON CONFLICT DO NOTHINGON CONFLICT DO UPDATE子句同样有助于实现幂等性。
  • 事务: 实际生产中,数据库操作应包裹在事务中,确保所有相关更新要么全部成功,要么全部失败。

Part 4: 容错、回溯与审计

一个健壮的计费系统必须能够优雅地处理各种异常情况。

4.1 消息队列的应用

消息队列(如Kafka, RabbitMQ, Redis Streams)在异步处理中扮演着核心角色。

  • 解耦: 生产者(App Server)和消费者(Token Counter Service)互不依赖,提高系统弹性。
  • 削峰填谷: 缓冲瞬时高并发流量,防止后端服务过载。
  • 持久化消息: 消息在队列中被持久化,即使消费者服务宕机,消息也不会丢失,待服务恢复后可继续处理。
  • 重试机制: 消费者处理失败的消息可以被重新投递,直到成功处理。
  • 死信队列 (Dead Letter Queue, DLQ): 对于多次重试仍然失败的消息,可以将其发送到DLQ进行人工干预或分析,防止阻塞主队列。

代码示例3: 发布/订阅Token消费事件到Kafka (概念性)

# app_server.py (生产者)
from kafka import KafkaProducer
import json
import uuid

producer = KafkaProducer(
    bootstrap_servers=['kafka:9092'],
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)

def publish_token_event_to_kafka(event_data: dict):
    event_data["event_id"] = str(uuid.uuid4()) # Ensure unique event ID
    event_data["timestamp"] = datetime.now().isoformat()
    producer.send('token_consumption_events', event_data)
    producer.flush() # 确保消息被发送

# token_counter_worker.py (消费者)
from kafka import KafkaConsumer
import json

consumer = KafkaConsumer(
    'token_consumption_events',
    group_id='token_counter_group',
    bootstrap_servers=['kafka:9092'],
    value_deserializer=lambda m: json.loads(m.decode('utf-8')),
    enable_auto_commit=False # 手动提交,更精细控制
)

def start_worker(worker_instance):
    for message in consumer:
        try:
            event = message.value
            worker_instance.process_event(json.dumps(event))
            consumer.commit() # 处理成功后手动提交offset
        except Exception as e:
            print(f"Failed to process message {message.offset}: {e}")
            # 将消息发送到DLQ或记录错误,等待重试
            # Kafka默认不会自动重试,需要我们自己实现或使用特定库
            # 例如,将错误消息推送到另一个“死信”topic
            producer.send('token_dlq_events', event)
            producer.flush()
            # 不提交offset,以便下次重启时重新处理此消息

4.2 幂等性处理

如前所述,幂等性是分布式系统中的关键。消息队列可能因为各种原因(网络抖动、消费者重启)重复投递消息。我们的消费者必须能识别并跳过重复的消息。

  • 唯一请求ID: 每个Token消费事件都应包含一个全局唯一的event_id(通常是UUID)。
  • 消费者端去重: 消费者处理消息时,首先检查这个event_id是否已经在“已处理事件集合”中。如果存在,则跳过。
    • 实现方式: Redis Set (SADD + SISMEMBER) 是一个高效的选择,但需要注意Set的大小和过期策略。或者,在数据库中维护一张processed_events表,以event_id作为主键,利用数据库的唯一性约束。
# 核心逻辑已在代码示例2中展示:
# self.redis_client.sadd(self.processed_events_set, event_id)
# if self.redis_client.sismember(self.processed_events_set, event_id): return

4.3 数据一致性与最终一致性

在高并发分布式系统中,实现强一致性(所有数据副本实时同步)的代价是巨大的,往往会牺牲性能和可用性。对于Token计费这样的场景,最终一致性通常是可接受的。

  • 最终一致性: 意味着数据在一段时间后会达到一致状态。例如,Redis中的实时计数和数据库中的持久化记录可能在短时间内不一致,但最终会通过异步同步和对账机制达到一致。
  • 对账机制:
    • 实时对账: Redis中的累计值与数据库中的累计值进行定期比对,发现差异时触发告警或自动修复。
    • 离线对账: 定期(例如每天或每周)从原始日志或消息队列的归档中重新计算Token消耗,与当前的数据库数据进行比对,确保没有任何遗漏或错误。这通常是一个批处理作业。

4.4 审计日志

为了确保计费的透明性和可追溯性,详尽的审计日志至关重要。

  • 记录所有关键事件: 每个LLM请求的完整输入、输出(或摘要)、Token计数、成本计算、发生时间、用户ID、会话ID、模型名称等都应被记录下来。
  • 不可篡改性: 审计日志应存储在不可篡改或难以篡改的存储中,例如Append-only日志文件、专用的日志服务或区块链(对于极高安全要求的场景)。
  • 日志分析: 结合ELK Stack (Elasticsearch, Logstash, Kibana) 或类似的日志分析工具,可以方便地查询、分析和可视化Token消费数据,用于问题排查、成本分析和用户行为洞察。

Part 5: 优化与高级策略

5.1 批处理与聚合

如果Token消费事件非常频繁(例如,一个用户在短时间内发送大量小请求),每次都立即写入数据库可能会带来巨大的IO压力。

  • 内存聚合: 可以在Token Counter Service中,将来自同一个用户或会话的小额Token消费先在内存中进行聚合,达到一定数量或时间间隔后,再批量写入数据库。
    • 风险: 服务崩溃可能导致内存中未持久化的数据丢失。
    • 解决方案: 结合WAL (Write-Ahead Log) 或其他持久化队列,将内存中的聚合数据定期写入本地持久化日志,即使崩溃也能恢复。

5.2 多级缓存

  • 模型价格缓存: 模型价格通常是静态的,可以缓存在内存或Redis中,减少数据库查询。
  • 用户余额缓存: 用户余额可以在Redis中进行缓存,用于快速判断是否可发起请求,但最终扣费仍需以数据库为准,或采用乐观锁/CAS操作确保一致性。

5.3 成本预警与限额

基于实时的Token消费数据,可以实现以下功能:

  • 用户限额: 设置每个用户、每个会话或每个API Key的Token使用上限。当达到阈值时,可以阻止进一步的LLM请求。Redis的原子计数在此场景下非常有用。
  • 成本预警: 当用户消费达到预设金额或Token量时,发送通知(邮件、短信、Webhook)。

5.4 云服务集成

利用云厂商提供的托管服务可以大大降低运维复杂性:

  • 托管消息队列: AWS SQS/SNS, Azure Service Bus, Google Cloud Pub/Sub。
  • 托管数据库: AWS RDS, Azure SQL Database, Google Cloud SQL。
  • 托管Redis: AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore。
  • 日志服务: AWS CloudWatch, Azure Monitor, Google Cloud Logging。

结语

高并发下LLM Token账单的精准计算是一个系统性工程,它要求我们在架构设计、数据模型、并发控制、容错机制和可观测性等方面都进行深思熟虑。通过构建一个解耦、异步、幂等且可审计的系统,我们可以为用户提供透明、可靠的计费服务,同时确保我们自身的运营成本得到有效管理。这个过程是一个持续演进的旅程,需要在实时性、准确性和系统复杂度之间找到最佳平衡点。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注