深入 ‘AsyncStream’:如何在图形结构中实现细粒度的部分结果实时推送(Token-level Streaming)?

深入 ‘AsyncStream’:在图形结构中实现细粒度的部分结果实时推送

在现代应用开发中,用户体验的提升往往依赖于对长时间运行操作的即时反馈。当处理复杂的数据结构,特别是具有复杂依赖关系的图形结构(如计算图、依赖图、工作流引擎)时,等待整个计算完成再返回结果是不可接受的。我们需要一种机制,能够实时地、细粒度地推送计算的中间结果,甚至是一个节点计算过程中的“令牌”(token)。

Swift Concurrency 引入的 AsyncStream 类型正是解决此类问题的强大工具。它允许我们以异步序列的方式,将一系列值渐进地发布给消费者,从而实现实时推送和流式处理。本讲座将深入探讨如何在图形结构中利用 AsyncStream 实现这种细粒度的、实时推送到“令牌”级别(Token-level Streaming)的部分结果推送。

1. 实时推送的必要性与图形结构中的挑战

想象一个场景:你正在构建一个智能数据处理平台,用户定义了一个由多个处理步骤(节点)组成的复杂工作流。每个步骤可能涉及数据加载、转换、模型推理等耗时操作,并且这些步骤之间存在明确的依赖关系。

传统方法的局限性:

  1. 高延迟感知: 用户必须等待整个工作流的所有节点计算完毕才能看到任何结果。对于长时间运行的工作流,这会导致用户体验不佳,甚至让用户误以为系统无响应。
  2. 资源浪费: 如果某个下游节点计算失败,只有在所有上游节点都完成后才能发现,导致之前成功的计算结果被浪费。
  3. 调试困难: 难以实时监控每个节点的执行状态和中间输出。

图形结构带来的额外挑战:

  1. 依赖管理: 一个节点可能需要等待其所有前驱节点完成并提供结果后才能开始计算。
  2. 并发执行: 独立的节点或子图可以并行执行,但如何将它们各自的流式输出合并成一个统一的、有序或半有序的全局流?
  3. 细粒度控制: 不仅仅是节点完成时推送结果,而是节点在“思考”或“计算”过程中,每生成一个“令牌”就立即推送。这在大型语言模型 (LLM) 的逐词生成、复杂算法的迭代过程或数据管道的逐条记录处理中尤为重要。
  4. 错误与取消: 如何在流式处理中有效地传播错误,以及如何实现整个图形或特定节点的计算取消?

AsyncStream 作为 Swift 并发模型的核心组件,为我们提供了优雅的解决方案。

2. AsyncStream 核心概念与工作原理

AsyncStreamAsyncSequence 协议的一个具体实现,它允许你以命令式的方式生成一个异步序列。你可以通过 yield 方法向序列中发送元素,并通过 finish()finish(throwing:) 来结束序列。

核心组件:

  1. AsyncSequence 协议: 这是 Swift 并发中异步迭代的基础。任何遵循此协议的类型都可以通过 for await ... in 循环进行遍历。
  2. AsyncStream.makeStream(of:bufferingPolicy:) 这是创建 AsyncStream 的入口点。它返回一个元组 (AsyncStream<Element>, AsyncStream.Continuation)
    • AsyncStream<Element>:这是实际的异步序列,消费者将从这里接收元素。
    • AsyncStream.Continuation:这是生产者用来发送元素、完成序列或抛出错误的句柄。
      • continuation.yield(element):发送一个元素。
      • continuation.finish():正常完成序列。
      • continuation.finish(throwing: Error):异常完成序列。
  3. bufferingPolicy 定义了当生产者发送元素的速度快于消费者处理元素的速度时,如何管理缓冲区。
    • .unbounded: 无限缓冲区,可能导致内存无限增长。
    • .buffered(Int): 有限缓冲区,当缓冲区满时,yield 调用会暂停生产者,直到消费者取出元素(实现背压)。
    • .dropping(Int): 有限缓冲区,当缓冲区满时,新元素会替换旧元素(或被丢弃),不暂停生产者。
    • .overflow(Int): 有限缓冲区,当缓冲区满时,新元素会被丢弃,不暂停生产者。

基本示例:

import Foundation

// 生产者函数:创建一个 AsyncStream,每秒生成一个数字
func createNumberStream() -> AsyncStream<Int> {
    AsyncStream { continuation in
        // 在一个独立的 Task 中执行生产逻辑
        Task {
            for i in 1...5 {
                print("生产者: 正在发送 (i)")
                continuation.yield(i) // 发送元素
                try? await Task.sleep(for: .seconds(1))
            }
            print("生产者: 完成")
            continuation.finish() // 完成序列
        }
    }
}

