JAVA 如何监控 LLM 成本?记录 Token 消耗的 Billing 拦截器设计

好的,下面我将以讲座模式,深入探讨如何在 Java 环境下监控 LLM (Large Language Model) 成本,并通过设计 Billing 拦截器来记录 Token 消耗。

讲座:Java 环境下 LLM 成本监控与 Billing 拦截器设计

大家好,今天我们来聊聊如何监控 Large Language Model (LLM) 的成本,并设计一个 Billing 拦截器来记录 Token 消耗。随着 LLM 在各种应用场景中的广泛应用,成本控制变得越来越重要。

1. 成本监控的重要性

在使用 LLM 时,成本通常与使用的 Token 数量直接相关。不同的 LLM 提供商(例如 OpenAI, Google Cloud AI 等)会根据输入和输出的 Token 数量收取费用。因此,准确监控 Token 消耗对于控制预算至关重要。主要体现在以下几个方面:

  • 预算控制: 了解每个 LLM 请求的成本,可以帮助我们更好地控制总体预算。
  • 优化提示词: 通过分析 Token 消耗,可以优化提示词,减少不必要的 Token 使用。
  • 性能分析: Token 消耗也可以作为性能指标之一,帮助我们了解 LLM 的响应速度和效率。
  • 成本分摊: 在多用户或多团队环境中,可以根据 Token 消耗进行成本分摊。

2. 监控 LLM 成本的策略

在 Java 环境下监控 LLM 成本,主要有以下几种策略:

  • API 拦截器: 编写一个拦截器,拦截所有与 LLM 提供商 API 的交互,记录 Token 数量。
  • 代理模式: 使用代理模式,在客户端和 LLM 提供商之间创建一个代理层,负责记录 Token 数量。
  • 自定义包装器: 创建一个自定义的 LLM 包装器,封装 LLM API 的调用,并在包装器中记录 Token 数量。
  • 监控工具: 一些 LLM 提供商或第三方公司,会提供专门的监控工具,可以帮助我们监控 LLM 成本。

3. Billing 拦截器设计

我们这里重点讨论 API 拦截器。API 拦截器是一个强大的工具,可以在不修改现有代码的情况下,监控和记录 LLM 的 Token 消耗。

3.1 拦截器架构

Billing 拦截器需要拦截所有与 LLM 提供商 API 的交互。这可以通过使用 Java 的拦截器框架来实现,例如 Spring Interceptor 或 Servlet Filter。

// 使用 Spring Interceptor
public class LLMBillingInterceptor implements HandlerInterceptor {

    private final TokenCounter tokenCounter;

    public LLMBillingInterceptor(TokenCounter tokenCounter) {
        this.tokenCounter = tokenCounter;
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // 在请求处理之前执行
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        // 在请求处理之后执行,但在视图渲染之前执行
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        // 在整个请求完成之后执行
        // 在这里记录 Token 消耗
        String requestBody = getRequestBody(request); // 获取请求体
        String responseBody = getResponseBody(response); // 获取响应体

        long inputTokens = tokenCounter.countTokens(requestBody);
        long outputTokens = tokenCounter.countTokens(responseBody);

        tokenCounter.recordTokens(inputTokens, outputTokens);
    }

    // 从 HttpServletRequest 中获取请求体
    private String getRequestBody(HttpServletRequest request) throws IOException {
        StringBuilder stringBuilder = new StringBuilder();
        try (BufferedReader bufferedReader = request.getReader()) {
            char[] charBuffer = new char[128];
            int bytesRead;
            while ((bytesRead = bufferedReader.read(charBuffer)) != -1) {
                stringBuilder.append(charBuffer, 0, bytesRead);
            }
        }
        return stringBuilder.toString();
    }

    // 从 HttpServletResponse 中获取响应体 (需要包装 response)
    private String getResponseBody(HttpServletResponse response) throws IOException {
        if (response instanceof ContentCachingResponseWrapper) {
            ContentCachingResponseWrapper responseWrapper = (ContentCachingResponseWrapper) response;
            byte[] content = responseWrapper.getContentAsByteArray();
            if (content.length > 0) {
                return new String(content, responseWrapper.getCharacterEncoding());
            }
        }
        return "";
    }
}

// 需要使用 ContentCachingRequestWrapper 和 ContentCachingResponseWrapper
@Configuration
public class WebConfig implements WebMvcConfigurer {

    private final TokenCounter tokenCounter;

    public WebConfig(TokenCounter tokenCounter) {
        this.tokenCounter = tokenCounter;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new LLMBillingInterceptor(tokenCounter));
    }

    @Bean
    public FilterRegistrationBean<ContentCachingFilter> contentCachingFilter() {
        FilterRegistrationBean<ContentCachingFilter> registrationBean = new FilterRegistrationBean<>();
        registrationBean.setFilter(new ContentCachingFilter());
        registrationBean.addUrlPatterns("/*"); // 拦截所有请求
        return registrationBean;
    }
}

3.2 TokenCounter 类

