C++ 实现一个 `std::shared_ptr`:理解引用计数与循环引用解决

哈喽,各位好!今天咱们来聊聊 C++ 里一个非常重要的智能指针:std::shared_ptr。这玩意儿能自动管理内存,避免内存泄漏,简直是现代 C++ 开发的必备良药。但是,shared_ptr 的实现原理,特别是引用计数和循环引用,经常让新手头疼。所以,今天我们就手撸一个简化版的 shared_ptr,彻底搞懂它背后的机制。

1. 为什么需要 shared_ptr

在 C++ 里,内存管理是个老大难问题。如果你用 new 分配了内存,就必须用 delete 来释放,否则就会造成内存泄漏。而手动管理内存很容易出错,比如忘记 delete,或者重复 delete

这时候,智能指针就派上用场了。智能指针本质上是一个类,它封装了原始指针,并在对象生命周期结束时自动释放所管理的内存。std::shared_ptr 是其中一种,它允许多个智能指针共享对同一块内存的所有权。当最后一个 shared_ptr 对象销毁时,它所管理的内存才会被释放。

2. shared_ptr 的核心:引用计数

shared_ptr 的核心机制是引用计数。简单来说,就是每次有一个新的 shared_ptr 指向同一块内存,计数器就加一;每次一个 shared_ptr 对象销毁,计数器就减一。当计数器变为零时,就说明没有任何 shared_ptr 指向这块内存了,这时候就可以安全地释放内存了。

3. 手撸一个简化版的 shared_ptr

为了更好地理解 shared_ptr 的工作原理,我们来自己实现一个简化版的 shared_ptr

#include <iostream>

template <typename T>
class SharedPtr {
private:
  T* ptr;
  size_t* count;

public:
  // 构造函数
  SharedPtr(T* p = nullptr) : ptr(p), count(new size_t(1)) {
    if (!ptr) {
      delete count;
      count = nullptr;
    }
    std::cout << "Constructor: Count = " << *count << std::endl;
  }

  // 拷贝构造函数
  SharedPtr(const SharedPtr& other) : ptr(other.ptr), count(other.count) {
    if (count) {
      (*count)++;
    }
    std::cout << "Copy Constructor: Count = " << (count ? *count : 0) << std::endl;
  }

  // 移动构造函数
  SharedPtr(SharedPtr&& other) : ptr(other.ptr), count(other.count) {
    other.ptr = nullptr;
    other.count = nullptr;
    std::cout << "Move Constructor" << std::endl;
  }

  // 赋值运算符
  SharedPtr& operator=(const SharedPtr& other) {
    // 避免自赋值
    if (this != &other) {
      // 减少旧资源的引用计数
      if (count && (*count > 0)) {
          (*count)--;
          if (*count == 0) {
              delete ptr;
              delete count;
          }
      }

      // 复制新资源
      ptr = other.ptr;
      count = other.count;
      if (count) {
          (*count)++;
      }
    }
    std::cout << "Assignment Operator: Count = " << (count ? *count : 0) << std::endl;
    return *this;
  }

  // 移动赋值运算符
  SharedPtr& operator=(SharedPtr&& other) {
    if (this != &other) {
        // 减少旧资源的引用计数
        if (count && (*count > 0)) {
            (*count)--;
            if (*count == 0) {
                delete ptr;
                delete count;
            }
        }

        // 转移新资源
        ptr = other.ptr;
        count = other.count;

        other.ptr = nullptr;
        other.count = nullptr;
    }
    std::cout << "Move Assignment Operator" << std::endl;
    return *this;
}

  // 解引用运算符
  T& operator*() {
    return *ptr;
  }

  // 箭头运算符
  T* operator->() {
    return ptr;
  }

  // 获取原始指针
  T* get() const {
    return ptr;
  }

  // 获取引用计数
  size_t use_count() const {
    return count ? *count : 0;
  }