// 消费者函数
func consumeNumberStream() async {
    let stream = createNumberStream()
    print("消费者: 开始接收")
    for await number in stream { // 使用 for await ... in 循环消费
        print("消费者: 接收到 (number)")
        // 模拟消费者处理延迟
        try? await Task.sleep(for: .milliseconds(500))
    }
    print("消费者: 流已结束")
}

// 运行示例
// Task { await consumeNumberStream() }

在这个例子中,AsyncStream 充当了生产者和消费者之间的桥梁。生产者在独立的 Task 中异步地生成数据,并通过 continuation.yield() 发送。消费者通过 for await ... in 循环异步地接收这些数据。

3. 图形结构的建模与表示

在深入实现细节之前,我们需要定义如何表示我们的图形结构。一个典型的图形结构由节点(Node)和边(Edge)组成,边表示节点之间的依赖关系。

import Foundation

// 节点的唯一标识符
typealias NodeID = String

// 定义不同类型的节点操作
enum GraphOperation: Codable, Sendable {
    case generateTokens(text: String, delayPerTokenMs: Int) // 模拟生成文本令牌
    case sum(inputIDs: [NodeID]) // 求和操作
    case multiply(inputIDs: [NodeID]) // 乘法操作
    case identity(inputID: NodeID) // 传递操作
    case custom(name: String, params: [String: String]) // 自定义操作
}

// 图中的一个节点
struct GraphNode: Identifiable, Hashable, Codable, Sendable {
    let id: NodeID
    let name: String
    let operation: GraphOperation
    let dependencies: Set<NodeID> // 此节点依赖的其他节点的ID

    var identifier: String { id } // Conforming to Identifiable
}

// 节点计算结果的类型
enum NodeResult: Codable, Sendable, CustomStringConvertible {
    case string(String)
    case integer(Int)
    case arrayOfString([String])
    case arrayOfInt([Int])
    case void

    var description: String {
        switch self {
        case .string(let s): return "String("(s)")"
        case .integer(let i): return "Int((i))"
        case .arrayOfString(let arr): return "Array<String>((arr.count) items)"
        case .arrayOfInt(let arr): return "Array<Int>((arr.count) items)"
        case .void: return "Void"
        }
    }
}

// 整个依赖图
struct DependencyGraph: Codable, Sendable {
    var nodes: [NodeID: GraphNode]
    var adjacencyList: [NodeID: Set<NodeID>] // 存储每个节点的所有直接依赖

    init(nodes: [GraphNode]) {
        self.nodes = [:]
        self.adjacencyList = [:]
        for node in nodes {
            self.nodes[node.id] = node
            self.adjacencyList[node.id] = node.dependencies
        }
    }

    // 辅助方法:获取一个节点的所有前驱节点(直接依赖)
    func getDependencies(for nodeID: NodeID) -> Set<NodeID> {
        adjacencyList[nodeID] ?? []
    }

    // 辅助方法:获取图中所有没有依赖的起始节点
    func getSourceNodes() -> Set<NodeID> {
        var allDependentNodes: Set<NodeID> = []
        for (_, dependencies) in adjacencyList {
            allDependentNodes.formUnion(dependencies)
        }
        let allNodes = Set(nodes.keys)
        return allNodes.subtracting(allDependentNodes)
    }

    // 简单的拓扑排序(用于调度参考,实际并发处理中可能不需要严格的线性排序)
    func topologicalSort() -> [NodeID]? {
        var visited: Set<NodeID> = []
        var visiting: Set<NodeID> = [] // 用于检测循环依赖
        var sortedNodes: [NodeID] = []

        func dfs(nodeID: NodeID) -> Bool {
            visiting.insert(nodeID)

            if let dependencies = adjacencyList[nodeID] {
                for depID in dependencies {
                    if visiting.contains(depID) {
                        return false // 发现循环依赖
                    }
                    if !visited.contains(depID) {
                        if !dfs(nodeID: depID) { return false }
                    }
                }
            }
            visiting.remove(nodeID)
            visited.insert(nodeID)
            sortedNodes.append(nodeID) // 后序遍历添加,结果是逆序拓扑排序
            return true
        }

        for nodeID in nodes.keys {
            if !visited.contains(nodeID) {
                if !dfs(nodeID: nodeID) { return nil } // 存在循环依赖
            }
        }
        return sortedNodes.reversed() // 返回正序拓扑排序
    }
}

