持续批处理 (Continuous Batching) 在 Go 后端中实现动态合并请求以提升 GPU 利用率
尊敬的各位开发者,大家好!
今天我们将深入探讨一个在高性能、低延迟服务,尤其是涉及大量计算密集型任务(如机器学习推理)时至关重要的技术:持续批处理 (Continuous Batching)。随着人工智能,特别是大型语言模型 (LLMs) 的飞速发展,如何高效利用昂贵的 GPU 资源成为了后端服务面临的核心挑战。传统的请求处理模式往往导致 GPU 资源的严重浪费。本讲座将从理论到实践,详细阐述持续批处理的概念、其在 Go 后端中的实现策略、关键数据结构与算法,并辅以丰富的 Go 语言代码示例,以期为大家提供一套提升 GPU 利用率的实战方案。
1. 传统请求处理模式与 GPU 利用率瓶颈
在典型的 Web 服务架构中,客户端发出请求,后端服务接收请求,处理后返回响应。当涉及机器学习推理时,这个过程通常是:
- 客户端发送包含输入数据的请求。
- 后端服务接收请求,将输入数据转化为模型所需的张量格式。
- 后端将张量发送给推理服务(可能是一个独立的微服务,或者通过 CGO/FFI 直接调用本地库)。
- 推理服务在 GPU 上执行模型推理。
- 推理结果返回给后端服务。
- 后端服务将结果格式化后返回给客户端。
这种“请求-响应”模式在处理单个请求时效率很高,但在高并发场景下,尤其是当每个请求的推理任务相对较小或间歇性发生时,会暴露出严重的 GPU 利用率问题。
GPU 利用率瓶颈的主要原因:
- GPU 启动开销 (Kernel Launch Overhead):每次在 GPU 上执行计算任务(称为 Kernel),都需要一定的启动时间。对于小型任务,这个启动时间可能占据总执行时间的很大一部分,导致实际计算时间占比很低。
- 数据传输开销 (Data Transfer Overhead):CPU 和 GPU 之间的数据传输(PCIe 带宽限制)也是一个显著的开销。对于小批量数据,传输时间可能超过实际计算时间。
- 空闲等待 (Idle Waiting):当 GPU 完成一个请求的推理后,如果没有新的请求立即到来,GPU 就会处于空闲状态,等待下一个请求。在高并发但请求到达不规律的场景下,这种空闲等待会频繁发生。
- 固定批处理大小的局限性 (Fixed Batch Size Limitations):为了提升吞吐量,传统的解决方案是固定批处理大小。例如,每次推理都处理 8 个请求。如果请求数量不足 8 个,推理会等待直到凑齐,这会增加延迟;如果请求数量远超 8 个,则需要排队,同样增加延迟。
这些问题共同导致了 GPU 资源的低效利用,直接影响了服务的吞吐量和成本效益。
2. 什么是持续批处理 (Continuous Batching)?
持续批处理是一种旨在最大化 GPU 利用率的技术,它通过动态地将多个客户端请求合并成一个更大的批次,然后一次性提交给 GPU 进行推理。与传统的固定批处理不同,持续批处理更加灵活和动态,它不会等待一个完整的固定批次,而是在一个预设的时间窗口内或达到某个最大批次大小时,就立即处理当前队列中的所有请求。
更进一步,特别是对于大型语言模型 (LLMs),持续批处理不仅仅是请求级别的合并,更是令牌级别 (Token-level) 的合并。LLM 推理通常分为两个阶段:
- 预填充 (Pre-fill):处理用户的输入提示 (prompt),计算初始的 KV 缓存 (Key-Value Cache)。这通常是一个矩阵-向量乘法密集的操作。
- 解码 (Decoding):逐个生成新的令牌。在每次生成一个令牌时,模型会根据之前的 KV 缓存和当前生成的令牌进行推理,并更新 KV 缓存。这是一个迭代过程。
在传统的处理方式中,每个请求的预填充和解码过程是独立的。当一个请求正在解码时,GPU 可能因为等待下一个令牌的输入而空闲。持续批处理,特别是对于 LLMs,会同时管理多个正在进行的序列 (sequences)。当一个序列完成预填充并进入解码阶段时,它会与其他正在解码的序列一起被调度到 GPU 上。这样,GPU 可以在单个操作中并行处理多个序列的下一个令牌生成,极大地减少了空闲时间,提高了吞吐量。
持续批处理的核心思想:
- 动态合并请求:不预设固定的批次大小,而是根据请求到达的速度和系统负载动态调整批次。
- 最小化 GPU 空闲时间:尽可能确保 GPU 在任何时刻都有任务可执行。
- 优化数据传输:通过合并请求,减少 CPU-GPU 之间的数据传输次数,并利用更大的传输块。
- 令牌级调度 (LLMs):对于 LLMs,在解码阶段,将多个序列的令牌生成任务合并到一个 GPU 操作中。
3. 持续批处理的关键组成部分与工作流程
为了在 Go 后端实现持续批处理,我们需要设计一个协调机制,它包含以下几个主要组件:
- 请求接收器 (Request Receiver):负责接收客户端的原始请求。
- 批处理管理器 (Batch Manager):核心组件,维护一个待处理请求队列,并根据策略(时间或大小)触发批处理。
- GPU 推理器接口 (GPU Inferencer Interface):负责与实际的 GPU 推理服务(通常是独立的微服务,如使用 Triton Inference Server, vLLM, 或 TensorRT-LLM 搭建的 Python/C++ 服务)进行通信。
- 响应分发器 (Response Dispatcher):将批处理推理的结果分解,并与原始请求关联,将结果返回给相应的客户端。
工作流程概览:
- 客户端发送请求
Req_A,Req_B,Req_C… - 请求接收器收到请求,为每个请求创建一个内部表示
InternalRequest,并将其发送到批处理管理器的队列。 - 批处理管理器在一个专门的 Goroutine 中运行。它不断地从队列中取出请求。
- 当满足批处理条件(例如,队列中的请求数量达到
MaxBatchSize,或距离上次批处理操作已超过BatchTimeout)时,管理器会将当前队列中的所有请求合并成一个Batch。
- 当满足批处理条件(例如,队列中的请求数量达到
- 批处理管理器将
Batch发送给 GPU 推理器接口。 - GPU 推理器接口与底层的 GPU 推理服务通信,提交批处理任务。
- GPU 推理服务执行推理,并返回一个批处理结果
BatchResult。 - GPU 推理器接口接收
BatchResult并将其传回批处理管理器。 - 批处理管理器将
BatchResult分解,根据原始请求的标识符,将结果派发给对应的等待者(通常通过chan或sync.Cond通知)。 - 响应分发器(或原始请求的处理 Goroutine)接收到结果,并将其返回给客户端。
针对 LLMs 的令牌级持续批处理:
对于 LLMs,这个流程会更复杂,因为推理是迭代的。批处理管理器需要维护每个序列的内部状态。
- 初始请求 (Prompt Processing):
- 客户端发送请求
(prompt_A, max_tokens_A)。 - 请求被放入批处理管理器的“预填充队列”。
- 当满足条件时,预填充队列中的请求被合并为一个批次,发送给 GPU 推理器接口进行预填充。
- GPU 推理服务完成预填充,返回初始的 KV 缓存和第一个生成的令牌。
- 批处理管理器接收结果,更新每个序列的状态(存储 KV 缓存,记录已生成的令牌),并将序列移至“解码队列”。
- 客户端发送请求
- 解码阶段 (Token Generation):
- 批处理管理器定期从“解码队列”中选择一批活动序列(那些尚未完成
max_tokens或未生成停止符的序列)。 - 将这些序列的当前状态(包括 KV 缓存、已生成的最后一个令牌)打包成一个批次,发送给 GPU 推理器接口进行解码。
- GPU 推理服务根据这些状态并行生成下一个令牌。
- 批处理管理器接收结果,更新每个序列的状态(追加新生成的令牌,更新 KV 缓存)。
- 如果序列达到
max_tokens或生成停止符,则将其从活动序列中移除,并派发最终结果给客户端。否则,它会继续留在解码队列中,等待下一轮的解码。
- 批处理管理器定期从“解码队列”中选择一批活动序列(那些尚未完成
这种令牌级的批处理需要 GPU 推理服务本身支持复杂的序列调度和 KV 缓存管理(例如 vLLM 的 PagedAttention)。Go 后端的作用是作为这些序列的管理者和调度者。
4. Go 后端中的架构设计与数据结构
在 Go 中实现持续批处理,我们需要精心设计数据结构和 Goroutine 之间的协作。
4.1 核心数据结构
1. 客户端请求的内部表示 (InternalRequest)
package main
import (
"context"
"fmt"
"time"
)
// InferenceRequest 代表客户端发送的原始推理请求
type InferenceRequest struct {
ID string // 请求唯一标识符
InputData string // 原始输入数据,例如文本
Params map[string]string // 推理参数,例如 max_new_tokens, temperature
}
// InferenceResponse 代表推理服务的响应
type InferenceResponse struct {
ID string
OutputData string
Error error
}
// InternalRequest 是在批处理系统内部使用的请求表示
// 它包含了客户端请求、上下文以及用于接收结果的通道
type InternalRequest struct {
InferenceRequest
ctx context.Context // 用于处理请求的上下文,支持超时和取消
response chan *InferenceResponse // 用于将推理结果返回给发起者
createdAt time.Time // 请求创建时间,用于计算等待时间
isLLM bool // 是否是 LLM 推理请求
sequence *LLMSequenceState // 如果是 LLM 请求,保存其序列状态
// ... 其他内部状态 ...
}
// LLMSequenceState 存储 LLM 推理的动态状态
type LLMSequenceState struct {
SequenceID string
PromptTokens []int // 原始 prompt 的 token ID
GeneratedTokens []int // 已生成的 token ID
KVStoreRef interface{} // 指向 GPU 端的 KV 缓存引用,具体类型取决于推理引擎接口
MaxNewTokens int // 用户请求的最大生成 token 数量
StopReason string // 停止生成的原因
IsFinished bool // 序列是否已完成
LastTokenOutput string // 上一轮生成的 token
InternalRequest *InternalRequest // 关联的原始 InternalRequest
LastAccessTime time.Time // 最后一次被调度的时间,用于调度策略
}
// Batch represents a collection of InternalRequests to be processed together
type Batch struct {
ID string
Requests []*InternalRequest
SubmittedAt time.Time
BatchType string // "prefill" 或 "decode"
// ... 其他批次信息 ...
}
// BatchResult represents the combined results from a batched inference operation
type BatchResult struct {
BatchID string
Results []*InferenceResponse
Error error
// ... 其他批次结果信息 ...
}
2. 批处理管理器 (Batcher)
批处理管理器是整个系统的核心,它负责协调请求的入队、批次的创建和分发。
// BatcherConfig 批处理器的配置
type BatcherConfig struct {
MaxBatchSize int // 单个批次最大请求数
BatchTimeout time.Duration // 批次超时时间,无论请求数量多少,都会触发批处理
WorkerPoolSize int // 后端推理工作 Goroutine 数量
LLMDecodingInterval time.Duration // LLM 解码批处理的间隔
}
// Batcher 核心批处理管理器
type Batcher struct {
config BatcherConfig
requestQueue chan *InternalRequest // 接收新到来的请求
internalReqStore map[string]*InternalRequest // 存储所有活跃的 InternalRequest,用于结果映射
prefillQueue []*InternalRequest // LLM 预填充请求队列
decodingSequences []*LLMSequenceState // LLM 正在解码的序列
batchWg sync.WaitGroup // 用于等待所有批次处理完成
shutdown chan struct{} // 关闭信号
mu sync.Mutex // 保护对 requestQueue 和 internalReqStore 的访问
nextBatchID int64 // 批次 ID 生成器
inferenceClient GPUInferencer // GPU 推理客户端接口
}
// GPUInferencer 定义了与 GPU 推理服务交互的接口
type GPUInferencer interface {
Infer(ctx context.Context, batch *Batch) (*BatchResult, error)
// 对于 LLM,可能需要更细粒度的接口
Prefill(ctx context.Context, batch *Batch) (*BatchResult, error)
Decode(ctx context.Context, decodeBatch *Batch) (*BatchResult, error)
// ... 其他控制接口,如 KV cache 清理等
}
4.2 批处理管理器 Goroutine
Batcher 内部会启动一个或多个 Goroutine 来执行批处理逻辑。
package main
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/google/uuid" // 假设使用 uuid 生成 ID
)
// NewBatcher 创建一个新的 Batcher 实例
func NewBatcher(config BatcherConfig, client GPUInferencer) *Batcher {
return &Batcher{
config: config,
requestQueue: make(chan *InternalRequest, config.MaxBatchSize*2), // 缓冲区大小
internalReqStore: make(map[string]*InternalRequest),
shutdown: make(chan struct{}),
nextBatchID: 0,
inferenceClient: client,
prefillQueue: make([]*InternalRequest, 0, config.MaxBatchSize),
decodingSequences: make([]*LLMSequenceState, 0, config.MaxBatchSize),
}
}
// Start 启动 Batcher 的主 Goroutine
func (b *Batcher) Start() {
log.Println("Batcher starting...")
b.batchWg.Add(1)
go b.run() // 主批处理循环
// 启动 LLM 解码调度器 (如果需要)
if b.config.LLMDecodingInterval > 0 {
b.batchWg.Add(1)
go b.runLLMDecoder()
}
}
// Stop 停止 Batcher
func (b *Batcher) Stop() {
log.Println("Batcher stopping...")
close(b.shutdown)
b.batchWg.Wait() // 等待所有批处理 Goroutine 退出
log.Println("Batcher stopped.")
}
// SubmitRequest 提交一个客户端请求到批处理器
func (b *Batcher) SubmitRequest(ctx context.Context, req InferenceRequest, isLLM bool) (*InferenceResponse, error) {
internalReq := &InternalRequest{
InferenceRequest: req,
ctx: ctx,
response: make(chan *InferenceResponse, 1), // 缓冲通道,避免阻塞发送
createdAt: time.Now(),
isLLM: isLLM,
}
b.mu.Lock()
b.internalReqStore[req.ID] = internalReq // 存储请求以便后续结果映射
b.mu.Unlock()
select {
case b.requestQueue <- internalReq:
// 请求成功入队
case <-ctx.Done():
b.mu.Lock()
delete(b.internalReqStore, req.ID) // 请求被取消,从存储中移除
b.mu.Unlock()
return nil, ctx.Err()
case <-b.shutdown:
b.mu.Lock()
delete(b.internalReqStore, req.ID)
b.mu.Unlock()
return nil, fmt.Errorf("batcher is shutting down")
}
// 等待推理结果
select {
case res := <-internalReq.response:
b.mu.Lock()
delete(b.internalReqStore, req.ID) // 结果已返回,从存储中移除
b.mu.Unlock()
return res, res.Error
case <-ctx.Done():
b.mu.Lock()
delete(b.internalReqStore, req.ID)
b.mu.Unlock()
return nil, ctx.Err()
case <-b.shutdown:
b.mu.Lock()
delete(b.internalReqStore, req.ID)
b.mu.Unlock()
return nil, fmt.Errorf("batcher is shutting down")
}
}
// run 是 Batcher 的主循环,负责收集请求和触发批处理
func (b *Batcher) run() {
defer b.batchWg.Done()
ticker := time.NewTicker(b.config.BatchTimeout)
defer ticker.Stop()
// 缓冲区用于暂存请求,以便在触发批处理时一次性取出
tempBuffer := make([]*InternalRequest, 0, b.config.MaxBatchSize)
for {
select {
case req := <-b.requestQueue:
if req.isLLM {
b.prefillQueue = append(b.prefillQueue, req)
} else {
tempBuffer = append(tempBuffer, req)
}
// 检查是否达到最大批次大小,如果达到则立即触发处理
b.tryProcessBatch(&tempBuffer, "general_inference")
b.tryProcessLLMPrefillBatch()
case <-ticker.C:
// 批次超时,无论请求数量多少,都触发处理
b.tryProcessBatch(&tempBuffer, "general_inference")
b.tryProcessLLMPrefillBatch()
case <-b.shutdown:
log.Println("Batcher run goroutine shutting down.")
// 在关闭前处理剩余的请求
b.tryProcessBatch(&tempBuffer, "general_inference")
b.tryProcessLLMPrefillBatch()
b.processRemainingLLMDecodingSequences(true) // 强制完成所有 LLM 序列
return
}
}
}
// tryProcessBatch 尝试处理一个通用推理批次
func (b *Batcher) tryProcessBatch(buffer *[]*InternalRequest, batchType string) {
if len(*buffer) == 0 {
return
}
// 如果达到最大批次大小,立即处理
if len(*buffer) >= b.config.MaxBatchSize {
b.processBatch(*buffer, batchType)
*buffer = (*buffer)[:0] // 清空 buffer
}
// 注意:这里需要一个外部的定时器来触发超时批处理,或者在 `run` 循环中处理
}
// processBatch 实际发送批次到 GPU 推理服务
func (b *Batcher) processBatch(requests []*InternalRequest, batchType string) {
if len(requests) == 0 {
return
}
batchID := fmt.Sprintf("batch-%d-%s", b.nextBatchID, uuid.New().String())
b.nextBatchID++
batch := &Batch{
ID: batchID,
Requests: requests,
SubmittedAt: time.Now(),
BatchType: batchType,
}
log.Printf("Processing batch %s with %d requests (type: %s).", batch.ID, len(batch.Requests), batch.BatchType)
b.batchWg.Add(1)
go func() {
defer b.batchWg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // 推理超时
defer cancel()
var batchResult *BatchResult
var err error
switch batch.BatchType {
case "general_inference":
batchResult, err = b.inferenceClient.Infer(ctx, batch)
case "prefill":
batchResult, err = b.inferenceClient.Prefill(ctx, batch)
case "decode":
batchResult, err = b.inferenceClient.Decode(ctx, batch)
default:
err = fmt.Errorf("unknown batch type: %s", batch.BatchType)
}
if err != nil {
log.Printf("Batch %s inference failed: %v", batch.ID, err)
for _, req := range requests {
// 通知所有请求失败
select {
case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
case <-time.After(time.Millisecond * 100): // 避免阻塞
log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
}
}
return
}
if batchResult == nil {
log.Printf("Batch %s inference returned nil result.", batch.ID)
err = fmt.Errorf("inference service returned nil result")
for _, req := range requests {
select {
case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
}
}
return
}
b.distributeResults(batch, batchResult)
}()
}
// distributeResults 将批处理结果分发给原始请求
func (b *Batcher) distributeResults(batch *Batch, batchResult *BatchResult) {
// 假设 BatchResult.Results 数组的顺序与 Batch.Requests 相同
// 或者 BatchResult 内部包含映射关系 (例如 map[string]*InferenceResponse)
// 这里我们假设是按顺序的,实际情况可能需要更复杂的匹配逻辑
if len(batch.Requests) != len(batchResult.Results) && batchResult.Error == nil {
log.Printf("Warning: Batch %s request count mismatch with result count. Requests: %d, Results: %d",
batch.ID, len(batch.Requests), len(batchResult.Results))
// 如果结果数量不匹配,视为整个批次失败或部分失败,需要更复杂的处理
// 简单处理:将错误分发给所有请求
err := fmt.Errorf("result count mismatch for batch %s", batch.ID)
for _, req := range batch.Requests {
select {
case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
}
}
return
}
for i, req := range batch.Requests {
var res *InferenceResponse
if batchResult.Error != nil {
// 整个批次失败
res = &InferenceResponse{ID: req.ID, Error: batchResult.Error}
} else {
// 正常分发单个结果
res = batchResult.Results[i]
if res.ID == "" { // 补充 ID
res.ID = req.ID
}
}
// 对于 LLM 预填充请求,需要更新其序列状态并将其移至解码队列
if req.isLLM && batch.BatchType == "prefill" && res.Error == nil {
// 假设 res.OutputData 包含第一个生成的 token 和 KV 缓存引用
// 实际中这部分逻辑会更复杂,可能需要从 BatchResult 中解析出 KV 缓存等信息
seqState := &LLMSequenceState{
SequenceID: uuid.New().String(), // 每个序列一个 ID
PromptTokens: nil, // 假设已在推理服务中处理
GeneratedTokens: []int{123}, // 假设第一个 token ID
KVStoreRef: fmt.Sprintf("kv_ref_%s", uuid.New().String()), // 模拟 KV 缓存引用
MaxNewTokens: 100, // 从 req.Params 中解析
IsFinished: false,
LastTokenOutput: "first_token",
InternalRequest: req,
LastAccessTime: time.Now(),
}
// 更新原始请求的 sequence 字段
req.sequence = seqState
b.mu.Lock()
b.decodingSequences = append(b.decodingSequences, seqState)
b.mu.Unlock()
// 此时不向客户端返回结果,而是等待整个序列完成
log.Printf("LLM prefill for request %s completed. Sequence %s moved to decoding queue.", req.ID, seqState.SequenceID)
continue // 跳过对 req.response 的发送,等待最终结果
}
// 将结果发送回原始请求的通道
select {
case req.response <- res:
// 成功发送
case <-req.ctx.Done():
// 客户端已取消请求
log.Printf("Client for request %s cancelled while sending response.", req.ID)
case <-time.After(time.Millisecond * 100):
// 通道被阻塞,可能客户端已退出或发生其他问题
log.Printf("Failed to send response for request %s (channel blocked)", req.ID)
}
}
}
4.3 LLM 令牌级持续批处理的实现细节
上面通用 Batcher 已经包含了 prefillQueue 和 decodingSequences。现在我们来具体实现 LLM 特有的批处理逻辑。
1. 预填充批次处理 (Prefill Batch)
当 prefillQueue 中有足够多的请求时,或者达到超时,就触发预填充批次。
// tryProcessLLMPrefillBatch 尝试处理 LLM 预填充批次
func (b *Batcher) tryProcessLLMPrefillBatch() {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.prefillQueue) == 0 {
return
}
// 达到最大批次大小,立即处理
if len(b.prefillQueue) >= b.config.MaxBatchSize {
// 取出当前所有预填充请求
batchRequests := b.prefillQueue
b.prefillQueue = make([]*InternalRequest, 0, b.config.MaxBatchSize) // 清空队列
b.processBatch(batchRequests, "prefill")
}
// 注意:预填充的超时触发逻辑与通用推理批次相同,由 `run` 循环中的 ticker 触发
}
2. 解码批次调度 (Decoding Batch)
解码批次需要一个独立的 Goroutine 定期运行,因为它是一个迭代过程。
// runLLMDecoder 负责 LLM 解码阶段的批处理调度
func (b *Batcher) runLLMDecoder() {
defer b.batchWg.Done()
ticker := time.NewTicker(b.config.LLMDecodingInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
b.processLLMDecodingBatch()
case <-b.shutdown:
log.Println("LLM decoder goroutine shutting down.")
b.processRemainingLLMDecodingSequences(false) // 尝试处理剩余序列,但不强制完成
return
}
}
}
// processLLMDecodingBatch 收集正在解码的序列并进行批处理
func (b *Batcher) processLLMDecodingBatch() {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.decodingSequences) == 0 {
return
}
// 收集所有未完成的序列,准备进行下一轮解码
activeSequences := make([]*LLMSequenceState, 0, len(b.decodingSequences))
pendingRequests := make([]*InternalRequest, 0, len(b.decodingSequences))
// 过滤掉已完成的序列,并准备批处理输入
for _, seq := range b.decodingSequences {
if !seq.IsFinished {
activeSequences = append(activeSequences, seq)
pendingRequests = append(pendingRequests, seq.InternalRequest)
}
}
if len(activeSequences) == 0 {
b.decodingSequences = b.decodingSequences[:0] // 清空已完成的序列
return
}
// 创建一个特殊的批次,包含所有活动序列的信息
// 注意:这里需要将 LLMSequenceState 转换为 GPU 推理服务所需的批处理输入格式
// 例如,包含所有序列的 KV 缓存引用、最后一个生成的 token ID 等
// 为了简化,我们假设 processBatch 可以处理这种批次类型,并从 InternalRequest 中提取必要信息
// 实际中,可能需要一个专门的 `buildLLMDecodeBatch` 函数
decodeBatchID := fmt.Sprintf("decode-batch-%d-%s", b.nextBatchID, uuid.New().String())
b.nextBatchID++
decodeBatch := &Batch{
ID: decodeBatchID,
Requests: pendingRequests, // 这里 Requests 实际上是持有 LLMSequenceState 的 InternalRequest
SubmittedAt: time.Now(),
BatchType: "decode",
}
log.Printf("Processing LLM decode batch %s with %d active sequences.", decodeBatch.ID, len(activeSequences))
// 提交解码批次到推理服务
b.batchWg.Add(1)
go func(batch *Batch, sequences []*LLMSequenceState) {
defer b.batchWg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // 解码超时
defer cancel()
batchResult, err := b.inferenceClient.Decode(ctx, batch) // 调用专门的解码接口
if err != nil {
log.Printf("LLM decode batch %s inference failed: %v", batch.ID, err)
// 通知所有相关请求失败
for _, req := range batch.Requests {
if req.sequence != nil {
req.sequence.IsFinished = true // 标记为完成,避免再次调度
}
select {
case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
}
}
return
}
if batchResult == nil {
log.Printf("LLM decode batch %s inference returned nil result.", batch.ID)
err = fmt.Errorf("inference service returned nil result for decode batch")
for _, req := range batch.Requests {
if req.sequence != nil {
req.sequence.IsFinished = true
}
select {
case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
}
}
return
}
// 处理解码结果,更新序列状态
b.handleLLMDecodeResults(batch, batchResult, sequences)
}(decodeBatch, activeSequences)
}
// handleLLMDecodeResults 处理 LLM 解码批次的结果
func (b *Batcher) handleLLMDecodeResults(batch *Batch, batchResult *BatchResult, activeSequences []*LLMSequenceState) {
b.mu.Lock()
defer b.mu.Unlock()
// 假设 batchResult.Results 也是与 activeSequences 顺序对应的
// 或者包含 SequenceID 到结果的映射
if len(batchResult.Results) != len(activeSequences) {
log.Printf("Warning: LLM decode batch %s result count mismatch. Active sequences: %d, Results: %d",
batch.ID, len(activeSequences), len(batchResult.Results))
// 处理错误,标记所有序列为完成并发送错误
for _, seq := range activeSequences {
seq.IsFinished = true
select {
case seq.InternalRequest.response <- &InferenceResponse{ID: seq.InternalRequest.ID, Error: fmt.Errorf("decode result mismatch")}:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s", seq.InternalRequest.ID)
}
}
return
}
newActiveSequences := make([]*LLMSequenceState, 0, len(b.decodingSequences))
for i, seq := range activeSequences {
res := batchResult.Results[i] // 对应当前序列的结果
if res.Error != nil {
log.Printf("Sequence %s failed during decode: %v", seq.SequenceID, res.Error)
seq.IsFinished = true
select {
case seq.InternalRequest.response <- res:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s", seq.InternalRequest.ID)
}
continue
}
// 更新序列状态
// 实际中,res.OutputData 会是新生成的 token,需要解析并添加到 GeneratedTokens
// 并且可能包含更新后的 KV 缓存引用
seq.GeneratedTokens = append(seq.GeneratedTokens, 456) // 模拟新 token ID
seq.LastTokenOutput = res.OutputData // 假设 OutputData 是新生成的 token 文本
seq.LastAccessTime = time.Now()
// 检查是否达到停止条件 (max_tokens 或生成停止符)
// 假设 max_new_tokens 是从 InferenceRequest.Params 中解析出来的
currentMaxTokens, _ := seq.InternalRequest.Params["max_new_tokens"] // 示例
maxNewTokens := 100
fmt.Sscanf(currentMaxTokens, "%d", &maxNewTokens)
if len(seq.GeneratedTokens) >= maxNewTokens || res.OutputData == "<EOS>" { // 假设 "<EOS>" 是停止符
seq.IsFinished = true
seq.StopReason = "max_tokens_reached" // 或 "stop_token_generated"
log.Printf("Sequence %s finished. Total tokens: %d, reason: %s", seq.SequenceID, len(seq.GeneratedTokens), seq.StopReason)
// 序列完成,发送最终结果给客户端
finalRes := &InferenceResponse{
ID: seq.InternalRequest.ID,
OutputData: seq.InternalRequest.InputData + " " + seq.LastTokenOutput, // 拼接所有生成的 token
Error: nil,
}
select {
case seq.InternalRequest.response <- finalRes:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send final response for request %s", seq.InternalRequest.ID)
}
// KV 缓存清理 (如果推理服务支持)
// b.inferenceClient.ReleaseKV(seq.KVStoreRef)
} else {
newActiveSequences = append(newActiveSequences, seq) // 仍然活跃,保留在队列中
}
}
b.decodingSequences = newActiveSequences // 更新活动序列列表
}
// processRemainingLLMDecodingSequences 在 Batcher 关闭时处理剩余的 LLM 序列
func (b *Batcher) processRemainingLLMDecodingSequences(forceComplete bool) {
b.mu.Lock()
defer b.mu.Unlock()
for _, seq := range b.decodingSequences {
if !seq.IsFinished {
err := fmt.Errorf("batcher shutting down, sequence %s incomplete", seq.SequenceID)
if forceComplete {
err = fmt.Errorf("batcher shut down, sequence %s forced to complete", seq.SequenceID)
// 强制返回部分结果或错误
finalRes := &InferenceResponse{
ID: seq.InternalRequest.ID,
OutputData: seq.InternalRequest.InputData + " " + seq.LastTokenOutput, // 返回部分结果
Error: err,
}
select {
case seq.InternalRequest.response <- finalRes:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send final response for request %s", seq.InternalRequest.ID)
}
} else {
// 仅发送错误
select {
case seq.InternalRequest.response <- &InferenceResponse{ID: seq.InternalRequest.ID, Error: err}:
case <-time.After(time.Millisecond * 100):
log.Printf("Failed to send error response for request %s", seq.InternalRequest.ID)
}
}
seq.IsFinished = true
}
}
b.decodingSequences = b.decodingSequences[:0] // 清空所有序列
}
4.4 模拟 GPU 推理客户端
为了测试上述 Batcher,我们需要一个 GPUInferencer 的模拟实现。
// MockGPUInferencer 模拟 GPU 推理客户端
type MockGPUInferencer struct{}
func (m *MockGPUInferencer) Infer(ctx context.Context, batch *Batch) (*BatchResult, error) {
log.Printf("[MockInferencer] Inferring batch %s with %d requests...", batch.ID, len(batch.Requests))
time.Sleep(time.Millisecond * time.Duration(100 + len(batch.Requests)*5)) // 模拟推理时间
results := make([]*InferenceResponse, len(batch.Requests))
for i, req := range batch.Requests {
results[i] = &InferenceResponse{
ID: req.ID,
OutputData: fmt.Sprintf("Processed: %s", req.InputData),
}
}
return &BatchResult{BatchID: batch.ID, Results: results}, nil
}
func (m *MockGPUInferencer) Prefill(ctx context.Context, batch *Batch) (*BatchResult, error) {
log.Printf("[MockInferencer] Prefilling LLM batch %s with %d prompts...", batch.ID, len(batch.Requests))
time.Sleep(time.Millisecond * time.Duration(150 + len(batch.Requests)*10)) // 模拟预填充时间
results := make([]*InferenceResponse, len(batch.Requests))
for i, req := range batch.Requests {
results[i] = &InferenceResponse{
ID: req.ID,
OutputData: fmt.Sprintf("LLM Prefill Output for: %s (first token)", req.InputData), // 返回第一个 token
}
}
return &BatchResult{BatchID: batch.ID, Results: results}, nil
}
func (m *MockGPUInferencer) Decode(ctx context.Context, batch *Batch) (*BatchResult, error) {
log.Printf("[MockInferencer] Decoding LLM batch %s with %d sequences...", batch.ID, len(batch.Requests))
time.Sleep(time.Millisecond * time.Duration(50 + len(batch.Requests)*2)) // 模拟解码时间
results := make([]*InferenceResponse, len(batch.Requests))
for i, req := range batch.Requests {
// 模拟生成下一个 token
// 实际中这里会根据 seq.KVStoreRef 和 seq.LastTokenOutput 进行推理
results[i] = &InferenceResponse{
ID: req.ID,
OutputData: fmt.Sprintf("LLM Decode Output (next token) for: %s", req.InputData),
}
}
return &BatchResult{BatchID: batch.ID, Results: results}, nil
}
5. 性能考量与权衡
持续批处理显著提升 GPU 利用率的同时,也引入了一些权衡和挑战:
| 特性 | 传统请求-响应模式 | 持续批处理模式 |
|---|---|---|
| GPU 利用率 | 低,存在大量空闲时间 | 高,持续有任务在 GPU 上执行 |
| 吞吐量 | 受限于单个请求的 GPU 启动/数据传输开销,总体吞吐量较低 | 高,通过并行处理多个请求大幅提升 |
| 平均延迟 | 对于单个请求可能较低(无批处理等待),但高并发下整体延迟高 | 对于部分请求可能因等待批次形成而略有增加,但整体系统平均延迟降低 |
| 尾部延迟 | 在高并发下可能很高 | 批次超时设置不当可能导致尾部延迟增加,但整体更可控 |
| 实现复杂度 | 简单 | 较高,需要复杂的批处理逻辑、状态管理和调度 |
| 资源消耗 (CPU) | 较低 | 较高,批处理逻辑本身需要 CPU 资源进行请求合并、调度和结果分发 |
| 内存消耗 | 较低 | 较高,需要维护请求队列、批次数据和 LLM 序列状态 |
| 动态性 | 差 | 强,能根据请求负载动态调整批次大小 |
主要权衡点:
- 延迟 vs. 吞吐量:通过批处理,我们牺牲了单个请求的最小可能延迟(因为有批处理等待时间),以换取系统整体吞吐量的显著提升。通过调整
BatchTimeout可以平衡这两者。更小的BatchTimeout会降低延迟但可能导致批次过小,降低 GPU 利用率;更大的BatchTimeout会提升 GPU 利用率但可能增加延迟。 - 复杂度 vs. 效率:持续批处理的实现比简单地处理单个请求复杂得多。需要仔细设计并发模型、数据结构、错误处理和状态管理。
- 资源消耗:批处理管理器本身会消耗一定的 CPU 和内存资源。特别是在 LLM 的令牌级批处理中,维护大量序列的 KV 缓存引用和状态需要精心管理。
6. 监控与度量
为了有效地运行和优化持续批处理系统,完善的监控至关重要:
- 队列深度 (Queue Depth):监控
requestQueue、prefillQueue和decodingSequences的当前长度。队列过长可能表明 GPU 推理服务处理能力不足。 - 批次大小 (Batch Size):记录每个批次的请求数量。理想情况下,批次大小应尽可能接近
MaxBatchSize。 - 批处理延迟 (Batching Latency):从请求进入队列到批次提交给 GPU 推理服务的时间。这直接影响了客户端的感知延迟。
- GPU 利用率 (GPU Utilization):通过 GPU 监控工具(如
nvidia-smi)获取 GPU 核心、内存利用率。 - 推理时间 (Inference Time):记录 GPU 推理服务处理一个批次所需的时间。
- 端到端延迟 (End-to-End Latency):从客户端发送请求到接收到响应的总时间。
- 吞吐量 (Throughput):每秒处理的请求数量或每秒生成的令牌数量。
- 错误率 (Error Rate):批处理或推理过程中发生的错误数量。
7. 总结与展望
持续批处理是提升 GPU 利用率和机器学习服务吞吐量的强大技术。在 Go 后端中实现它,需要深入理解并发编程范式、细致的数据结构设计以及与外部 GPU 推理服务的有效协作。特别是对于大型语言模型,令牌级的持续批处理已成为行业标准,它通过精细的序列调度和 KV 缓存管理,实现了前所未有的 GPU 效率。
虽然实现复杂度较高,但通过合理的架构设计和 Go 语言的并发特性,我们可以构建出高性能、可扩展的持续批处理系统。未来的发展方向包括更智能的调度算法、动态调整批次大小以适应实时负载、以及更紧密的与底层推理框架(如 vLLM, TensorRT-LLM)的集成,以进一步榨取 GPU 性能。