什么是 ‘Continuous Batching’?在 Go 后端中实现动态合并请求以提升 GPU 利用率的算法

持续批处理 (Continuous Batching) 在 Go 后端中实现动态合并请求以提升 GPU 利用率

尊敬的各位开发者,大家好!

今天我们将深入探讨一个在高性能、低延迟服务,尤其是涉及大量计算密集型任务(如机器学习推理)时至关重要的技术:持续批处理 (Continuous Batching)。随着人工智能,特别是大型语言模型 (LLMs) 的飞速发展,如何高效利用昂贵的 GPU 资源成为了后端服务面临的核心挑战。传统的请求处理模式往往导致 GPU 资源的严重浪费。本讲座将从理论到实践,详细阐述持续批处理的概念、其在 Go 后端中的实现策略、关键数据结构与算法,并辅以丰富的 Go 语言代码示例,以期为大家提供一套提升 GPU 利用率的实战方案。

1. 传统请求处理模式与 GPU 利用率瓶颈

在典型的 Web 服务架构中,客户端发出请求,后端服务接收请求,处理后返回响应。当涉及机器学习推理时,这个过程通常是:

  1. 客户端发送包含输入数据的请求。
  2. 后端服务接收请求,将输入数据转化为模型所需的张量格式。
  3. 后端将张量发送给推理服务(可能是一个独立的微服务,或者通过 CGO/FFI 直接调用本地库)。
  4. 推理服务在 GPU 上执行模型推理。
  5. 推理结果返回给后端服务。
  6. 后端服务将结果格式化后返回给客户端。

这种“请求-响应”模式在处理单个请求时效率很高,但在高并发场景下,尤其是当每个请求的推理任务相对较小或间歇性发生时,会暴露出严重的 GPU 利用率问题。

GPU 利用率瓶颈的主要原因:

  • GPU 启动开销 (Kernel Launch Overhead):每次在 GPU 上执行计算任务(称为 Kernel),都需要一定的启动时间。对于小型任务,这个启动时间可能占据总执行时间的很大一部分,导致实际计算时间占比很低。
  • 数据传输开销 (Data Transfer Overhead):CPU 和 GPU 之间的数据传输(PCIe 带宽限制)也是一个显著的开销。对于小批量数据,传输时间可能超过实际计算时间。
  • 空闲等待 (Idle Waiting):当 GPU 完成一个请求的推理后,如果没有新的请求立即到来,GPU 就会处于空闲状态,等待下一个请求。在高并发但请求到达不规律的场景下,这种空闲等待会频繁发生。
  • 固定批处理大小的局限性 (Fixed Batch Size Limitations):为了提升吞吐量,传统的解决方案是固定批处理大小。例如,每次推理都处理 8 个请求。如果请求数量不足 8 个,推理会等待直到凑齐,这会增加延迟;如果请求数量远超 8 个,则需要排队,同样增加延迟。

这些问题共同导致了 GPU 资源的低效利用,直接影响了服务的吞吐量和成本效益。

2. 什么是持续批处理 (Continuous Batching)?

持续批处理是一种旨在最大化 GPU 利用率的技术,它通过动态地将多个客户端请求合并成一个更大的批次,然后一次性提交给 GPU 进行推理。与传统的固定批处理不同,持续批处理更加灵活和动态,它不会等待一个完整的固定批次,而是在一个预设的时间窗口内或达到某个最大批次大小时,就立即处理当前队列中的所有请求。

更进一步,特别是对于大型语言模型 (LLMs),持续批处理不仅仅是请求级别的合并,更是令牌级别 (Token-level) 的合并。LLM 推理通常分为两个阶段:

  1. 预填充 (Pre-fill):处理用户的输入提示 (prompt),计算初始的 KV 缓存 (Key-Value Cache)。这通常是一个矩阵-向量乘法密集的操作。
  2. 解码 (Decoding):逐个生成新的令牌。在每次生成一个令牌时,模型会根据之前的 KV 缓存和当前生成的令牌进行推理,并更新 KV 缓存。这是一个迭代过程。

在传统的处理方式中,每个请求的预填充和解码过程是独立的。当一个请求正在解码时,GPU 可能因为等待下一个令牌的输入而空闲。持续批处理,特别是对于 LLMs,会同时管理多个正在进行的序列 (sequences)。当一个序列完成预填充并进入解码阶段时,它会与其他正在解码的序列一起被调度到 GPU 上。这样,GPU 可以在单个操作中并行处理多个序列的下一个令牌生成,极大地减少了空闲时间,提高了吞吐量。

持续批处理的核心思想:

  • 动态合并请求:不预设固定的批次大小,而是根据请求到达的速度和系统负载动态调整批次。
  • 最小化 GPU 空闲时间:尽可能确保 GPU 在任何时刻都有任务可执行。
  • 优化数据传输:通过合并请求,减少 CPU-GPU 之间的数据传输次数,并利用更大的传输块。
  • 令牌级调度 (LLMs):对于 LLMs,在解码阶段,将多个序列的令牌生成任务合并到一个 GPU 操作中。