关键点:

  • NodeID: 使用 String 作为节点ID,方便调试和可读性。
  • GraphOperation: 定义了节点可以执行的不同类型的操作,这使得我们的图处理器具有通用性。
  • GraphNode: 包含ID、名称、具体操作和依赖节点集。Sendable 协议确保在并发环境中安全传递。
  • NodeResult: 节点计算完成后返回的结果类型。
  • DependencyGraph: 存储所有节点及其邻接列表,方便查询依赖关系。拓扑排序可以帮助我们理解依赖顺序,但在并发执行中,我们更多地依赖于 await 来强制执行依赖。

4. 核心策略:依赖管理与并发执行下的流合并

要实现图形结构中的细粒度实时推送,我们需要一个能够管理节点状态、协调并发执行并合并所有节点输出流的组件。Actor 是实现这一目标的理想选择,因为它提供了隔离的、同步的状态访问,从而避免了并发修改问题。

我们将创建一个 GraphProcessor Actor,它将负责:

  1. 启动和管理图中所有节点的计算任务。
  2. 跟踪每个节点的完成状态和结果。
  3. 确保节点只在其所有依赖项完成后才开始计算。
  4. 将所有节点产生的“令牌”合并到一个统一的 AsyncStream 中,并推送给外部消费者。
  5. 处理错误和取消。

为了提供更丰富的实时反馈,我们定义一个 GraphStreamEvent 类型,它不仅包含计算的“令牌”,还包含节点的状态变化信息。

// 流推送的事件类型,包含节点ID和具体内容
enum GraphStreamEvent: Sendable {
    case graphStarted(graphID: String)
    case nodeStarted(nodeID: NodeID, name: String)
    case nodeProgress(nodeID: NodeID, token: String) // 细粒度令牌
    case nodeFinished(nodeID: NodeID, result: NodeResult)
    case nodeFailed(nodeID: NodeID, error: String)
    case graphFinished(graphID: String)
    case graphFailed(graphID: String, error: String)
}

// 节点计算的内部状态和结果
struct NodeComputationState: Sendable {
    var status: NodeStatus
    var result: Result<NodeResult, Error>?

    enum NodeStatus: Sendable {
        case pending
        case running
        case completed
        case failed
        case cancelled
    }
}