TokenCounter 类负责统计 Token 数量和记录 Token 消耗。可以使用不同的 Tokenizer 库,例如 Hugging Face Tokenizers 或 OpenAI 的 Tiktoken。

// TokenCounter 接口
public interface TokenCounter {
    long countTokens(String text);
    void recordTokens(long inputTokens, long outputTokens);
}

// 使用 Tiktoken 实现 TokenCounter
@Component
public class TiktokenTokenCounter implements TokenCounter {

    private final Encoding enc;

    public TiktokenTokenCounter() {
        try {
            enc = Encoding.of("cl100k_base"); // 使用 OpenAI 的 cl100k_base 编码
        } catch (IOException e) {
            throw new RuntimeException("Failed to initialize Tiktoken encoding", e);
        }
    }

    @Override
    public long countTokens(String text) {
        return enc.encode(text).size();
    }

    @Override
    public void recordTokens(long inputTokens, long outputTokens) {
        // 将 Token 数量记录到数据库或日志文件中
        System.out.println("Input Tokens: " + inputTokens + ", Output Tokens: " + outputTokens);
        // 也可以将数据保存到数据库
        // billingRepository.save(new BillingRecord(inputTokens, outputTokens, new Date()));
    }
}

// BillingRecord 实体类 (示例)
@Entity
public class BillingRecord {
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    private long inputTokens;
    private long outputTokens;
    private Date timestamp;

    // Getters and setters
    public BillingRecord(){}

    public BillingRecord(long inputTokens, long outputTokens, Date timestamp) {
        this.inputTokens = inputTokens;
        this.outputTokens = outputTokens;
        this.timestamp = timestamp;
    }

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public long getInputTokens() {
        return inputTokens;
    }

    public void setInputTokens(long inputTokens) {
        this.inputTokens = inputTokens;
    }

    public long getOutputTokens() {
        return outputTokens;
    }

    public void setOutputTokens(long outputTokens) {
        this.outputTokens = outputTokens;
    }

    public Date getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(Date timestamp) {
        this.timestamp = timestamp;
    }
}

// BillingRepository 接口 (示例)
public interface BillingRepository extends JpaRepository<BillingRecord, Long> {
}

// Tiktoken 的 Maven 依赖
<!-- https://mvnrepository.com/artifact/com.knuddels/tiktoken -->
<dependency>
    <groupId>com.knuddels</groupId>
    <artifactId>tiktoken</artifactId>
    <version>0.5.1</version>
</dependency>

3.3 获取请求和响应内容

afterCompletion 方法中,我们需要获取请求和响应的内容,以便计算 Token 数量。可以使用 HttpServletRequestHttpServletResponse 对象来获取请求和响应的内容。但是,默认情况下,HttpServletRequestHttpServletResponse 的输入流只能读取一次。为了解决这个问题,可以使用 ContentCachingRequestWrapperContentCachingResponseWrapper 来缓存请求和响应的内容。

import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;

// 在拦截器中使用 ContentCachingRequestWrapper 和 ContentCachingResponseWrapper
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
    ContentCachingRequestWrapper requestWrapper = (ContentCachingRequestWrapper) request;
    ContentCachingResponseWrapper responseWrapper = (ContentCachingResponseWrapper) response;

    String requestBody = new String(requestWrapper.getContentAsByteArray(), requestWrapper.getCharacterEncoding());
    String responseBody = new String(responseWrapper.getContentAsByteArray(), responseWrapper.getCharacterEncoding());

    long inputTokens = tokenCounter.countTokens(requestBody);
    long outputTokens = tokenCounter.countTokens(responseBody);

    tokenCounter.recordTokens(inputTokens, outputTokens);

    // 重要:需要复制响应内容,以便后续处理
    responseWrapper.copyBodyToResponse();
}

3.4 成本计算

根据 Token 数量和 LLM 提供商的定价,可以计算出每个请求的成本。

// 假设 OpenAI 的定价为 $0.0001 / 1K tokens
private static final double PRICE_PER_1K_TOKENS = 0.0001;

public double calculateCost(long inputTokens, long outputTokens) {
    double totalTokens = (inputTokens + outputTokens) / 1000.0;
    return totalTokens * PRICE_PER_1K_TOKENS;
}

4. 代码示例:集成 OpenAI API

为了更具体地说明如何使用拦截器,我们将展示一个与 OpenAI API 集成的示例。

// OpenAI API 客户端 (简化版)
@Service
public class OpenAIService {

    private final RestTemplate restTemplate;
    private final String apiKey;

    @Value("${openai.api.key}")
    public void setApiKey(String apiKey) {
        this.apiKey = apiKey;
    }

    private static final String COMPLETIONS_URL = "https://api.openai.com/v1/completions";

    public OpenAIService(RestTemplateBuilder restTemplateBuilder) {
        this.restTemplate = restTemplateBuilder.build();
    }