3. 持续批处理的关键组成部分与工作流程

为了在 Go 后端实现持续批处理,我们需要设计一个协调机制,它包含以下几个主要组件:

  1. 请求接收器 (Request Receiver):负责接收客户端的原始请求。
  2. 批处理管理器 (Batch Manager):核心组件,维护一个待处理请求队列,并根据策略(时间或大小)触发批处理。
  3. GPU 推理器接口 (GPU Inferencer Interface):负责与实际的 GPU 推理服务(通常是独立的微服务,如使用 Triton Inference Server, vLLM, 或 TensorRT-LLM 搭建的 Python/C++ 服务)进行通信。
  4. 响应分发器 (Response Dispatcher):将批处理推理的结果分解,并与原始请求关联,将结果返回给相应的客户端。

工作流程概览:

  1. 客户端发送请求 Req_A, Req_B, Req_C
  2. 请求接收器收到请求,为每个请求创建一个内部表示 InternalRequest,并将其发送到批处理管理器的队列。
  3. 批处理管理器在一个专门的 Goroutine 中运行。它不断地从队列中取出请求。
    • 当满足批处理条件(例如,队列中的请求数量达到 MaxBatchSize,或距离上次批处理操作已超过 BatchTimeout)时,管理器会将当前队列中的所有请求合并成一个 Batch
  4. 批处理管理器将 Batch 发送给 GPU 推理器接口。
  5. GPU 推理器接口与底层的 GPU 推理服务通信,提交批处理任务。
  6. GPU 推理服务执行推理,并返回一个批处理结果 BatchResult
  7. GPU 推理器接口接收 BatchResult 并将其传回批处理管理器。
  8. 批处理管理器将 BatchResult 分解,根据原始请求的标识符,将结果派发给对应的等待者(通常通过 chansync.Cond 通知)。
  9. 响应分发器(或原始请求的处理 Goroutine)接收到结果,并将其返回给客户端。

针对 LLMs 的令牌级持续批处理:

对于 LLMs,这个流程会更复杂,因为推理是迭代的。批处理管理器需要维护每个序列的内部状态。

  1. 初始请求 (Prompt Processing)
    • 客户端发送请求 (prompt_A, max_tokens_A)
    • 请求被放入批处理管理器的“预填充队列”。
    • 当满足条件时,预填充队列中的请求被合并为一个批次,发送给 GPU 推理器接口进行预填充。
    • GPU 推理服务完成预填充,返回初始的 KV 缓存和第一个生成的令牌。
    • 批处理管理器接收结果,更新每个序列的状态(存储 KV 缓存,记录已生成的令牌),并将序列移至“解码队列”。
  2. 解码阶段 (Token Generation)
    • 批处理管理器定期从“解码队列”中选择一批活动序列(那些尚未完成 max_tokens 或未生成停止符的序列)。
    • 将这些序列的当前状态(包括 KV 缓存、已生成的最后一个令牌)打包成一个批次,发送给 GPU 推理器接口进行解码。
    • GPU 推理服务根据这些状态并行生成下一个令牌。
    • 批处理管理器接收结果,更新每个序列的状态(追加新生成的令牌,更新 KV 缓存)。
    • 如果序列达到 max_tokens 或生成停止符,则将其从活动序列中移除,并派发最终结果给客户端。否则,它会继续留在解码队列中,等待下一轮的解码。

这种令牌级的批处理需要 GPU 推理服务本身支持复杂的序列调度和 KV 缓存管理(例如 vLLM 的 PagedAttention)。Go 后端的作用是作为这些序列的管理者和调度者。

4. Go 后端中的架构设计与数据结构

在 Go 中实现持续批处理,我们需要精心设计数据结构和 Goroutine 之间的协作。

4.1 核心数据结构

1. 客户端请求的内部表示 (InternalRequest)

package main

import (
    "context"
    "fmt"
    "time"
)

// InferenceRequest 代表客户端发送的原始推理请求
type InferenceRequest struct {
    ID        string            // 请求唯一标识符
    InputData string            // 原始输入数据,例如文本
    Params    map[string]string // 推理参数,例如 max_new_tokens, temperature
}

// InferenceResponse 代表推理服务的响应
type InferenceResponse struct {
    ID         string
    OutputData string
    Error      error
}

// InternalRequest 是在批处理系统内部使用的请求表示
// 它包含了客户端请求、上下文以及用于接收结果的通道
type InternalRequest struct {
    InferenceRequest
    ctx      context.Context             // 用于处理请求的上下文,支持超时和取消
    response chan *InferenceResponse     // 用于将推理结果返回给发起者
    createdAt time.Time                   // 请求创建时间,用于计算等待时间
    isLLM    bool                        // 是否是 LLM 推理请求
    sequence *LLMSequenceState           // 如果是 LLM 请求,保存其序列状态
    // ... 其他内部状态 ...
}