actor GraphProcessor {
    let graph: DependencyGraph
    private var nodeStates: [NodeID: NodeComputationState] // 存储每个节点的计算状态
    private var nodeResults: [NodeID: Result<NodeResult, Error>] // 存储已完成节点的最终结果

    // 用于管理主 AsyncStream 的 continuation
    private var streamContinuation: AsyncStream<GraphStreamEvent>.Continuation?

    init(graph: DependencyGraph) {
        self.graph = graph
        self.nodeStates = [:]
        self.nodeResults = [:]
        for nodeID in graph.nodes.keys {
            nodeStates[nodeID] = NodeComputationState(status: .pending)
        }
    }

    // 主处理方法,返回一个 AsyncStream<GraphStreamEvent>
    func processGraph() -> AsyncStream<GraphStreamEvent> {
        AsyncStream { continuation in
            self.streamContinuation = continuation

            // 启动一个 Task 来执行实际的图处理逻辑
            Task { [weak self] in
                guard let self = self else { return }

                self.streamContinuation?.yield(.graphStarted(graphID: "main_graph"))

                do {
                    // 使用 TaskGroup 来并发处理节点
                    // TaskGroup 会等待所有子 Task 完成或取消
                    try await withTaskGroup(of: Void.self) { group in
                        // 为每个节点创建一个 Task
                        for nodeID in self.graph.nodes.keys {
                            group.addTask {
                                // 每个节点的处理逻辑在一个独立的 Task 中执行
                                await self.processSingleNode(nodeID: nodeID)
                            }
                        }

                        // 等待所有节点任务完成
                        // 注意:这里不会阻塞,因为 TaskGroup 会在内部管理其子任务的生命周期
                        // 当所有子任务完成时,group.next() 将返回 nil
                        // 我们不需要显式地迭代 group.next(),因为我们只关心所有任务的完成
                    }
                    // 所有节点处理完毕,图处理完成
                    self.streamContinuation?.yield(.graphFinished(graphID: "main_graph"))
                    self.streamContinuation?.finish()
                } catch {
                    // 图处理过程中发生错误
                    self.streamContinuation?.yield(.graphFailed(graphID: "main_graph", error: error.localizedDescription))
                    self.streamContinuation?.finish(throwing: error)
                }
            }

            // 处理取消回调
            continuation.onTermination = { [weak self] @Sendable status in
                print("GraphProcessor: Stream terminated with status: (status)")
                // 可以在这里取消所有正在运行的节点 Task
                // Swift Concurrency 的 Task 默认支持协作式取消
                // 当外部取消 `processGraph` 返回的 `AsyncStream` 时,
                // `TaskGroup` 中的 `Task` 会收到取消信号
            }
        }
    }

    // 处理单个节点的逻辑
    private func processSingleNode(nodeID: NodeID) async {
        guard let node = graph.nodes[nodeID] else {
            self.streamContinuation?.yield(.nodeFailed(nodeID: nodeID, error: "Node (nodeID) not found."))
            return
        }

        // 检查是否已取消
        guard !Task.isCancelled else {
            await updateNodeState(nodeID: nodeID, status: .cancelled)
            self.streamContinuation?.yield(.nodeFailed(nodeID: nodeID, error: "Node (nodeID) cancelled."))
            return
        }

        // 等待所有依赖节点完成
        await waitForDependencies(for: node)

        // 如果在等待过程中被取消,则退出
        guard !Task.isCancelled else {
            await updateNodeState(nodeID: nodeID, status: .cancelled)
            self.streamContinuation?.yield(.nodeFailed(nodeID: nodeID, error: "Node (nodeID) cancelled during dependency wait."))
            return
        }

        // 更新节点状态为运行中
        await updateNodeState(nodeID: nodeID, status: .running)
        self.streamContinuation?.yield(.nodeStarted(nodeID: nodeID, name: node.name))

        do {
            let result: NodeResult
            switch node.operation {
            case .generateTokens(let text, let delayMs):
                result = await simulateTokenGeneration(nodeID: nodeID, text: text, delayPerTokenMs: delayMs)
            case .sum(let inputIDs):
                result = try await calculateSum(nodeID: nodeID, inputIDs: inputIDs)
            case .multiply(let inputIDs):
                result = try await calculateMultiply(nodeID: nodeID, inputIDs: inputIDs)
            case .identity(let inputID):
                result = try await identityOperation(nodeID: nodeID, inputID: inputID)
            case .custom(let name, _):
                // 模拟自定义操作
                print("Node (nodeID): Executing custom operation '(name)'")
                try await Task.sleep(for: .seconds(0.5))
                self.streamContinuation?.yield(.nodeProgress(nodeID: nodeID, token: "Custom operation '(name)' processed."))
                result = .string("Custom operation (name) completed.")
            }

            await updateNodeState(nodeID: nodeID, status: .completed, result: .success(result))
            self.streamContinuation?.yield(.nodeFinished(nodeID: nodeID, result: result))

        } catch {
            await updateNodeState(nodeID: nodeID, status: .failed, result: .failure(error))
            self.streamContinuation?.yield(.nodeFailed(nodeID: nodeID, error: error.localizedDescription))
        }
    }

    // MARK: - 辅助方法 (Actor Isolation)

    // 更新节点状态 (Actor隔离)
    private func updateNodeState(nodeID: NodeID, status: NodeComputationState.NodeStatus, result: Result<NodeResult, Error>? = nil) {
        var state = nodeStates[nodeID] ?? NodeComputationState(status: .pending)
        state.status = status
        if let result = result {
            state.result = result
            nodeResults[nodeID] = result // 存储最终结果
        }
        nodeStates[nodeID] = state
        // print("Node (nodeID) state updated to: (status)")
    }

    // 等待依赖节点完成 (Actor隔离)
    private func waitForDependencies(for node: GraphNode) async {
        guard !node.dependencies.isEmpty else { return }

        // 使用 withThrowingTaskGroup 来等待所有依赖完成
        // 如果任何一个依赖失败,则此组也会失败
        try? await withThrowingTaskGroup(of: NodeResult.self) { group in
            for depID in node.dependencies {
                group.addTask {
                    // 等待直到依赖节点完成
                    // 这里是一个自旋锁的简化形式,实际应用中可能需要更复杂的条件变量或 AsyncChannel
                    while await self.nodeStates[depID]?.status != .completed {
                        // 检查依赖是否失败或被取消
                        if await self.nodeStates[depID]?.status == .failed {
                            throw GraphProcessingError.dependencyFailed(depID)
                        }
                        if await self.nodeStates[depID]?.status == .cancelled {
                            throw GraphProcessingError.dependencyCancelled(depID)
                        }
                        try? await Task.sleep(for: .milliseconds(50)) // 短暂等待
                        guard !Task.isCancelled else { // 检查当前Task是否被取消
                            throw GraphProcessingError.taskCancelled("Node (node.id) cancelled while waiting for dependency (depID)")
                        }
                    }
                    // 依赖节点已完成,获取其结果
                    guard let result = await self.nodeResults[depID] else {
                        throw GraphProcessingError.missingDependencyResult(depID)
                    }
                    return try result.get() // 如果依赖结果是失败,这里会抛出
                }
            }
            // 收集所有依赖的结果,确保它们都成功完成
            // 实际上,我们只是等待它们完成,并不会在这里使用这些结果
            // 节点计算时会再次从 nodeResults 中获取
            for try await _ in group {}
        } catch let error as GraphProcessingError {
            // 记录依赖失败,但允许当前节点 Task 继续,以便它能标记自己为失败
            print("Node (node.id) dependency failed: (error.localizedDescription)")
            // 可以在这里设置当前节点的 Task 状态为失败,或者抛出错误
            throw error // 重新抛出,让 processSingleNode 捕获
        } catch {
            print("Node (node.id) dependency check encountered unknown error: (error.localizedDescription)")
            throw error
        }
    }

    // MARK: - 节点操作的具体实现

    private func simulateTokenGeneration(nodeID: NodeID, text: String, delayPerTokenMs: Int) async -> NodeResult {
        var generatedTokens: [String] = []
        for (index, char) in text.enumerated() {
            guard !Task.isCancelled else { // 检查任务是否被取消
                print("Node (nodeID): Token generation cancelled.")
                return .arrayOfString(generatedTokens) // 返回已生成的令牌
            }
            let token = String(char)
            self.streamContinuation?.yield(.nodeProgress(nodeID: nodeID, token: token))
            generatedTokens.append(token)
            try? await Task.sleep(for: .milliseconds(Double(delayPerTokenMs)))
        }
        return .string(text) // 最终结果可以是你想要的任何形式,这里是完整字符串
    }

    private func getIntegerInputs(inputIDs: [NodeID]) async throws -> [Int] {
        var inputs: [Int] = []
        for id in inputIDs {
            guard let result = await self.nodeResults[id] else {
                throw GraphProcessingError.missingDependencyResult(id)
            }
            switch try result.get() {
            case .integer(let val): inputs.append(val)
            default: throw GraphProcessingError.invalidInputType(id, "Expected Int")
            }
        }
        return inputs
    }

    private func calculateSum(nodeID: NodeID, inputIDs: [NodeID]) async throws -> NodeResult {
        print("Node (nodeID): Calculating sum...")
        try await Task.sleep(for: .milliseconds(200)) // 模拟计算延迟
        let inputs = try await getIntegerInputs(inputIDs: inputIDs)
        let sum = inputs.reduce(0, +)
        self.streamContinuation?.yield(.nodeProgress(nodeID: nodeID, token: "Sum calculated: (sum)"))
        return .integer(sum)
    }

    private func calculateMultiply(nodeID: NodeID, inputIDs: [NodeID]) async throws -> NodeResult {
        print("Node (nodeID): Calculating product...")
        try await Task.sleep(for: .milliseconds(200)) // 模拟计算延迟
        let inputs = try await getIntegerInputs(inputIDs: inputIDs)
        let product = inputs.reduce(1, *)
        self.streamContinuation?.yield(.nodeProgress(nodeID: nodeID, token: "Product calculated: (product)"))
        return .integer(product)
    }

    private func identityOperation(nodeID: NodeID, inputID: NodeID) async throws -> NodeResult {
        print("Node (nodeID): Performing identity operation...")
        try await Task.sleep(for: .milliseconds(100)) // 模拟延迟
        guard let result = await self.nodeResults[inputID] else {
            throw GraphProcessingError.missingDependencyResult(inputID)
        }
        self.streamContinuation?.yield(.nodeProgress(nodeID: nodeID, token: "Identity passed: (result.description)"))
        return try result.get()
    }
}

