手写高性能向量索引:基于 Go 的 HNSW 算法实现全过程与性能压测

手写高性能向量索引:基于 Go 的 HNSW 算法实现全过程与性能压测

各位同仁,大家下午好!

今天,我们将深入探讨一个在人工智能和大数据领域日益重要的主题:高性能向量索引。随着深度学习模型的普及,我们每天都在生成和处理海量的向量数据,例如图片、文本或推荐系统中的嵌入向量。如何在这些高维向量空间中进行高效的相似性搜索,即找到与给定查询向量最相似的 K 个向量,成为了一个核心挑战。暴力搜索在数据量庞大时显然不可行,因此,近似最近邻(Approximate Nearest Neighbor, ANN)搜索算法应运而生。

在众多 ANN 算法中,Hierarchical Navigable Small World (HNSW) 算法凭借其卓越的召回率和查询速度,成为了业界的明星。本次讲座,我将带领大家从零开始,使用 Go 语言亲手实现一个 HNSW 向量索引库,并对其进行详尽的性能压测与分析。选择 Go 语言,是因为它天生具备优秀的并发能力、简洁的语法和接近 C/C++ 的执行效率,非常适合构建高性能的基础设施服务。

一、向量索引的基石:为何需要 HNSW?

在进入 HNSW 的具体实现之前,我们先来回顾一下向量索引的必要性以及 HNSW 相较于传统方法的优势。

1.1 向量搜索的挑战

假设我们有一个包含数百万甚至数十亿个高维向量的数据库,现在需要找出与用户查询向量最相似的 K 个向量。

  • 暴力搜索 (Brute Force): 最直观的方法是计算查询向量与数据库中每个向量的距离,然后排序找出 K 个最近的。其时间复杂度为 $O(N cdot D)$,其中 $N$ 是向量数量,$D$ 是向量维度。当 $N$ 很大时,这种方法是不可接受的。
  • 维度灾难 (Curse of Dimensionality): 在高维空间中,点之间的距离会变得越来越稀疏,使得许多传统的索引结构(如 KD 树、R 树)效果大打折扣。
  • 存储与计算成本: 存储高维向量本身就需要大量空间,而频繁的距离计算更是 CPU 密集型操作。

1.2 传统 ANN 方法的局限性

为了应对暴力搜索的挑战,研究者们提出了多种 ANN 算法:

  • 基于树的方法 (Tree-based): 例如 KD 树、VP 树。它们将高维空间递归地划分为更小的区域。在高维空间中性能急剧下降,因为数据分布稀疏,树的深度和分支因子会使得搜索效率降低。
  • 局部敏感哈希 (Locality Sensitive Hashing, LSH): 通过哈希函数将相似的向量映射到相同的哈希桶中。优点是查询速度快,但召回率通常较低,且哈希函数的选择和参数调整复杂。
  • 基于图的方法 (Graph-based): 核心思想是构建一个图,其中节点代表向量,边代表向量之间的相似性。查询时,从图的某个入口点开始,沿着边贪婪地探索,直到找到 K 个最近邻。HNSW 就属于此类。
  • 量化方法 (Quantization-based): 例如乘积量化 (Product Quantization, PQ)。通过将向量分解为子向量并进行量化来压缩存储,减少距离计算量。通常需要权衡精度和速度。

这些方法各有优缺点,但在追求高召回率和低延迟的场景中,HNSW 逐渐脱颖而出。

1.3 HNSW 的崛起与优势

HNSW 是 Navigable Small World (NSW) 算法的改进版。NSW 算法构建了一个图,使得任意两个节点之间可以通过相对较少的“跳跃”到达(小世界特性)。HNSW 在此基础上引入了“分层”思想,构建了一个多层图结构,每一层都是一个 NSW 图,但上层图的连接更稀疏,覆盖范围更广,用于快速定位到查询向量的大致区域;下层图的连接更密集,用于精确搜索。

HNSW 的主要优势体现在:

  • 高召回率: 通过分层搜索和启发式邻居选择,HNSW 能够在保证较高精度的前提下进行近似搜索。
  • 低查询延迟: 在高层快速定位,在低层精确搜索,大大减少了距离计算次数。
  • 可扩展性: 能够有效地处理大规模数据集。
  • 增量更新友好: 新增向量可以在不重建整个索引的情况下插入。
  • 内存效率相对较高: 相较于一些树结构,HNSW 的图结构可以更紧凑。

正是这些优势,使得 HNSW 成为构建高性能向量搜索引擎的首选。

二、HNSW 核心原理深度解析

理解 HNSW 的核心在于把握其分层图结构、构建过程中的邻居选择策略以及查询时的贪婪遍历。

2.1 Navigable Small World (NSW) 基础

想象一个社交网络,你可能通过几个共同朋友就能认识一个陌生人。这就是“小世界”现象。在 NSW 算法中,我们为每个向量构建一个邻居列表,使得图中的任意两点之间都能通过较短的路径连接。

  • 图的构建: 随机选择一些初始节点作为图的入口点。当插入一个新节点时,从入口点开始进行贪婪搜索,找到其 K 个最近邻,并将新节点与这些邻居连接起来。
  • 贪婪搜索: 给定查询向量,从入口点开始,每次选择与查询向量距离最近的邻居作为下一步的搜索点,直到无法找到更近的邻居为止。
  • 局限性: NSW 图是扁平的,这意味着在大型图中,贪婪搜索可能会在局部最优解中陷入困境,导致搜索路径过长,效率不高。

2.2 Hierarchical Navigable Small World (HNSW) 改进