    public String getCompletion(String prompt) {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.set("Authorization", "Bearer " + apiKey);

        Map<String, Object> requestBody = new HashMap<>();
        requestBody.put("model", "text-davinci-003");
        requestBody.put("prompt", prompt);
        requestBody.put("max_tokens", 150);

        HttpEntity<Map<String, Object>> requestEntity = new HttpEntity<>(requestBody, headers);

        try {
            ResponseEntity<Map> response = restTemplate.postForEntity(COMPLETIONS_URL, requestEntity, Map.class);
            if (response.getStatusCode() == HttpStatus.OK && response.getBody() != null) {
                List<Map<String, String>> choices = (List<Map<String, String>>) response.getBody().get("choices");
                if (choices != null && !choices.isEmpty()) {
                    return choices.get(0).get("text");
                }
            }
            return "Error: Could not retrieve completion.";
        } catch (Exception e) {
            e.printStackTrace();
            return "Error: " + e.getMessage();
        }
    }
}

5. 拦截器配置

在 Spring Boot 应用中,需要配置拦截器才能生效。

@Configuration
public class AppConfig implements WebMvcConfigurer {

    private final LLMBillingInterceptor billingInterceptor;

    public AppConfig(LLMBillingInterceptor billingInterceptor) {
        this.billingInterceptor = billingInterceptor;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(billingInterceptor).addPathPatterns("/api/llm/*");
    }
}

6. 关键代码片段总结

  • Token 计数器 (TokenCounter):
    public interface TokenCounter {
        long countTokens(String text);
        void recordTokens(long inputTokens, long outputTokens);
    }
  • Billing 拦截器 (LLMBillingInterceptor):
    public class LLMBillingInterceptor implements HandlerInterceptor {
        // ...
        @Override
        public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
            // ...
        }
    }
  • 配置类 (WebConfig):

    @Configuration
    public class WebConfig implements WebMvcConfigurer {
        // ...
        @Override
        public void addInterceptors(InterceptorRegistry registry) {
            // ...
        }
    
        @Bean
        public FilterRegistrationBean<ContentCachingFilter> contentCachingFilter() {
            // ...
        }
    }

7. 优化方向

  • 异步记录: 将 Token 消耗记录到数据库或日志文件中,可以使用异步方式,避免阻塞请求处理。
  • 缓存 Token 数量: 可以缓存 Token 数量,避免重复计算。
  • 动态配置: 可以使用配置中心,动态配置 LLM 提供商的定价。
  • 数据可视化: 可以使用数据可视化工具,将 Token 消耗和成本以图表的形式展示出来。
  • 更精细的监控: 可以根据用户,项目等维度进行更精细的监控.
  • 告警机制: 设定成本阈值,当成本超过阈值时,发送告警通知。

8. 遇到的问题与解决方案

问题 解决方案
无法多次读取 Request Body 使用 ContentCachingRequestWrapper 包装 HttpServletRequest,允许重复读取请求内容。
无法多次读取 Response Body 使用 ContentCachingResponseWrapper 包装 HttpServletResponse,允许重复读取响应内容。 并且需要调用responseWrapper.copyBodyToResponse(); 将缓存的内容写回。
Tokenizer 初始化失败 确保 Tokenizer 库已正确安装,并且已正确配置相关参数。 检查网络连接,确保可以下载必要的模型文件。
性能问题 使用异步方式记录 Token 消耗,避免阻塞请求处理。 缓存 Token 数量,避免重复计算。
成本计算不准确 确保 LLM 提供商的定价已正确配置。 定期检查 LLM 提供商的定价,并及时更新配置。
无法区分不同用户的 Token 消耗 在拦截器中获取用户信息,并将用户信息与 Token 消耗记录关联起来。 可以使用 Spring Security 或其他身份验证框架来获取用户信息。
无法监控流式 API 的 Token 消耗 对于流式 API,需要在每次接收到数据时,增量计算 Token 数量。 可以使用 ResponseBodyEmitterStreamingResponseBody 来处理流式响应。
如何处理图片等非文本数据的 Token 计算 不同的 LLM 对非文本数据的 Token 计算方式不同。 需要根据具体的 LLM 提供商的文档,选择合适的处理方式。 可以使用图像处理库,将图片转换为文本描述,然后计算 Token 数量。 也可以使用多模态 LLM,直接处理图片数据。
如何应对 LLM 厂商 API 变更 抽象 LLM 客户端接口,将 LLM 厂商的 API 调用封装在接口后面。 当 LLM 厂商 API 变更时,只需要修改接口的实现,而不需要修改其他代码。 使用适配器模式,将 LLM 厂商的 API 响应转换为统一的格式。

9. LLM 成本监控是长期工作

总而言之,LLM 成本监控是一个持续的过程,需要根据实际情况进行调整和优化。 通过设计 Billing 拦截器,可以有效地监控 LLM 的 Token 消耗,从而更好地控制成本。

希望今天的讲座对大家有所帮助。谢谢!

下一步可以做的事情

  • 部署拦截器到生产环境,持续观察和优化.
  • 结合实际业务场景,定制更精细的监控策略.
  • 探索更多 LLM 成本优化方法,例如模型压缩,知识蒸馏等。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注