// 定义自定义错误类型
enum GraphProcessingError: LocalizedError, Sendable {
    case dependencyFailed(NodeID)
    case dependencyCancelled(NodeID)
    case missingDependencyResult(NodeID)
    case invalidInputType(NodeID, String)
    case taskCancelled(String)
    case generalError(String)

    var errorDescription: String? {
        switch self {
        case .dependencyFailed(let id): return "Dependency node '(id)' failed."
        case .dependencyCancelled(let id): return "Dependency node '(id)' was cancelled."
        case .missingDependencyResult(let id): return "Missing result for dependency node '(id)'."
        case .invalidInputType(let id, let expected): return "Node '(id)' received invalid input type. (expected)."
        case .taskCancelled(let msg): return "Task cancelled: (msg)"
        case .generalError(let msg): return "Graph processing error: (msg)"
        }
    }
}

代码解释与核心逻辑:

  1. GraphStreamEvent: 定义了不同粒度的事件类型。nodeProgress(nodeID:token:) 是实现“令牌级”推送的关键,它允许节点在计算过程中发送中间状态。
  2. GraphProcessor Actor:
    • 状态管理: nodeStates (记录每个节点的运行状态) 和 nodeResults (存储已完成节点的最终结果) 都被 actor 隔离保护,确保并发安全。
    • streamContinuation: 这是 AsyncStream 的核心。GraphProcessorprocessGraph() 方法中创建 AsyncStream 时捕获其 continuation。所有节点在生成 GraphStreamEvent 时,都通过这个 continuation 将事件 yield 到主 AsyncStream
    • processGraph(): 这是整个图处理的入口点。
      • 它返回一个 AsyncStream<GraphStreamEvent>,消费者将 for await 这个流。
      • 内部在一个新的 Task 中执行实际的图处理逻辑,以避免阻塞调用者。
      • 使用 withTaskGroup(of: Void.self) 来并发启动所有节点的 processSingleNode 任务。TaskGroup 会自动等待所有子任务完成,或者在外部取消时通知所有子任务。
    • processSingleNode(nodeID:):
      • 这是每个节点计算的独立任务。
      • 取消检查: 在关键位置(开始、等待依赖、模拟计算中)检查 Task.isCancelled,实现协作式取消。
      • waitForDependencies(for:): 这是处理依赖的核心。它会循环检查所有依赖节点的状态。
        • 由于 nodeStatesnodeResults 都是 GraphProcessor Actor 的内部状态,访问它们需要 await
        • 如果依赖节点尚未完成,它会短暂 sleep 然后重试,这是一种简单的自旋锁等待。
        • 如果依赖节点失败或被取消,当前节点也会抛出错误。
        • withThrowingTaskGroup 用于并行等待所有依赖,如果其中任何一个依赖失败,整个组都会失败。
      • 操作执行: 根据 node.operation 调用相应的模拟计算函数。这些函数会通过 self.streamContinuation?.yield() 实时推送 nodeProgress 事件。
    • updateNodeState(...): Actor 隔离的辅助方法,用于安全地更新 nodeStatesnodeResults
    • 模拟操作函数: simulateTokenGeneration, calculateSum, calculateMultiply 等函数模拟了实际的耗时操作,并通过 yield(.nodeProgress(...)) 实现了令牌级推送。
  3. 错误处理: 使用 do-catch 块捕获节点计算中的错误,并通过 yield(.nodeFailed(...)) 推送错误事件。streamContinuation?.finish(throwing: error) 用于终止整个流并传播错误。
  4. onTermination: 当消费者取消 AsyncStream 时,会触发 continuation.onTermination 闭包。我们可以在这里执行清理工作,例如取消所有正在运行的节点任务(尽管 TaskGroup 在外部任务取消时会自动传播取消)。