// LLMSequenceState 存储 LLM 推理的动态状态
type LLMSequenceState struct {
    SequenceID        string
    PromptTokens      []int                 // 原始 prompt 的 token ID
    GeneratedTokens   []int                 // 已生成的 token ID
    KVStoreRef        interface{}           // 指向 GPU 端的 KV 缓存引用,具体类型取决于推理引擎接口
    MaxNewTokens      int                   // 用户请求的最大生成 token 数量
    StopReason        string                // 停止生成的原因
    IsFinished        bool                  // 序列是否已完成
    LastTokenOutput   string                // 上一轮生成的 token
    InternalRequest   *InternalRequest      // 关联的原始 InternalRequest
    LastAccessTime    time.Time             // 最后一次被调度的时间,用于调度策略
}

// Batch represents a collection of InternalRequests to be processed together
type Batch struct {
    ID          string
    Requests    []*InternalRequest
    SubmittedAt time.Time
    BatchType   string // "prefill" 或 "decode"
    // ... 其他批次信息 ...
}

// BatchResult represents the combined results from a batched inference operation
type BatchResult struct {
    BatchID string
    Results []*InferenceResponse
    Error   error
    // ... 其他批次结果信息 ...
}

2. 批处理管理器 (Batcher)

批处理管理器是整个系统的核心,它负责协调请求的入队、批次的创建和分发。

// BatcherConfig 批处理器的配置
type BatcherConfig struct {
    MaxBatchSize     int           // 单个批次最大请求数
    BatchTimeout     time.Duration // 批次超时时间,无论请求数量多少,都会触发批处理
    WorkerPoolSize   int           // 后端推理工作 Goroutine 数量
    LLMDecodingInterval time.Duration // LLM 解码批处理的间隔
}

// Batcher 核心批处理管理器
type Batcher struct {
    config BatcherConfig

    requestQueue     chan *InternalRequest       // 接收新到来的请求
    internalReqStore map[string]*InternalRequest // 存储所有活跃的 InternalRequest,用于结果映射

    prefillQueue      []*InternalRequest          // LLM 预填充请求队列
    decodingSequences []*LLMSequenceState         // LLM 正在解码的序列

    batchWg           sync.WaitGroup              // 用于等待所有批次处理完成
    shutdown          chan struct{}               // 关闭信号
    mu                sync.Mutex                  // 保护对 requestQueue 和 internalReqStore 的访问
    nextBatchID       int64                       // 批次 ID 生成器

    inferenceClient GPUInferencer               // GPU 推理客户端接口
}

// GPUInferencer 定义了与 GPU 推理服务交互的接口
type GPUInferencer interface {
    Infer(ctx context.Context, batch *Batch) (*BatchResult, error)
    // 对于 LLM,可能需要更细粒度的接口
    Prefill(ctx context.Context, batch *Batch) (*BatchResult, error)
    Decode(ctx context.Context, decodeBatch *Batch) (*BatchResult, error)
    // ... 其他控制接口,如 KV cache 清理等
}

4.2 批处理管理器 Goroutine

Batcher 内部会启动一个或多个 Goroutine 来执行批处理逻辑。

package main

import (
    "context"
    "fmt"
    "log"
    "sync"
    "time"

    "github.com/google/uuid" // 假设使用 uuid 生成 ID
)

// NewBatcher 创建一个新的 Batcher 实例
func NewBatcher(config BatcherConfig, client GPUInferencer) *Batcher {
    return &Batcher{
        config:           config,
        requestQueue:     make(chan *InternalRequest, config.MaxBatchSize*2), // 缓冲区大小
        internalReqStore: make(map[string]*InternalRequest),
        shutdown:         make(chan struct{}),
        nextBatchID:      0,
        inferenceClient:  client,
        prefillQueue:     make([]*InternalRequest, 0, config.MaxBatchSize),
        decodingSequences: make([]*LLMSequenceState, 0, config.MaxBatchSize),
    }
}

// Start 启动 Batcher 的主 Goroutine
func (b *Batcher) Start() {
    log.Println("Batcher starting...")
    b.batchWg.Add(1)
    go b.run() // 主批处理循环

    // 启动 LLM 解码调度器 (如果需要)
    if b.config.LLMDecodingInterval > 0 {
        b.batchWg.Add(1)
        go b.runLLMDecoder()
    }
}

// Stop 停止 Batcher
func (b *Batcher) Stop() {
    log.Println("Batcher stopping...")
    close(b.shutdown)
    b.batchWg.Wait() // 等待所有批处理 Goroutine 退出
    log.Println("Batcher stopped.")
}