HNSW 算法通过引入分层结构解决了 NSW 的局限性,其核心思想类似于跳表 (Skip List):

  • 多层图结构:

    • 图被组织成多层 $L_0, L1, dots, L{max}$。
    • $L_0$ 是包含所有向量的完整图,连接最密集,用于精确搜索。
    • $L1, dots, L{max}$ 是稀疏的图,每层只包含一部分向量。层数越高,图越稀疏,节点之间的距离越远,但覆盖范围越大。
    • 每个向量节点 $v$ 都会被随机分配一个最大层数 $l_v$,表示该节点存在于 $L0, dots, L{l_v}$ 这些层中。层数通常服从指数分布,使得高层节点数量稀少。
    • 上层的节点是下层节点的子集。
  • 构建过程 (Add 操作):

    1. 随机层数分配: 当插入一个新向量 $q$ 时,首先根据指数分布随机为其分配一个最大层数 $l_q$。这意味着 $q$ 将被添加到 $L0, dots, L{l_q}$ 这些层中。
    2. 入口点选择: 如果索引为空,新节点成为入口点。否则,从当前索引的全局入口点 $ep$ 开始,从最高层 $L{max}$ 向下到 $L{l_q+1}$ 进行贪婪搜索。在每一层,找到与 $q$ 距离最近的节点 $w$,并将其作为下一层的搜索入口。这一步的目的是快速定位到 $q$ 所在的大致区域。
    3. 逐层插入与邻居选择: 从 $L_{l_q}$ 层开始,向下到 $L_0$ 层,将 $q$ 插入到每一层中。
      • 在每一层 $L_l$:
        • 从上一步找到的入口点 $w$ 开始,执行 searchLayer 操作,找出 $efConstruction$ 个与 $q$ 距离最近的候选邻居。efConstruction 是一个控制构建时搜索范围的参数,值越大,构建时间越长,但召回率越高。
        • 从这些候选邻居中,使用启发式算法(如 selectNeighbors)选择 $M$ 个最佳邻居与 $q$ 连接。启发式算法旨在选择那些“最具信息量”的邻居,避免连接到过于集中的邻居,从而保持图的连通性和导航效率。
        • 同时,如果某个邻居因为与 $q$ 连接而导致其邻居数量超过了限制($M{max}$ 或 $M{max_0}$),则需要对该邻居的连接进行剪枝,重新选择最佳的 $M$ (或 $M_0$) 个邻居。
    4. 更新入口点: 如果新节点的层数 $lq$ 大于当前索引的全局最大层数 $L{max}$,则更新 $L_{max}$ 并将新节点 $q$ 设置为新的全局入口点。
  • 查询过程 (Search 操作):

    1. 入口点遍历: 从当前索引的全局入口点 $ep$ 开始,从最高层 $L_{max}$ 向下到 $L_1$ (如果查询目标是 $L_0$) 进行贪婪搜索。在每一层 $L_l$,找到与查询向量 $q$ 距离最近的节点 $w$,并将其作为下一层的搜索入口。这一步是快速收敛到查询向量的近似位置。
    2. 精细化搜索: 在 $L_0$ 层(或查询目标层),从上一步找到的入口点 $w$ 开始,执行 searchLayer 操作。
      • searchLayer 使用一个优先级队列(min-heap)来维护待探索的候选节点,以及另一个优先级队列(max-heap)来维护当前已找到的 $efSearch$ 个最佳结果。efSearch 是一个控制查询时搜索范围的参数,值越大,查询时间越长,但召回率越高。
      • 每次从候选队列中取出距离最近的节点进行探索,并将其邻居添加到候选队列(如果未访问过)。同时,将这些邻居与查询向量的距离加入到结果队列,并始终保持结果队列中只有 $efSearch$ 个最佳结果。
    3. 返回结果: 最终,从结果队列中取出距离最近的 K 个节点作为查询结果。

2.3 核心参数

HNSW 的性能和召回率高度依赖于几个核心参数:

参数名称 描述 影响 典型值
M 每个节点在更高层($L_l, l>0$)连接的最大邻居数量。 增加 M 会提高召回率和构建时间,但查询时间不会显著增加。 16
M_max 内部参数,通常等于 M 16
M_max_0 每个节点在 $L_0$ 层连接的最大邻居数量。通常为 2 * M 增加 M_max_0 会提高 $L_0$ 层的连通性,提升召回率,但会增加构建时间和内存占用。 32
efConstruction 构建索引时,搜索邻居的动态列表大小。 影响构建时间、召回率和索引质量。值越大,构建时间越长,但召回率越高。 100-500
efSearch 查询索引时,搜索 K 个最近邻的动态列表大小。 影响查询时间、召回率。值越大,查询时间越长,但召回率越高。 50-200
maxLayers 索引的最大层数。通常根据数据集大小自动计算,如 1/log(M) 的乘数与数据集大小的对数相关。 影响索引的深度。 自动计算
multiplier 用于计算节点最大层数的指数分布参数,通常为 1/log(M) 决定了高层节点的稀疏程度。 1/log(M)
heuristic 是否使用启发式邻居选择算法。 启发式算法通过避免选择过于接近的邻居来提高图的连通性和搜索效率,但会略微增加构建的复杂性。 True

三、Go 语言实现细节与数据结构设计

现在,让我们开始着手使用 Go 语言实现 HNSW。一个高质量的 Go HNSW 库需要精心设计的数据结构和并发策略。

3.1 核心数据结构

package hnsw

import (
    "math"
    "sync"
    "time"
    "fmt"
    "container/heap"
    "sort"
    "math/rand"
)

// Vector interface defines the contract for vector types.
// This allows for different underlying vector implementations (e.g., []float32, []float64).
type Vector interface {
    Dim() int            // Dimension of the vector
    Get(int) float32     // Get the value at a specific index
}

// Float32Vector is a concrete implementation of Vector using []float32.
type Float32Vector []float32

func (v Float32Vector) Dim() int { return len(v) }
func (v Float32Vector) Get(i int) float32 { return v[i] }

// DistanceCalculator defines an interface for calculating distances between vectors.
type DistanceCalculator interface {
    Calculate(v1, v2 Vector) float32
}

// EuclideanDistance implements Euclidean distance calculation.
type EuclideanDistance struct{}

