WAL Checkpointing:在 Go 中实现高性能、可故障恢复的崩溃一致性逻辑
数据存储系统的核心挑战之一是确保数据在面对系统崩溃、电源故障或硬件错误时仍然保持完整和可用。仅仅将数据写入磁盘是不够的,因为操作系统的缓冲、硬件缓存以及写入操作的非原子性都可能导致部分写入或数据不一致。为了解决这些问题,数据库系统发展出了复杂的机制来保证所谓的“崩溃一致性”(Crash Consistency)和“原子性”(Atomicity)。其中,预写日志 (Write-Ahead Logging, WAL) 是基石,而 WAL Checkpointing 则是其不可或缺的优化和管理机制。
本讲座将深入探讨 WAL Checkpointing 的原理、它在实现高性能和故障恢复中的作用,并结合 Go 语言,演示如何构建一个具备这些特性的存储逻辑。
1. 数据持久性与崩溃一致性的挑战
在深入 WAL 之前,我们首先理解数据持久化所面临的基本困境。当应用程序修改数据时,这些数据通常首先存在于内存中。要使其持久化,需要将其写入非易失性存储(如硬盘、SSD)。这个过程看似简单,实则充满陷阱:
- 写入粒度不匹配: 应用程序通常以字节或结构体为单位操作数据,而磁盘通常以块(例如 4KB 或 8KB)为单位进行读写。这意味着一个逻辑上的原子操作可能需要修改磁盘上的多个块。
- 操作系统与硬件缓存: 操作系统通常会缓存写入操作,而不是立即将其刷新到物理磁盘。硬件控制器也有自己的缓存。这意味着 write() 系统调用返回成功并不代表数据已安全写入非易失性存储。
- 部分写入与撕裂写入: 如果在写入一个多块数据结构的过程中发生崩溃,只有部分块被写入磁盘,数据结构就会处于损坏或不一致的状态,这被称为“部分写入”或“撕裂写入”(Torn Write)。
- 原子性需求: 数据库事务(Transaction)要求其所有操作要么全部成功提交,要么全部失败回滚,不能出现中间状态。这被称为原子性(Atomicity)。
为了应对这些挑战,我们需要一种机制,能够在系统崩溃后,将数据恢复到一个已知的、一致的状态,并且所有已提交的事务都已持久化,所有未提交的事务都已回滚。这就是 崩溃一致性 的目标。
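顺带用一小段 Go 程序直观验证上面提到的缓存问题:Write 成功返回并不意味着数据已经落盘,只有 fsync(Go 中对应 File.Sync)成功返回后,这次写入才具备掉电持久性(独立的示意代码,文件名 demo.dat 为假设):
package main

import (
	"log"
	"os"
)

func main() {
	f, err := os.Create("demo.dat")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Write 只是把数据交给内核页缓存,返回成功不代表已经落盘。
	if _, err := f.Write([]byte("hello")); err != nil {
		log.Fatal(err)
	}

	// Sync 底层调用 fsync,强制把数据刷到非易失性存储;
	// 只有它成功返回,才能认为这次写入在崩溃或掉电后仍然存在。
	if err := f.Sync(); err != nil {
		log.Fatal(err)
	}
}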
2. 预写日志 (WAL) 的核心原理
WAL 是实现崩溃一致性和事务原子性的核心技术。其基本原则非常简单但极其强大:在对数据进行任何实际修改之前,必须先将描述这些修改的日志记录写入并刷新到持久化存储。
这个原则确保了两件事:
- 原子性 (Atomicity): 如果系统在数据修改过程中崩溃,我们可以通过查看日志来判断一个事务是否已完整提交。如果日志中存在事务的提交记录,那么即使数据本身尚未完全写入磁盘,我们也知道它“应该”是完整的。
- 持久性 (Durability): 一旦事务的提交记录被写入并刷新到 WAL,即使随后发生系统崩溃,该事务的修改也能够通过日志重做(Redo)来恢复。
2.1 WAL 的工作流程
- 修改数据: 当应用程序请求修改数据(例如,更新一个数据库页面)时,首先在内存中执行修改。
- 生成日志记录: 为此修改生成一条日志记录。这条记录包含了足够的信息来重做(Redo)或撤销(Undo)该修改。例如,它可能包含:
- 日志序列号 (LSN): 唯一标识日志记录的单调递增数字。
- 记录类型: 例如,页面更新、事务开始、事务提交、事务中止等。
- 页面 ID: 被修改的数据页的标识符。
- 旧值和新值(或 Delta): 用于重做或撤销操作所需的数据。
- 事务 ID: 关联到特定事务。
- 写入 WAL: 将日志记录追加到 WAL 文件(通常是内存缓冲区,然后批量刷新)。
- 刷新 WAL: 在事务提交时,必须将所有相关日志记录强制刷新(fsync() 或 fdatasync())到持久化存储。这是保证持久性的关键步骤。
- 修改数据页: 只有在日志记录被安全写入磁盘之后,被修改的内存中的数据页才可以异步地写入其在数据文件中的实际位置。
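下面用一小段示意代码体现这条“先日志、后数据”的顺序约束(依赖第 6 节将要定义的 WAL、PageCache 等组件;commitTx 为假设的函数名,错误处理从简):
// commitTx 展示 WAL 的顺序约束:日志先落盘,数据页后写回(示意代码)。
func commitTx(wal *WAL, pc *PageCache, page *Page, rec LogRecord) error {
	// 1. 先把描述修改的日志记录追加到 WAL(此时可能还在内存缓冲中)。
	lsn, err := wal.AppendRecord(rec)
	if err != nil {
		return err
	}
	// 2. 提交点:强制刷新 WAL。只有这一步成功,事务才算持久。
	if err := wal.Flush(); err != nil {
		return err
	}
	// 3. 日志安全落盘之后,才允许把脏页异步写回数据文件。
	pc.MarkDirty(page, lsn)
	go pc.FlushPage(page) // 异步写回;错误处理从简,仅为示意
	return nil
}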
2.2 崩溃恢复过程(基于 WAL)
当系统从崩溃中恢复时,WAL 是其唯一的真理来源。恢复过程通常分为两个阶段:
- 重做阶段 (Redo Phase):
- 从 WAL 的开头(或某个已知的一致点)开始扫描。
- 对于每一条日志记录,将其描述的修改重新应用到数据文件中。
- 这个阶段的目标是将数据库带到崩溃前那一刻的最新状态,包括所有已提交和未提交事务的修改。
- 撤销阶段 (Undo Phase):
- 在重做阶段结束后,数据库可能包含了未提交事务的部分修改。
- 撤销阶段识别所有在崩溃时仍处于活动状态(未提交)的事务。
- 对于这些未提交事务的修改,通过日志记录中记录的旧值或反向操作来撤销它们,将数据恢复到这些事务开始之前的状态。
通过这两个阶段,系统能够确保所有已提交的事务都已持久化,并且所有未提交的事务都已回滚,从而恢复到一个崩溃一致的状态。
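为了直观感受这两个阶段的判定依据,下面给出一个独立的小程序(纯示意,与后文的实现无关):根据 WAL 中是否存在某事务的提交记录,决定它的修改在恢复后最终被保留(重做)还是被回滚(撤销)。
package main

import "fmt"

// rec 是一条极简的日志记录:只有事务 ID 和类型。
type rec struct {
	txID int
	kind string // "begin" | "update" | "commit"
}

func main() {
	// 崩溃前的 WAL 时间线:T1 已提交,T2 尚未提交。
	timeline := []rec{
		{1, "begin"}, {1, "update"}, {1, "commit"},
		{2, "begin"}, {2, "update"},
	}
	// 第一遍扫描:找出所有已提交的事务。
	committed := map[int]bool{}
	for _, r := range timeline {
		if r.kind == "commit" {
			committed[r.txID] = true
		}
	}
	// 第二遍:已提交事务的修改保留(重做),其余撤销。
	for _, r := range timeline {
		if r.kind != "update" {
			continue
		}
		if committed[r.txID] {
			fmt.Printf("T%d 的修改:重做并保留\n", r.txID)
		} else {
			fmt.Printf("T%d 的修改:撤销回滚\n", r.txID)
		}
	}
}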
3. 纯 WAL 的局限性:日志膨胀
纯粹的 WAL 方案存在一个显著的问题:日志文件会无限增长。 随着时间的推移,日志文件将占用巨大的磁盘空间。更重要的是,恢复时间会变得非常长,因为每次恢复都需要从日志的起点开始重做所有操作。这在生产环境中是不可接受的。
为了解决日志膨胀和恢复时间过长的问题,我们需要一种机制来定期地“清理”WAL,丢弃不再需要的旧日志记录。这就是 Checkpointing 的作用。
4. 引入 Checkpointing:WAL 的优化与管理
Checkpointing 是一种定期创建数据库一致性快照的机制。它的核心思想是:通过将内存中已修改的数据页(脏页)强制写入到其在数据文件中的实际位置,来“固化”一部分数据状态,从而使得在某个时间点之前的 WAL 记录不再需要用于恢复。
4.1 Checkpointing 的目标
- 限制 WAL 的增长: 允许系统删除旧的、不再需要的日志记录。
- 缩短恢复时间: 恢复不再需要从 WAL 的起点开始,而是从最近的有效检查点开始。
- 确保数据一致性: 在崩溃发生时,检查点提供了一个已知的、一致的恢复起点。
4.2 Checkpointing 的类型
主要有两种检查点类型:
- 阻塞式检查点 (Blocking Checkpointing):
- 在检查点开始时,系统会暂停所有新的事务和数据修改操作。
- 将所有内存中的脏页强制刷新到磁盘。
- 等待所有当前活动的事务完成。
- 完成后,写入一个检查点记录,指示在此之前的 WAL 都可以被截断。
- 优点: 实现相对简单。
- 缺点: 极大地影响系统性能和可用性,因为它会暂停所有操作。不适用于高并发系统。
- 模糊检查点 / 异步检查点 (Fuzzy Checkpointing / Asynchronous Checkpointing):
- 这是现代高性能数据库系统(如 PostgreSQL, RocksDB, LevelDB)普遍采用的方法。
- 它允许事务和数据修改在检查点过程中继续进行。
- 优点: 对系统性能影响最小,可以在后台异步执行。
- 缺点: 实现更为复杂,恢复逻辑也更复杂。
本讲座将重点讨论更复杂但更实用的 模糊检查点。
5. WAL Checkpointing 机制深度解析 (模糊检查点)
模糊检查点的核心在于,它允许读写操作在检查点进行时继续。为了实现这一点,检查点记录需要包含更多信息,并且恢复过程也需要更智能地处理。
5.1 检查点记录 (Checkpoint Record)
一个典型的模糊检查点记录(通常是 CheckpointEndRecord)会包含以下关键信息:
| 字段名称 | 描述 |
| --- | --- |
| min_lsn_for_recovery (LSN) | 所有尚未持久化到数据文件的数据页中,最早的那个 LSN。简单来说,这是崩溃恢复时 WAL 需要被扫描并重做操作的起点;在此之前的 WAL 记录可以被安全地截断。 |
| active_transactions (Transaction IDs) | 在检查点开始时仍处于活动状态的事务 ID 列表。恢复时需要对这些事务进行撤销操作。 |
| dirty_pages_at_ckpt_start (Page IDs) | 在检查点开始时内存中所有脏页的 ID 列表。这些页面的最新修改可能尚未写入数据文件,需要确保它们在检查点结束前被刷新。 |
| checkpoint_lsn (LSN) | CheckpointEndRecord 本身在 WAL 中的 LSN。 |
5.2 模糊检查点过程 (Fuzzy Checkpoint Process)
- 启动检查点 (Start Checkpoint):
- 在一个相对空闲的时刻(或定期触发),系统决定启动一个检查点。
- 获取一个轻量级锁或 Latch,以原子方式捕获当前系统中所有活跃事务的列表 TxList_active_at_ckpt_start,以及所有内存中脏页的列表 DirtyPages_at_ckpt_start。
- 记录当前的 WAL 尾部 LSN,称之为 checkpoint_start_LSN。
- 将一个 CheckpointStartRecord 写入 WAL,标记检查点过程的开始。这个记录通常包含 checkpoint_start_LSN。
- 释放轻量级锁,允许正常数据库操作继续。
- 异步刷新脏页 (Asynchronous Flushing of Dirty Pages):
- 在后台线程中,系统开始将 DirtyPages_at_ckpt_start 列表中的脏页异步地刷新到它们在数据文件中的实际位置。
- 这个刷新过程是并发的,其他事务仍然可以修改这些页面。如果一个页面在刷新过程中被再次修改,它将保持为脏,并在 PageCache 中记录其新的 lastLSN。检查点只关心在 checkpoint_start_LSN 之前的所有修改都已持久化。
- 为了确保写入数据文件时的原子性,通常会使用写时复制(Copy-on-Write)或影子分页(Shadow Paging)技术,或者直接对页进行覆盖写入,但需要保证页内写入的原子性(例如,在页面头部写入一个校验和或 LSN)。
- 完成检查点 (Complete Checkpoint):
- 当所有在 checkpoint_start_LSN 时刻为脏的页面都已成功刷新到磁盘时,系统进入完成阶段。
- 确定 min_recovery_LSN:这是恢复所需的最小 LSN。它通常是 checkpoint_start_LSN 和所有在 TxList_active_at_ckpt_start 中事务的 start_LSN 中的最小值。换句话说,任何早于 min_recovery_LSN 的 WAL 记录都将不再需要。
- 将 CheckpointEndRecord 写入 WAL,其中包含 min_recovery_LSN、TxList_active_at_ckpt_start 等信息,并刷新 WAL。
- 更新持久化的检查点元数据(通常是一个单独的文件或 WAL 的特定位置),指向最新的 CheckpointEndRecord 的 LSN。
- 截断 WAL (Truncate WAL):
- 一旦 CheckpointEndRecord 被安全写入并持久化,所有早于 min_recovery_LSN 的 WAL 记录都可以被安全地删除或归档。这有效地限制了 WAL 的大小。
5.3 崩溃恢复过程 (基于模糊检查点)
有了检查点,恢复过程被大大简化和加速:
- 定位最新检查点:
- 系统启动时,首先读取持久化的检查点元数据,找到最近一个有效的 CheckpointEndRecord 的 LSN。
- 从这个 LSN 开始扫描 WAL,读出实际的 CheckpointEndRecord。这个记录包含了 min_recovery_LSN 和 TxList_active_at_ckpt_start。
- 如果找不到任何有效的 CheckpointEndRecord(例如,系统在启动后从未成功完成过检查点),则需要从 WAL 的最开始进行恢复。
- 重做阶段 (Redo Phase):
- 从 CheckpointEndRecord 中指定的 min_recovery_LSN 开始,向前扫描 WAL 直到 WAL 的末尾。
- 维护一个 活动事务表 (Active Transaction Table, ATT) 和一个 脏页表 (Dirty Page Table, DPT)。
- 对于每个日志记录:
- 如果它是事务开始记录,将其事务 ID 添加到 ATT。
- 如果它是页面更新记录,将受影响的页面 ID 添加到 DPT,并更新其 rec_LSN(记录中包含的 LSN)。
- 如果它是事务提交/中止记录,将其事务 ID 从 ATT 中移除。
- 重做逻辑: 只有当日志记录的 LSN 大于或等于数据页上记录的 LSN(或者数据页不在 DPT 中)时,才应用该日志记录的修改。这确保了我们不会重做已经存在于数据页上的旧修改。
- 将所有日志记录的修改应用到内存中的页面上(或者直接写入磁盘,但通常是先在内存中构建状态)。
- 撤销阶段 (Undo Phase):
- 重做阶段结束后,ATT 中仍然包含在崩溃时未提交的事务。
- 对于 ATT 中的每个事务,从 WAL 中找到其所有相关的日志记录。
- 以相反的顺序应用这些日志记录的撤销操作(使用日志记录中的旧值或反向操作),将数据恢复到这些事务开始之前的状态。
- 撤销完成后,将事务 ID 从 ATT 中移除。
一旦 Redo 和 Undo 阶段都完成,数据库就恢复到了一个崩溃一致的状态,所有已提交的事务都已持久化,所有未提交的事务都已回滚。
6. 在 Go 中实现一个高性能、可故障恢复的崩溃一致性逻辑
现在,我们将这些理论概念转化为 Go 语言的实际代码结构。为了简洁和聚焦核心概念,我们将构建一个简化的、概念性的存储引擎骨架。
6.1 核心组件设计
我们将定义以下核心组件:
- LogRecord: 定义日志记录的接口和具体实现。
- WAL: 负责日志记录的追加、刷新和读取。
- Page 和 PageCache: 管理内存中的数据页,追踪脏页。
- StorageEngine: 协调 WAL 和 PageCache,提供数据操作接口。
- CheckpointManager: 负责触发和执行检查点。
- RecoveryManager: 负责系统启动时的崩溃恢复。
6.2 数据结构和接口定义
package main
import (
"bufio" // buffered readers for WAL scanning
"bytes"
"encoding/binary"
"encoding/gob"
"errors" // error aggregation in PageCache.Close
"fmt"
"io"
"os"
"path/filepath"
"sort" // ordering WAL segment files by LSN
"sync"
"time" // periodic checkpoint scheduling
)
// LSN (Log Sequence Number)
type LSN uint64
// PageID identifies a data page.
type PageID uint64
// TransactionID identifies a transaction.
type TransactionID uint64
// LogRecordType defines the type of a log record.
type LogRecordType uint8
const (
TypeNoOp LogRecordType = iota // No operation
TypePageUpdate // Update a data page
TypeTransactionStart // Transaction starts
TypeTransactionCommit // Transaction commits
TypeTransactionAbort // Transaction aborts
TypeCheckpointStart // Checkpoint starts
TypeCheckpointEnd // Checkpoint ends
)
// LogRecord interface represents a generic log record.
type LogRecord interface {
GetType() LogRecordType
GetLSN() LSN
SetLSN(lsn LSN)
Serialize() ([]byte, error)
Deserialize([]byte) error
String() string
}
// BaseLogRecord provides common fields for log records.
type BaseLogRecord struct {
LSNVal LSN
TypeVal LogRecordType
}
func (b *BaseLogRecord) GetType() LogRecordType { return b.TypeVal }
func (b *BaseLogRecord) GetLSN() LSN { return b.LSNVal }
func (b *BaseLogRecord) SetLSN(lsn LSN) { b.LSNVal = lsn }
// PageUpdateRecord represents an update to a data page.
type PageUpdateRecord struct {
BaseLogRecord
PageID PageID
Offset uint16
Length uint16 // Go has no uint11; uint16 is enough for offsets within a 4KB page
OldData []byte
NewData []byte
TxID TransactionID
}
func NewPageUpdateRecord(pageID PageID, offset uint16, oldData, newData []byte, txID TransactionID) *PageUpdateRecord {
return &PageUpdateRecord{
BaseLogRecord: BaseLogRecord{TypeVal: TypePageUpdate},
PageID: pageID,
Offset: offset,
Length: uint16(len(newData)), // Assuming len(oldData) == len(newData) for simplicity
OldData: oldData,
NewData: newData,
TxID: txID,
}
}
func (r *PageUpdateRecord) Serialize() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(r)
return buf.Bytes(), err
}
func (r *PageUpdateRecord) Deserialize(data []byte) error {
dec := gob.NewDecoder(bytes.NewReader(data))
return dec.Decode(r)
}
func (r *PageUpdateRecord) String() string {
return fmt.Sprintf("PageUpdate{LSN:%d, PageID:%d, Offset:%d, Len:%d, TxID:%d}", r.LSNVal, r.PageID, r.Offset, r.Length, r.TxID)
}
// TransactionStartRecord marks the beginning of a transaction.
type TransactionStartRecord struct {
BaseLogRecord
TxID TransactionID
}
func NewTransactionStartRecord(txID TransactionID) *TransactionStartRecord {
return &TransactionStartRecord{
BaseLogRecord: BaseLogRecord{TypeVal: TypeTransactionStart},
TxID: txID,
}
}
func (r *TransactionStartRecord) Serialize() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(r)
return buf.Bytes(), err
}
func (r *TransactionStartRecord) Deserialize(data []byte) error {
dec := gob.NewDecoder(bytes.NewReader(data))
return dec.Decode(r)
}
func (r *TransactionStartRecord) String() string {
return fmt.Sprintf("TxStart{LSN:%d, TxID:%d}", r.LSNVal, r.TxID)
}
// TransactionCommitRecord marks the successful end of a transaction.
type TransactionCommitRecord struct {
BaseLogRecord
TxID TransactionID
}
func NewTransactionCommitRecord(txID TransactionID) *TransactionCommitRecord {
return &TransactionCommitRecord{
BaseLogRecord: BaseLogRecord{TypeVal: TypeTransactionCommit},
TxID: txID,
}
}
func (r *TransactionCommitRecord) Serialize() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(r)
return buf.Bytes(), err
}
func (r *TransactionCommitRecord) Deserialize(data []byte) error {
dec := gob.NewDecoder(bytes.NewReader(data))
return dec.Decode(r)
}
func (r *TransactionCommitRecord) String() string {
return fmt.Sprintf("TxCommit{LSN:%d, TxID:%d}", r.LSNVal, r.TxID)
}
// TransactionAbortRecord marks the failed end of a transaction.
type TransactionAbortRecord struct {
BaseLogRecord
TxID TransactionID
}
func NewTransactionAbortRecord(txID TransactionID) *TransactionAbortRecord {
return &TransactionAbortRecord{
BaseLogRecord: BaseLogRecord{TypeVal: TypeTransactionAbort},
TxID: txID,
}
}
func (r *TransactionAbortRecord) Serialize() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(r)
return buf.Bytes(), err
}
func (r *TransactionAbortRecord) Deserialize(data []byte) error {
dec := gob.NewDecoder(bytes.NewReader(data))
return dec.Decode(r)
}
func (r *TransactionAbortRecord) String() string {
return fmt.Sprintf("TxAbort{LSN:%d, TxID:%d}", r.LSNVal, r.TxID)
}
// CheckpointStartRecord marks the beginning of a checkpoint process.
type CheckpointStartRecord struct {
BaseLogRecord
CheckpointStartLSN LSN // The LSN at which the checkpoint process began
}
func NewCheckpointStartRecord(ckptStartLSN LSN) *CheckpointStartRecord {
return &CheckpointStartRecord{
BaseLogRecord: BaseLogRecord{TypeVal: TypeCheckpointStart},
CheckpointStartLSN: ckptStartLSN,
}
}
func (r *CheckpointStartRecord) Serialize() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(r)
return buf.Bytes(), err
}
func (r *CheckpointStartRecord) Deserialize(data []byte) error {
dec := gob.NewDecoder(bytes.NewReader(data))
return dec.Decode(r)
}
func (r *CheckpointStartRecord) String() string {
return fmt.Sprintf("CkptStart{LSN:%d, CkptStartLSN:%d}", r.LSNVal, r.CheckpointStartLSN)
}
// CheckpointEndRecord marks the end of a checkpoint process.
// It contains critical information for recovery.
type CheckpointEndRecord struct {
BaseLogRecord
MinRecoveryLSN LSN // The LSN from which recovery should start
ActiveTransactions []TransactionID // List of transactions active at checkpoint start
DirtyPagesAtCkptStart []PageID // List of dirty pages at checkpoint start (for analysis)
}
func NewCheckpointEndRecord(minLSN LSN, activeTx []TransactionID, dirtyPages []PageID) *CheckpointEndRecord {
return &CheckpointEndRecord{
BaseLogRecord: BaseLogRecord{TypeVal: TypeCheckpointEnd},
MinRecoveryLSN: minLSN,
ActiveTransactions: activeTx,
DirtyPagesAtCkptStart: dirtyPages,
}
}
func (r *CheckpointEndRecord) Serialize() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(r)
return buf.Bytes(), err
}
func (r *CheckpointEndRecord) Deserialize(data []byte) error {
dec := gob.NewDecoder(bytes.NewReader(data))
return dec.Decode(r)
}
func (r *CheckpointEndRecord) String() string {
return fmt.Sprintf("CkptEnd{LSN:%d, MinRecoveryLSN:%d, ActiveTx:%v, DirtyPages:%v}",
r.LSNVal, r.MinRecoveryLSN, r.ActiveTransactions, r.DirtyPagesAtCkptStart)
}
// Register gob types for serialization/deserialization
func init() {
gob.Register(&PageUpdateRecord{})
gob.Register(&TransactionStartRecord{})
gob.Register(&TransactionCommitRecord{})
gob.Register(&TransactionAbortRecord{})
gob.Register(&CheckpointStartRecord{})
gob.Register(&CheckpointEndRecord{})
}
// Helper to deserialize a log record from bytes
func DeserializeLogRecord(data []byte) (LogRecord, error) {
// gob encodes an embedded struct as a single field named after its type
// ("BaseLogRecord"), so decoding the stream directly into BaseLogRecord
// would match no fields. Decode into a probe struct whose field name
// matches instead; gob silently skips the stream's other fields.
var probe struct {
BaseLogRecord BaseLogRecord
}
if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&probe); err != nil {
return nil, fmt.Errorf("failed to decode base log record: %w", err)
}
base := probe.BaseLogRecord
// Decode again from the start, this time into the concrete type
dec := gob.NewDecoder(bytes.NewReader(data))
var record LogRecord
switch base.TypeVal {
case TypePageUpdate:
record = &PageUpdateRecord{}
case TypeTransactionStart:
record = &TransactionStartRecord{}
case TypeTransactionCommit:
record = &TransactionCommitRecord{}
case TypeTransactionAbort:
record = &TransactionAbortRecord{}
case TypeCheckpointStart:
record = &CheckpointStartRecord{}
case TypeCheckpointEnd:
record = &CheckpointEndRecord{}
default:
return nil, fmt.Errorf("unknown log record type: %d", base.TypeVal)
}
if err := dec.Decode(record); err != nil {
return nil, fmt.Errorf("failed to decode concrete log record: %w", err)
}
return record, nil
}
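可以用一个简单的往返示例验证上述辅助函数(示意代码,roundTripDemo 为假设的函数名):
// roundTripDemo 演示日志记录的序列化与反序列化往返(示意代码)。
func roundTripDemo() {
	rec := NewPageUpdateRecord(42, 0, []byte("old"), []byte("new"), 7)
	rec.SetLSN(100)
	data, err := rec.Serialize()
	if err != nil {
		panic(err)
	}
	decoded, err := DeserializeLogRecord(data)
	if err != nil {
		panic(err)
	}
	fmt.Println(decoded) // PageUpdate{LSN:100, PageID:42, Offset:0, Len:3, TxID:7}
}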
6.3 WAL (Write-Ahead Log) 实现
WAL 将日志记录写入文件。为了高性能,它会缓冲写入并批量刷新。为了管理文件大小,它会使用分段文件(segments)。
const (
WALSegmentSize = 16 * 1024 * 1024 // 16MB per segment
WALFileNamePattern = "wal_%016x.log"
CheckpointMetaFile = "checkpoint.meta"
)
// WALSegment represents a single WAL file segment.
type WALSegment struct {
file *os.File
startLSN LSN
endLSN LSN
size int64
mu sync.Mutex // Protects file operations on this segment
}
func NewWALSegment(dir string, startLSN LSN) (*WALSegment, error) {
filename := filepath.Join(dir, fmt.Sprintf(WALFileNamePattern, startLSN))
file, err := os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open WAL segment file %s: %w", filename, err)
}
info, err := file.Stat()
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to stat WAL segment file %s: %w", filename, err)
}
return &WALSegment{
file: file,
startLSN: startLSN,
size: info.Size(),
}, nil
}
func (ws *WALSegment) Append(data []byte) (int, error) {
ws.mu.Lock()
defer ws.mu.Unlock()
n, err := ws.file.Write(data)
if err == nil {
ws.size += int64(n)
}
return n, err
}
func (ws *WALSegment) Sync() error {
ws.mu.Lock()
defer ws.mu.Unlock()
return ws.file.Sync()
}
func (ws *WALSegment) Close() error {
ws.mu.Lock()
defer ws.mu.Unlock()
return ws.file.Close()
}
// WAL manages multiple WAL segments.
type WAL struct {
dir string
mu sync.RWMutex
segments []*WALSegment // Ordered list of segments
currentSeg *WALSegment
nextLSN LSN // The next LSN to be assigned
buffer bytes.Buffer // In-memory buffer for batching writes
bufferSize int // Current size of the buffer
maxBufferSize int // Max size before flush
}
func NewWAL(dir string, maxBufferSize int) (*WAL, error) {
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create WAL directory %s: %w", dir, err)
}
wal := &WAL{
dir: dir,
nextLSN: 1, // Start LSN from 1
segments: make([]*WALSegment, 0),
maxBufferSize: maxBufferSize,
}
// Load existing segments or create a new one
err := wal.loadSegments()
if err != nil {
return nil, err
}
if len(wal.segments) == 0 {
if err := wal.createNewSegment(); err != nil {
return nil, err
}
} else {
wal.currentSeg = wal.segments[len(wal.segments)-1]
// Determine nextLSN by scanning the last segment
lastLSN, err := wal.determineLastLSN()
if err != nil {
return nil, err
}
wal.nextLSN = lastLSN + 1
}
return wal, nil
}
func (w *WAL) loadSegments() error {
entries, err := os.ReadDir(w.dir)
if err != nil {
return fmt.Errorf("failed to read WAL directory: %w", err)
}
var segmentFiles []string
for _, entry := range entries {
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".log" && len(entry.Name()) == len(fmt.Sprintf(WALFileNamePattern, 0)) {
segmentFiles = append(segmentFiles, entry.Name())
}
}
// Sort segment files by LSN
// Assuming LSN is encoded in the filename
sort.Slice(segmentFiles, func(i, j int) bool {
var lsnI, lsnJ uint64
fmt.Sscanf(segmentFiles[i], WALFileNamePattern, &lsnI)
fmt.Sscanf(segmentFiles[j], WALFileNamePattern, &lsnJ)
return lsnI < lsnJ
})
for _, filename := range segmentFiles {
var startLSN uint64
fmt.Sscanf(filename, WALFileNamePattern, &startLSN)
seg, err := NewWALSegment(w.dir, LSN(startLSN))
if err != nil {
return fmt.Errorf("failed to load WAL segment %s: %w", filename, err)
}
w.segments = append(w.segments, seg)
}
return nil
}
func (w *WAL) createNewSegment() error {
if w.currentSeg != nil {
if err := w.currentSeg.Close(); err != nil {
return fmt.Errorf("failed to close previous WAL segment: %w", err)
}
}
newSeg, err := NewWALSegment(w.dir, w.nextLSN) // New segment starts with next available LSN
if err != nil {
return fmt.Errorf("failed to create new WAL segment: %w", err)
}
w.segments = append(w.segments, newSeg)
w.currentSeg = newSeg
return nil
}
// determineLastLSN scans the last segment to find the actual last LSN.
// This is crucial for recovery or restarting a WAL.
func (w *WAL) determineLastLSN() (LSN, error) {
if w.currentSeg == nil {
return 0, nil
}
var lastLSN LSN = 0
err := w.ReadRecords(w.currentSeg.startLSN, func(record LogRecord) bool {
lastLSN = record.GetLSN()
return true // Continue reading
})
if err != nil && err != io.EOF {
return 0, fmt.Errorf("failed to scan last WAL segment to determine last LSN: %w", err)
}
return lastLSN, nil
}
// AppendRecord appends a log record to the WAL.
func (w *WAL) AppendRecord(record LogRecord) (LSN, error) {
w.mu.Lock()
defer w.mu.Unlock()
record.SetLSN(w.nextLSN)
serializedData, err := record.Serialize()
if err != nil {
return 0, fmt.Errorf("failed to serialize log record: %w", err)
}
// Format: [length (4 bytes)][record data]
length := uint32(len(serializedData))
lengthBytes := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBytes, length)
// Append to buffer
w.buffer.Write(lengthBytes)
w.buffer.Write(serializedData)
w.bufferSize += 4 + len(serializedData)
currentLSN := w.nextLSN
w.nextLSN++
// Check if buffer needs flushing or segment needs rotation
if w.bufferSize >= w.maxBufferSize || w.currentSeg.size+int64(w.bufferSize) >= WALSegmentSize {
return currentLSN, w.flushBufferAndRotate()
}
return currentLSN, nil
}
// Flush ensures all buffered records are written to disk.
func (w *WAL) Flush() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.flushBufferAndRotate()
}
// flushBufferAndRotate writes the buffer to the current segment and rotates if necessary.
// Must be called with w.mu held.
func (w *WAL) flushBufferAndRotate() error {
if w.bufferSize == 0 {
return nil
}
data := w.buffer.Bytes()
n, err := w.currentSeg.Append(data)
if err != nil {
return fmt.Errorf("failed to write WAL buffer to segment: %w", err)
}
if n != w.bufferSize {
return fmt.Errorf("partial write to WAL segment: wrote %d, expected %d", n, w.bufferSize)
}
if err := w.currentSeg.Sync(); err != nil {
return fmt.Errorf("failed to sync WAL segment: %w", err)
}
w.buffer.Reset()
w.bufferSize = 0
// Rotate segment if it's full
if w.currentSeg.size >= WALSegmentSize {
return w.createNewSegment()
}
return nil
}
// ReadRecords reads log records from a starting LSN.
// The callback function returns true to continue reading, false to stop.
func (w *WAL) ReadRecords(startLSN LSN, callback func(LogRecord) bool) error {
w.mu.RLock()
defer w.mu.RUnlock()
// Find the segment containing startLSN
startIndex := -1
for i, seg := range w.segments {
if startLSN >= seg.startLSN {
startIndex = i
} else {
break
}
}
if startIndex == -1 {
return fmt.Errorf("start LSN %d is too old or invalid", startLSN)
}
for i := startIndex; i < len(w.segments); i++ {
seg := w.segments[i]
currentOffset := int64(0)
if startLSN > seg.startLSN {
// Need to find the exact offset for startLSN within this segment
// This is a simplified approach, a real WAL would use an index for faster lookup
tempFile, err := os.Open(filepath.Join(w.dir, fmt.Sprintf(WALFileNamePattern, seg.startLSN)))
if err != nil {
return fmt.Errorf("failed to open segment for scanning: %w", err)
}
defer tempFile.Close()
reader := bufio.NewReader(tempFile)
var currentRecordLSN LSN
for {
lengthBytes := make([]byte, 4)
_, err := io.ReadFull(reader, lengthBytes)
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read record length: %w", err)
}
recordLen := binary.BigEndian.Uint32(lengthBytes)
recordData := make([]byte, recordLen)
_, err = io.ReadFull(reader, recordData)
if err != nil {
return fmt.Errorf("failed to read record data: %w", err)
}
record, err := DeserializeLogRecord(recordData)
if err != nil {
return fmt.Errorf("failed to deserialize record during scan: %w", err)
}
currentRecordLSN = record.GetLSN()
if currentRecordLSN >= startLSN {
break // Found the starting record
}
currentOffset += 4 + int64(recordLen) // Move offset past this record
}
// currentOffset now marks the byte offset of the first record with
// LSN >= startLSN; the segment is re-opened below and seeked to it.
}
// Read from the current segment file from currentOffset
// For simplicity, we'll re-open the file. In a real system, would use a reader.
file, err := os.Open(filepath.Join(w.dir, fmt.Sprintf(WALFileNamePattern, seg.startLSN)))
if err != nil {
return fmt.Errorf("failed to open WAL segment %s for reading: %w", seg.file.Name(), err)
}
defer file.Close()
if _, err := file.Seek(currentOffset, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek in WAL segment %s: %w", seg.file.Name(), err)
}
reader := bufio.NewReader(file)
for {
lengthBytes := make([]byte, 4)
_, err := io.ReadFull(reader, lengthBytes)
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read record length from WAL: %w", err)
}
recordLen := binary.BigEndian.Uint32(lengthBytes)
recordData := make([]byte, recordLen)
_, err = io.ReadFull(reader, recordData)
if err != nil {
return fmt.Errorf("failed to read record data from WAL: %w", err)
}
record, err := DeserializeLogRecord(recordData)
if err != nil {
return fmt.Errorf("failed to deserialize log record from WAL: %w", err)
}
if !callback(record) {
return nil // Callback requested to stop
}
}
}
return nil
}
// TruncateBefore removes all WAL segments whose end LSN is less than minLSN.
func (w *WAL) TruncateBefore(minLSN LSN) error {
w.mu.Lock()
defer w.mu.Unlock()
var newSegments []*WALSegment
for _, seg := range w.segments {
// If segment's start LSN is before minLSN, it might be truncated.
// A full segment is only truncated if its *entirety* is before minLSN.
// For simplicity, we assume a segment is eligible for deletion if its startLSN < minLSN
// and it's not the current active segment. A more precise check would involve
// parsing records in the segment to find its effective "end LSN".
if seg.startLSN < minLSN && seg != w.currentSeg {
fmt.Printf("Truncating WAL segment %s (start LSN: %d < min LSN: %d)\n", seg.file.Name(), seg.startLSN, minLSN)
if err := seg.Close(); err != nil {
fmt.Printf("Warning: failed to close WAL segment %s during truncation: %v\n", seg.file.Name(), err)
}
if err := os.Remove(filepath.Join(w.dir, fmt.Sprintf(WALFileNamePattern, seg.startLSN))); err != nil {
fmt.Printf("Warning: failed to remove WAL segment file %s: %v\n", seg.file.Name(), err)
}
} else {
newSegments = append(newSegments, seg)
}
}
w.segments = newSegments
return nil
}
func (w *WAL) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if err := w.flushBufferAndRotate(); err != nil {
return fmt.Errorf("failed to flush WAL buffer on close: %w", err)
}
for _, seg := range w.segments {
if err := seg.Close(); err != nil {
return fmt.Errorf("failed to close WAL segment %s on close: %w", seg.file.Name(), err)
}
}
return nil
}
注意: WAL.ReadRecords 方法在查找 startLSN 时,如果 startLSN 不在某个分段的开头,需要从分段开头扫描以找到确切偏移。这是一个简化的实现,实际生产系统会维护一个 WAL 索引来快速定位 LSN 在文件中的物理偏移。
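作为补充,下面给出一个假设性的稀疏 LSN 索引草图(示意实现,sparseLSNIndex 等命名均为杜撰):写入侧每隔若干条记录登记一次 (LSN, 段内字节偏移),读取侧先二分定位到不大于目标 LSN 的最近登记点,再从那里顺序扫描,从而避免每次都从分段开头全量扫描。
// indexEntry 记录一个采样点:某条记录的 LSN 及其在段文件中的字节偏移。
type indexEntry struct {
	lsn    LSN
	offset int64
}

// sparseLSNIndex 是一个假设性的 WAL 稀疏索引(示意,非生产实现)。
type sparseLSNIndex struct {
	entries []indexEntry // 按 lsn 升序排列
}

// Add 在记录落盘时登记一个采样点(采样间隔由调用方控制)。
func (idx *sparseLSNIndex) Add(lsn LSN, offset int64) {
	idx.entries = append(idx.entries, indexEntry{lsn, offset})
}

// SeekOffset 返回顺序扫描的起始偏移:
// 最后一个满足 entry.lsn <= target 的登记点的偏移;没有则从 0 开始。
func (idx *sparseLSNIndex) SeekOffset(target LSN) int64 {
	i := sort.Search(len(idx.entries), func(i int) bool {
		return idx.entries[i].lsn > target
	})
	if i == 0 {
		return 0
	}
	return idx.entries[i-1].offset
}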
6.4 Page 和 PageCache
PageCache 负责管理内存中的数据页,并跟踪哪些页是“脏”的(即在内存中被修改但尚未写入磁盘)。
const PageSize = 4096 // 4KB page size
// Page represents a single data page in memory.
type Page struct {
ID PageID
Data []byte
IsDirty bool // True if page has been modified in memory
LastLSN LSN // LSN of the last log record that modified this page
mu sync.RWMutex // Protects Data and IsDirty
}
func NewPage(id PageID) *Page {
return &Page{
ID: id,
Data: make([]byte, PageSize),
}
}
// PageCache manages in-memory data pages (buffer pool).
type PageCache struct {
mu sync.RWMutex
pages map[PageID]*Page // In-memory pages
maxPages int // Max number of pages in cache
dataDir string // Directory for data files
}
func NewPageCache(dataDir string, maxPages int) (*PageCache, error) {
if err := os.MkdirAll(dataDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", dataDir, err)
}
return &PageCache{
pages: make(map[PageID]*Page),
maxPages: maxPages,
dataDir: dataDir,
}, nil
}
// GetPage fetches a page from cache or loads it from disk.
func (pc *PageCache) GetPage(id PageID) (*Page, error) {
pc.mu.RLock()
page, ok := pc.pages[id]
pc.mu.RUnlock()
if ok {
return page, nil
}
// Page not in cache, load from disk
pc.mu.Lock()
defer pc.mu.Unlock()
// Check again, might have been loaded by another goroutine
if page, ok = pc.pages[id]; ok {
return page, nil
}
// Evict if cache is full (simplified LRU/random eviction)
if len(pc.pages) >= pc.maxPages {
pc.evictPage() // Placeholder for eviction logic
}
// Load from disk
page = NewPage(id)
filename := filepath.Join(pc.dataDir, fmt.Sprintf("data_%d.page", id))
file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open data file %s: %w", filename, err)
}
defer file.Close()
n, err := file.Read(page.Data)
if err != nil && err != io.EOF {
return nil, fmt.Errorf("failed to read data page %d: %w", id, err)
}
if n < PageSize {
// Page might be new or partially written, pad with zeros
for i := n; i < PageSize; i++ {
page.Data[i] = 0
}
}
// A real system would also read the last LSN from the page header on disk
pc.pages[id] = page
return page, nil
}
// MarkDirty marks a page as dirty and updates its LastLSN.
func (pc *PageCache) MarkDirty(page *Page, lsn LSN) {
page.mu.Lock()
defer page.mu.Unlock()
page.IsDirty = true
page.LastLSN = lsn
}
// FlushPage writes a dirty page to disk.
func (pc *PageCache) FlushPage(page *Page) error {
page.mu.Lock()
defer page.mu.Unlock()
if !page.IsDirty {
return nil // Not dirty, no need to flush
}
filename := filepath.Join(pc.dataDir, fmt.Sprintf("data_%d.page", page.ID))
file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("failed to open data file %s for writing: %w", filename, err)
}
defer file.Close()
if _, err := file.Write(page.Data); err != nil {
return fmt.Errorf("failed to write data page %d: %w", page.ID, err)
}
if err := file.Sync(); err != nil { // Ensure data is on disk
return fmt.Errorf("failed to sync data page %d: %w", page.ID, err)
}
page.IsDirty = false
return nil
}
// GetDirtyPages returns a snapshot of currently dirty pages.
func (pc *PageCache) GetDirtyPages() map[PageID]*Page {
pc.mu.RLock()
defer pc.mu.RUnlock()
dirtyPages := make(map[PageID]*Page)
for id, page := range pc.pages {
page.mu.RLock()
if page.IsDirty {
dirtyPages[id] = page
}
page.mu.RUnlock()
}
return dirtyPages
}
// evictPage is a placeholder for a page eviction policy (e.g., LRU).
// For simplicity, we just delete a random page.
func (pc *PageCache) evictPage() {
for id, page := range pc.pages {
if page.IsDirty {
// Cannot evict dirty page without flushing it first (or write-ahead logging)
// For a real system, would need to flush or handle differently.
// Here, we'd ideally flush it and then evict.
if err := pc.FlushPage(page); err != nil {
fmt.Printf("Warning: failed to flush page %d during eviction: %v\n", id, err)
continue // Try another page
}
}
delete(pc.pages, id)
fmt.Printf("Evicted page %d\n", id)
return // Evict one page for now
}
}
// Close flushes all dirty pages before shutting down.
func (pc *PageCache) Close() error {
pc.mu.Lock()
defer pc.mu.Unlock()
var err error
for _, page := range pc.pages {
if page.IsDirty {
if flushErr := pc.FlushPage(page); flushErr != nil {
err = errors.Join(err, fmt.Errorf("failed to flush page %d: %w", page.ID, flushErr)) // errors.Join handles a nil err on the first failure
}
}
}
return err
}
6.5 TransactionManager (简化)
事务管理器负责分配事务 ID,并跟踪活动事务。
// TransactionManager manages active transactions.
type TransactionManager struct {
mu sync.Mutex
nextTxID TransactionID
activeTx map[TransactionID]LSN // Map of TxID to its start LSN
}
func NewTransactionManager() *TransactionManager {
return &TransactionManager{
nextTxID: 1,
activeTx: make(map[TransactionID]LSN),
}
}
func (tm *TransactionManager) StartTransaction(startLSN LSN) TransactionID {
tm.mu.Lock()
defer tm.mu.Unlock()
txID := tm.nextTxID
tm.nextTxID++
tm.activeTx[txID] = startLSN
return txID
}
func (tm *TransactionManager) EndTransaction(txID TransactionID) {
tm.mu.Lock()
defer tm.mu.Unlock()
delete(tm.activeTx, txID)
}
func (tm *TransactionManager) GetActiveTransactions() map[TransactionID]LSN {
tm.mu.Lock() // Need lock to ensure consistent snapshot
defer tm.mu.Unlock()
snapshot := make(map[TransactionID]LSN, len(tm.activeTx))
for txID, lsn := range tm.activeTx {
snapshot[txID] = lsn
}
return snapshot
}
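在进入检查点逻辑之前,先用一段示意代码把上面的组件串起来,展示一次事务性写入的完整路径(可以看作组件清单中 StorageEngine 的一个最小草图;writeValue 为假设的函数名,假设一次修改不跨页,并发控制与错误处理从简):
// writeValue 串联 WAL、PageCache 与 TransactionManager(示意代码)。
func writeValue(wal *WAL, pc *PageCache, tm *TransactionManager,
	pageID PageID, offset uint16, newData []byte) error {
	// 1. 读取页面并保存旧值(撤销操作所需)。
	page, err := pc.GetPage(pageID)
	if err != nil {
		return err
	}
	end := int(offset) + len(newData)
	page.mu.RLock()
	oldData := make([]byte, len(newData))
	copy(oldData, page.Data[offset:end])
	page.mu.RUnlock()
	// 2. 先写日志:TxStart、PageUpdate、TxCommit,提交时强制刷盘。
	txID := tm.StartTransaction(wal.nextLSN) // 示意:以当前 WAL 尾部近似事务起始 LSN
	if _, err := wal.AppendRecord(NewTransactionStartRecord(txID)); err != nil {
		return err
	}
	lsn, err := wal.AppendRecord(NewPageUpdateRecord(pageID, offset, oldData, newData, txID))
	if err != nil {
		return err
	}
	if _, err := wal.AppendRecord(NewTransactionCommitRecord(txID)); err != nil {
		return err
	}
	if err := wal.Flush(); err != nil { // 提交点:日志落盘后事务才算持久
		return err
	}
	// 3. 日志持久化之后,才修改内存中的数据页并标记为脏。
	page.mu.Lock()
	copy(page.Data[offset:end], newData)
	page.mu.Unlock()
	pc.MarkDirty(page, lsn)
	tm.EndTransaction(txID)
	return nil
}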
6.6 CheckpointManager
检查点管理器协调检查点过程。
// CheckpointManager handles creating checkpoints.
type CheckpointManager struct {
wal *WAL
pageCache *PageCache
txManager *TransactionManager
lastCheckpointLSN LSN // LSN of the last successfully completed CheckpointEndRecord
metaFile string // File to store lastCheckpointLSN
mu sync.Mutex // Protects checkpoint initiation
}
func NewCheckpointManager(wal *WAL, pc *PageCache, tm *TransactionManager, baseDir string) (*CheckpointManager, error) {
cm := &CheckpointManager{
wal: wal,
pageCache: pc,
txManager: tm,
metaFile: filepath.Join(baseDir, CheckpointMetaFile),
}
// Load last checkpoint LSN from meta file
if err := cm.loadLastCheckpointLSN(); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("failed to load last checkpoint LSN: %w", err)
}
return cm, nil
}
func (cm *CheckpointManager) loadLastCheckpointLSN() error {
data, err := os.ReadFile(cm.metaFile)
if err != nil {
return err // os.IsNotExist is handled by the caller
}
if len(data) != 8 { // LSN is uint64, 8 bytes
return fmt.Errorf("invalid checkpoint meta file format")
}
cm.lastCheckpointLSN = LSN(binary.BigEndian.Uint64(data))
return nil
}
func (cm *CheckpointManager) saveLastCheckpointLSN(lsn LSN) error {
data := make([]byte, 8)
binary.BigEndian.PutUint64(data, uint64(lsn))
return os.WriteFile(cm.metaFile, data, 0644)
}
// PerformCheckpoint executes a fuzzy checkpoint.
func (cm *CheckpointManager) PerformCheckpoint() error {
cm.mu.Lock()
defer cm.mu.Unlock()
fmt.Println("Starting checkpoint...")
// 1. Write CheckpointStartRecord to WAL
checkpointStartLSN, err := cm.wal.AppendRecord(NewCheckpointStartRecord(cm.wal.nextLSN))
if err != nil {
return fmt.Errorf("failed to write CheckpointStartRecord: %w", err)
}
if err := cm.wal.Flush(); err != nil { // Ensure checkpoint start is durable
return fmt.Errorf("failed to flush WAL after CheckpointStartRecord: %w", err)
}
// 2. Capture snapshot of dirty pages and active transactions
dirtyPagesAtCkptStart := cm.pageCache.GetDirtyPages() // Snapshot of dirty pages
activeTxAtCkptStart := cm.txManager.GetActiveTransactions() // Snapshot of active transactions
// 3. Asynchronously flush dirty pages.
// In a real system, this would be done by a background flusher.
// For this example, we'll flush them synchronously but conceptually they are 'at checkpoint start time'.
var pagesToFlush []PageID
for pageID, page := range dirtyPagesAtCkptStart {
pagesToFlush = append(pagesToFlush, pageID)
if err := cm.pageCache.FlushPage(page); err != nil {
fmt.Printf("Warning: failed to flush page %d during checkpoint: %v\n", pageID, err)
// A real system would need to track unflushed pages and retry,
// or ensure recovery can handle them. For simplicity, we proceed.
}
}
fmt.Printf("Flushed %d dirty pages for checkpoint.\n", len(pagesToFlush))
// 4. Determine MinRecoveryLSN
minRecoveryLSN := checkpointStartLSN // Start with the LSN of CheckpointStartRecord
for _, txStartLSN := range activeTxAtCkptStart {
if txStartLSN < minRecoveryLSN {
minRecoveryLSN = txStartLSN
}
}
// For any page that was dirty at checkpoint start, its LastLSN might be older than minRecoveryLSN.
// We need to ensure that the oldest LSN of any page that was dirty *at the time the checkpoint started*
// is covered by minRecoveryLSN. A more robust system tracks this more carefully.
// For simplicity, we assume pages are flushed and their changes are covered.
// 5. Write CheckpointEndRecord to WAL
activeTxIDs := make([]TransactionID, 0, len(activeTxAtCkptStart))
for txID := range activeTxAtCkptStart {
activeTxIDs = append(activeTxIDs, txID)
}
ckptEndRecord := NewCheckpointEndRecord(minRecoveryLSN, activeTxIDs, pagesToFlush)
ckptEndLSN, err := cm.wal.AppendRecord(ckptEndRecord)
if err != nil {
return fmt.Errorf("failed to write CheckpointEndRecord: %w", err)
}
if err := cm.wal.Flush(); err != nil { // Crucial: ensure checkpoint end is durable
return fmt.Errorf("failed to flush WAL after CheckpointEndRecord: %w", err)
}
// 6. Persist last checkpoint LSN
if err := cm.saveLastCheckpointLSN(ckptEndLSN); err != nil {
return fmt.Errorf("failed to save last checkpoint LSN to meta file: %w", err)
}
cm.lastCheckpointLSN = ckptEndLSN
// 7. Truncate WAL
if err := cm.wal.TruncateBefore(minRecoveryLSN); err != nil {
fmt.Printf("Warning: failed to truncate WAL before LSN %d: %v\n", minRecoveryLSN, err)
}
fmt.Printf("Checkpoint completed. LastCheckpointLSN: %d, MinRecoveryLSN: %d\n", ckptEndLSN, minRecoveryLSN)
return nil
}
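实际系统通常由后台 goroutine 按时间间隔(或 WAL 增长量、脏页数量等阈值)触发检查点。下面是一个简单的调度草图(示意代码,runCheckpointLoop 与 stop 通道均为假设):
// runCheckpointLoop 在后台定期触发检查点,直到 stop 通道关闭(示意代码)。
func runCheckpointLoop(cm *CheckpointManager, interval time.Duration, stop <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := cm.PerformCheckpoint(); err != nil {
				fmt.Printf("checkpoint failed: %v\n", err)
			}
		case <-stop:
			return
		}
	}
}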
6.7 RecoveryManager
恢复管理器在系统启动时执行崩溃恢复。
// RecoveryManager handles crash recovery using WAL and checkpoint info.
type RecoveryManager struct {
wal *WAL
pageCache *PageCache
txManager *TransactionManager
checkpointManager *CheckpointManager
}
func NewRecoveryManager(wal *WAL, pc *PageCache, tm *TransactionManager, cm *CheckpointManager) *RecoveryManager {
return &RecoveryManager{
wal: wal,
pageCache: pc,
txManager: tm,
checkpointManager: cm,
}
}
// Recover performs crash recovery.
func (rm *RecoveryManager) Recover() error {
fmt.Println("Starting recovery...")
// 1. Locate the latest CheckpointEndRecord
var lastCkptEndRecord *CheckpointEndRecord
var lastCkptEndLSN LSN
// Scan WAL backwards from the end to find the latest CheckpointEndRecord
// For simplicity, we'll scan forward from the beginning of the WAL.
// A real system would efficiently locate the latest checkpoint metadata.
// Here, we rely on checkpointManager.lastCheckpointLSN.
// If lastCheckpointLSN is 0, no checkpoint was ever completed, need to scan from beginning of WAL (LSN 1).
scanStartLSN := LSN(1)
if rm.checkpointManager.lastCheckpointLSN > 0 {
// Found a previous checkpoint, read it to get minRecoveryLSN
// This requires reading a specific record by LSN, which our simple WAL.ReadRecords doesn't do directly efficiently.
// So we iterate until we find it.
err := rm.wal.ReadRecords(rm.checkpointManager.lastCheckpointLSN, func(record LogRecord) bool {
if ckptEnd, ok := record.(*CheckpointEndRecord); ok {
lastCkptEndRecord = ckptEnd
lastCkptEndLSN = ckptEnd.GetLSN()
return false // Stop after finding
}
return true
})
if err != nil {
fmt.Printf("Warning: Failed to read last checkpoint record at LSN %d, will scan from beginning of WAL: %v\n", rm.checkpointManager.lastCheckpointLSN, err)
lastCkptEndRecord = nil // Reset, proceed as if no checkpoint found
}
}
if lastCkptEndRecord != nil {
scanStartLSN = lastCkptEndRecord.MinRecoveryLSN
fmt.Printf("Found latest checkpoint at LSN %d. Recovery will start from MinRecoveryLSN: %d\n", lastCkptEndLSN, scanStartLSN)
} else {
fmt.Println("No valid checkpoint found. Recovery will scan WAL from LSN 1.")
}
// Data structures for recovery
redoSet := make(map[PageID]LSN) // Pages that need redo, and their LSN
activeTransactions := make(map[TransactionID]LSN) // TxID -> startLSN
// 2. Redo Phase: Apply all changes from minRecoveryLSN to the end of WAL
fmt.Printf("Starting Redo phase from LSN %d...\n", scanStartLSN)
err := rm.wal.ReadRecords(scanStartLSN, func(record LogRecord) bool {
switch r := record.(type) {
case *TransactionStartRecord:
activeTransactions[r.TxID] = r.GetLSN()
case *TransactionCommitRecord, *TransactionAbortRecord:
delete(activeTransactions, r.TxID)
case *PageUpdateRecord:
// Apply update to page cache's view of the page
page, err := rm.pageCache.GetPage(r.PageID)
if err != nil {
fmt.Printf("Error getting page %d for redo: %v\n", r.PageID, err)
return false // Abort recovery
}
page.mu.Lock()
// Only redo if the log record is newer than what's currently on the page
// (or what's known to be on disk if page.LastLSN represents disk state)
// For simplicity, we assume page.LastLSN tracks the latest applied LSN from WAL.
if r.GetLSN() > page.LastLSN {
// Apply new data
copy(page.Data[r.Offset:r.Offset+uint16(r.Length)], r.NewData)
page.LastLSN = r.GetLSN()
page.IsDirty = true // Mark as dirty, will be flushed later if needed
redoSet[r.PageID] = r.GetLSN() // Track pages updated during redo
}
page.mu.Unlock()
case *CheckpointEndRecord:
// If we encounter another CheckpointEndRecord, it means the earlier one was valid.
// This scenario is for robust recovery in case the initial search was incomplete.
// For this simplified example, we'd generally use the lastCkptEndRecord found initially.
// If a newer one is found, we should update our recovery start if needed.
// For now, we assume initial search is correct.
}
return true // Continue reading
})
if err != nil && err != io.EOF {
return fmt.Errorf("error during Redo phase: %w", err)
}
fmt.Printf("Redo phase completed. Active transactions: %v\n", activeTransactions)
// 3. Undo Phase: Rollback uncommitted transactions
fmt.Println("Starting Undo phase...")
for txID, txStartLSN := range activeTransactions {
fmt.Printf("Undoing transaction %d (started at LSN %d)...\n", txID, txStartLSN)
// Scan WAL for records belonging to this transaction in reverse chronological order
// Our WAL.ReadRecords reads forward, so we collect relevant records first, then reverse.
var txRecords []LogRecord
err := rm.wal.ReadRecords(txStartLSN, func(record LogRecord) bool {
if updateRec, ok := record.(*PageUpdateRecord); ok && updateRec.TxID == txID {
txRecords = append(txRecords, updateRec)
} else if startRec, ok := record.(*TransactionStartRecord); ok && startRec.TxID == txID {
txRecords = append(txRecords, startRec)
}
// Stop scanning if we passed the current LSN in a multi-transaction system
// For simplicity, we collect all records for a given TxID up to the end of WAL.
return true
})
if err != nil && err != io.EOF {
return fmt.Errorf("error reading WAL for undo of transaction %d: %w", txID, err)
}
// Apply undo in reverse order
for i := len(txRecords) - 1; i >= 0; i-- {
record := txRecords[i]
if updateRec, ok := record.(*PageUpdateRecord); ok {
page, err := rm.pageCache.GetPage(updateRec.PageID)
if err != nil {
fmt.Printf("Error getting page %d for undo of tx %d: %v\n", updateRec.PageID, txID, err)
continue // Try to undo other records
}
page.mu.Lock()
// Only undo if the current state of the page reflects this transaction's change
// (i.e., its LSN matches the update record's LSN or is newer but needs undo)
// ARIES-style recovery is more complex, involving CLRs (Compensation Log Records)
// For simplicity, we just apply OldData.
copy(page.Data[updateRec.Offset:updateRec.Offset+uint16(updateRec.Length)], updateRec.OldData)
page.IsDirty = true // Mark as dirty
page.mu.Unlock()
}
}
// Write an Abort record to WAL to finalize the undo (optional, for explicit history)
if _, err := rm.wal.AppendRecord(NewTransactionAbortRecord(txID)); err != nil {
fmt.Printf("Warning: failed to write abort record for tx %d during undo: %v\n", txID, err)