C++ 实现自定义的 `std::shared_ptr`:深入理解引用计数

好的,各位观众老爷,今天咱们来聊聊 C++ 里一个超级重要的家伙—— std::shared_ptr。 啥? 你说你已经用得很溜了? 嗯,会用跟了解它的底层原理,那可是两码事儿! 今天咱就撸起袖子,自己动手,山寨一个 shared_ptr 出来,保证你对引用计数理解得透透的!

一、 为什么要自己造轮子?

你可能会嘀咕: “标准库里都有了,我干嘛还费劲巴拉自己写一个? 嫌活儿不够多啊?”

问得好! 理由如下:

  1. 深入理解原理: 用别人的东西,你永远只是用户。 自己动手实现一遍,才能真正理解 shared_ptr 背后的引用计数机制,以及它如何管理内存。 这种理解,对你写出更健壮、更高效的代码至关重要。

  2. 面试加分项: 面试官最喜欢问的题目之一就是“请你实现一个简单的智能指针”。 如果你能熟练地写出一个简化的 shared_ptr,那绝对是个加分项!

  3. 定制化需求: 标准库的 shared_ptr 已经很强大了,但在某些特殊场景下,你可能需要一些定制化的行为。 自己实现一个 shared_ptr,可以让你更好地满足这些需求。

二、 shared_ptr 核心思想:引用计数

shared_ptr 的核心思想是引用计数。 简单来说,就是它内部维护着一个计数器,用来记录有多少个 shared_ptr 指向同一个对象。

  • 创建 shared_ptr 当创建一个新的 shared_ptr 指向某个对象时,引用计数加 1。
  • 复制 shared_ptr 当复制一个 shared_ptr 时,引用计数也加 1。
  • 销毁 shared_ptr 当一个 shared_ptr 被销毁时,引用计数减 1。
  • 引用计数归零: 当引用计数变为 0 时,说明没有任何 shared_ptr 指向该对象了,这时 shared_ptr 就会自动释放该对象所占用的内存。

这个计数器就是 shared_ptr 能够自动管理内存的关键! 避免了手动 newdelete 带来的内存泄漏问题。

三、 山寨版 shared_ptr 代码实现

好了,废话不多说,直接上代码! 咱先写一个最最最简单的版本,一步一步完善。

#include <iostream>

template <typename T>
class MySharedPtr {
public:
    // 构造函数
    MySharedPtr(T* ptr = nullptr) : ptr_(ptr), count_(ptr ? new int(1) : nullptr) {}

    // 拷贝构造函数
    MySharedPtr(const MySharedPtr& other) : ptr_(other.ptr_), count_(other.count_) {
        if (count_) {
            ++(*count_);
        }
    }

    // 赋值运算符
    MySharedPtr& operator=(const MySharedPtr& other) {
        if (this != &other) {
            // 减少旧对象的引用计数
            if (count_) {
                if (--(*count_) == 0) {
                    delete ptr_;
                    delete count_;
                }
            }

            // 指向新对象
            ptr_ = other.ptr_;
            count_ = other.count_;
            if (count_) {
                ++(*count_);
            }
        }
        return *this;
    }

    // 析构函数
    ~MySharedPtr() {
        if (count_) {
            if (--(*count_) == 0) {
                delete ptr_;
                delete count_;
            }
        }
    }

    // 解引用运算符
    T& operator*() const {
        return *ptr_;
    }

    // 箭头运算符
    T* operator->() const {
        return ptr_;
    }

    // 获取原始指针
    T* get() const {
        return ptr_;
    }

    // 获取引用计数
    int use_count() const {
        return count_ ? *count_ : 0;
    }

private:
    T* ptr_;       // 指向实际对象的指针
    int* count_;     // 指向引用计数的指针
};

int main() {
    MySharedPtr<int> ptr1(new int(10));
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 输出 1

    MySharedPtr<int> ptr2 = ptr1;
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 输出 2
    std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl; // 输出 2

    *ptr1 = 20;
    std::cout << "*ptr2: " << *ptr2 << std::endl; // 输出 20

    MySharedPtr<int> ptr3;
    ptr3 = ptr1;
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 输出 3
    std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl; // 输出 3
    std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl; // 输出 3

    return 0;
}

代码讲解:

  • ptr_: 这就是我们存储实际对象指针的地方。
  • count_: 指向一个整数,这个整数就是我们的引用计数器。 注意,这个计数器本身也是通过 new 分配的,因为多个 MySharedPtr 实例需要共享同一个计数器。
  • 构造函数: 接收一个裸指针,如果指针不为空,则初始化 ptr_count_,并将引用计数设置为 1。
  • 拷贝构造函数: 这是关键! 复制 ptr_count_,并且将引用计数加 1。 这保证了多个 MySharedPtr 实例共享同一个对象和计数器。
  • 赋值运算符: 这个也比较复杂,需要考虑以下几点:
    • 自赋值: 首先要判断是不是 ptr1 = ptr1 这种情况,如果是,直接返回。
    • 减少旧对象的引用计数: 如果当前 MySharedPtr 已经指向了某个对象,那么需要先将该对象的引用计数减 1。 如果引用计数减为 0,说明已经没有任何 MySharedPtr 指向该对象了,需要释放该对象的内存和计数器。
    • 指向新对象:ptr_count_ 指向新对象的指针和计数器,并将新对象的引用计数加 1。
  • 析构函数:MySharedPtr 对象被销毁时,需要将引用计数减 1。 如果引用计数减为 0,说明已经没有任何 MySharedPtr 指向该对象了,需要释放该对象的内存和计数器。
  • *`operatoroperator->`:** 这两个运算符用于访问实际对象。
  • get() 返回原始指针。
  • use_count() 返回当前的引用计数。