5. 精细化控制:错误处理与取消机制

在流式处理中,错误和取消是至关重要的方面,它们直接影响系统的健壮性和用户体验。

错误处理策略:

  1. 节点级错误:
    • 当单个节点计算失败时,我们通过 streamContinuation?.yield(.nodeFailed(nodeID: error:)) 发送一个特定的错误事件,通知消费者哪个节点出了问题。
    • 该节点的 nodeStates 会被标记为 .failed,其结果在 nodeResults 中存储为 .failure(error)
    • 依赖于此失败节点的其他节点在 waitForDependencies 中会检测到此状态,并相应地抛出 GraphProcessingError.dependencyFailed,从而终止其自身的计算。
  2. 图级错误:
    • 如果 processGraphTaskGroup 中有任何任务抛出未捕获的错误,或者在 processGraph 自身的主逻辑中发生错误,我们会通过 streamContinuation?.yield(.graphFailed(graphID: error:)) 发送一个图级错误事件。
    • 最终,streamContinuation?.finish(throwing: error) 会终止整个 AsyncStream,并向上游消费者传播错误。消费者在 for await 循环中可以通过 do-catch 捕获到这个错误。

取消机制:

  1. 协作式取消: Swift Concurrency 中的 Task 默认支持协作式取消。这意味着当一个 Task 被取消时,它不会立即停止执行,而是在下一个 await 点或显式检查 Task.isCancelled 时响应取消。
  2. Task.isCancelledprocessSingleNodesimulateTokenGeneration 等耗时操作中,我们定期检查 Task.isCancelled。一旦检测到取消,任务应立即停止当前工作并优雅退出。
  3. continuation.onTerminationAsyncStream 的消费者停止迭代(例如,因为它被取消或不再需要更多数据)时,onTermination 闭包会被调用。这提供了一个清理资源的机会。在我们的例子中,TaskGroup 已经很好地处理了子任务的取消传播,所以这里可能不需要额外的取消逻辑,但对于其他类型的资源(如文件句柄、网络连接),这里是释放它们的理想位置。
  4. withTaskCancellationHandler 对于更复杂的取消逻辑,例如需要在取消时执行特定清理操作,可以使用 withTaskCancellationHandler