func (e EuclideanDistance) Calculate(v1, v2 Vector) float32 {
    if v1.Dim() != v2.Dim() {
        panic("vector dimensions mismatch")
    }
    var sum float32
    for i := 0; i < v1.Dim(); i++ {
        diff := v1.Get(i) - v2.Get(i)
        sum += diff * diff
    }
    return sum // Return squared Euclidean distance for performance, sqrt not needed for comparison
}

// CosineDistance implements Cosine distance calculation.
type CosineDistance struct{}

func (c CosineDistance) Calculate(v1, v2 Vector) float32 {
    if v1.Dim() != v2.Dim() {
        panic("vector dimensions mismatch")
    }
    var dotProduct float32
    var normA float32
    var normB float32
    for i := 0; i < v1.Dim(); i++ {
        dotProduct += v1.Get(i) * v2.Get(i)
        normA += v1.Get(i) * v1.Get(i)
        normB += v2.Get(i) * v2.Get(i)
    }
    if normA == 0 || normB == 0 {
        return 1.0 // Or handle as an error, or return 0 if both are zero vectors
    }
    // Cosine similarity ranges from -1 to 1. Distance is 1 - similarity.
    return 1.0 - (dotProduct / (float32(math.Sqrt(float64(normA))) * float32(math.Sqrt(float64(normB)))))
}

// Node represents a single data point in the HNSW graph.
type Node struct {
    ID        uint64 // Unique identifier for the node
    Vector    Vector
    MaxLayer  int           // The maximum layer this node participates in
    Neighbors [][]uint64    // Neighbors at each layer [layer_idx][neighbor_id]
    mu        sync.RWMutex  // Protects Neighbors for concurrent updates
}

// NewNode creates a new Node instance.
func NewNode(id uint64, vector Vector, maxLayer int, M_max int) *Node {
    n := &Node{
        ID:        id,
        Vector:    vector,
        MaxLayer:  maxLayer,
        Neighbors: make([][]uint64, maxLayer+1),
    }
    for i := 0; i <= maxLayer; i++ {
        // Pre-allocate capacity for efficiency, M_max_0 for layer 0, M_max for others
        if i == 0 {
            n.Neighbors[i] = make([]uint64, 0, M_max*2) // M_max_0 is usually 2*M
        } else {
            n.Neighbors[i] = make([]uint64, 0, M_max)
        }
    }
    return n
}

// Config holds the configuration parameters for the HNSW index.
type Config struct {
    M             int     // Number of neighbors to connect to during graph construction (higher layers)
    MaxM          int     // Max neighbors per node at higher layers (usually M)
    MaxM0         int     // Max neighbors per node at layer 0 (usually 2*M)
    EfConstruction int     // Size of the dynamic list for nearest neighbors during construction
    EfSearch      int     // Size of the dynamic list for nearest neighbors during search
    Heuristic     bool    // Whether to use the heuristic neighbor selection algorithm
    DistanceType  string  // "euclidean" or "cosine"
    MaxLayers     int     // Max possible layers (optional, usually dynamic based on data size)
    Multiplier    float64 // Layer multiplier (usually 1/log(M))
}

// DefaultConfig provides sensible defaults for HNSW.
func DefaultConfig() Config {
    m := 16
    return Config{
        M:             m,
        MaxM:          m,
        MaxM0:         m * 2,
        EfConstruction: 200,
        EfSearch:      50,
        Heuristic:     true,
        DistanceType:  "euclidean",
        Multiplier:    1 / math.Log(float64(m)), // Typical value 1/ln(M)
    }
}

// HNSWIndex represents the Hierarchical Navigable Small World index.
type HNSWIndex struct {
    mu          sync.RWMutex // Protects overall index structure (entryPointID, maxLayer, nodes, nextID)
    config      Config
    distCalc    DistanceCalculator
    nodes       map[uint64]*Node // Map from node ID to Node object
    entryPointID uint64         // ID of the current entry point for search
    maxLayer    int            // Max layer of the entire graph
    hnswRand    *hnswRand      // Thread-safe random number generator
    nextID      uint64         // For assigning new node IDs
    initialised bool           // Flag to indicate if the index has been initialised with at least one node
}

// NewHNSWIndex creates and initializes a new HNSW index.
func NewHNSWIndex(cfg Config) (*HNSWIndex, error) {
    if cfg.M <= 0 || cfg.EfConstruction <= 0 || cfg.EfSearch <= 0 {
        return nil, fmt.Errorf("HNSW config parameters M, efConstruction, efSearch must be positive")
    }
    if cfg.MaxM == 0 {
        cfg.MaxM = cfg.M
    }
    if cfg.MaxM0 == 0 {
        cfg.MaxM0 = cfg.M * 2
    }
    if cfg.Multiplier == 0 {
        cfg.Multiplier = 1 / math.Log(float64(cfg.M))
    }

    var distCalc DistanceCalculator
    switch cfg.DistanceType {
    case "euclidean":
        distCalc = EuclideanDistance{}
    case "cosine":
        distCalc = CosineDistance{}
    default:
        return nil, fmt.Errorf("unsupported distance type: %s", cfg.DistanceType)
    }

    return &HNSWIndex{
        config:      cfg,
        distCalc:    distCalc,
        nodes:       make(map[uint64]*Node),
        hnswRand:    newHnswRand(),
        nextID:      1, // Start IDs from 1, 0 can represent uninitialised entry point
        initialised: false,
    }, nil
}

代码解析:

  • Vector 接口:提供了通用的向量操作,方便扩展不同类型的向量。
  • Float32Vector:一个简单的 []float32 实现。
  • DistanceCalculator 接口:定义了距离计算方法,支持欧氏距离和余弦距离,方便切换。
  • Node 结构体:存储向量数据、最大层数以及各层的邻居列表。sync.RWMutex 用于保护 Neighbors 列表,实现并发安全的读写。
  • Config 结构体:封装了 HNSW 的所有可配置参数,DefaultConfig 提供了一组合理的默认值。
  • HNSWIndex 结构体:HNSW 索引的主结构,包含配置、距离计算器、所有节点的映射、入口点 ID、当前最大层数、随机数生成器和下一个可用 ID。sync.RWMutex 保护索引的全局状态。

