C++实现自定义Smart Pointer:实现特有资源管理、引用计数与线程安全

C++自定义智能指针:特有资源管理、引用计数与线程安全

各位听众,大家好!今天我们来深入探讨C++中自定义智能指针的实现,重点关注如何管理特有资源、实现引用计数以及保证线程安全。智能指针是C++中管理动态分配内存的重要工具,可以有效避免内存泄漏等问题。虽然标准库提供了std::unique_ptrstd::shared_ptrstd::weak_ptr,但在某些特定场景下,我们需要自定义智能指针以满足更复杂的需求,例如管理文件句柄、数据库连接等非内存资源,或者需要更细粒度的线程安全控制。

1. 理解智能指针的核心概念

在开始实现之前,我们先回顾一下智能指针的核心概念:

  • 资源获取即初始化 (RAII): 智能指针是RAII原则的典型应用。RAII的核心思想是将资源的生命周期与对象的生命周期绑定,在对象构造时获取资源,在对象析构时释放资源。智能指针通过析构函数自动释放所管理的资源,从而避免手动释放资源可能导致的错误。
  • 所有权: 智能指针负责管理所拥有的资源。不同类型的智能指针采用不同的所有权模型:
    • unique_ptr: 独占所有权,一个资源只能被一个unique_ptr拥有。
    • shared_ptr: 共享所有权,多个shared_ptr可以共同拥有一个资源,通过引用计数来跟踪资源的生命周期。
    • weak_ptr: 弱引用,不拥有资源的所有权,可以观察shared_ptr所管理的资源是否仍然有效。
  • 引用计数: 用于shared_ptr,记录有多少个shared_ptr指向同一个资源。当引用计数变为0时,资源被释放。
  • 线程安全: 在多线程环境下,智能指针的引用计数操作需要保证线程安全,以避免数据竞争。

2. 自定义智能指针的设计与实现

接下来,我们将通过一个具体的例子,来演示如何自定义一个支持特有资源管理、引用计数和线程安全的智能指针。假设我们需要管理一个文件句柄,并确保在多个线程中安全地共享该句柄。

2.1 定义资源管理类

首先,我们需要定义一个类来封装文件句柄的管理逻辑。这个类负责打开、关闭文件句柄,并提供访问句柄的方法。

#include <iostream>
#include <fstream>
#include <mutex>

class FileHandle {
public:
    FileHandle(const std::string& filename) : filename_(filename), handle_(nullptr) {
        handle_ = fopen(filename_.c_str(), "r+"); // 以读写模式打开文件
        if (handle_ == nullptr) {
            throw std::runtime_error("Failed to open file: " + filename_);
        }
        std::cout << "File opened: " << filename_ << std::endl;
    }

    ~FileHandle() {
        if (handle_ != nullptr) {
            fclose(handle_);
            std::cout << "File closed: " << filename_ << std::endl;
        }
    }

    FILE* get() const {
        return handle_;
    }

    std::string getFileName() const {
        return filename_;
    }

private:
    std::string filename_;
    FILE* handle_;
};

2.2 实现自定义智能指针类

现在,我们可以实现自定义智能指针类MySharedPtr,它将管理FileHandle对象。

template <typename T>
class MySharedPtr {
public:
    // 构造函数
    MySharedPtr(T* ptr = nullptr) : ptr_(ptr), count_(nullptr) {
        if (ptr_) {
            count_ = new std::atomic<int>(1); // 初始化引用计数为1
        }
    }

    // 拷贝构造函数
    MySharedPtr(const MySharedPtr& other) : ptr_(other.ptr_), count_(other.count_) {
        if (count_) {
            (*count_)++; // 增加引用计数
        }
    }

    // 移动构造函数
    MySharedPtr(MySharedPtr&& other) noexcept : ptr_(other.ptr_), count_(other.count_) {
        other.ptr_ = nullptr;
        other.count_ = nullptr;
    }

    // 赋值运算符
    MySharedPtr& operator=(const MySharedPtr& other) {
        if (this != &other) {
            // 减少当前对象的引用计数
            decrementCount();

            // 复制其他对象的引用计数
            ptr_ = other.ptr_;
            count_ = other.count_;
            if (count_) {
                (*count_)++;
            }
        }
        return *this;
    }

    // 移动赋值运算符
    MySharedPtr& operator=(MySharedPtr&& other) noexcept {
        if (this != &other) {
            decrementCount();

            ptr_ = other.ptr_;
            count_ = other.count_;

            other.ptr_ = nullptr;
            other.count_ = nullptr;
        }
        return *this;
    }

    // 解引用运算符
    T& operator*() const {
        if (ptr_) {
            return *ptr_;
        }
        throw std::runtime_error("Dereferencing a null pointer.");
    }