表格:错误与取消事件处理

事件类型 触发时机 生产者行为 消费者感知
nodeFailed 单个节点计算中发生错误 yield(.nodeFailed(...)),更新 nodeStates 接收到 GraphStreamEvent.nodeFailed 事件
dependencyFailed 节点等待的依赖项失败 waitForDependencies 抛出错误,当前节点标记失败 依赖节点已报告 nodeFailed,当前节点也会报告 nodeFailed
graphFailed 整个图处理过程中发生致命错误 yield(.graphFailed(...))finish(throwing:) 接收到 GraphStreamEvent.graphFailedfor await 循环抛出错误
nodeProgress (取消) 节点在生成令牌时检测到 Task.isCancelled 停止 yield,返回部分结果或空结果 停止接收该节点的 nodeProgress 事件
streamTermination 消费者取消 AsyncStream continuation.onTermination 被调用,TaskGroup 子任务收到取消信号 for await 循环终止或抛出 CancellationError

6. 综合示例:一个完整的图形计算流推送系统

让我们构建一个简单的图,并演示如何使用 GraphProcessor 进行处理和消费。

// 1. 定义图节点
let node1 = GraphNode(id: "node_A", name: "Generate Tokens for Hello", operation: .generateTokens(text: "Hello", delayPerTokenMs: 100), dependencies: [])
let node2 = GraphNode(id: "node_B", name: "Generate Tokens for World", operation: .generateTokens(text: "World", delayPerTokenMs: 150), dependencies: [])
let node3 = GraphNode(id: "node_C", name: "Pass Hello Tokens", operation: .identity(inputID: node1.id), dependencies: [node1.id])
let node4 = GraphNode(id: "node_D", name: "Sum Node A & C Lengths", operation: .sum(inputIDs: [node1.id, node3.id]), dependencies: [node1.id, node3.id]) // 假设 sum 接受 string 并计算长度
let node5 = GraphNode(id: "node_E", name: "Multiply Node B & D Results", operation: .multiply(inputIDs: [node2.id, node4.id]), dependencies: [node2.id, node4.id]) // 假设 multiply 接受 string 并计算长度

// 2. 创建依赖图
let graph = DependencyGraph(nodes: [node1, node2, node3, node4, node5])

// 3. 创建 GraphProcessor
let processor = GraphProcessor(graph: graph)

// 4. 消费者 Task
Task {
    print("n--- 开始处理图 ---")
    do {
        for await event in processor.processGraph() {
            switch event {
            case .graphStarted(let graphID):
                print("[Graph] (graphID) 启动")
            case .nodeStarted(let nodeID, let name):
                print("  [Node] (nodeID) ((name)) 启动")
            case .nodeProgress(let nodeID, let token):
                print("    [Token] (nodeID): '(token)'")
            case .nodeFinished(let nodeID, let result):
                print("  [Node] (nodeID) 完成,结果: (result)")
            case .nodeFailed(let nodeID, let error):
                print("  [Node] (nodeID) 失败,错误: (error)")
            case .graphFinished(let graphID):
                print("[Graph] (graphID) 完成")
            case .graphFailed(let graphID, let error):
                print("[Graph] (graphID) 失败,错误: (error)")
            }
        }
        print("--- 图处理流已终止 ---")
    } catch {
        print("--- 图处理过程中发生致命错误: (error.localizedDescription) ---")
    }
}

// 示例:在一定时间后取消 Task (可选)
// Task {
//     try? await Task.sleep(for: .seconds(3))
//     print("n--- 尝试取消图处理 ---")
//     processor.streamContinuation?.finish(throwing: CancellationError()) // 外部取消
// }

运行输出示例 (部分):