3.2 优先级队列 (Min-Heap / Max-Heap)

Go 语言标准库 container/heap 提供了一个通用的堆接口。我们需要实现 heap.Interface 才能使用它。在 HNSW 中,我们既需要 min-heap 来探索距离最近的候选节点,也需要 max-heap 来维护 K 个最佳结果(这样最大的距离总在堆顶,方便替换)。

// Item represents an item in the priority queue.
type Item struct {
    ID       uint64
    Distance float32
    index    int // The index is needed by update and is maintained by the heap.Interface methods.
}

// A priorityQueue implements heap.Interface and holds Items.
type priorityQueue []*Item

func (pq priorityQueue) Len() int { return len(pq) }

func (pq priorityQueue) Swap(i, j int) {
    pq[i], pq[j] = pq[j], pq[i]
    pq[i].index = i
    pq[j].index = j
}

func (pq *priorityQueue) Push(x interface{}) {
    n := len(*pq)
    item := x.(*Item)
    item.index = n
    *pq = append(*pq, item)
}

func (pq *priorityQueue) Pop() interface{} {
    old := *pq
    n := len(old)
    item := old[n-1]
    old[n-1] = nil // avoid memory leak
    item.index = -1 // for safety
    *pq = old[0 : n-1]
    return item
}

// minHeap is a min-priority queue for candidate nodes (smallest distance is highest priority).
type minHeap struct {
    priorityQueue
}

func NewMinHeap() *minHeap {
    return &minHeap{make(priorityQueue, 0)}
}

func (pq minHeap) Less(i, j int) bool {
    // We want a min-heap for distances (smaller distance is higher priority).
    return pq[i].Distance < pq[j].Distance
}

func (pq *minHeap) Peek() *Item {
    if pq.Len() == 0 {
        return nil
    }
    return (*pq).priorityQueue[0]
}

// maxHeap is a max-priority queue for result nodes (largest distance is highest priority, to be replaced).
type maxHeap struct {
    priorityQueue
}

func NewMaxHeap() *maxHeap {
    return &maxHeap{make(priorityQueue, 0)}
}

func (pq maxHeap) Less(i, j int) bool {
    // For a max-heap, we want the *largest* distance to be considered "less"
    // so it bubbles up to be replaced when full.
    return pq[i].Distance > pq[j].Distance
}

func (pq *maxHeap) Peek() *Item {
    if pq.Len() == 0 {
        return nil
    }
    return (*pq).priorityQueue[0]
}

3.3 线程安全的随机数生成器

HNSW 算法在为新节点分配层数时需要随机数。Go 的 math/rand 包默认不是并发安全的,因此我们需要封装一个线程安全的随机数生成器。

// hnswRand provides a thread-safe random number generator.
type hnswRand struct {
    src rand.Source
    mu  sync.Mutex
}

func newHnswRand() *hnswRand {
    return &hnswRand{
        src: rand.NewSource(time.Now().UnixNano()),
    }
}

// ExpRandom generates a random layer level for a new node based on an exponential distribution.
func (r *hnswRand) ExpRandom(mL float64) int {
    r.mu.Lock()
    defer r.mu.Unlock()
    return int(math.Floor(-math.Log(r.src.Float64()) * mL))
}

3.4 核心方法:Add 和 Search

接下来是 HNSW 索引的核心逻辑:Add (构建) 和 Search (查询)。

3.4.1 Add 方法 (插入新向量)

Add 方法将一个新向量插入到 HNSW 索引中。为了简化代码,这里的 selectNeighbors 仅实现了最简单的距离排序选择。在实际应用中,应实现更复杂的启发式选择(如 "extend-candidates" 策略),以确保图的质量。