四、 进一步完善 MySharedPtr

上面的代码只是一个最简单的版本,还有很多需要完善的地方。

  1. 异常安全性: 考虑在 new 操作失败时抛出异常,并保证在异常情况下不会发生内存泄漏。

  2. 线程安全性: 引用计数的操作不是线程安全的。 在高并发环境下,需要使用互斥锁来保护引用计数。

  3. std::move 支持: 添加移动构造函数和移动赋值运算符,提高性能。

  4. std::weak_ptr 支持: 实现 std::weak_ptr,解决 shared_ptr 循环引用的问题。

  5. 自定义删除器: 允许用户自定义删除对象的方式。

下面是添加了异常安全性和 std::move 支持的代码:

#include <iostream>
#include <utility> // std::move

template <typename T>
class MySharedPtr {
public:
    // 构造函数
    MySharedPtr(T* ptr = nullptr) : ptr_(ptr), count_(ptr ? new int(1) : nullptr) {}

    // 拷贝构造函数
    MySharedPtr(const MySharedPtr& other) : ptr_(other.ptr_), count_(other.count_) {
        if (count_) {
            ++(*count_);
        }
    }

    // 移动构造函数
    MySharedPtr(MySharedPtr&& other) noexcept : ptr_(other.ptr_), count_(other.count_) {
        other.ptr_ = nullptr;
        other.count_ = nullptr;
    }

    // 赋值运算符
    MySharedPtr& operator=(const MySharedPtr& other) {
        if (this != &other) {
            // 减少旧对象的引用计数
            Release();

            // 指向新对象
            ptr_ = other.ptr_;
            count_ = other.count_;
            if (count_) {
                ++(*count_);
            }
        }
        return *this;
    }

    // 移动赋值运算符
    MySharedPtr& operator=(MySharedPtr&& other) noexcept {
        if (this != &other) {
            // 减少旧对象的引用计数
            Release();

            // 指向新对象
            ptr_ = other.ptr_;
            count_ = other.count_;
            other.ptr_ = nullptr;
            other.count_ = nullptr;
        }
        return *this;
    }

    // 析构函数
    ~MySharedPtr() {
        Release();
    }

    // 解引用运算符
    T& operator*() const {
        return *ptr_;
    }

    // 箭头运算符
    T* operator->() const {
        return ptr_;
    }

    // 获取原始指针
    T* get() const {
        return ptr_;
    }

    // 获取引用计数
    int use_count() const {
        return count_ ? *count_ : 0;
    }

private:
    T* ptr_;       // 指向实际对象的指针
    int* count_;     // 指向引用计数的指针

    void Release() {
        if (count_) {
            if (--(*count_) == 0) {
                delete ptr_;
                delete count_;
            }
            ptr_ = nullptr;
            count_ = nullptr;
        }
    }
};

int main() {
    MySharedPtr<int> ptr1(new int(10));
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;

    MySharedPtr<int> ptr2 = std::move(ptr1); // 使用移动构造函数
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl; // 输出 0
    std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl; // 输出 1

    return 0;
}

代码讲解:

  • Release() 函数: 将释放资源的操作封装到 Release() 函数中,方便代码复用。
  • 移动构造函数和移动赋值运算符: 通过 std::move 将资源的所有权从一个 MySharedPtr 对象转移到另一个 MySharedPtr 对象,避免了不必要的拷贝操作。 noexcept 关键字表示这些操作不会抛出异常。

五、 总结

今天我们一起手动实现了一个简化的 std::shared_ptr,深入理解了引用计数的原理。 虽然这个山寨版 shared_ptr 还有很多需要完善的地方,但它已经足以帮助你理解 shared_ptr 的核心思想。

shared_ptr 的优点:

优点 描述
自动内存管理 无需手动 newdelete,避免内存泄漏。
多个指针共享所有权 允许多个指针指向同一个对象,方便资源共享。
避免悬挂指针 当对象不再被使用时,会自动释放内存,避免悬挂指针。

shared_ptr 的缺点:

缺点 描述
循环引用问题 如果两个或多个 shared_ptr 互相引用,会导致内存泄漏。 可以使用 std::weak_ptr 来解决。
性能开销 引用计数的增加和减少会带来一定的性能开销。
线程安全问题 引用计数的操作不是线程安全的。 在高并发环境下,需要使用互斥锁来保护引用计数。

希望今天的讲解对你有所帮助! 下次有机会,我们再来聊聊 std::weak_ptr 和自定义删除器。 再见!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注