// SubmitRequest 提交一个客户端请求到批处理器
func (b *Batcher) SubmitRequest(ctx context.Context, req InferenceRequest, isLLM bool) (*InferenceResponse, error) {
    internalReq := &InternalRequest{
        InferenceRequest: req,
        ctx:              ctx,
        response:         make(chan *InferenceResponse, 1), // 缓冲通道,避免阻塞发送
        createdAt:        time.Now(),
        isLLM:            isLLM,
    }

    b.mu.Lock()
    b.internalReqStore[req.ID] = internalReq // 存储请求以便后续结果映射
    b.mu.Unlock()

    select {
    case b.requestQueue <- internalReq:
        // 请求成功入队
    case <-ctx.Done():
        b.mu.Lock()
        delete(b.internalReqStore, req.ID) // 请求被取消,从存储中移除
        b.mu.Unlock()
        return nil, ctx.Err()
    case <-b.shutdown:
        b.mu.Lock()
        delete(b.internalReqStore, req.ID)
        b.mu.Unlock()
        return nil, fmt.Errorf("batcher is shutting down")
    }

    // 等待推理结果
    select {
    case res := <-internalReq.response:
        b.mu.Lock()
        delete(b.internalReqStore, req.ID) // 结果已返回,从存储中移除
        b.mu.Unlock()
        return res, res.Error
    case <-ctx.Done():
        b.mu.Lock()
        delete(b.internalReqStore, req.ID)
        b.mu.Unlock()
        return nil, ctx.Err()
    case <-b.shutdown:
        b.mu.Lock()
        delete(b.internalReqStore, req.ID)
        b.mu.Unlock()
        return nil, fmt.Errorf("batcher is shutting down")
    }
}

// run 是 Batcher 的主循环,负责收集请求和触发批处理
func (b *Batcher) run() {
    defer b.batchWg.Done()

    ticker := time.NewTicker(b.config.BatchTimeout)
    defer ticker.Stop()

    // 缓冲区用于暂存请求,以便在触发批处理时一次性取出
    tempBuffer := make([]*InternalRequest, 0, b.config.MaxBatchSize)

    for {
        select {
        case req := <-b.requestQueue:
            if req.isLLM {
                b.prefillQueue = append(b.prefillQueue, req)
            } else {
                tempBuffer = append(tempBuffer, req)
            }
            // 检查是否达到最大批次大小,如果达到则立即触发处理
            b.tryProcessBatch(&tempBuffer, "general_inference")
            b.tryProcessLLMPrefillBatch()

        case <-ticker.C:
            // 批次超时,无论请求数量多少,都触发处理
            b.tryProcessBatch(&tempBuffer, "general_inference")
            b.tryProcessLLMPrefillBatch()

        case <-b.shutdown:
            log.Println("Batcher run goroutine shutting down.")
            // 在关闭前处理剩余的请求
            b.tryProcessBatch(&tempBuffer, "general_inference")
            b.tryProcessLLMPrefillBatch()
            b.processRemainingLLMDecodingSequences(true) // 强制完成所有 LLM 序列
            return
        }
    }
}

// tryProcessBatch 尝试处理一个通用推理批次
func (b *Batcher) tryProcessBatch(buffer *[]*InternalRequest, batchType string) {
    if len(*buffer) == 0 {
        return
    }

    // 如果达到最大批次大小,立即处理
    if len(*buffer) >= b.config.MaxBatchSize {
        b.processBatch(*buffer, batchType)
        *buffer = (*buffer)[:0] // 清空 buffer
    }
    // 注意:这里需要一个外部的定时器来触发超时批处理,或者在 `run` 循环中处理
}

// processBatch 实际发送批次到 GPU 推理服务
func (b *Batcher) processBatch(requests []*InternalRequest, batchType string) {
    if len(requests) == 0 {
        return
    }

    batchID := fmt.Sprintf("batch-%d-%s", b.nextBatchID, uuid.New().String())
    b.nextBatchID++

    batch := &Batch{
        ID:          batchID,
        Requests:    requests,
        SubmittedAt: time.Now(),
        BatchType:   batchType,
    }

    log.Printf("Processing batch %s with %d requests (type: %s).", batch.ID, len(batch.Requests), batch.BatchType)

    b.batchWg.Add(1)
    go func() {
        defer b.batchWg.Done()
        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // 推理超时
        defer cancel()

        var batchResult *BatchResult
        var err error

        switch batch.BatchType {
        case "general_inference":
            batchResult, err = b.inferenceClient.Infer(ctx, batch)
        case "prefill":
            batchResult, err = b.inferenceClient.Prefill(ctx, batch)
        case "decode":
            batchResult, err = b.inferenceClient.Decode(ctx, batch)
        default:
            err = fmt.Errorf("unknown batch type: %s", batch.BatchType)
        }

        if err != nil {
            log.Printf("Batch %s inference failed: %v", batch.ID, err)
            for _, req := range requests {
                // 通知所有请求失败
                select {
                case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
                case <-time.After(time.Millisecond * 100): // 避免阻塞
                    log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
                }
            }
            return
        }

        if batchResult == nil {
            log.Printf("Batch %s inference returned nil result.", batch.ID)
            err = fmt.Errorf("inference service returned nil result")
            for _, req := range requests {
                select {
                case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
                case <-time.After(time.Millisecond * 100):
                    log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
                }
            }
            return
        }

        b.distributeResults(batch, batchResult)
    }()
}