    // 箭头运算符
    T* operator->() const {
        if (ptr_) {
            return ptr_;
        }
        throw std::runtime_error("Dereferencing a null pointer.");
    }

    // 获取原始指针
    T* get() const {
        return ptr_;
    }

    // 获取引用计数
    int use_count() const {
        if (count_) {
            return *count_;
        }
        return 0;
    }

    // 析构函数
    ~MySharedPtr() {
        decrementCount();
    }

private:
    T* ptr_;
    std::atomic<int>* count_; // 使用atomic保证线程安全

    void decrementCount() {
        if (count_) {
            if ((*count_)-- == 1) { // 原子递减引用计数
                delete ptr_;
                delete count_;
                ptr_ = nullptr;
                count_ = nullptr;
            }
        }
    }
};

代码解释:

  • ptr_: 指向所管理的资源的原始指针。
  • count_: 指向引用计数的原子指针。使用std::atomic<int>保证引用计数的线程安全。
  • 构造函数:当传入原始指针时,分配一个新的原子计数器并初始化为1。
  • 拷贝构造函数:复制原始指针和原子计数器指针,并将原子计数器加1。
  • 移动构造函数:转移原始指针和原子计数器指针,并将源对象的指针置为nullptr
  • 赋值运算符:先减少当前对象的引用计数,然后复制其他对象的指针和计数器,并增加计数器。
  • 移动赋值运算符:先减少当前对象的引用计数,然后转移其他对象的指针和计数器,并将源对象置空。
  • decrementCount(): 原子递减引用计数。如果计数变为0,则释放资源和计数器。
  • 解引用运算符(*)和箭头运算符(->):提供访问所管理资源的方式。
  • get(): 返回原始指针。
  • use_count(): 返回当前的引用计数。

2.3 线程安全性的实现

在上面的代码中,我们使用std::atomic<int>来存储引用计数,保证了引用计数的原子操作。这意味着多个线程可以同时增加或减少引用计数,而不会发生数据竞争。

2.4 使用示例

下面是一个使用MySharedPtr的示例:

#include <iostream>
#include <thread>

void thread_function(MySharedPtr<FileHandle> ptr) {
    std::cout << "Thread ID: " << std::this_thread::get_id() << ", File name: " << ptr->getFileName() << ", Use count: " << ptr.use_count() << std::endl;
    // 在线程中使用文件句柄
    FILE* file = ptr->get();
    if (file) {
        // 读写文件操作
        char buffer[100];
        if (fgets(buffer, sizeof(buffer), file)) {
            std::cout << "Thread ID: " << std::this_thread::get_id() << ", Read from file: " << buffer << std::endl;
        }
    }
}

int main() {
    try {
        MySharedPtr<FileHandle> file_ptr(new FileHandle("test.txt")); // 创建一个 FileHandle 对象
        std::cout << "Main Thread: File Name: " << file_ptr->getFileName() << ", Use count: " << file_ptr.use_count() << std::endl;
        std::thread t1(thread_function, file_ptr); // 创建一个线程,并将智能指针传递给它
        std::thread t2(thread_function, file_ptr); // 创建另一个线程,并将智能指针传递给它

        t1.join();
        t2.join();

        std::cout << "Main Thread: File Name: " << file_ptr->getFileName() << ", Use count: " << file_ptr.use_count() << std::endl;

    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }

    return 0;
}

代码解释:

  • main()函数中,我们创建了一个MySharedPtr<FileHandle>对象,并将其传递给两个线程。
  • 每个线程都可以安全地访问和操作文件句柄,因为MySharedPtr保证了线程安全。
  • 当所有MySharedPtr对象都超出作用域时,文件句柄会被自动关闭。

2.5 完整代码

#include <iostream>
#include <fstream>
#include <mutex>
#include <thread>
#include <atomic>
#include <stdexcept>

class FileHandle {
public:
    FileHandle(const std::string& filename) : filename_(filename), handle_(nullptr) {
        handle_ = fopen(filename_.c_str(), "r+"); // 以读写模式打开文件
        if (handle_ == nullptr) {
            throw std::runtime_error("Failed to open file: " + filename_);
        }
        std::cout << "File opened: " << filename_ << std::endl;
    }

    ~FileHandle() {
        if (handle_ != nullptr) {
            fclose(handle_);
            std::cout << "File closed: " << filename_ << std::endl;
        }
    }

    FILE* get() const {
        return handle_;
    }

    std::string getFileName() const {
        return filename_;
    }

private:
    std::string filename_;
    FILE* handle_;
};