  // 析构函数
  ~SharedPtr() {
    if (count && (*count > 0)) {
      (*count)--;
      std::cout << "Destructor: Count = " << *count << std::endl;
      if (*count == 0) {
        std::cout << "Deleting memory" << std::endl;
        delete ptr;
        delete count;
      }
    } else {
      std::cout << "Destructor: Count = 0 (or nullptr)" << std::endl;
    }
    ptr = nullptr;
    count = nullptr;
  }
};

int main() {
  SharedPtr<int> ptr1(new int(10));
  std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;

  SharedPtr<int> ptr2 = ptr1;
  std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;
  std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl;

  SharedPtr<int> ptr3;
  ptr3 = ptr1;
  std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;
  std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl;
  std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;

  {
      SharedPtr<int> ptr4(new int(20));
      std::cout << "ptr4 use_count: " << ptr4.use_count() << std::endl;
      ptr3 = ptr4;
      std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;
      std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl;
      std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;
      std::cout << "ptr4 use_count: " << ptr4.use_count() << std::endl;
  }

  std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;
  std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl;
  std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;
  return 0;
}

代码解释:

  • ptr: 这是指向实际数据的原始指针。
  • count: 这是指向引用计数器的指针。引用计数器是一个 size_t 类型的变量,用来记录有多少个 shared_ptr 指向同一块内存。
  • 构造函数: 构造函数负责初始化 ptrcount。如果传入了原始指针,就创建一个新的引用计数器,并将计数器初始化为 1。
  • 拷贝构造函数: 拷贝构造函数负责创建一个新的 shared_ptr 对象,指向与原始 shared_ptr 对象相同的内存,并将引用计数器加一。
  • 赋值运算符: 赋值运算符负责将一个 shared_ptr 对象赋值给另一个 shared_ptr 对象。首先,它会减少左侧 shared_ptr 对象原来指向的内存的引用计数。如果引用计数变为零,就释放该内存。然后,它会将左侧 shared_ptr 对象指向与右侧 shared_ptr 对象相同的内存,并将引用计数器加一。注意避免自赋值的情况。
  • 析构函数: 析构函数负责在 shared_ptr 对象销毁时,将引用计数器减一。如果引用计数变为零,就释放所管理的内存。
  • use_count(): 返回当前的引用计数。

4. 循环引用:shared_ptr 的阿喀琉斯之踵

虽然 shared_ptr 很好用,但是它也有一个致命的弱点:循环引用。循环引用是指两个或多个对象互相持有对方的 shared_ptr,导致它们的引用计数永远不会变为零,从而造成内存泄漏。

举个例子:

#include <iostream>
#include <memory>

class A; // 前置声明

class B {
public:
    std::shared_ptr<A> a_ptr;

    ~B() {
        std::cout << "B destructor" << std::endl;
    }
};

class A {
public:
    std::shared_ptr<B> b_ptr;

    ~A() {
        std::cout << "A destructor" << std::endl;
    }
};

int main() {
    std::shared_ptr<A> a = std::make_shared<A>();
    std::shared_ptr<B> b = std::make_shared<B>();

    a->b_ptr = b;
    b->a_ptr = a;

    std::cout << "Program ends" << std::endl;
    return 0;
}

在这个例子中,A 对象持有一个指向 B 对象的 shared_ptr,而 B 对象又持有一个指向 A 对象的 shared_ptr。当 main 函数结束时,ab 两个 shared_ptr 对象都会被销毁,它们的引用计数都会减一,但是由于它们互相持有对方的 shared_ptr,它们的引用计数永远不会变为零,所以 AB 对象的析构函数永远不会被调用,从而造成内存泄漏。

5. 解决循环引用:std::weak_ptr

为了解决循环引用问题,C++ 提供了 std::weak_ptrweak_ptr 是一种弱引用,它不会增加引用计数。weak_ptr 可以用来观察 shared_ptr 所管理的对象,但是它不会阻止对象的销毁。

我们可以将上面的代码修改如下:

#include <iostream>
#include <memory>

class A; // 前置声明

class B {
public:
    std::weak_ptr<A> a_ptr; // 使用 weak_ptr