// distributeResults 将批处理结果分发给原始请求
func (b *Batcher) distributeResults(batch *Batch, batchResult *BatchResult) {
    // 假设 BatchResult.Results 数组的顺序与 Batch.Requests 相同
    // 或者 BatchResult 内部包含映射关系 (例如 map[string]*InferenceResponse)
    // 这里我们假设是按顺序的,实际情况可能需要更复杂的匹配逻辑

    if len(batch.Requests) != len(batchResult.Results) && batchResult.Error == nil {
        log.Printf("Warning: Batch %s request count mismatch with result count. Requests: %d, Results: %d",
            batch.ID, len(batch.Requests), len(batchResult.Results))
        // 如果结果数量不匹配,视为整个批次失败或部分失败,需要更复杂的处理
        // 简单处理:将错误分发给所有请求
        err := fmt.Errorf("result count mismatch for batch %s", batch.ID)
        for _, req := range batch.Requests {
            select {
            case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
            case <-time.After(time.Millisecond * 100):
                log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
            }
        }
        return
    }

    for i, req := range batch.Requests {
        var res *InferenceResponse
        if batchResult.Error != nil {
            // 整个批次失败
            res = &InferenceResponse{ID: req.ID, Error: batchResult.Error}
        } else {
            // 正常分发单个结果
            res = batchResult.Results[i]
            if res.ID == "" { // 补充 ID
                res.ID = req.ID
            }
        }

        // 对于 LLM 预填充请求,需要更新其序列状态并将其移至解码队列
        if req.isLLM && batch.BatchType == "prefill" && res.Error == nil {
            // 假设 res.OutputData 包含第一个生成的 token 和 KV 缓存引用
            // 实际中这部分逻辑会更复杂,可能需要从 BatchResult 中解析出 KV 缓存等信息
            seqState := &LLMSequenceState{
                SequenceID:        uuid.New().String(), // 每个序列一个 ID
                PromptTokens:      nil, // 假设已在推理服务中处理
                GeneratedTokens:   []int{123}, // 假设第一个 token ID
                KVStoreRef:        fmt.Sprintf("kv_ref_%s", uuid.New().String()), // 模拟 KV 缓存引用
                MaxNewTokens:      100, // 从 req.Params 中解析
                IsFinished:        false,
                LastTokenOutput:   "first_token",
                InternalRequest:   req,
                LastAccessTime:    time.Now(),
            }
            // 更新原始请求的 sequence 字段
            req.sequence = seqState

            b.mu.Lock()
            b.decodingSequences = append(b.decodingSequences, seqState)
            b.mu.Unlock()
            // 此时不向客户端返回结果,而是等待整个序列完成
            log.Printf("LLM prefill for request %s completed. Sequence %s moved to decoding queue.", req.ID, seqState.SequenceID)
            continue // 跳过对 req.response 的发送,等待最终结果
        }

        // 将结果发送回原始请求的通道
        select {
        case req.response <- res:
            // 成功发送
        case <-req.ctx.Done():
            // 客户端已取消请求
            log.Printf("Client for request %s cancelled while sending response.", req.ID)
        case <-time.After(time.Millisecond * 100):
            // 通道被阻塞,可能客户端已退出或发生其他问题
            log.Printf("Failed to send response for request %s (channel blocked)", req.ID)
        }
    }
}

4.3 LLM 令牌级持续批处理的实现细节

上面通用 Batcher 已经包含了 prefillQueuedecodingSequences。现在我们来具体实现 LLM 特有的批处理逻辑。

1. 预填充批次处理 (Prefill Batch)

prefillQueue 中有足够多的请求时,或者达到超时,就触发预填充批次。

// tryProcessLLMPrefillBatch 尝试处理 LLM 预填充批次
func (b *Batcher) tryProcessLLMPrefillBatch() {
    b.mu.Lock()
    defer b.mu.Unlock()

    if len(b.prefillQueue) == 0 {
        return
    }

    // 达到最大批次大小,立即处理
    if len(b.prefillQueue) >= b.config.MaxBatchSize {
        // 取出当前所有预填充请求
        batchRequests := b.prefillQueue
        b.prefillQueue = make([]*InternalRequest, 0, b.config.MaxBatchSize) // 清空队列
        b.processBatch(batchRequests, "prefill")
    }
    // 注意:预填充的超时触发逻辑与通用推理批次相同,由 `run` 循环中的 ticker 触发
}

2. 解码批次调度 (Decoding Batch)

解码批次需要一个独立的 Goroutine 定期运行,因为它是一个迭代过程。

// runLLMDecoder 负责 LLM 解码阶段的批处理调度
func (b *Batcher) runLLMDecoder() {
    defer b.batchWg.Done()

    ticker := time.NewTicker(b.config.LLMDecodingInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            b.processLLMDecodingBatch()
        case <-b.shutdown:
            log.Println("LLM decoder goroutine shutting down.")
            b.processRemainingLLMDecodingSequences(false) // 尝试处理剩余序列,但不强制完成
            return
        }
    }
}

