What Is ‘WAL (Write-Ahead Logging) Checkpointing’: Implementing High-Performance, Fault-Recoverable Crash-Consistency Logic in Go

One of the core challenges for any data storage system is keeping data intact and available in the face of system crashes, power failures, or hardware faults. Simply writing data to disk is not enough: operating-system buffering, hardware caches, and the non-atomicity of write operations can all produce partial writes or inconsistent state. To solve these problems, database systems developed sophisticated mechanisms to guarantee what is known as crash consistency and atomicity. Among them, Write-Ahead Logging (WAL) is the cornerstone, and WAL checkpointing is its indispensable optimization and management mechanism.

This lecture takes a deep look at how WAL checkpointing works and the role it plays in achieving high performance and fault recovery, then demonstrates, in Go, how to build storage logic with these properties.

1. The Challenges of Data Durability and Crash Consistency

Before diving into WAL, let us first understand the basic dilemma of data persistence. When an application modifies data, that data usually lives in memory first. Making it durable requires writing it to non-volatile storage (a hard disk or SSD). The process looks simple but is full of traps:

  • Write-granularity mismatch: Applications operate on data in units of bytes or structs, while disks read and write in blocks (for example 4KB or 8KB). A single logically atomic operation may therefore need to modify multiple blocks on disk.
  • OS and hardware caches: The operating system typically caches writes instead of flushing them to the physical disk immediately, and hardware controllers have caches of their own. A successful return from the write() system call therefore does not mean the data has safely reached non-volatile storage.
  • Partial and torn writes: If a crash occurs midway through writing a multi-block data structure, only some of the blocks reach disk, leaving the structure corrupted or inconsistent. This is called a "partial write" or a "torn write".
  • Atomicity requirements: A database transaction requires that either all of its operations commit or all of them roll back, with no intermediate state. This property is called atomicity.

To cope with these challenges, we need a mechanism that, after a crash, restores the data to a known, consistent state in which every committed transaction is durable and every uncommitted transaction has been rolled back. That is the goal of crash consistency.

2. The Core Principle of Write-Ahead Logging (WAL)

WAL is the core technique for achieving crash consistency and transactional atomicity. Its basic rule is simple but extremely powerful: before making any actual modification to the data, a log record describing that modification must first be written and flushed to durable storage.

This rule guarantees two things:

  1. Atomicity: If the system crashes in the middle of a data modification, we can inspect the log to decide whether a transaction committed completely. If the log contains the transaction's commit record, then even if the data itself has not been fully written to disk, we know what it "should" look like.
  2. Durability: Once a transaction's commit record has been written and flushed to the WAL, its modifications can be recovered by redoing the log even if the system crashes immediately afterwards.

2.1 The WAL Workflow

  1. Modify the data: When the application asks to modify data (for example, update a database page), the modification is first performed in memory.
  2. Generate a log record: A log record is produced for the modification, containing enough information to redo or undo it. For example, it may include:
    • Log sequence number (LSN): a monotonically increasing number that uniquely identifies the record.
    • Record type: page update, transaction start, transaction commit, transaction abort, and so on.
    • Page ID: the identifier of the modified data page.
    • Old and new values (or a delta): the data needed to redo or undo the operation.
    • Transaction ID: associates the record with a specific transaction.
  3. Write to the WAL: Append the log record to the WAL file (usually into an in-memory buffer that is flushed in batches).
  4. Flush the WAL: At transaction commit, all relevant log records must be force-flushed (fsync() or fdatasync()) to durable storage. This is the key step that guarantees durability.
  5. Modify the data page: Only after the log records are safely on disk may the modified in-memory data pages be written, asynchronously, to their actual locations in the data files.

2.2 Crash Recovery (WAL-Based)

When the system recovers from a crash, the WAL is its sole source of truth. Recovery usually proceeds in two phases:

  1. Redo phase:

    • Scan the WAL from the beginning (or from some known consistent point).
    • For each log record, re-apply the modification it describes to the data files.
    • The goal of this phase is to bring the database to its state at the instant of the crash, including the modifications of both committed and uncommitted transactions.
  2. Undo phase:

    • After the redo phase, the database may contain partial modifications from uncommitted transactions.
    • The undo phase identifies every transaction that was still active (uncommitted) at the time of the crash.
    • The modifications of those transactions are reversed, using the old values or inverse operations recorded in the log, restoring the data to its state before each transaction began.

Together, these two phases ensure that every committed transaction is durable and every uncommitted transaction is rolled back, bringing the system back to a crash-consistent state.

3. The Limitation of Pure WAL: Log Bloat

A pure WAL scheme has one glaring problem: the log file grows without bound. Over time it consumes enormous disk space. Worse, recovery time grows with it, because every recovery must redo all operations from the very start of the log. That is unacceptable in production.

To keep the log bounded and recovery fast, we need a mechanism that periodically "cleans up" the WAL, discarding old records that are no longer needed. That is the job of checkpointing.

4. Introducing Checkpointing: Optimizing and Managing the WAL

Checkpointing periodically establishes a consistent snapshot of the database. The core idea: by forcing modified in-memory data pages (dirty pages) out to their actual locations in the data files, a portion of the data state is "solidified", so that WAL records older than some point in time are no longer needed for recovery.

4.1 Goals of Checkpointing

  • Bound WAL growth: allow the system to delete old, no-longer-needed log records.
  • Shorten recovery time: recovery starts from the most recent valid checkpoint instead of the start of the WAL.
  • Guarantee consistency: when a crash happens, the checkpoint provides a known, consistent starting point for recovery.

4.2 Types of Checkpoints

There are two main kinds of checkpoints:

  1. Blocking checkpointing:

    • When the checkpoint begins, the system pauses all new transactions and data-modification operations.
    • All dirty pages in memory are force-flushed to disk.
    • The system waits for every currently active transaction to finish.
    • Finally, a checkpoint record is written indicating that all WAL before this point can be truncated.
    • Pros: relatively simple to implement.
    • Cons: badly hurts performance and availability, since every operation is paused. Unsuitable for highly concurrent systems.
  2. Fuzzy / asynchronous checkpointing:

    • This is the approach generally taken by modern high-performance database systems such as PostgreSQL, RocksDB, and LevelDB.
    • Transactions and data modifications are allowed to continue while the checkpoint runs.
    • Pros: minimal impact on performance; runs asynchronously in the background.
    • Cons: more complex to implement, with correspondingly more complex recovery logic.

This lecture focuses on the more complex but more practical fuzzy checkpoint.

5. WAL Checkpointing in Depth (Fuzzy Checkpoints)

The essence of a fuzzy checkpoint is that reads and writes continue while the checkpoint runs. To make that possible, the checkpoint record must carry more information, and recovery must process it more intelligently.

5.1 The Checkpoint Record

A typical fuzzy checkpoint record (usually the CheckpointEndRecord) carries the following key fields:

  • min_lsn_for_recovery (LSN): the earliest LSN among all data pages not yet persisted to the data files. Put simply, this is where the WAL scan and redo must begin during crash recovery; records before it can be safely truncated.

  • active_transactions (transaction IDs): the list of transactions still active when the checkpoint started. Recovery must undo these transactions.
  • dirty_pages_at_ckpt_start (page IDs): the list of all dirty in-memory pages at checkpoint start. Their latest modifications may not yet be in the data files, and they must be flushed before the checkpoint completes.
  • checkpoint_lsn (LSN): the LSN of the CheckpointEndRecord itself within the WAL.

5.2 The Fuzzy Checkpoint Process

  1. Start the checkpoint:

    • At a relatively idle moment (or on a periodic trigger), the system decides to start a checkpoint.
    • Under a lightweight lock or latch, atomically capture the list of all currently active transactions, TxList_active_at_ckpt_start, and the list of all dirty in-memory pages, DirtyPages_at_ckpt_start.
    • Record the current WAL tail LSN; call it checkpoint_start_LSN.
    • Write a CheckpointStartRecord to the WAL to mark the beginning of the checkpoint. This record usually contains checkpoint_start_LSN.
    • Release the lightweight lock so normal database operation continues.
  2. Asynchronously flush dirty pages:

    • A background thread starts flushing the pages in DirtyPages_at_ckpt_start to their actual locations in the data files.
    • The flush runs concurrently with normal traffic: other transactions may keep modifying these pages. If a page is modified again mid-flush, it simply remains dirty, and the PageCache records its new lastLSN. The checkpoint only needs every modification made before checkpoint_start_LSN to become durable.
    • To keep data-file writes atomic, systems typically use copy-on-write or shadow paging, or overwrite pages in place while guaranteeing per-page write atomicity (for example, via a checksum or LSN in the page header).
  3. Complete the checkpoint:

    • Once every page that was dirty at checkpoint_start_LSN has been flushed to disk, the system enters the completion phase.
    • Determine min_recovery_LSN, the smallest LSN recovery will need. It is typically the minimum of checkpoint_start_LSN and the start_LSNs of all transactions in TxList_active_at_ckpt_start. In other words, any WAL record older than min_recovery_LSN is no longer needed.
    • Write a CheckpointEndRecord containing min_recovery_LSN, TxList_active_at_ckpt_start, and related fields to the WAL, and flush the WAL.
    • Update the durable checkpoint metadata (usually a separate file, or a fixed location in the WAL) to point at the LSN of this newest CheckpointEndRecord.
  4. Truncate the WAL:

    • Once the CheckpointEndRecord is safely written and durable, every WAL record older than min_recovery_LSN can be deleted or archived. This effectively bounds the size of the WAL.
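
The min_recovery_LSN computation in step 3 is just a minimum over the checkpoint-start LSN and the start LSNs of the still-active transactions. A small sketch under exactly that assumption (the function name is invented for illustration):

```go
package main

import "fmt"

type LSN uint64

// minRecoveryLSN returns the truncation point after a fuzzy checkpoint:
// the minimum of checkpoint_start_LSN and the start LSNs of every
// transaction that was still active when the checkpoint began.
// WAL records older than this value are no longer needed for recovery.
func minRecoveryLSN(ckptStartLSN LSN, activeTxStartLSNs []LSN) LSN {
	min := ckptStartLSN
	for _, lsn := range activeTxStartLSNs {
		if lsn < min {
			min = lsn
		}
	}
	return min
}

func main() {
	// Checkpoint started at LSN 900 while two transactions that began
	// at LSNs 850 and 920 were still active: recovery must start at 850,
	// because transaction 850 may need to be undone using older records.
	fmt.Println(minRecoveryLSN(900, []LSN{850, 920})) // prints 850
}
```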

5.3 Crash Recovery (With Fuzzy Checkpoints)

With checkpoints, recovery is greatly simplified and accelerated:

  1. Locate the latest checkpoint:

    • At startup, the system first reads the durable checkpoint metadata to find the LSN of the most recent valid CheckpointEndRecord.
    • Starting from that LSN, it reads the WAL to find the actual CheckpointEndRecord, which contains min_recovery_LSN and TxList_active_at_ckpt_start.
    • If no valid CheckpointEndRecord can be found (for example, the system never successfully completed a checkpoint), recovery must begin at the very start of the WAL.
  2. Redo phase:

    • Scan the WAL forward, from the min_recovery_LSN given in the CheckpointEndRecord to the end of the WAL.
    • Maintain an active transaction table (ATT) and a dirty page table (DPT).
    • For each log record:
      • If it is a transaction-start record, add its transaction ID to the ATT.
      • If it is a page-update record, add the affected page ID to the DPT and update its rec_LSN (the LSN carried in the record).
      • If it is a transaction commit/abort record, remove its transaction ID from the ATT.
      • Redo rule: apply a record's modification only if its LSN is strictly greater than the LSN stamped on the data page (or the page is not tracked in the DPT). This ensures we never re-apply old modifications that are already present on the page.
    • Apply the modifications to the pages in memory (or write them straight to disk, though state is usually built up in memory first).
  3. Undo phase:

    • After the redo phase, the ATT contains exactly the transactions that were uncommitted at the crash.
    • For each transaction in the ATT, locate all of its log records in the WAL.
    • Apply their undo operations in reverse order (using the old values or inverse operations stored in the records), restoring the data to its state before each transaction began.
    • When a transaction has been fully undone, remove its ID from the ATT.

Once both the redo and undo phases finish, the database is back in a crash-consistent state: every committed transaction is durable, and every uncommitted one has been rolled back.
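
The redo rule, comparing a record's LSN against the LSN stamped on the page, is worth seeing in isolation. A minimal sketch (the page-header LSN field is assumed; the function name is invented):

```go
package main

import "fmt"

// shouldRedo implements the redo-phase test: a log record is re-applied
// only when its LSN is newer than the LSN already recorded on the page.
// Skipping older or equal LSNs keeps redo idempotent across repeated
// recovery attempts, even if recovery itself crashes and restarts.
func shouldRedo(recordLSN, pageLSN uint64) bool {
	return recordLSN > pageLSN
}

func main() {
	fmt.Println(shouldRedo(10, 7)) // newer record: apply it
	fmt.Println(shouldRedo(7, 7))  // already on the page: skip
}
```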

6. Implementing High-Performance, Fault-Recoverable Crash-Consistency Logic in Go

Now we turn this theory into actual Go code. To stay concise and focused on the core concepts, we will build a simplified, conceptual storage-engine skeleton.

6.1 Core Component Design

We define the following core components:

  1. LogRecord: the log-record interface and its concrete implementations.
  2. WAL: responsible for appending, flushing, and reading log records.
  3. Page and PageCache: manage in-memory data pages and track dirty pages.
  4. StorageEngine: coordinates the WAL and PageCache and exposes the data-operation interface.
  5. CheckpointManager: triggers and executes checkpoints.
  6. RecoveryManager: performs crash recovery at system startup.

6.2 Data Structures and Interfaces

package main

import (
    "bufio" // used by WAL.ReadRecords below
    "bytes"
    "encoding/binary"
    "encoding/gob"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "sort" // used by WAL.loadSegments below
    "sync"
    "time"
)

// LSN (Log Sequence Number)
type LSN uint64

// PageID identifies a data page.
type PageID uint64

// TransactionID identifies a transaction.
type TransactionID uint64

// LogRecordType defines the type of a log record.
type LogRecordType uint8

const (
    TypeNoOp LogRecordType = iota // No operation
    TypePageUpdate               // Update a data page
    TypeTransactionStart         // Transaction starts
    TypeTransactionCommit        // Transaction commits
    TypeTransactionAbort         // Transaction aborts
    TypeCheckpointStart          // Checkpoint starts
    TypeCheckpointEnd            // Checkpoint ends
)

// LogRecord interface represents a generic log record.
type LogRecord interface {
    GetType() LogRecordType
    GetLSN() LSN
    SetLSN(lsn LSN)
    Serialize() ([]byte, error)
    Deserialize([]byte) error
    String() string
}

// BaseLogRecord provides common fields for log records.
type BaseLogRecord struct {
    LSNVal LSN
    TypeVal LogRecordType
}

func (b *BaseLogRecord) GetType() LogRecordType { return b.TypeVal }
func (b *BaseLogRecord) GetLSN() LSN           { return b.LSNVal }
func (b *BaseLogRecord) SetLSN(lsn LSN)        { b.LSNVal = lsn }

// PageUpdateRecord represents an update to a data page.
type PageUpdateRecord struct {
    BaseLogRecord
    PageID PageID
    Offset uint16
    Length uint16 // length in bytes of the updated region
    OldData []byte
    NewData []byte
    TxID TransactionID
}

func NewPageUpdateRecord(pageID PageID, offset uint16, oldData, newData []byte, txID TransactionID) *PageUpdateRecord {
    return &PageUpdateRecord{
        BaseLogRecord: BaseLogRecord{TypeVal: TypePageUpdate},
        PageID: pageID,
        Offset: offset,
        Length: uint16(len(newData)), // Assuming len(oldData) == len(newData) for simplicity
        OldData: oldData,
        NewData: newData,
        TxID: txID,
    }
}

func (r *PageUpdateRecord) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    err := enc.Encode(r)
    return buf.Bytes(), err
}

func (r *PageUpdateRecord) Deserialize(data []byte) error {
    dec := gob.NewDecoder(bytes.NewReader(data))
    return dec.Decode(r)
}

func (r *PageUpdateRecord) String() string {
    return fmt.Sprintf("PageUpdate{LSN:%d, PageID:%d, Offset:%d, Len:%d, TxID:%d}", r.LSNVal, r.PageID, r.Offset, r.Length, r.TxID)
}

// TransactionStartRecord marks the beginning of a transaction.
type TransactionStartRecord struct {
    BaseLogRecord
    TxID TransactionID
}

func NewTransactionStartRecord(txID TransactionID) *TransactionStartRecord {
    return &TransactionStartRecord{
        BaseLogRecord: BaseLogRecord{TypeVal: TypeTransactionStart},
        TxID: txID,
    }
}

func (r *TransactionStartRecord) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    err := enc.Encode(r)
    return buf.Bytes(), err
}

func (r *TransactionStartRecord) Deserialize(data []byte) error {
    dec := gob.NewDecoder(bytes.NewReader(data))
    return dec.Decode(r)
}

func (r *TransactionStartRecord) String() string {
    return fmt.Sprintf("TxStart{LSN:%d, TxID:%d}", r.LSNVal, r.TxID)
}

// TransactionCommitRecord marks the successful end of a transaction.
type TransactionCommitRecord struct {
    BaseLogRecord
    TxID TransactionID
}

func NewTransactionCommitRecord(txID TransactionID) *TransactionCommitRecord {
    return &TransactionCommitRecord{
        BaseLogRecord: BaseLogRecord{TypeVal: TypeTransactionCommit},
        TxID: txID,
    }
}

func (r *TransactionCommitRecord) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    err := enc.Encode(r)
    return buf.Bytes(), err
}

func (r *TransactionCommitRecord) Deserialize(data []byte) error {
    dec := gob.NewDecoder(bytes.NewReader(data))
    return dec.Decode(r)
}

func (r *TransactionCommitRecord) String() string {
    return fmt.Sprintf("TxCommit{LSN:%d, TxID:%d}", r.LSNVal, r.TxID)
}

// TransactionAbortRecord marks the failed end of a transaction.
type TransactionAbortRecord struct {
    BaseLogRecord
    TxID TransactionID
}

func NewTransactionAbortRecord(txID TransactionID) *TransactionAbortRecord {
    return &TransactionAbortRecord{
        BaseLogRecord: BaseLogRecord{TypeVal: TypeTransactionAbort},
        TxID: txID,
    }
}

func (r *TransactionAbortRecord) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    err := enc.Encode(r)
    return buf.Bytes(), err
}

func (r *TransactionAbortRecord) Deserialize(data []byte) error {
    dec := gob.NewDecoder(bytes.NewReader(data))
    return dec.Decode(r)
}

func (r *TransactionAbortRecord) String() string {
    return fmt.Sprintf("TxAbort{LSN:%d, TxID:%d}", r.LSNVal, r.TxID)
}

// CheckpointStartRecord marks the beginning of a checkpoint process.
type CheckpointStartRecord struct {
    BaseLogRecord
    CheckpointStartLSN LSN // The LSN at which the checkpoint process began
}

func NewCheckpointStartRecord(ckptStartLSN LSN) *CheckpointStartRecord {
    return &CheckpointStartRecord{
        BaseLogRecord: BaseLogRecord{TypeVal: TypeCheckpointStart},
        CheckpointStartLSN: ckptStartLSN,
    }
}

func (r *CheckpointStartRecord) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    err := enc.Encode(r)
    return buf.Bytes(), err
}

func (r *CheckpointStartRecord) Deserialize(data []byte) error {
    dec := gob.NewDecoder(bytes.NewReader(data))
    return dec.Decode(r)
}

func (r *CheckpointStartRecord) String() string {
    return fmt.Sprintf("CkptStart{LSN:%d, CkptStartLSN:%d}", r.LSNVal, r.CheckpointStartLSN)
}

// CheckpointEndRecord marks the end of a checkpoint process.
// It contains critical information for recovery.
type CheckpointEndRecord struct {
    BaseLogRecord
    MinRecoveryLSN LSN           // The LSN from which recovery should start
    ActiveTransactions []TransactionID // List of transactions active at checkpoint start
    DirtyPagesAtCkptStart []PageID      // List of dirty pages at checkpoint start (for analysis)
}

func NewCheckpointEndRecord(minLSN LSN, activeTx []TransactionID, dirtyPages []PageID) *CheckpointEndRecord {
    return &CheckpointEndRecord{
        BaseLogRecord: BaseLogRecord{TypeVal: TypeCheckpointEnd},
        MinRecoveryLSN: minLSN,
        ActiveTransactions: activeTx,
        DirtyPagesAtCkptStart: dirtyPages,
    }
}

func (r *CheckpointEndRecord) Serialize() ([]byte, error) {
    var buf bytes.Buffer
    enc := gob.NewEncoder(&buf)
    err := enc.Encode(r)
    return buf.Bytes(), err
}

func (r *CheckpointEndRecord) Deserialize(data []byte) error {
    dec := gob.NewDecoder(bytes.NewReader(data))
    return dec.Decode(r)
}

func (r *CheckpointEndRecord) String() string {
    return fmt.Sprintf("CkptEnd{LSN:%d, MinRecoveryLSN:%d, ActiveTx:%v, DirtyPages:%v}",
        r.LSNVal, r.MinRecoveryLSN, r.ActiveTransactions, r.DirtyPagesAtCkptStart)
}

// Register gob types for serialization/deserialization
func init() {
    gob.Register(&PageUpdateRecord{})
    gob.Register(&TransactionStartRecord{})
    gob.Register(&TransactionCommitRecord{})
    gob.Register(&TransactionAbortRecord{})
    gob.Register(&CheckpointStartRecord{})
    gob.Register(&CheckpointEndRecord{})
}

// DeserializeLogRecord reconstructs a concrete log record from its
// gob-encoded bytes. gob cannot decode the stream into BaseLogRecord
// directly: the embedded struct is encoded as a field named
// "BaseLogRecord", so we probe that field first to learn the type.
func DeserializeLogRecord(data []byte) (LogRecord, error) {
    var probe struct{ BaseLogRecord BaseLogRecord }
    if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&probe); err != nil {
        return nil, fmt.Errorf("failed to decode base log record: %w", err)
    }

    var record LogRecord
    switch probe.BaseLogRecord.TypeVal {
    case TypePageUpdate:
        record = &PageUpdateRecord{}
    case TypeTransactionStart:
        record = &TransactionStartRecord{}
    case TypeTransactionCommit:
        record = &TransactionCommitRecord{}
    case TypeTransactionAbort:
        record = &TransactionAbortRecord{}
    case TypeCheckpointStart:
        record = &CheckpointStartRecord{}
    case TypeCheckpointEnd:
        record = &CheckpointEndRecord{}
    default:
        return nil, fmt.Errorf("unknown log record type: %d", probe.BaseLogRecord.TypeVal)
    }

    // Decode again from the start, this time into the concrete type.
    if err := gob.NewDecoder(bytes.NewReader(data)).Decode(record); err != nil {
        return nil, fmt.Errorf("failed to decode concrete log record: %w", err)
    }
    return record, nil
}

6.3 Implementing the WAL (Write-Ahead Log)

The WAL writes log records to files. For performance it buffers writes and flushes them in batches; to keep file sizes manageable it splits the log into segment files.

const (
    WALSegmentSize     = 16 * 1024 * 1024 // 16MB per segment
    WALFileNamePattern = "wal_%016x.log"
    CheckpointMetaFile = "checkpoint.meta"
)

// WALSegment represents a single WAL file segment.
type WALSegment struct {
    file   *os.File
    startLSN LSN
    endLSN   LSN
    size   int64
    mu     sync.Mutex // Protects file operations on this segment
}

func NewWALSegment(dir string, startLSN LSN) (*WALSegment, error) {
    filename := filepath.Join(dir, fmt.Sprintf(WALFileNamePattern, startLSN))
    file, err := os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
    if err != nil {
        return nil, fmt.Errorf("failed to open WAL segment file %s: %w", filename, err)
    }
    info, err := file.Stat()
    if err != nil {
        file.Close()
        return nil, fmt.Errorf("failed to stat WAL segment file %s: %w", filename, err)
    }
    return &WALSegment{
        file:   file,
        startLSN: startLSN,
        size:   info.Size(),
    }, nil
}

func (ws *WALSegment) Append(data []byte) (int, error) {
    ws.mu.Lock()
    defer ws.mu.Unlock()
    n, err := ws.file.Write(data)
    if err == nil {
        ws.size += int64(n)
    }
    return n, err
}

func (ws *WALSegment) Sync() error {
    ws.mu.Lock()
    defer ws.mu.Unlock()
    return ws.file.Sync()
}

func (ws *WALSegment) Close() error {
    ws.mu.Lock()
    defer ws.mu.Unlock()
    return ws.file.Close()
}

// WAL manages multiple WAL segments.
type WAL struct {
    dir        string
    mu         sync.RWMutex
    segments   []*WALSegment // Ordered list of segments
    currentSeg *WALSegment
    nextLSN    LSN // The next LSN to be assigned

    buffer     bytes.Buffer // In-memory buffer for batching writes
    bufferSize int          // Current size of the buffer
    maxBufferSize int // Max size before flush
}

func NewWAL(dir string, maxBufferSize int) (*WAL, error) {
    if err := os.MkdirAll(dir, 0755); err != nil {
        return nil, fmt.Errorf("failed to create WAL directory %s: %w", dir, err)
    }

    wal := &WAL{
        dir:         dir,
        nextLSN:     1, // Start LSN from 1
        segments:    make([]*WALSegment, 0),
        maxBufferSize: maxBufferSize,
    }

    // Load existing segments or create a new one
    err := wal.loadSegments()
    if err != nil {
        return nil, err
    }
    if len(wal.segments) == 0 {
        if err := wal.createNewSegment(); err != nil {
            return nil, err
        }
    } else {
        wal.currentSeg = wal.segments[len(wal.segments)-1]
        // Determine nextLSN by scanning the last segment
        lastLSN, err := wal.determineLastLSN()
        if err != nil {
            return nil, err
        }
        wal.nextLSN = lastLSN + 1
    }

    return wal, nil
}

func (w *WAL) loadSegments() error {
    entries, err := os.ReadDir(w.dir)
    if err != nil {
        return fmt.Errorf("failed to read WAL directory: %w", err)
    }

    var segmentFiles []string
    for _, entry := range entries {
        if !entry.IsDir() && filepath.Ext(entry.Name()) == ".log" && len(entry.Name()) == len(fmt.Sprintf(WALFileNamePattern, 0)) {
            segmentFiles = append(segmentFiles, entry.Name())
        }
    }

    // Sort segment files by LSN
    // Assuming LSN is encoded in the filename
    sort.Slice(segmentFiles, func(i, j int) bool {
        var lsnI, lsnJ uint64
        fmt.Sscanf(segmentFiles[i], WALFileNamePattern, &lsnI)
        fmt.Sscanf(segmentFiles[j], WALFileNamePattern, &lsnJ)
        return lsnI < lsnJ
    })

    for _, filename := range segmentFiles {
        var startLSN uint64
        fmt.Sscanf(filename, WALFileNamePattern, &startLSN)
        seg, err := NewWALSegment(w.dir, LSN(startLSN))
        if err != nil {
            return fmt.Errorf("failed to load WAL segment %s: %w", filename, err)
        }
        w.segments = append(w.segments, seg)
    }
    return nil
}

func (w *WAL) createNewSegment() error {
    if w.currentSeg != nil {
        if err := w.currentSeg.Close(); err != nil {
            return fmt.Errorf("failed to close previous WAL segment: %w", err)
        }
    }
    newSeg, err := NewWALSegment(w.dir, w.nextLSN) // New segment starts with next available LSN
    if err != nil {
        return fmt.Errorf("failed to create new WAL segment: %w", err)
    }
    w.segments = append(w.segments, newSeg)
    w.currentSeg = newSeg
    return nil
}

// determineLastLSN scans the last segment to find the actual last LSN.
// This is crucial for recovery or restarting a WAL.
func (w *WAL) determineLastLSN() (LSN, error) {
    if w.currentSeg == nil {
        return 0, nil
    }

    var lastLSN LSN = 0
    err := w.ReadRecords(w.currentSeg.startLSN, func(record LogRecord) bool {
        lastLSN = record.GetLSN()
        return true // Continue reading
    })
    if err != nil && err != io.EOF {
        return 0, fmt.Errorf("failed to scan last WAL segment to determine last LSN: %w", err)
    }
    return lastLSN, nil
}

// AppendRecord appends a log record to the WAL.
func (w *WAL) AppendRecord(record LogRecord) (LSN, error) {
    w.mu.Lock()
    defer w.mu.Unlock()

    record.SetLSN(w.nextLSN)
    serializedData, err := record.Serialize()
    if err != nil {
        return 0, fmt.Errorf("failed to serialize log record: %w", err)
    }

    // Format: [length (4 bytes)][record data]
    length := uint32(len(serializedData))
    lengthBytes := make([]byte, 4)
    binary.BigEndian.PutUint32(lengthBytes, length)

    // Append to buffer
    w.buffer.Write(lengthBytes)
    w.buffer.Write(serializedData)
    w.bufferSize += 4 + len(serializedData)

    currentLSN := w.nextLSN
    w.nextLSN++

    // Check if buffer needs flushing or segment needs rotation
    if w.bufferSize >= w.maxBufferSize || w.currentSeg.size+int64(w.bufferSize) >= WALSegmentSize {
        return currentLSN, w.flushBufferAndRotate()
    }

    return currentLSN, nil
}

// Flush ensures all buffered records are written to disk.
func (w *WAL) Flush() error {
    w.mu.Lock()
    defer w.mu.Unlock()
    return w.flushBufferAndRotate()
}

// flushBufferAndRotate writes the buffer to the current segment and rotates if necessary.
// Must be called with w.mu held.
func (w *WAL) flushBufferAndRotate() error {
    if w.bufferSize == 0 {
        return nil
    }

    data := w.buffer.Bytes()
    n, err := w.currentSeg.Append(data)
    if err != nil {
        return fmt.Errorf("failed to write WAL buffer to segment: %w", err)
    }
    if n != w.bufferSize {
        return fmt.Errorf("partial write to WAL segment: wrote %d, expected %d", n, w.bufferSize)
    }

    if err := w.currentSeg.Sync(); err != nil {
        return fmt.Errorf("failed to sync WAL segment: %w", err)
    }

    w.buffer.Reset()
    w.bufferSize = 0

    // Rotate segment if it's full
    if w.currentSeg.size >= WALSegmentSize {
        return w.createNewSegment()
    }
    return nil
}

// ReadRecords reads log records from a starting LSN.
// The callback function returns true to continue reading, false to stop.
func (w *WAL) ReadRecords(startLSN LSN, callback func(LogRecord) bool) error {
    w.mu.RLock()
    defer w.mu.RUnlock()

    // Find the segment containing startLSN
    startIndex := -1
    for i, seg := range w.segments {
        if startLSN >= seg.startLSN {
            startIndex = i
        } else {
            break
        }
    }

    if startIndex == -1 {
        return fmt.Errorf("start LSN %d is too old or invalid", startLSN)
    }

    for i := startIndex; i < len(w.segments); i++ {
        seg := w.segments[i]
        currentOffset := int64(0)
        if startLSN > seg.startLSN {
            // Need to find the exact offset for startLSN within this segment
            // This is a simplified approach, a real WAL would use an index for faster lookup
            tempFile, err := os.Open(filepath.Join(w.dir, fmt.Sprintf(WALFileNamePattern, seg.startLSN)))
            if err != nil {
                return fmt.Errorf("failed to open segment for scanning: %w", err)
            }
            defer tempFile.Close()

            reader := bufio.NewReader(tempFile)
            var currentRecordLSN LSN
            for {
                lengthBytes := make([]byte, 4)
                _, err := io.ReadFull(reader, lengthBytes)
                if err == io.EOF {
                    break
                }
                if err != nil {
                    return fmt.Errorf("failed to read record length: %w", err)
                }
                recordLen := binary.BigEndian.Uint32(lengthBytes)

                recordData := make([]byte, recordLen)
                _, err = io.ReadFull(reader, recordData)
                if err != nil {
                    return fmt.Errorf("failed to read record data: %w", err)
                }

                record, err := DeserializeLogRecord(recordData)
                if err != nil {
                    return fmt.Errorf("failed to deserialize record during scan: %w", err)
                }
                currentRecordLSN = record.GetLSN()
                if currentRecordLSN >= startLSN {
                    break // Found the starting record
                }
                currentOffset += 4 + int64(recordLen) // Move offset past this record
            }
            // currentOffset now points at the first record with LSN >= startLSN.
            // The segment file is reopened below and seeked to this offset.
        }

        // Read from the current segment file from currentOffset
        // For simplicity, we'll re-open the file. In a real system, would use a reader.
        file, err := os.Open(filepath.Join(w.dir, fmt.Sprintf(WALFileNamePattern, seg.startLSN)))
        if err != nil {
            return fmt.Errorf("failed to open WAL segment %s for reading: %w", seg.file.Name(), err)
        }
        defer file.Close()

        if _, err := file.Seek(currentOffset, io.SeekStart); err != nil {
            return fmt.Errorf("failed to seek in WAL segment %s: %w", seg.file.Name(), err)
        }

        reader := bufio.NewReader(file)
        for {
            lengthBytes := make([]byte, 4)
            _, err := io.ReadFull(reader, lengthBytes)
            if err == io.EOF {
                break
            }
            if err != nil {
                return fmt.Errorf("failed to read record length from WAL: %w", err)
            }
            recordLen := binary.BigEndian.Uint32(lengthBytes)

            recordData := make([]byte, recordLen)
            _, err = io.ReadFull(reader, recordData)
            if err != nil {
                return fmt.Errorf("failed to read record data from WAL: %w", err)
            }

            record, err := DeserializeLogRecord(recordData)
            if err != nil {
                return fmt.Errorf("failed to deserialize log record from WAL: %w", err)
            }

            if !callback(record) {
                return nil // Callback requested to stop
            }
        }
    }
    return nil
}

// TruncateBefore removes all WAL segments whose end LSN is less than minLSN.
func (w *WAL) TruncateBefore(minLSN LSN) error {
    w.mu.Lock()
    defer w.mu.Unlock()

    var newSegments []*WALSegment
    for _, seg := range w.segments {
        // If segment's start LSN is before minLSN, it might be truncated.
        // A full segment is only truncated if its *entirety* is before minLSN.
        // For simplicity, we assume a segment is eligible for deletion if its startLSN < minLSN
        // and it's not the current active segment. A more precise check would involve
        // parsing records in the segment to find its effective "end LSN".
        if seg.startLSN < minLSN && seg != w.currentSeg {
            fmt.Printf("Truncating WAL segment %s (start LSN: %d < min LSN: %d)\n", seg.file.Name(), seg.startLSN, minLSN)
            if err := seg.Close(); err != nil {
                fmt.Printf("Warning: failed to close WAL segment %s during truncation: %v\n", seg.file.Name(), err)
            }
            if err := os.Remove(filepath.Join(w.dir, fmt.Sprintf(WALFileNamePattern, seg.startLSN))); err != nil {
                fmt.Printf("Warning: failed to remove WAL segment file %s: %v\n", seg.file.Name(), err)
            }
            }
        } else {
            newSegments = append(newSegments, seg)
        }
    }
    w.segments = newSegments
    return nil
}

func (w *WAL) Close() error {
    w.mu.Lock()
    defer w.mu.Unlock()
    if err := w.flushBufferAndRotate(); err != nil {
        return fmt.Errorf("failed to flush WAL buffer on close: %w", err)
    }
    for _, seg := range w.segments {
        if err := seg.Close(); err != nil {
            return fmt.Errorf("failed to close WAL segment %s on close: %w", seg.file.Name(), err)
        }
    }
    return nil
}

Note: when WAL.ReadRecords has to locate a startLSN that is not at the head of a segment, it scans from the start of that segment to find the exact byte offset. This is a simplification; a production WAL maintains an index mapping LSNs to their physical offsets in the file for fast lookup.
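One way such an index could look is a sparse map from every Nth LSN to its byte offset, consulted before scanning. This is a hypothetical sketch, not part of the engine above; the names sparseIndex, note, and seekHint are invented for illustration:

```go
package main

import "fmt"

// sparseIndex records the file offset of every Nth LSN in a segment, so
// a reader can seek near startLSN instead of scanning from offset 0.
type sparseIndex struct {
	every   uint64           // granularity: index one record in every N
	offsets map[uint64]int64 // indexed LSN -> byte offset in the segment
}

func newSparseIndex(every uint64) *sparseIndex {
	return &sparseIndex{every: every, offsets: make(map[uint64]int64)}
}

// note is called while appending: it remembers offsets of indexed LSNs.
func (ix *sparseIndex) note(lsn uint64, off int64) {
	if lsn%ix.every == 0 {
		ix.offsets[lsn] = off
	}
}

// seekHint returns the offset of the nearest indexed LSN at or before
// lsn; the reader seeks there and scans forward to the exact record.
func (ix *sparseIndex) seekHint(lsn uint64) int64 {
	for base := (lsn / ix.every) * ix.every; ; base -= ix.every {
		if off, ok := ix.offsets[base]; ok {
			return off
		}
		if base < ix.every {
			return 0 // nothing indexed below lsn: scan from the start
		}
	}
}

func main() {
	ix := newSparseIndex(4)
	for lsn, off := uint64(1), int64(0); lsn <= 10; lsn++ {
		ix.note(lsn, off)
		off += 32 // pretend every record is 32 bytes
	}
	fmt.Println(ix.seekHint(9)) // nearest indexed LSN is 8, at offset 224
}
```

The trade-off is the usual one for sparse indexes: a smaller `every` means less scanning after the seek but more index memory per segment.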

6.4 Page and PageCache

The PageCache manages the in-memory data pages and tracks which of them are "dirty", that is, modified in memory but not yet written to disk.

const PageSize = 4096 // 4KB page size

// Page represents a single data page in memory.
type Page struct {
    ID        PageID
    Data      []byte
    IsDirty   bool // True if page has been modified in memory
    LastLSN   LSN  // LSN of the last log record that modified this page
    mu        sync.RWMutex // Protects Data and IsDirty
}

func NewPage(id PageID) *Page {
    return &Page{
        ID:   id,
        Data: make([]byte, PageSize),
    }
}

// PageCache manages in-memory data pages (buffer pool).
type PageCache struct {
    mu        sync.RWMutex
    pages     map[PageID]*Page // In-memory pages
    maxPages  int              // Max number of pages in cache
    dataDir   string           // Directory for data files
}

func NewPageCache(dataDir string, maxPages int) (*PageCache, error) {
    if err := os.MkdirAll(dataDir, 0755); err != nil {
        return nil, fmt.Errorf("failed to create data directory %s: %w", dataDir, err)
    }
    return &PageCache{
        pages:    make(map[PageID]*Page),
        maxPages: maxPages,
        dataDir:  dataDir,
    }, nil
}

// GetPage fetches a page from cache or loads it from disk.
func (pc *PageCache) GetPage(id PageID) (*Page, error) {
    pc.mu.RLock()
    page, ok := pc.pages[id]
    pc.mu.RUnlock()

    if ok {
        return page, nil
    }

    // Page not in cache, load from disk
    pc.mu.Lock()
    defer pc.mu.Unlock()

    // Check again, might have been loaded by another goroutine
    if page, ok = pc.pages[id]; ok {
        return page, nil
    }

    // Evict if cache is full (simplified LRU/random eviction)
    if len(pc.pages) >= pc.maxPages {
        pc.evictPage() // Placeholder for eviction logic
    }

    // Load from disk
    page = NewPage(id)
    filename := filepath.Join(pc.dataDir, fmt.Sprintf("data_%d.page", id))
    file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
    if err != nil {
        return nil, fmt.Errorf("failed to open data file %s: %w", filename, err)
    }
    defer file.Close()

    n, err := file.Read(page.Data)
    if err != nil && err != io.EOF {
        return nil, fmt.Errorf("failed to read data page %d: %w", id, err)
    }
    if n < PageSize {
        // Page might be new or partially written, pad with zeros
        for i := n; i < PageSize; i++ {
            page.Data[i] = 0
        }
    }
    // A real system would also read the last LSN from the page header on disk

    pc.pages[id] = page
    return page, nil
}

// MarkDirty marks a page as dirty and updates its LastLSN.
func (pc *PageCache) MarkDirty(page *Page, lsn LSN) {
    page.mu.Lock()
    defer page.mu.Unlock()
    page.IsDirty = true
    page.LastLSN = lsn
}

// FlushPage writes a dirty page to disk.
func (pc *PageCache) FlushPage(page *Page) error {
    page.mu.Lock()
    defer page.mu.Unlock()

    if !page.IsDirty {
        return nil // Not dirty, no need to flush
    }

    filename := filepath.Join(pc.dataDir, fmt.Sprintf("data_%d.page", page.ID))
    file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, 0644)
    if err != nil {
        return fmt.Errorf("failed to open data file %s for writing: %w", filename, err)
    }
    defer file.Close()

    if _, err := file.Write(page.Data); err != nil {
        return fmt.Errorf("failed to write data page %d: %w", page.ID, err)
    }
    if err := file.Sync(); err != nil { // Ensure data is on disk
        return fmt.Errorf("failed to sync data page %d: %w", page.ID, err)
    }

    page.IsDirty = false
    return nil
}

// GetDirtyPages returns a snapshot of currently dirty pages.
func (pc *PageCache) GetDirtyPages() map[PageID]*Page {
    pc.mu.RLock()
    defer pc.mu.RUnlock()
    dirtyPages := make(map[PageID]*Page)
    for id, page := range pc.pages {
        page.mu.RLock()
        if page.IsDirty {
            dirtyPages[id] = page
        }
        page.mu.RUnlock()
    }
    return dirtyPages
}

// evictPage is a placeholder for a page eviction policy (e.g., LRU).
// For simplicity, we just delete a random page.
func (pc *PageCache) evictPage() {
    for id, page := range pc.pages {
        if page.IsDirty {
            // Cannot evict dirty page without flushing it first (or write-ahead logging)
            // For a real system, would need to flush or handle differently.
            // Here, we'd ideally flush it and then evict.
            if err := pc.FlushPage(page); err != nil {
                fmt.Printf("Warning: failed to flush page %d during eviction: %v\n", id, err)
                continue // Try another page
            }
        }
        delete(pc.pages, id)
        fmt.Printf("Evicted page %d\n", id)
        return // Evict one page for now
    }
}

// Close flushes all dirty pages before shutting down.
func (pc *PageCache) Close() error {
    pc.mu.Lock()
    defer pc.mu.Unlock()
    var err error
    for _, page := range pc.pages {
        if page.IsDirty {
            if flushErr := pc.FlushPage(page); flushErr != nil {
                // errors.Join tolerates a nil err on the first failure (requires Go 1.20+).
                err = errors.Join(err, fmt.Errorf("failed to flush page %d: %w", page.ID, flushErr))
            }
        }
        }
    }
    return err
}
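The flush path above must respect the WAL rule from Section 2: a dirty page may reach the data file only after every log record up to its `LastLSN` is durable. The following standalone sketch illustrates that invariant; the `WAL`/`Page` stand-ins and the `flushPageWithWAL` helper are hypothetical simplifications, not part of the code above:

```go
package main

import "fmt"

type LSN uint64

// Minimal stand-ins for the article's Page and WAL types.
type Page struct {
	ID      int
	LastLSN LSN
	IsDirty bool
}

type WAL struct {
	durableLSN LSN // highest LSN known to be fsync'd
}

// Flush is a stand-in for forcing the log to disk up to a given LSN.
func (w *WAL) Flush(upTo LSN) {
	if upTo > w.durableLSN {
		w.durableLSN = upTo
	}
}

// flushPageWithWAL enforces the WAL rule: force the log covering the page's
// last change before the page itself is allowed to hit the data file.
func flushPageWithWAL(w *WAL, p *Page, writePage func(*Page) error) error {
	if !p.IsDirty {
		return nil
	}
	if p.LastLSN > w.durableLSN {
		w.Flush(p.LastLSN) // log first, data second
	}
	if err := writePage(p); err != nil {
		return err
	}
	p.IsDirty = false
	return nil
}

func main() {
	w := &WAL{durableLSN: 5}
	p := &Page{ID: 1, LastLSN: 9, IsDirty: true}
	_ = flushPageWithWAL(w, p, func(*Page) error { return nil })
	fmt.Println(w.durableLSN >= p.LastLSN, p.IsDirty) // true false
}
```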

6.5 TransactionManager (Simplified)

The transaction manager allocates transaction IDs and tracks active transactions.

// TransactionManager manages active transactions.
type TransactionManager struct {
    mu         sync.Mutex
    nextTxID   TransactionID
    activeTx   map[TransactionID]LSN // Map of TxID to its start LSN
}

func NewTransactionManager() *TransactionManager {
    return &TransactionManager{
        nextTxID: 1,
        activeTx: make(map[TransactionID]LSN),
    }
}

func (tm *TransactionManager) StartTransaction(startLSN LSN) TransactionID {
    tm.mu.Lock()
    defer tm.mu.Unlock()
    txID := tm.nextTxID
    tm.nextTxID++
    tm.activeTx[txID] = startLSN
    return txID
}

func (tm *TransactionManager) EndTransaction(txID TransactionID) {
    tm.mu.Lock()
    defer tm.mu.Unlock()
    delete(tm.activeTx, txID)
}

func (tm *TransactionManager) GetActiveTransactions() map[TransactionID]LSN {
    tm.mu.Lock() // Need lock to ensure consistent snapshot
    defer tm.mu.Unlock()

    snapshot := make(map[TransactionID]LSN, len(tm.activeTx))
    for txID, lsn := range tm.activeTx {
        snapshot[txID] = lsn
    }
    return snapshot
}
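As a standalone usage sketch, the program below reproduces a minimal copy of the TransactionManager and verifies that concurrent StartTransaction calls still hand out unique IDs — the reason the mutex exists:

```go
package main

import (
	"fmt"
	"sync"
)

type LSN uint64
type TransactionID uint64

// Minimal copy of the TransactionManager above, so this sketch runs standalone.
type TransactionManager struct {
	mu       sync.Mutex
	nextTxID TransactionID
	activeTx map[TransactionID]LSN
}

func NewTransactionManager() *TransactionManager {
	return &TransactionManager{nextTxID: 1, activeTx: make(map[TransactionID]LSN)}
}

func (tm *TransactionManager) StartTransaction(startLSN LSN) TransactionID {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	txID := tm.nextTxID
	tm.nextTxID++
	tm.activeTx[txID] = startLSN
	return txID
}

func (tm *TransactionManager) EndTransaction(txID TransactionID) {
	tm.mu.Lock()
	defer tm.mu.Unlock()
	delete(tm.activeTx, txID)
}

func main() {
	tm := NewTransactionManager()
	var wg sync.WaitGroup
	ids := make(chan TransactionID, 100)
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ids <- tm.StartTransaction(LSN(1)) // concurrent begins
		}()
	}
	wg.Wait()
	close(ids)
	seen := make(map[TransactionID]bool)
	for id := range ids {
		if seen[id] {
			panic("duplicate transaction ID")
		}
		seen[id] = true
		tm.EndTransaction(id) // commit/abort removes it from the active set
	}
	fmt.Println(len(seen), len(tm.activeTx)) // 100 0
}
```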

6.6 CheckpointManager

The checkpoint manager coordinates the checkpoint process.

// CheckpointManager handles creating checkpoints.
type CheckpointManager struct {
    wal         *WAL
    pageCache   *PageCache
    txManager   *TransactionManager
    lastCheckpointLSN LSN // LSN of the last successfully completed CheckpointEndRecord
    metaFile    string // File to store lastCheckpointLSN
    mu          sync.Mutex // Protects checkpoint initiation
}

func NewCheckpointManager(wal *WAL, pc *PageCache, tm *TransactionManager, baseDir string) (*CheckpointManager, error) {
    cm := &CheckpointManager{
        wal:       wal,
        pageCache: pc,
        txManager: tm,
        metaFile:  filepath.Join(baseDir, CheckpointMetaFile),
    }

    // Load last checkpoint LSN from meta file
    if err := cm.loadLastCheckpointLSN(); err != nil && !os.IsNotExist(err) {
        return nil, fmt.Errorf("failed to load last checkpoint LSN: %w", err)
    }
    return cm, nil
}

func (cm *CheckpointManager) loadLastCheckpointLSN() error {
    data, err := os.ReadFile(cm.metaFile)
    if err != nil {
        return err // os.IsNotExist is handled by the caller
    }
    if len(data) != 8 { // LSN is uint64, 8 bytes
        return fmt.Errorf("invalid checkpoint meta file format")
    }
    cm.lastCheckpointLSN = LSN(binary.BigEndian.Uint64(data))
    return nil
}

func (cm *CheckpointManager) saveLastCheckpointLSN(lsn LSN) error {
    data := make([]byte, 8)
    binary.BigEndian.PutUint64(data, uint64(lsn))
    return os.WriteFile(cm.metaFile, data, 0644)
}

// PerformCheckpoint executes a fuzzy checkpoint.
func (cm *CheckpointManager) PerformCheckpoint() error {
    cm.mu.Lock()
    defer cm.mu.Unlock()

    fmt.Println("Starting checkpoint...")

    // 1. Write CheckpointStartRecord to WAL
    checkpointStartLSN, err := cm.wal.AppendRecord(NewCheckpointStartRecord(cm.wal.nextLSN))
    if err != nil {
        return fmt.Errorf("failed to write CheckpointStartRecord: %w", err)
    }
    if err := cm.wal.Flush(); err != nil { // Ensure checkpoint start is durable
        return fmt.Errorf("failed to flush WAL after CheckpointStartRecord: %w", err)
    }

    // 2. Capture snapshot of dirty pages and active transactions
    dirtyPagesAtCkptStart := cm.pageCache.GetDirtyPages() // Snapshot of dirty pages
    activeTxAtCkptStart := cm.txManager.GetActiveTransactions() // Snapshot of active transactions

    // 3. Asynchronously flush dirty pages.
    // In a real system, this would be done by a background flusher.
    // For this example, we'll flush them synchronously but conceptually they are 'at checkpoint start time'.
    var pagesToFlush []PageID
    for pageID, page := range dirtyPagesAtCkptStart {
        pagesToFlush = append(pagesToFlush, pageID)
        if err := cm.pageCache.FlushPage(page); err != nil {
            fmt.Printf("Warning: failed to flush page %d during checkpoint: %v\n", pageID, err)
            // A real system would need to track unflushed pages and retry,
            // or ensure recovery can handle them. For simplicity, we proceed.
        }
    }
    fmt.Printf("Flushed %d dirty pages for checkpoint.\n", len(pagesToFlush))

    // 4. Determine MinRecoveryLSN
    minRecoveryLSN := checkpointStartLSN // Start with the LSN of CheckpointStartRecord
    for _, txStartLSN := range activeTxAtCkptStart {
        if txStartLSN < minRecoveryLSN {
            minRecoveryLSN = txStartLSN
        }
    }
    // For any page that was dirty at checkpoint start, its LastLSN might be older than minRecoveryLSN.
    // We need to ensure that the oldest LSN of any page that was dirty *at the time the checkpoint started*
    // is covered by minRecoveryLSN. A more robust system tracks this more carefully.
    // For simplicity, we assume pages are flushed and their changes are covered.

    // 5. Write CheckpointEndRecord to WAL
    activeTxIDs := make([]TransactionID, 0, len(activeTxAtCkptStart))
    for txID := range activeTxAtCkptStart {
        activeTxIDs = append(activeTxIDs, txID)
    }

    ckptEndRecord := NewCheckpointEndRecord(minRecoveryLSN, activeTxIDs, pagesToFlush)
    ckptEndLSN, err := cm.wal.AppendRecord(ckptEndRecord)
    if err != nil {
        return fmt.Errorf("failed to write CheckpointEndRecord: %w", err)
    }
    if err := cm.wal.Flush(); err != nil { // Crucial: ensure checkpoint end is durable
        return fmt.Errorf("failed to flush WAL after CheckpointEndRecord: %w", err)
    }

    // 6. Persist last checkpoint LSN
    if err := cm.saveLastCheckpointLSN(ckptEndLSN); err != nil {
        return fmt.Errorf("failed to save last checkpoint LSN to meta file: %w", err)
    }
    cm.lastCheckpointLSN = ckptEndLSN

    // 7. Truncate WAL
    if err := cm.wal.TruncateBefore(minRecoveryLSN); err != nil {
        fmt.Printf("Warning: failed to truncate WAL before LSN %d: %v\n", minRecoveryLSN, err)
    }

    fmt.Printf("Checkpoint completed. LastCheckpointLSN: %d, MinRecoveryLSN: %d\n", ckptEndLSN, minRecoveryLSN)
    return nil
}

6.7 RecoveryManager

The recovery manager performs crash recovery at system startup.


import "bufio" // For WAL ReadRecords

// RecoveryManager handles crash recovery using WAL and checkpoint info.
type RecoveryManager struct {
    wal       *WAL
    pageCache *PageCache
    txManager *TransactionManager
    checkpointManager *CheckpointManager
}

func NewRecoveryManager(wal *WAL, pc *PageCache, tm *TransactionManager, cm *CheckpointManager) *RecoveryManager {
    return &RecoveryManager{
        wal:       wal,
        pageCache: pc,
        txManager: tm,
        checkpointManager: cm,
    }
}

// Recover performs crash recovery.
func (rm *RecoveryManager) Recover() error {
    fmt.Println("Starting recovery...")

    // 1. Locate the latest CheckpointEndRecord
    var lastCkptEndRecord *CheckpointEndRecord
    var lastCkptEndLSN LSN

    // Scan WAL backwards from the end to find the latest CheckpointEndRecord
    // For simplicity, we'll scan forward from the beginning of the WAL.
    // A real system would efficiently locate the latest checkpoint metadata.
    // Here, we rely on checkpointManager.lastCheckpointLSN.

    // If lastCheckpointLSN is 0, no checkpoint was ever completed, need to scan from beginning of WAL (LSN 1).
    scanStartLSN := LSN(1) 
    if rm.checkpointManager.lastCheckpointLSN > 0 {
        // Found a previous checkpoint, read it to get minRecoveryLSN
        // This requires reading a specific record by LSN, which our simple WAL.ReadRecords doesn't do directly efficiently.
        // So we iterate until we find it.
        err := rm.wal.ReadRecords(rm.checkpointManager.lastCheckpointLSN, func(record LogRecord) bool {
            if ckptEnd, ok := record.(*CheckpointEndRecord); ok {
                lastCkptEndRecord = ckptEnd
                lastCkptEndLSN = ckptEnd.GetLSN()
                return false // Stop after finding
            }
            return true
        })
        if err != nil {
            fmt.Printf("Warning: failed to read last checkpoint record at LSN %d, will scan from beginning of WAL: %v\n", rm.checkpointManager.lastCheckpointLSN, err)
            lastCkptEndRecord = nil // Reset, proceed as if no checkpoint found
        }
    }

    if lastCkptEndRecord != nil {
        scanStartLSN = lastCkptEndRecord.MinRecoveryLSN
        fmt.Printf("Found latest checkpoint at LSN %d. Recovery will start from MinRecoveryLSN: %d\n", lastCkptEndLSN, scanStartLSN)
    } else {
        fmt.Println("No valid checkpoint found. Recovery will scan WAL from LSN 1.")
    }

    // Data structures for recovery
    redoSet := make(map[PageID]LSN) // Pages that need redo, and their LSN
    activeTransactions := make(map[TransactionID]LSN) // TxID -> startLSN

    // 2. Redo Phase: Apply all changes from minRecoveryLSN to the end of WAL
    fmt.Printf("Starting Redo phase from LSN %d...\n", scanStartLSN)
    err := rm.wal.ReadRecords(scanStartLSN, func(record LogRecord) bool {
        switch r := record.(type) {
        case *TransactionStartRecord:
            activeTransactions[r.TxID] = r.GetLSN()
        case *TransactionCommitRecord:
            delete(activeTransactions, r.TxID)
        case *TransactionAbortRecord:
            delete(activeTransactions, r.TxID)
        case *PageUpdateRecord:
            // Apply update to page cache's view of the page
            page, err := rm.pageCache.GetPage(r.PageID)
            if err != nil {
                fmt.Printf("Error getting page %d for redo: %v\n", r.PageID, err)
                return false // Abort recovery
            }
            page.mu.Lock()
            // Only redo if the log record is newer than what's currently on the page
            // (or what's known to be on disk if page.LastLSN represents disk state)
            // For simplicity, we assume page.LastLSN tracks the latest applied LSN from WAL.
            if r.GetLSN() > page.LastLSN {
                // Apply new data
                copy(page.Data[r.Offset:r.Offset+uint16(r.Length)], r.NewData)
                page.LastLSN = r.GetLSN()
                page.IsDirty = true // Mark as dirty, will be flushed later if needed
                redoSet[r.PageID] = r.GetLSN() // Track pages updated during redo
            }
            page.mu.Unlock()
        case *CheckpointEndRecord:
            // If the scan encounters a newer CheckpointEndRecord, a robust
            // implementation would restart recovery from its MinRecoveryLSN;
            // for this simplified example we trust the checkpoint located initially.
        }
        return true // Continue reading
    })

    if err != nil && err != io.EOF {
        return fmt.Errorf("error during Redo phase: %w", err)
    }
    fmt.Printf("Redo phase completed. Active transactions: %v\n", activeTransactions)

    // 3. Undo Phase: Rollback uncommitted transactions
    fmt.Println("Starting Undo phase...")
    for txID, txStartLSN := range activeTransactions {
        fmt.Printf("Undoing transaction %d (started at LSN %d)...\n", txID, txStartLSN)

        // Scan WAL for records belonging to this transaction in reverse chronological order
        // Our WAL.ReadRecords reads forward, so we collect relevant records first, then reverse.
        var txRecords []LogRecord
        err := rm.wal.ReadRecords(txStartLSN, func(record LogRecord) bool {
            if updateRec, ok := record.(*PageUpdateRecord); ok && updateRec.TxID == txID {
                txRecords = append(txRecords, updateRec)
            } else if startRec, ok := record.(*TransactionStartRecord); ok && startRec.TxID == txID {
                txRecords = append(txRecords, startRec)
            }
            // Stop scanning if we passed the current LSN in a multi-transaction system
            // For simplicity, we collect all records for a given TxID up to the end of WAL.
            return true
        })
        if err != nil && err != io.EOF {
            return fmt.Errorf("error reading WAL for undo of transaction %d: %w", txID, err)
        }

        // Apply undo in reverse order
        for i := len(txRecords) - 1; i >= 0; i-- {
            record := txRecords[i]
            if updateRec, ok := record.(*PageUpdateRecord); ok {
                page, err := rm.pageCache.GetPage(updateRec.PageID)
                if err != nil {
                    fmt.Printf("Error getting page %d for undo of tx %d: %v\n", updateRec.PageID, txID, err)
                    continue // Try to undo other records
                }
                page.mu.Lock()
                // Only undo if the current state of the page reflects this transaction's change
                // (i.e., its LSN matches the update record's LSN or is newer but needs undo)
                // ARIES-style recovery is more complex, involving CLRs (Compensation Log Records)
                // For simplicity, we just apply OldData.
                copy(page.Data[updateRec.Offset:updateRec.Offset+uint16(updateRec.Length)], updateRec.OldData)
                page.IsDirty = true // Mark as dirty
                page.mu.Unlock()
            }
        }
        // Write an Abort record to WAL to finalize the undo (optional, for explicit history)
        if _, err := rm.wal.AppendRecord(NewTransactionAbortRecord(txID)); err != nil {
            fmt.Printf("Warning: failed to write abort record for tx %d during undo: %v\n", txID, err)
        }
    }

    if err := rm.wal.Flush(); err != nil {
        return fmt.Errorf("failed to flush WAL after undo: %w", err)
    }
    fmt.Println("Undo phase completed. Recovery finished.")
    return nil
}