    ~B() {
        std::cout << "B destructor" << std::endl;
    }
};

class A {
public:
    std::shared_ptr<B> b_ptr;

    ~A() {
        std::cout << "A destructor" << std::endl;
    }
};

int main() {
    std::shared_ptr<A> a = std::make_shared<A>();
    std::shared_ptr<B> b = std::make_shared<B>();

    a->b_ptr = b;
    b->a_ptr = a; // 使用 weak_ptr 不会增加引用计数

    std::cout << "Program ends" << std::endl;
    return 0;
}

在这个修改后的例子中,B 对象持有一个指向 A 对象的 weak_ptr。当 main 函数结束时,ab 两个 shared_ptr 对象都会被销毁,它们的引用计数都会减一。由于 B 对象持有的是一个 weak_ptr,它不会增加 A 对象的引用计数,所以 A 对象的引用计数会变为零,A 对象的析构函数会被调用,从而释放 A 对象所占用的内存。然后,B 对象的析构函数也会被调用,从而释放 B 对象所占用的内存。

weak_ptr 的使用:

weak_ptr 不能直接访问所指向的对象,需要先调用 lock() 方法将其转换为 shared_ptr,才能访问对象。如果对象已经被销毁,lock() 方法会返回一个空的 shared_ptr

#include <iostream>
#include <memory>

int main() {
    std::shared_ptr<int> sp = std::make_shared<int>(10);
    std::weak_ptr<int> wp = sp;

    // 使用 weak_ptr 之前需要先 lock()
    if (auto shared_ptr = wp.lock()) {
        std::cout << "Value: " << *shared_ptr << std::endl;
    } else {
        std::cout << "Object has been destroyed" << std::endl;
    }

    sp.reset(); // 释放 shared_ptr

    // 再次尝试 lock()
    if (auto shared_ptr = wp.lock()) {
        std::cout << "Value: " << *shared_ptr << std::endl;
    } else {
        std::cout << "Object has been destroyed" << std::endl;
    }

    return 0;
}

6. shared_ptr 的一些注意事项:

  • 不要将原始指针直接赋值给多个 shared_ptr: 如果这样做,会导致多个 shared_ptr 对象管理同一块内存,每个 shared_ptr 对象都会认为自己是唯一的所有者,从而导致重复释放内存。

    int* raw_ptr = new int(10);
    std::shared_ptr<int> sp1(raw_ptr); // 正确
    //std::shared_ptr<int> sp2(raw_ptr); // 错误!重复释放
    std::shared_ptr<int> sp2 = sp1; // 正确,共享所有权
  • 使用 std::make_shared 创建 shared_ptr: std::make_shared 可以一次性分配对象和引用计数器的内存,避免了两次内存分配,提高了效率。

    // 推荐使用 make_shared
    std::shared_ptr<int> sp = std::make_shared<int>(10);
    
    // 不推荐
    std::shared_ptr<int> sp2(new int(10));

7. 总结

std::shared_ptr 是 C++ 中一个非常重要的智能指针,它可以自动管理内存,避免内存泄漏。理解 shared_ptr 的引用计数机制和循环引用问题,并掌握 std::weak_ptr 的使用,可以帮助你更好地使用 shared_ptr,编写更安全、更可靠的 C++ 代码。

表格总结:

特性 std::shared_ptr std::weak_ptr
所有权 拥有所有权,增加引用计数 不拥有所有权,不增加引用计数
作用 自动管理对象生命周期,防止内存泄漏 观察 shared_ptr 所管理的对象,解决循环引用问题
使用方法 可以直接访问所指向的对象 需要先 lock() 转换为 shared_ptr 才能访问
适用场景 需要多个对象共享所有权时 需要观察对象,但不希望影响对象生命周期时
循环引用解决方案 造成循环引用,导致内存泄漏 用于打破循环引用

希望今天的讲解能够帮助大家更好地理解 std::shared_ptr。记住,理解原理才能更好地应用! 祝大家编程愉快!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注