// processLLMDecodingBatch 收集正在解码的序列并进行批处理
func (b *Batcher) processLLMDecodingBatch() {
    b.mu.Lock()
    defer b.mu.Unlock()

    if len(b.decodingSequences) == 0 {
        return
    }

    // 收集所有未完成的序列,准备进行下一轮解码
    activeSequences := make([]*LLMSequenceState, 0, len(b.decodingSequences))
    pendingRequests := make([]*InternalRequest, 0, len(b.decodingSequences))

    // 过滤掉已完成的序列,并准备批处理输入
    for _, seq := range b.decodingSequences {
        if !seq.IsFinished {
            activeSequences = append(activeSequences, seq)
            pendingRequests = append(pendingRequests, seq.InternalRequest)
        }
    }

    if len(activeSequences) == 0 {
        b.decodingSequences = b.decodingSequences[:0] // 清空已完成的序列
        return
    }

    // 创建一个特殊的批次,包含所有活动序列的信息
    // 注意:这里需要将 LLMSequenceState 转换为 GPU 推理服务所需的批处理输入格式
    // 例如,包含所有序列的 KV 缓存引用、最后一个生成的 token ID 等
    // 为了简化,我们假设 processBatch 可以处理这种批次类型,并从 InternalRequest 中提取必要信息
    // 实际中,可能需要一个专门的 `buildLLMDecodeBatch` 函数

    decodeBatchID := fmt.Sprintf("decode-batch-%d-%s", b.nextBatchID, uuid.New().String())
    b.nextBatchID++

    decodeBatch := &Batch{
        ID:          decodeBatchID,
        Requests:    pendingRequests, // 这里 Requests 实际上是持有 LLMSequenceState 的 InternalRequest
        SubmittedAt: time.Now(),
        BatchType:   "decode",
    }

    log.Printf("Processing LLM decode batch %s with %d active sequences.", decodeBatch.ID, len(activeSequences))

    // 提交解码批次到推理服务
    b.batchWg.Add(1)
    go func(batch *Batch, sequences []*LLMSequenceState) {
        defer b.batchWg.Done()
        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // 解码超时
        defer cancel()

        batchResult, err := b.inferenceClient.Decode(ctx, batch) // 调用专门的解码接口

        if err != nil {
            log.Printf("LLM decode batch %s inference failed: %v", batch.ID, err)
            // 通知所有相关请求失败
            for _, req := range batch.Requests {
                if req.sequence != nil {
                    req.sequence.IsFinished = true // 标记为完成,避免再次调度
                }
                select {
                case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
                case <-time.After(time.Millisecond * 100):
                    log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
                }
            }
            return
        }

        if batchResult == nil {
            log.Printf("LLM decode batch %s inference returned nil result.", batch.ID)
            err = fmt.Errorf("inference service returned nil result for decode batch")
            for _, req := range batch.Requests {
                if req.sequence != nil {
                    req.sequence.IsFinished = true
                }
                select {
                case req.response <- &InferenceResponse{ID: req.ID, Error: err}:
                case <-time.After(time.Millisecond * 100):
                    log.Printf("Failed to send error response for request %s (channel blocked)", req.ID)
                }
            }
            return
        }

        // 处理解码结果,更新序列状态
        b.handleLLMDecodeResults(batch, batchResult, sequences)

    }(decodeBatch, activeSequences)
}