// Add inserts a new vector into the HNSW index.
func (h *HNSWIndex) Add(vector Vector) (uint64, error) {
    // Assign a new ID and protect global index state
    h.mu.Lock()
    id := h.nextID
    h.nextID++
    currentEntryPointID := h.entryPointID
    currentMaxLayer := h.maxLayer
    h.mu.Unlock()

    // Determine the max layer for the new node
    newNodeMaxLayer := h.hnswRand.ExpRandom(h.config.Multiplier)
    if h.config.MaxLayers > 0 && newNodeMaxLayer > h.config.MaxLayers {
        newNodeMaxLayer = h.config.MaxLayers
    }

    newNode := NewNode(id, vector, newNodeMaxLayer, h.config.MaxM)
    h.nodes[id] = newNode // Add new node to map early for neighbor lookup

    // Handle the very first node
    if !h.initialised {
        h.mu.Lock()
        h.entryPointID = id
        h.maxLayer = newNodeMaxLayer
        h.initialised = true
        h.mu.Unlock()
        return id, nil // First node, no connections yet
    }

    // Step 1: Find a good entry point for the new node
    // Traverse down from the current maxLayer of the graph to newNodeMaxLayer + 1
    // or 0 if newNodeMaxLayer is already the highest
    ep := h.nodes[currentEntryPointID]
    var entryPointForSearch *Node
    var entryPointForSearchDist float32

    if newNodeMaxLayer < currentMaxLayer {
        // Find the nearest neighbor to the new node at layers higher than its own max layer
        currObj := ep // Start from the global entry point
        currDist := h.distCalc.Calculate(vector, currObj.Vector)

        for L := currentMaxLayer; L > newNodeMaxLayer; L-- {
            changed := true
            for changed {
                changed = false
                currObj.mu.RLock() // Read lock on current search node
                neighbors := currObj.Neighbors[L]
                currObj.mu.RUnlock()

                for _, neighborID := range neighbors {
                    neighbor := h.nodes[neighborID] // Assuming neighbor exists
                    dist := h.distCalc.Calculate(vector, neighbor.Vector)
                    if dist < currDist {
                        currDist = dist
                        currObj = neighbor
                        changed = true
                    }
                }
            }
        }
        entryPointForSearch = currObj
        entryPointForSearchDist = currDist
    } else {
        // New node is at or above current max layer, start search from global entry point
        entryPointForSearch = ep
        entryPointForSearchDist = h.distCalc.Calculate(vector, ep.Vector)
    }

    // Step 2: Insert the new node into layers from newNodeMaxLayer down to 0
    for L := newNodeMaxLayer; L >= 0; L-- {
        // Find efConstruction nearest neighbors in layer L
        // The search starts from entryPointForSearch, but for layers higher than newNode's max layer,
        // we already found the 'currObj' from the previous loop.
        // For L <= newNodeMaxLayer, we reuse the result from searchLayer.
        var candidates []uint64
        if L == newNodeMaxLayer && entryPointForSearch != nil {
            // For the first layer of insertion, use the entry point found earlier
            candidates = h.searchLayer(vector, entryPointForSearch, L, h.config.EfConstruction)
        } else if L < newNodeMaxLayer && entryPointForSearch != nil {
            // For subsequent layers, use the closest node from the previous layer's result as entry
            // (simplified; a more robust approach would use efConstruction candidates from higher layers)
            // Here, we just pick the closest from the previous layer's candidates to start.
            if len(candidates) > 0 {
                bestCandID := candidates[0] // Assuming searchLayer returns sorted by distance
                entryPointForSearch = h.nodes[bestCandID]
            }
            candidates = h.searchLayer(vector, entryPointForSearch, L, h.config.EfConstruction)
        } else { // No entry point, likely the first node or an error
            return 0, fmt.Errorf("could not find an entry point for layer %d", L)
        }

        // Select M (or M0) neighbors from the candidates using heuristic
        selectedNeighborsIDs := h.selectNeighbors(newNode, candidates, L, h.config.Heuristic)

        // Connect new node to selected neighbors
        newNode.mu.Lock()
        newNode.Neighbors[L] = append(newNode.Neighbors[L], selectedNeighborsIDs...)
        newNode.mu.Unlock()

        // Connect selected neighbors back to new node (bidirectional) and potentially prune
        maxConnections := h.config.MaxM
        if L == 0 {
            maxConnections = h.config.MaxM0
        }

        for _, neighborID := range selectedNeighborsIDs {
            neighborNode := h.nodes[neighborID]
            neighborNode.mu.Lock()
            neighborNode.Neighbors[L] = append(neighborNode.Neighbors[L], id)
            if len(neighborNode.Neighbors[L]) > maxConnections {
                // Prune if over capacity
                neighborNode.Neighbors[L] = h.selectNeighbors(neighborNode, neighborNode.Neighbors[L], L, h.config.Heuristic)
            }
            neighborNode.mu.Unlock()
        }

        // Ensure new node's connections are also pruned if over capacity (after adding neighbors' connections)
        newNode.mu.Lock()
        if len(newNode.Neighbors[L]) > maxConnections {
            newNode.Neighbors[L] = h.selectNeighbors(newNode, newNode.Neighbors[L], L, h.config.Heuristic)
        }
        newNode.mu.Unlock()
    }

    // Step 3: Update global entry point if new node is on a higher layer
    if newNodeMaxLayer > currentMaxLayer {
        h.mu.Lock()
        h.entryPointID = id
        h.maxLayer = newNodeMaxLayer
        h.mu.Unlock()
    }

    return id, nil
}

// selectNeighbors selects M (or M0) neighbors from a list of candidates using a heuristic.
// The heuristic aims to select diverse neighbors, avoiding overly clustered connections.
func (h *HNSWIndex) selectNeighbors(targetNode *Node, candidates []uint64, layer int, heuristic bool) []uint64 {
    maxConnections := h.config.MaxM
    if layer == 0 {
        maxConnections = h.config.MaxM0
    }

    if len(candidates) == 0 {
        return []uint64{}
    }

    if !heuristic {
        // Non-heuristic: simply take the closest 'maxConnections' neighbors.
        type tempNeighbor struct {
            ID       uint64
            Distance float32
        }
        tempNeighbors := make([]tempNeighbor, len(candidates))
        for i, candID := range candidates {
            tempNeighbors[i] = tempNeighbor{ID: candID, Distance: h.distCalc.Calculate(targetNode.Vector, h.nodes[candID].Vector)}
        }
        sort.Slice(tempNeighbors, func(i, j int) bool {
            return tempNeighbors[i].Distance < tempNeighbors[j].Distance
        })

        selected := make([]uint64, 0, maxConnections)
        for i := 0; i < len(tempNeighbors) && len(selected) < maxConnections; i++ {
            selected = append(selected, tempNeighbors[i].ID)
        }
        return selected
    }

    // Heuristic selection (simplified for demonstration, full implementation is more involved):
    // This heuristic (referred to as "extend-candidates" or "select-neighbors-heuristic" in papers)
    // aims to select neighbors that are not only close to the target node but also
    // not too close to each other, promoting graph connectivity and diversity.
    //
    // A common approach involves:
    // 1. Maintain a min-heap of `candidates` (nodes to explore, ordered by distance to target).
    // 2. Maintain a max-heap of `selected` nodes (best `maxConnections` neighbors, ordered by distance to target).
    // 3. Iteratively pop the closest node `c` from `candidates`.
    // 4. If `c` is not "too close" to any node already in `selected` (e.g., its distance to existing selected nodes
    //    is greater than `dist(target, c)`), then add `c` to `selected`.
    // 5. If `selected` exceeds `maxConnections`, remove the farthest node from `selected`.
    // 6. This process continues until `candidates` is empty or `maxConnections` neighbors are chosen.
    //
    // For this lecture, we'll demonstrate a simplified heuristic by just taking the closest candidates
    // and then applying a basic check to prevent adding extremely redundant neighbors (not fully robust).
    // For a production-grade implementation, refer to the original HNSW paper's `selectNeighbors` algorithm.

    // For now, fall back to non-heuristic for simplicity, or implement a basic version.
    // A basic heuristic might involve:
    // 1. Sort candidates by distance to target.
    // 2. Add closest candidate to selected.
    // 3. For subsequent candidates, only add if its distance to target is less than the max distance in selected AND
    //    its distance to existing selected elements is not too small (to avoid redundancy).
    // This is often implemented with a temporary max-heap (W_max) for selected candidates and a min-heap for unvisited candidates.

    // A very basic approximation of the heuristic for the sake of brevity:
    // Take the closest ones, then slightly prefer those which are further from already selected neighbors.
    // This is NOT the full HNSW heuristic.

    type scoredNeighbor struct {
        ID uint64
        Distance float32 // Distance to targetNode
    }

    scoredCandidates := make([]scoredNeighbor, len(candidates))
    for i, candID := range candidates {
        scoredCandidates[i] = scoredNeighbor{ID: candID, Distance: h.distCalc.Calculate(targetNode.Vector, h.nodes[candID].Vector)}
    }

    sort.Slice(scoredCandidates, func(i, j int) bool {
        return scoredCandidates[i].Distance < scoredCandidates[j].Distance
    })

    selected := make([]uint64, 0, maxConnections)
    selectedSet := make(map[uint64]struct{}) // To quickly check if a node is already selected

    for _, cand := range scoredCandidates {
        if len(selected) >= maxConnections {
            break
        }

        // Check if this candidate is too close to any already selected node.
        // This is a simplified redundancy check.
        isRedundant := false
        for _, sID := range selected {
            if h.distCalc.Calculate(h.nodes[sID].Vector, h.nodes[cand.ID].Vector) < cand.Distance * 0.5 { // Arbitrary threshold
                isRedundant = true
                break
            }
        }

        if !isRedundant {
            selected = append(selected, cand.ID)
            selectedSet[cand.ID] = struct{}{}
        }
    }

    // If after this simplified heuristic, we still don't have enough neighbors,
    // just fill up with the closest remaining ones.
    if len(selected) < maxConnections {
        for _, cand := range scoredCandidates {
            if len(selected) >= maxConnections {
                break
            }
            if _, exists := selectedSet[cand.ID]; !exists {
                selected = append(selected, cand.ID)
                selectedSet[cand.ID] = struct{}{}
            }
        }
    }

    return selected
}