--- 开始处理图 ---
[Graph] main_graph 启动
  [Node] node_A (Generate Tokens for Hello) 启动
  [Node] node_B (Generate Tokens for World) 启动
    [Token] node_A: 'H'
    [Token] node_B: 'W'
    [Token] node_A: 'e'
    [Token] node_B: 'o'
    [Token] node_A: 'l'
    [Token] node_B: 'r'
    [Token] node_A: 'l'
    [Token] node_B: 'l'
    [Token] node_A: 'o'
    [Token] node_B: 'd'
  [Node] node_A 完成,结果: String("Hello")
  [Node] node_B 完成,结果: String("World")
  [Node] node_C (Pass Hello Tokens) 启动
    [Token] node_C: 'Identity passed: String("Hello")'
  [Node] node_C 完成,结果: String("Hello")
  [Node] node_D (Sum Node A & C Lengths) 启动
Node node_D: Calculating sum...
    [Token] node_D: 'Sum calculated: 10'
  [Node] node_D 完成,结果: Int(10)
  [Node] node_E (Multiply Node B & D Results) 启动
Node node_E: Calculating product...
    [Token] node_E: 'Product calculated: 50'
  [Node] node_E 完成,结果: Int(50)
[Graph] main_graph 完成
--- 图处理流已终止 ---

这个输出清晰地展示了:

  • 图和节点启动的消息。
  • node_Anode_B 并发地生成令牌,它们的 nodeProgress 事件交错出现。
  • node_Cnode_D 在其依赖 node_Anode_B 完成后才启动。
  • 所有节点完成,最终图完成。

7. 性能优化与最佳实践

  1. 选择合适的 bufferingPolicy
    • 对于大多数实时推送场景,.buffered(_:) 是一个不错的选择,它提供背压机制,防止生产者速度过快导致内存溢出。
    • 如果数据丢失可以接受,且不希望生产者被阻塞,可以考虑 .dropping(_:).overflow(_:)
    • .unbounded 应该谨慎使用,只在你能保证消费者总能跟上生产者速度或数据量有限的情况下使用。
  2. Actor 的正确使用:
    • 将所有共享的可变状态(如 nodeStates, nodeResults)封装在 Actor 中,确保并发访问安全。
    • 避免在 Actor 内部执行长时间的同步操作,否则会阻塞 Actor 的并发性。耗时操作应转移到 Task 中,并在 await 时机回到 Actor。
  3. 细粒度 TaskTaskGroup
    • 将每个节点的计算封装为独立的 Task,并由 TaskGroup 管理,可以最大化并发。
    • TaskGroup 会自动处理子任务的生命周期和取消传播,简化了管理。
  4. 协作式取消:
    • 在长时间运行的计算和 await 调用之间,定期检查 Task.isCancelled
    • 确保你的自定义 yield 方法或模拟操作能响应取消。
  5. 减少不必要的 await
    • 尽管 await 是安全的,但频繁的 await 会带来上下文切换开销。在 Actor 内部,如果可以连续执行多步操作而无需等待外部资源,则应尽量减少 await self. 调用。
    • 例如,在 updateNodeState 这样的辅助函数中,它只修改 Actor 的隔离状态,无需 await 自身。
  6. 结果类型与 Sendable
    • 确保所有在并发任务之间传递的类型(如 GraphStreamEvent, NodeResult, GraphNode)都符合 Sendable 协议,以保证类型安全和数据隔离。
  7. 避免循环依赖:
    • 虽然我们的 waitForDependencies 机制可以检测到循环依赖(导致死锁或无限等待),但最佳实践是在图构建阶段就通过拓扑排序或其他验证手段避免它们。
  8. 考虑 AsyncChannel (Swift Async Algorithms):
    • 对于更复杂的流合并场景,或者当你有多个独立的生产者需要向一个统一的流发送数据时,swift-async-algorithms 库中的 AsyncChannel 是一个非常有用的工具。它可以作为一个线程安全的、支持背压的通道,用于在不同 Task 之间传递数据。
    • 在我们的 GraphProcessor 例子中,通过直接传递 streamContinuation 已经实现了类似的效果,但 AsyncChannel 提供了更抽象和通用的解决方案。

8. 展望与总结

通过深入理解 AsyncStream 的工作原理,并将其与 Swift Concurrency 的 ActorTaskGroup 结合,我们成功地在复杂的图形结构中构建了一个细粒度的实时结果推送系统。这不仅大幅提升了用户体验,使得用户能够即时看到每个计算步骤的进展和中间“令牌”输出,还增强了系统的可观测性和调试能力。这种模式在需要处理复杂依赖、并行计算并提供即时反馈的场景中,如数据管道、AI 推理工作流、持续集成/部署流程等,都具有广泛的应用前景。

这种设计模式使得我们能够创建高度响应、可伸缩且易于维护的异步系统,充分利用现代多核处理器的能力,同时保持代码的清晰和并发安全。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注