// handleLLMDecodeResults 处理 LLM 解码批次的结果
func (b *Batcher) handleLLMDecodeResults(batch *Batch, batchResult *BatchResult, activeSequences []*LLMSequenceState) {
    b.mu.Lock()
    defer b.mu.Unlock()

    // 假设 batchResult.Results 也是与 activeSequences 顺序对应的
    // 或者包含 SequenceID 到结果的映射
    if len(batchResult.Results) != len(activeSequences) {
        log.Printf("Warning: LLM decode batch %s result count mismatch. Active sequences: %d, Results: %d",
            batch.ID, len(activeSequences), len(batchResult.Results))
        // 处理错误,标记所有序列为完成并发送错误
        for _, seq := range activeSequences {
            seq.IsFinished = true
            select {
            case seq.InternalRequest.response <- &InferenceResponse{ID: seq.InternalRequest.ID, Error: fmt.Errorf("decode result mismatch")}:
            case <-time.After(time.Millisecond * 100):
                log.Printf("Failed to send error response for request %s", seq.InternalRequest.ID)
            }
        }
        return
    }

    newActiveSequences := make([]*LLMSequenceState, 0, len(b.decodingSequences))
    for i, seq := range activeSequences {
        res := batchResult.Results[i] // 对应当前序列的结果

        if res.Error != nil {
            log.Printf("Sequence %s failed during decode: %v", seq.SequenceID, res.Error)
            seq.IsFinished = true
            select {
            case seq.InternalRequest.response <- res:
            case <-time.After(time.Millisecond * 100):
                log.Printf("Failed to send error response for request %s", seq.InternalRequest.ID)
            }
            continue
        }

        // 更新序列状态
        // 实际中,res.OutputData 会是新生成的 token,需要解析并添加到 GeneratedTokens
        // 并且可能包含更新后的 KV 缓存引用
        seq.GeneratedTokens = append(seq.GeneratedTokens, 456) // 模拟新 token ID
        seq.LastTokenOutput = res.OutputData // 假设 OutputData 是新生成的 token 文本
        seq.LastAccessTime = time.Now()

        // 检查是否达到停止条件 (max_tokens 或生成停止符)
        // 假设 max_new_tokens 是从 InferenceRequest.Params 中解析出来的
        currentMaxTokens, _ := seq.InternalRequest.Params["max_new_tokens"] // 示例
        maxNewTokens := 100
        fmt.Sscanf(currentMaxTokens, "%d", &maxNewTokens)

        if len(seq.GeneratedTokens) >= maxNewTokens || res.OutputData == "<EOS>" { // 假设 "<EOS>" 是停止符
            seq.IsFinished = true
            seq.StopReason = "max_tokens_reached" // 或 "stop_token_generated"
            log.Printf("Sequence %s finished. Total tokens: %d, reason: %s", seq.SequenceID, len(seq.GeneratedTokens), seq.StopReason)

            // 序列完成,发送最终结果给客户端
            finalRes := &InferenceResponse{
                ID:         seq.InternalRequest.ID,
                OutputData: seq.InternalRequest.InputData + " " + seq.LastTokenOutput, // 拼接所有生成的 token
                Error:      nil,
            }
            select {
            case seq.InternalRequest.response <- finalRes:
            case <-time.After(time.Millisecond * 100):
                log.Printf("Failed to send final response for request %s", seq.InternalRequest.ID)
            }
            // KV 缓存清理 (如果推理服务支持)
            // b.inferenceClient.ReleaseKV(seq.KVStoreRef)
        } else {
            newActiveSequences = append(newActiveSequences, seq) // 仍然活跃,保留在队列中
        }
    }
    b.decodingSequences = newActiveSequences // 更新活动序列列表
}

// processRemainingLLMDecodingSequences 在 Batcher 关闭时处理剩余的 LLM 序列
func (b *Batcher) processRemainingLLMDecodingSequences(forceComplete bool) {
    b.mu.Lock()
    defer b.mu.Unlock()

    for _, seq := range b.decodingSequences {
        if !seq.IsFinished {
            err := fmt.Errorf("batcher shutting down, sequence %s incomplete", seq.SequenceID)
            if forceComplete {
                err = fmt.Errorf("batcher shut down, sequence %s forced to complete", seq.SequenceID)
                // 强制返回部分结果或错误
                finalRes := &InferenceResponse{
                    ID:         seq.InternalRequest.ID,
                    OutputData: seq.InternalRequest.InputData + " " + seq.LastTokenOutput, // 返回部分结果
                    Error:      err,
                }
                select {
                case seq.InternalRequest.response <- finalRes:
                case <-time.After(time.Millisecond * 100):
                    log.Printf("Failed to send final response for request %s", seq.InternalRequest.ID)
                }
            } else {
                // 仅发送错误
                select {
                case seq.InternalRequest.response <- &InferenceResponse{ID: seq.InternalRequest.ID, Error: err}:
                case <-time.After(time.Millisecond * 100):
                    log.Printf("Failed to send error response for request %s", seq.InternalRequest.ID)
                }
            }
            seq.IsFinished = true
        }
    }
    b.decodingSequences = b.decodingSequences[:0] // 清空所有序列
}

4.4 模拟 GPU 推理客户端

为了测试上述 Batcher,我们需要一个 GPUInferencer 的模拟实现。

// MockGPUInferencer 模拟 GPU 推理客户端
type MockGPUInferencer struct{}

func (m *MockGPUInferencer) Infer(ctx context.Context, batch *Batch) (*BatchResult, error) {
    log.Printf("[MockInferencer] Inferring batch %s with %d requests...", batch.ID, len(batch.Requests))
    time.Sleep(time.Millisecond * time.Duration(100 + len(batch.Requests)*5)) // 模拟推理时间
    results := make([]*InferenceResponse, len(batch.Requests))
    for i, req := range batch.Requests {
        results[i] = &InferenceResponse{
            ID:         req.ID,
            OutputData: fmt.Sprintf("Processed: %s", req.InputData),
        }
    }
    return &BatchResult{BatchID: batch.ID, Results: results}, nil
}

func (m *MockGPUInferencer) Prefill(ctx context.Context, batch *Batch) (*BatchResult, error) {
    log.Printf("[MockInferencer] Prefilling LLM batch %s with %d prompts...", batch.ID, len(batch.Requests))
    time.Sleep(time.Millisecond * time.Duration(150 + len(batch.Requests)*10)) // 模拟预填充时间
    results := make([]*InferenceResponse, len(batch.Requests))
    for i, req := range batch.Requests {
        results[i] = &InferenceResponse{
            ID:         req.ID,
            OutputData: fmt.Sprintf("LLM Prefill Output for: %s (first token)", req.InputData), // 返回第一个 token
        }
    }
    return &BatchResult{BatchID: batch.ID, Results: results}, nil
}