template <typename T>
class MySharedPtr {
public:
    // 构造函数
    MySharedPtr(T* ptr = nullptr) : ptr_(ptr), count_(nullptr) {
        if (ptr_) {
            count_ = new std::atomic<int>(1); // 初始化引用计数为1
        }
    }

    // 拷贝构造函数
    MySharedPtr(const MySharedPtr& other) : ptr_(other.ptr_), count_(other.count_) {
        if (count_) {
            (*count_)++; // 增加引用计数
        }
    }

    // 移动构造函数
    MySharedPtr(MySharedPtr&& other) noexcept : ptr_(other.ptr_), count_(other.count_) {
        other.ptr_ = nullptr;
        other.count_ = nullptr;
    }

    // 赋值运算符
    MySharedPtr& operator=(const MySharedPtr& other) {
        if (this != &other) {
            // 减少当前对象的引用计数
            decrementCount();

            // 复制其他对象的引用计数
            ptr_ = other.ptr_;
            count_ = other.count_;
            if (count_) {
                (*count_)++;
            }
        }
        return *this;
    }

    // 移动赋值运算符
    MySharedPtr& operator=(MySharedPtr&& other) noexcept {
        if (this != &other) {
            decrementCount();

            ptr_ = other.ptr_;
            count_ = other.count_;

            other.ptr_ = nullptr;
            other.count_ = nullptr;
        }
        return *this;
    }

    // 解引用运算符
    T& operator*() const {
        if (ptr_) {
            return *ptr_;
        }
        throw std::runtime_error("Dereferencing a null pointer.");
    }

    // 箭头运算符
    T* operator->() const {
        if (ptr_) {
            return ptr_;
        }
        throw std::runtime_error("Dereferencing a null pointer.");
    }

    // 获取原始指针
    T* get() const {
        return ptr_;
    }

    // 获取引用计数
    int use_count() const {
        if (count_) {
            return *count_;
        }
        return 0;
    }

    // 析构函数
    ~MySharedPtr() {
        decrementCount();
    }

private:
    T* ptr_;
    std::atomic<int>* count_; // 使用atomic保证线程安全

    void decrementCount() {
        if (count_) {
            if ((*count_)-- == 1) { // 原子递减引用计数
                delete ptr_;
                delete count_;
                ptr_ = nullptr;
                count_ = nullptr;
            }
        }
    }
};

void thread_function(MySharedPtr<FileHandle> ptr) {
    std::cout << "Thread ID: " << std::this_thread::get_id() << ", File name: " << ptr->getFileName() << ", Use count: " << ptr.use_count() << std::endl;
    // 在线程中使用文件句柄
    FILE* file = ptr->get();
    if (file) {
        // 读写文件操作
        char buffer[100];
        if (fgets(buffer, sizeof(buffer), file)) {
            std::cout << "Thread ID: " << std::this_thread::get_id() << ", Read from file: " << buffer << std::endl;
        }
    }
}

int main() {
    try {
        MySharedPtr<FileHandle> file_ptr(new FileHandle("test.txt")); // 创建一个 FileHandle 对象
        std::cout << "Main Thread: File Name: " << file_ptr->getFileName() << ", Use count: " << file_ptr.use_count() << std::endl;
        std::thread t1(thread_function, file_ptr); // 创建一个线程,并将智能指针传递给它
        std::thread t2(thread_function, file_ptr); // 创建另一个线程,并将智能指针传递给它

        t1.join();
        t2.join();

        std::cout << "Main Thread: File Name: " << file_ptr->getFileName() << ", Use count: " << file_ptr.use_count() << std::endl;

    } catch (const std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }

    return 0;
}

3. 扩展与改进

上述代码只是一个基本的示例,我们可以根据实际需求进行扩展和改进。

  • 自定义删除器 (Custom Deleter): MySharedPtr的模板可以接受一个删除器作为参数,允许用户自定义资源的释放方式。这对于管理非内存资源非常有用。
  • weak_ptr支持: 可以实现一个MyWeakPtr类,用于观察MySharedPtr所管理的资源,但不拥有所有权。
  • 更细粒度的线程安全控制: 可以使用读写锁 (Read-Write Lock) 来提高并发性能。读写锁允许多个线程同时读取资源,但只允许一个线程写入资源。

4. 总结

通过自定义智能指针,我们可以更好地管理特有资源,并根据实际需求定制线程安全策略。MySharedPtr的实现展示了如何使用引用计数和原子操作来保证资源在多线程环境下的安全释放。理解RAII原则和所有权模型是实现自定义智能指针的关键。

5. 关键点回顾

自定义智能指针的关键在于理解RAII原则和所有权模型。线程安全可以通过原子操作和读写锁等机制来实现。自定义删除器可以让我们灵活地管理各种类型的资源。

更多IT精英技术系列讲座,到智猿学院

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注