Add 方法解析:

  • ID 和层数分配: 为新节点分配唯一 ID 和随机层数。
  • 全局入口点处理: 如果是第一个节点,直接成为入口点。如果新节点的层数高于当前最大层数,它也将成为新的入口点。
  • 逐层搜索入口: 从全局入口点开始,从最高层向下遍历到 newNodeMaxLayer+1,在每层找到一个与新向量最近的节点,作为下一层搜索的起点。
  • 逐层插入:newNodeMaxLayer 向下到 L0,在每一层执行:
    • 调用 searchLayer 找到 efConstruction 个候选邻居。
    • 调用 selectNeighbors 从候选集中选择 M (或 M_max_0) 个最佳邻居。
    • 双向连接:将新节点与选定的邻居连接,并确保邻居也连接回新节点。
    • 剪枝:如果任何节点的邻居数量超过 M_max (或 M_max_0),则进行剪枝,保留最佳的连接。
  • 并发控制: HNSWIndex.mu 保护全局状态,Node.mu 保护单个节点的邻居列表。这允许不同节点在插入过程中并发更新其邻居。
3.4.2 Search 方法 (查询 K 个最近邻)

Search 方法用于查询与给定向量最相似的 K 个向量。

// Search finds the k nearest neighbors to the query vector.
func (h *HNSWIndex) Search(query Vector, k int) ([]uint64, error) {
    h.mu.RLock() // Read lock for overall index structure
    currentEntryPointID := h.entryPointID
    currentMaxLayer := h.maxLayer
    initialised := h.initialised
    h.mu.RUnlock()

    if !initialised || currentEntryPointID == 0 { // Empty index
        return []uint64{}, nil
    }

    ep := h.nodes[currentEntryPointID]
    if ep == nil {
        return []uint64{}, fmt.Errorf("entry point node %d not found", currentEntryPointID)
    }

    // Step 1: Find a good entry point for the target layer (layer 0)
    // Traverse down from maxLayer to layer 1 (or 0 if maxLayer is 0)
    currObj := ep // Start from the global entry point
    currDist := h.distCalc.Calculate(query, currObj.Vector)

    for L := currentMaxLayer; L > 0; L-- {
        changed := true
        for changed {
            changed = false
            currObj.mu.RLock()
            neighbors := currObj.Neighbors[L]
            currObj.mu.RUnlock()

            for _, neighborID := range neighbors {
                neighbor := h.nodes[neighborID]
                dist := h.distCalc.Calculate(query, neighbor.Vector)
                if dist < currDist {
                    currDist = dist
                    currObj = neighbor
                    changed = true
                }
            }
        }
    }
    // After this loop, currObj is the best entry point found at layer 0.

    // Step 2: Perform the actual K-NN search at layer 0 using efSearch
    // This will return a list of efSearch candidates.
    candidates := h.searchLayer(query, currObj, 0, h.config.EfSearch)

    // Step 3: From candidates, select the top k results
    // Use a max-heap to keep track of k best results, largest distance at top to be replaced.
    resultsHeap := NewMaxHeap()

    for _, candID := range candidates {
        candNode := h.nodes[candID]
        if candNode == nil { // Should not happen if candidates are valid IDs
            continue
        }
        dist := h.distCalc.Calculate(query, candNode.Vector)
        if resultsHeap.Len() < k {
            heap.Push(resultsHeap, &Item{ID: candID, Distance: dist})
        } else if dist < resultsHeap.Peek().Distance {
            heap.Pop(resultsHeap) // Remove the current farthest
            heap.Push(resultsHeap, &Item{ID: candID, Distance: dist})
        }
    }

    // Extract final K-NN results in ascending order of distance
    finalKnn := make([]uint64, resultsHeap.Len())
    for i := resultsHeap.Len() - 1; i >= 0; i-- {
        finalKnn[i] = heap.Pop(resultsHeap).(*Item).ID
    }
    return finalKnn, nil
}