func (m *MockGPUInferencer) Decode(ctx context.Context, batch *Batch) (*BatchResult, error) {
    log.Printf("[MockInferencer] Decoding LLM batch %s with %d sequences...", batch.ID, len(batch.Requests))
    time.Sleep(time.Millisecond * time.Duration(50 + len(batch.Requests)*2)) // 模拟解码时间
    results := make([]*InferenceResponse, len(batch.Requests))
    for i, req := range batch.Requests {
        // 模拟生成下一个 token
        // 实际中这里会根据 seq.KVStoreRef 和 seq.LastTokenOutput 进行推理
        results[i] = &InferenceResponse{
            ID:         req.ID,
            OutputData: fmt.Sprintf("LLM Decode Output (next token) for: %s", req.InputData),
        }
    }
    return &BatchResult{BatchID: batch.ID, Results: results}, nil
}

5. 性能考量与权衡

持续批处理显著提升 GPU 利用率的同时,也引入了一些权衡和挑战:

特性 传统请求-响应模式 持续批处理模式
GPU 利用率 低,存在大量空闲时间 高,持续有任务在 GPU 上执行
吞吐量 受限于单个请求的 GPU 启动/数据传输开销,总体吞吐量较低 高,通过并行处理多个请求大幅提升
平均延迟 对于单个请求可能较低(无批处理等待),但高并发下整体延迟高 对于部分请求可能因等待批次形成而略有增加,但整体系统平均延迟降低
尾部延迟 在高并发下可能很高 批次超时设置不当可能导致尾部延迟增加,但整体更可控
实现复杂度 简单 较高,需要复杂的批处理逻辑、状态管理和调度
资源消耗 (CPU) 较低 较高,批处理逻辑本身需要 CPU 资源进行请求合并、调度和结果分发
内存消耗 较低 较高,需要维护请求队列、批次数据和 LLM 序列状态
动态性 强,能根据请求负载动态调整批次大小

主要权衡点:

  • 延迟 vs. 吞吐量:通过批处理,我们牺牲了单个请求的最小可能延迟(因为有批处理等待时间),以换取系统整体吞吐量的显著提升。通过调整 BatchTimeout 可以平衡这两者。更小的 BatchTimeout 会降低延迟但可能导致批次过小,降低 GPU 利用率;更大的 BatchTimeout 会提升 GPU 利用率但可能增加延迟。
  • 复杂度 vs. 效率:持续批处理的实现比简单地处理单个请求复杂得多。需要仔细设计并发模型、数据结构、错误处理和状态管理。
  • 资源消耗:批处理管理器本身会消耗一定的 CPU 和内存资源。特别是在 LLM 的令牌级批处理中,维护大量序列的 KV 缓存引用和状态需要精心管理。

6. 监控与度量

为了有效地运行和优化持续批处理系统,完善的监控至关重要:

  • 队列深度 (Queue Depth):监控 requestQueueprefillQueuedecodingSequences 的当前长度。队列过长可能表明 GPU 推理服务处理能力不足。
  • 批次大小 (Batch Size):记录每个批次的请求数量。理想情况下,批次大小应尽可能接近 MaxBatchSize
  • 批处理延迟 (Batching Latency):从请求进入队列到批次提交给 GPU 推理服务的时间。这直接影响了客户端的感知延迟。
  • GPU 利用率 (GPU Utilization):通过 GPU 监控工具(如 nvidia-smi)获取 GPU 核心、内存利用率。
  • 推理时间 (Inference Time):记录 GPU 推理服务处理一个批次所需的时间。
  • 端到端延迟 (End-to-End Latency):从客户端发送请求到接收到响应的总时间。
  • 吞吐量 (Throughput):每秒处理的请求数量或每秒生成的令牌数量。
  • 错误率 (Error Rate):批处理或推理过程中发生的错误数量。

7. 总结与展望

持续批处理是提升 GPU 利用率和机器学习服务吞吐量的强大技术。在 Go 后端中实现它,需要深入理解并发编程范式、细致的数据结构设计以及与外部 GPU 推理服务的有效协作。特别是对于大型语言模型,令牌级的持续批处理已成为行业标准,它通过精细的序列调度和 KV 缓存管理,实现了前所未有的 GPU 效率。

虽然实现复杂度较高,但通过合理的架构设计和 Go 语言的并发特性,我们可以构建出高性能、可扩展的持续批处理系统。未来的发展方向包括更智能的调度算法、动态调整批次大小以适应实时负载、以及更紧密的与底层推理框架(如 vLLM, TensorRT-LLM)的集成,以进一步榨取 GPU 性能。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注