// searchLayer finds ef nearest neighbors to 'query' starting from 'entryPoint' at 'layer'.
// Returns a slice of node IDs, sorted by distance (closest first).
func (h *HNSWIndex) searchLayer(query Vector, entryPoint *Node, layer int, ef int) []uint64 {
    // min-heap for visited candidates (explore nodes with smallest distance first)
    candidateQueue := NewMinHeap()
    heap.Push(candidateQueue, &Item{ID: entryPoint.ID, Distance: h.distCalc.Calculate(query, entryPoint.Vector)})

    // max-heap to store the 'ef' best results found so far (largest distance at top to be replaced)
    resultQueue := NewMaxHeap()
    heap.Push(resultQueue, &Item{ID: entryPoint.ID, Distance: h.distCalc.Calculate(query, entryPoint.Vector)})

    visited := make(map[uint64]struct{})
    visited[entryPoint.ID] = struct{}{}

    for candidateQueue.Len() > 0 {
        curr := heap.Pop(candidateQueue).(*Item)

        // Optimization: if the current best result is further than the worst candidate, stop.
        // This can prune the search early but might slightly reduce recall in specific, sparse cases.
        if resultQueue.Len() > 0 && curr.Distance > resultQueue.Peek().Distance && resultQueue.Len() == ef {
            break
        }

        currNode := h.nodes[curr.ID]
        if currNode == nil {
            continue // Should not happen if ID is valid
        }

        currNode.mu.RLock() // Read lock for current node's neighbors
        neighbors := currNode.Neighbors[layer]
        currNode.mu.RUnlock()

        for _, neighborID := range neighbors {
            if _, ok := visited[neighborID]; !ok {
                visited[neighborID] = struct{}{}
                neighborNode := h.nodes[neighborID]
                if neighborNode == nil {
                    continue // Should not happen if ID is valid
                }
                dist := h.distCalc.Calculate(query, neighborNode.Vector)

                if resultQueue.Len() < ef || dist < resultQueue.Peek().Distance {
                    heap.Push(candidateQueue, &Item{ID: neighborID, Distance: dist})
                    heap.Push(resultQueue, &Item{ID: neighborID, Distance: dist})
                    for resultQueue.Len() > ef {
                        heap.Pop(resultQueue) // Maintain ef best results
                    }
                }
            }
        }
    }

    // Extract results from the max-heap (which are the ef best results)
    finalResults := make([]uint64, resultQueue.Len())
    // Pop all elements, they will be in reverse order of distance (farthest first)
    // We want them closest first, so we collect and then reverse.
    temp := make([]*Item, 0, resultQueue.Len())
    for resultQueue.Len() > 0 {
        temp = append(temp, heap.Pop(resultQueue).(*Item))
    }
    // Reverse to get closest first
    for i := len(temp) - 1; i >= 0; i-- {
        finalResults[len(temp)-1-i] = temp[i].ID
    }
    return finalResults
}

Search 方法解析:

  • 入口点遍历: 从全局入口点开始,从最高层向下遍历到 L1。在每一层,通过贪婪搜索找到与查询向量最近的节点,作为下一层搜索的起点。这个过程不维护 K 个结果,只是为了快速定位。
  • L0 层精细搜索:L0 层(或目标查询层),从上一步找到的入口点开始,执行 searchLayer 操作。
    • searchLayer 使用 minHeap (candidateQueue) 来探索邻居,确保总是优先探索距离查询向量最近的节点。
    • 同时使用 maxHeap (resultQueue) 来维护 efSearch 个最佳结果。efSearch 是一个关键参数,它决定了搜索的广度,越大则召回率越高但查询时间越长。
    • visited map 用于防止重复访问节点,避免死循环。
    • 优化:当 candidateQueue 中距离最远的候选节点比 resultQueue 中距离最近的节点还要远时,可以提前终止搜索,因为再探索下去也不会找到更好的结果。
  • 结果排序: searchLayer 返回的是 efSearch 个候选 ID,Search 方法再从这些候选者中筛选出 k 个最佳结果并返回。

四、Go 并发优化与性能考量

Go 语言在并发方面有着天然的优势。在 HNSW 的实现中,我们可以利用 Goroutine 和 Channel 进行优化。

4.1 索引构建的并发

HNSW 的 Add 操作是 CPU 密集型的,特别是距离计算和邻居选择。虽然单个 Add 操作内部有锁,但如果能批量的并发添加多个向量,将显著提升构建速度。

// AddBatch inserts multiple vectors into the HNSW index concurrently.
func (h *HNSWIndex) AddBatch(vectors []Vector) ([]uint64, error) {
    var wg sync.WaitGroup
    ids := make([]uint64, len(vectors))
    errs := make(chan error, len(vectors))

    for i, vec := range vectors {
        wg.Add(1)
        go func(idx int, vectorToAdd Vector) {
            defer wg.Done()
            id, err := h.Add(vectorToAdd)
            if err != nil {
                errs <- fmt.Errorf("failed to add vector %d: %w", idx, err)
                return
            }
            ids[idx] = id
        }(i, vec)
    }

    wg.Wait()
    close(errs)

    for err := range errs {
        if err != nil {
            return nil, err
        }
    }
    return ids, nil
}

解析:

  • AddBatch 方法使用 sync.WaitGroup 来等待所有 Goroutine 完成。
  • 每个向量的插入操作在一个独立的 Goroutine 中执行,调用 h.Add(vectorToAdd)
  • h.Add 内部的 Node.mu 锁保证了对单个节点邻居列表的并发安全,HNSWIndex.mu 保护了全局的入口点和最大层数。这使得多个 Add 操作可以安全地并行。
  • errs channel 用于收集 Goroutine 中可能发生的错误。

4.2 内存管理

  • 预分配切片容量:NewNode 中,我们为 Neighbors 切片预先分配了 M_max (或 M_max_0) 的容量,可以减少后续 append 操作时的内存重新分配开销。
  • *`map[uint64]Node:** Go 的 map 是一种高效的数据结构,但存储大量指针仍会产生一定的内存开销。对于非常大的数据集,可以考虑将Node结构体中的Vector` 字段优化为只存储向量的索引或偏移量,实际向量数据存储在一个大的连续数组中,以提高缓存局部性。

4.3 距离计算优化

距离计算是 HNSW 的主要性能瓶颈之一。

  • 平方距离: 欧氏距离的平方距离与欧氏距离本身在比较大小时是等价的,因此在很多地方我们直接计算平方距离,避免了昂贵的 math.Sqrt 操作。
  • SIMD 指令: 对于高性能的向量计算,可以考虑使用 Go 的 go:asmgolang.org/x/sys/cpu 包来利用 CPU 的 SIMD (Single Instruction, Multiple Data) 指令集(如 AVX2, AVX512)进行并行计算。这通常需要手写汇编或使用专门的库,但能带来数倍的性能提升。

五、性能压测与结果分析

为了验证我们实现的 HNSW 库的性能,我们需要进行严格的压测。

5.1 压测数据集

  • SIFT1M: 经典的计算机视觉数据集,包含 100 万个 128 维浮点向量。通常用于评估 ANN 算法。
  • GloVe: 文本嵌入向量,例如 100 维或 300 维。
  • Synthetic Data: 可以生成随机的高维浮点向量,用于控制数据规模和维度,方便参数调优。

5.2 压测指标

  • 构建时间 (Build Time): 插入所有向量并构建索引所需的时间。
  • 查询时间 (Query Time): 单次查询 K 个最近邻的平均时间。通常会计算 QPS (Queries Per Second) 或 P95/P99 延迟。
  • 召回率 (Recall@K): 衡量 ANN 算法精度。计算方式为:Recall@K = (ANN 算法找到的 K 个最近邻中与真实 K 个最近邻重叠的数量) / K。通常会比较不同 efSearch 参数下的召回率。
  • 内存占用 (Memory Usage): 索引在内存中占用的总空间。

5.3 参数调优示例

以下是一个简单的压测框架示例,以及如何调整参数进行性能分析:


package main

import (
    "fmt"
    "log"
    "math/rand"
    "time"

    "your_module_path/hnsw" // Replace with your actual module path
)

// generateRandomVectors generates N random D-dimensional float32 vectors.
func generateRandomVectors(numVectors, dim int) []hnsw.Vector {
    vectors := make([]hnsw.Vector, numVectors)
    src := rand.NewSource(time.Now().UnixNano())
    r := rand.New(src)
    for i := 0; i < numVectors; i++ {
        vec := make(hnsw.Float32Vector, dim)
        for j := 0; j < dim; j++ {
            vec[j] = r.Float32() * 100 // Scale to a reasonable range
        }
        vectors[i] = vec
    }
    return vectors
}

func main() {
    numVectors := 100000 // 100k vectors
    dim := 128           // 128 dimensions
    k := 10              // Find 10 nearest neighbors
    numQueries := 1000   // Number of queries to run

    // Generate synthetic data
    fmt.Printf("Generating %d vectors with %d dimensions...n", numVectors, dim)
    dataVectors := generateRandomVectors(numVectors, dim)
    queryVectors := generateRandomVectors(numQueries, dim)
    fmt.Println("Data generation complete.")

    // --- HNSW Configuration and Benchmarking ---
    configs := []struct {
        Name string
        Config hnsw.Config
    }{
        {
            Name: "DefaultConfig",
            Config: hnsw.DefaultConfig(),
        },
        {
            Name: "HighRecallLowSpeed", // Higher efConstruction, efSearch for better recall
            Config: hnsw.Config{
                M:             24,
                MaxM:          24,
                MaxM0:         48,
                EfConstruction: 400,
                EfSearch:      100,
                Heuristic:     true,
                DistanceType:  "euclidean",
                Multiplier:    1 / math.Log(24.0),
            },
        },
        {
            Name: "LowRecallHighSpeed", // Lower efConstruction, efSearch for faster queries
            Config: hnsw.Config{
                M:             12,
                MaxM:          12,
                MaxM0:         24,
                EfConstruction: 100,
                EfSearch:      20,
                Heuristic:     true,
                DistanceType:  "euclidean",
                Multiplier:    1 / math.Log(12.0),
            },
        },
    }

    for _, cfgEntry := range configs {
        fmt.Printf("n--- Benchmarking with %s (%+v) ---n", cfgEntry.Name, cfgEntry.Config)

        hnswIndex, err := hnsw.NewHNSWIndex(cfgEntry.Config)
        if err != nil {
            log.Fatalf("Failed to create HNSW index: %v", err)
        }

        // --- Build Time ---
        buildStartTime := time.Now()
        // Using AddBatch for concurrent insertion
        _, err = hnswIndex.AddBatch(dataVectors)
        if err != nil {
            log.Fatalf("Failed to add batch: %v", err)
        }
        buildDuration := time.Since(buildStartTime)
        fmt.Printf("Build Time: %sn", buildDuration)

        // --- Query Time and Recall (simplified, true recall needs ground truth) ---
        // For a full recall test, you'd need a dataset with pre-computed ground truth nearest neighbors.
        // Here, we just measure query time.
        queryStartTime := time.Now()
        for i := 0; i < numQueries; i++ {
            _, err := hnswIndex.Search(queryVectors[i], k)
            if err != nil {
                log.Printf("Query %d failed: %v", i, err)
            }
        }
        queryDuration := time.Since(queryStartTime)
        avgQueryTime := float64(queryDuration.Nanoseconds()) / float64(numQueries) / 1e6 // ms
        fmt.Printf("Average Query Time (k=%d): %.3f msn", k, avgQueryTime)
        fmt.Printf("Queries Per Second (QPS): %.2fn", float64(numQueries)/queryDuration.Seconds())

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注