C++ 构建一个微型 `std::variant`:理解类型安全联合体

哈喽,各位好!今天咱们来聊聊一个在 C++ 世界里既神秘又实用的家伙—— std::variant 的微型版。 别害怕,我们不搞火箭科学,而是用一种轻松幽默的方式,一起拆解它,看看类型安全的联合体到底是怎么工作的。

开场白:联合体的爱恨情仇

在 C++ 的江湖里,联合体(union)一直是个颇具争议的角色。 它允许你在同一块内存空间里存储不同类型的数据,这在某些场景下非常高效。 但同时,它的类型安全性却让人头疼:编译器不会帮你检查你到底存的是什么类型,取的时候是不是取的也是这个类型。 一旦取错,那可就惨了,轻则数据错误,重则程序崩溃。

std::variant 的出现,就是为了解决这个问题。 它提供了一种类型安全的联合体,让你可以放心地使用联合体的效率,而不用担心类型错误。 今天,咱们就来自己动手,打造一个微型的 std::variant,深入理解它的原理。

我们的目标:MiniVariant

我们的目标是创建一个名为 MiniVariant 的类,它应该具备以下功能:

  • 可以存储多种不同类型的数据。
  • 在编译时检查类型安全性。
  • 提供一种方式来确定当前存储的类型。
  • 提供一种方式来访问存储的数据。

Step 1: 基本结构

首先,我们需要一个能够存储数据的底层存储空间。 这里我们直接用 std::aligned_storage 来分配一块足够大的内存,以容纳所有可能的类型。

#include <type_traits>
#include <memory>
#include <new> // for placement new
#include <utility> // for std::forward

template <typename... Types>
class MiniVariant {
private:
    using StorageType = std::aligned_storage_t<
        std::max({sizeof(Types)...}),
        std::max({alignof(Types)...})
    >;
    StorageType storage;
    std::size_t index = 0; // 记录当前存储的类型索引
    using TypeList = std::tuple<Types...>;

public:
    // ... 构造函数、析构函数、访问函数等
};
  • std::aligned_storage_t: 它可以分配一块原始内存,并且保证这块内存的对齐方式适合任何类型。 我们使用 std::max({sizeof(Types)...})std::max({alignof(Types)...}) 来确定需要分配的内存大小和对齐方式,确保它能容纳所有可能的类型。
  • index: 这个变量用来记录当前 MiniVariant 中存储的是哪种类型。 它的值对应于 Types... 中的类型顺序。
  • TypeList: 使用 std::tuple 将所有可能的类型打包在一起,方便后续的操作。

Step 2: 构造函数

我们需要提供多个构造函数,允许用户使用不同的类型来初始化 MiniVariant

    template <typename T,
              typename = std::enable_if_t<
                  (std::is_constructible_v<T> &&
                   (std::is_same_v<T, Types> || ...))>>
    MiniVariant(T&& value) {
        using U = std::decay_t<T>; // 去掉引用和 const/volatile 修饰符
        size_t i = 0;
        bool found = false;
        std::apply([&](auto... args){
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if constexpr (std::is_same_v<U, ArgType>) {
                    new (&storage) ArgType(std::forward<T>(value));
                    index = i;
                    found = true;
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        if (!found) {
            throw std::bad_variant_access(); // 如果类型不在 Types 中,抛出异常
        }
    }
  • std::enable_if_t:这个模板元编程的利器,用于在编译时控制构造函数的可用性。 只有当 T 可以构造,并且 TTypes... 中的一种类型时,这个构造函数才可用。
  • std::decay_t:用于去掉 T 的引用和 const/volatile 修饰符,方便后续的类型比较。
  • placement new:这是一个特殊的 new 操作符,它不会分配内存,而是在已有的内存空间上构造对象。 我们使用 new (&storage) ArgType(std::forward<T>(value))storage 上构造对象。
  • std::forward:用于完美转发,保持 value 的左右值属性。
  • std::apply: 用于遍历 TypeList 中的所有类型,并找到匹配的类型。 这里使用了一个 lambda 表达式和一个 fold expression 来实现遍历。

Step 3: 析构函数

MiniVariant 对象销毁时,我们需要手动调用当前存储对象的析构函数,释放资源。

    ~MiniVariant() {
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    using DestructorType = void (*) (ArgType*);
                    DestructorType destructor = [](ArgType* ptr){ ptr->~ArgType(); };
                    destructor(reinterpret_cast<ArgType*>(&storage));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
    }
  • 这里同样使用了 std::apply 和一个 fold expression 来遍历 TypeList
  • reinterpret_cast: 由于 storage 是原始内存,我们需要使用 reinterpret_cast 将它转换为正确的类型指针。
  • 显式调用析构函数: ptr->~ArgType() 显式调用对象的析构函数。

Step 4: 拷贝构造函数和拷贝赋值运算符

为了保证拷贝的正确性,我们需要自定义拷贝构造函数和拷贝赋值运算符。

    MiniVariant(const MiniVariant& other) {
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    const ArgType* other_ptr = reinterpret_cast<const ArgType*>(&other.storage);
                    new (&storage) ArgType(*other_ptr);
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
    }

    MiniVariant& operator=(const MiniVariant& other) {
        if (this == &other) {
            return *this;
        }

        // 销毁当前对象
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    using DestructorType = void (*) (ArgType*);
                    DestructorType destructor = [](ArgType* ptr){ ptr->~ArgType(); };
                    destructor(reinterpret_cast<ArgType*>(&storage));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        // 拷贝新对象
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    const ArgType* other_ptr = reinterpret_cast<const ArgType*>(&other.storage);
                    new (&storage) ArgType(*other_ptr);
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        return *this;
    }
  • 拷贝构造函数: 在新的 MiniVariant 对象中,拷贝 otherindex,然后在 storage 上构造一个与 other 相同类型的对象。
  • 拷贝赋值运算符: 首先检查是否是自赋值,如果是,则直接返回。 否则,销毁当前对象,然后拷贝 otherindex,并在 storage 上构造一个与 other 相同类型的对象。

Step 5: 移动构造函数和移动赋值运算符

同样,为了提高效率,我们需要自定义移动构造函数和移动赋值运算符。

    MiniVariant(MiniVariant&& other) noexcept {
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    ArgType* other_ptr = reinterpret_cast<ArgType*>(&other.storage);
                    new (&storage) ArgType(std::move(*other_ptr));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
        other.index = 0; // 将 other 的 index 设置为 0,表示不再持有任何值
    }

    MiniVariant& operator=(MiniVariant&& other) noexcept {
        if (this == &other) {
            return *this;
        }

        // 销毁当前对象
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    using DestructorType = void (*) (ArgType*);
                    DestructorType destructor = [](ArgType* ptr){ ptr->~ArgType(); };
                    destructor(reinterpret_cast<ArgType*>(&storage));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        // 移动新对象
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    ArgType* other_ptr = reinterpret_cast<ArgType*>(&other.storage);
                    new (&storage) ArgType(std::move(*other_ptr));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
        other.index = 0; // 将 other 的 index 设置为 0,表示不再持有任何值

        return *this;
    }
  • 移动构造函数: 在新的 MiniVariant 对象中,拷贝 otherindex,然后在 storage 上使用 std::move 构造一个与 other 相同类型的对象。 同时,将 otherindex 设置为 0,表示 other 不再持有任何值。
  • 移动赋值运算符: 首先检查是否是自赋值,如果是,则直接返回。 否则,销毁当前对象,然后拷贝 otherindex,并在 storage 上使用 std::move 构造一个与 other 相同类型的对象。 同时,将 otherindex 设置为 0,表示 other 不再持有任何值。

Step 6: 访问数据

我们需要提供一种方式来访问 MiniVariant 中存储的数据。 这里我们使用 std::get 函数模板。

    template <typename T>
    T& get() {
        using U = std::decay_t<T>;
        size_t i = 0;
        bool found = false;
        std::apply([&](auto... args){
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if constexpr (std::is_same_v<U, ArgType>) {
                    if (i == index) {
                        found = true;
                        return;
                    } else {
                        throw std::bad_variant_access();
                    }
                }
                i++;
            }(), ... );
        }, TypeList{});
        if (!found) {
            throw std::bad_variant_access();
        }
        return *reinterpret_cast<T*>(&storage);
    }

    template <typename T>
    const T& get() const {
        using U = std::decay_t<T>;
        size_t i = 0;
        bool found = false;
        std::apply([&](auto... args){
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if constexpr (std::is_same_v<U, ArgType>) {
                    if (i == index) {
                        found = true;
                        return;
                    } else {
                        throw std::bad_variant_access();
                    }
                }
                i++;
            }(), ... );
        }, TypeList{});
        if (!found) {
            throw std::bad_variant_access();
        }
        return *reinterpret_cast<const T*>(&storage);
    }
  • std::get 函数模板接受一个类型 T 作为参数,并返回 MiniVariant 中存储的 T 类型的引用。
  • 如果 MiniVariant 中存储的不是 T 类型,则抛出 std::bad_variant_access 异常。
  • 这里同样使用了 std::apply 和一个 fold expression 来遍历 TypeList

Step 7: 获取当前类型索引

我们需要提供一种方式来获取当前 MiniVariant 中存储的类型的索引。

    std::size_t index_of() const {
        return index;
    }

Step 8: 测试代码

现在,我们可以编写一些测试代码来验证 MiniVariant 的功能。

#include <iostream>
#include <string>

int main() {
    MiniVariant<int, double, std::string> mv(10);
    std::cout << "Value: " << mv.get<int>() << std::endl;

    mv = 3.14;
    std::cout << "Value: " << mv.get<double>() << std::endl;

    mv = "hello";
    std::cout << "Value: " << mv.get<std::string>() << std::endl;

    try {
        std::cout << "Value: " << mv.get<int>() << std::endl; // 抛出异常
    } catch (const std::bad_variant_access& e) {
        std::cerr << "Error: " << e.what() << std::endl;
    }

    MiniVariant<int, double, std::string> mv2 = mv; // Copy constructor
    std::cout << "Value in mv2: " << mv2.get<std::string>() << std::endl;

    MiniVariant<int, double, std::string> mv3(std::move(mv)); // Move constructor
    std::cout << "Value in mv3: " << mv3.get<std::string>() << std::endl;
    //std::cout << "Value in mv: " << mv.get<std::string>() << std::endl; //  mv is now in valid but unspecified state

    return 0;
}

完整代码

#include <type_traits>
#include <memory>
#include <new> // for placement new
#include <utility> // for std::forward
#include <stdexcept> // for std::bad_variant_access
#include <iostream>
#include <string>

template <typename... Types>
class MiniVariant {
private:
    using StorageType = std::aligned_storage_t<
        std::max({sizeof(Types)...}),
        std::max({alignof(Types)...})
    >;
    StorageType storage;
    std::size_t index = 0; // 记录当前存储的类型索引
    using TypeList = std::tuple<Types...>;

public:
    template <typename T,
              typename = std::enable_if_t<
                  (std::is_constructible_v<T> &&
                   (std::is_same_v<T, Types> || ...))>>
    MiniVariant(T&& value) {
        using U = std::decay_t<T>; // 去掉引用和 const/volatile 修饰符
        size_t i = 0;
        bool found = false;
        std::apply([&](auto... args){
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if constexpr (std::is_same_v<U, ArgType>) {
                    new (&storage) ArgType(std::forward<T>(value));
                    index = i;
                    found = true;
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        if (!found) {
            throw std::bad_variant_access(); // 如果类型不在 Types 中,抛出异常
        }
    }

    ~MiniVariant() {
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    using DestructorType = void (*) (ArgType*);
                    DestructorType destructor = [](ArgType* ptr){ ptr->~ArgType(); };
                    destructor(reinterpret_cast<ArgType*>(&storage));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
    }

    MiniVariant(const MiniVariant& other) {
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    const ArgType* other_ptr = reinterpret_cast<const ArgType*>(&other.storage);
                    new (&storage) ArgType(*other_ptr);
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
    }

    MiniVariant& operator=(const MiniVariant& other) {
        if (this == &other) {
            return *this;
        }

        // 销毁当前对象
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    using DestructorType = void (*) (ArgType*);
                    DestructorType destructor = [](ArgType* ptr){ ptr->~ArgType(); };
                    destructor(reinterpret_cast<ArgType*>(&storage));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        // 拷贝新对象
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    const ArgType* other_ptr = reinterpret_cast<const ArgType*>(&other.storage);
                    new (&storage) ArgType(*other_ptr);
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        return *this;
    }

    MiniVariant(MiniVariant&& other) noexcept {
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    ArgType* other_ptr = reinterpret_cast<ArgType*>(&other.storage);
                    new (&storage) ArgType(std::move(*other_ptr));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
        other.index = 0; // 将 other 的 index 设置为 0,表示不再持有任何值
    }

    MiniVariant& operator=(MiniVariant&& other) noexcept {
        if (this == &other) {
            return *this;
        }

        // 销毁当前对象
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    using DestructorType = void (*) (ArgType*);
                    DestructorType destructor = [](ArgType* ptr){ ptr->~ArgType(); };
                    destructor(reinterpret_cast<ArgType*>(&storage));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});

        // 移动新对象
        index = other.index;
        std::apply([&](auto... args){
            size_t i = 0;
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if (i == index) {
                    ArgType* other_ptr = reinterpret_cast<ArgType*>(&other.storage);
                    new (&storage) ArgType(std::move(*other_ptr));
                    return;
                }
                i++;
            }(), ... );
        }, TypeList{});
        other.index = 0; // 将 other 的 index 设置为 0,表示不再持有任何值

        return *this;
    }

    template <typename T>
    T& get() {
        using U = std::decay_t<T>;
        size_t i = 0;
        bool found = false;
        std::apply([&](auto... args){
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if constexpr (std::is_same_v<U, ArgType>) {
                    if (i == index) {
                        found = true;
                        return;
                    } else {
                        throw std::bad_variant_access();
                    }
                }
                i++;
            }(), ... );
        }, TypeList{});
        if (!found) {
            throw std::bad_variant_access();
        }
        return *reinterpret_cast<T*>(&storage);
    }

    template <typename T>
    const T& get() const {
        using U = std::decay_t<T>;
        size_t i = 0;
        bool found = false;
        std::apply([&](auto... args){
            ( [&](){
                using ArgType = std::decay_t<decltype(args)>;
                if constexpr (std::is_same_v<U, ArgType>) {
                    if (i == index) {
                        found = true;
                        return;
                    } else {
                        throw std::bad_variant_access();
                    }
                }
                i++;
            }(), ... );
        }, TypeList{});
        if (!found) {
            throw std::bad_variant_access();
        }
        return *reinterpret_cast<const T*>(&storage);
    }

    std::size_t index_of() const {
        return index;
    }
};

int main() {
    MiniVariant<int, double, std::string> mv(10);
    std::cout << "Value: " << mv.get<int>() << std::endl;

    mv = 3.14;
    std::cout << "Value: " << mv.get<double>() << std::endl;

    mv = "hello";
    std::cout << "Value: " << mv.get<std::string>() << std::endl;

    try {
        std::cout << "Value: " << mv.get<int>() << std::endl; // 抛出异常
    } catch (const std::bad_variant_access& e) {
        std::cerr << "Error: " << e.what() << std::endl;
    }

    MiniVariant<int, double, std::string> mv2 = mv; // Copy constructor
    std::cout << "Value in mv2: " << mv2.get<std::string>() << std::endl;

    MiniVariant<int, double, std::string> mv3(std::move(mv)); // Move constructor
    std::cout << "Value in mv3: " << mv3.get<std::string>() << std::endl;
    //std::cout << "Value in mv: " << mv.get<std::string>() << std::endl; //  mv is now in valid but unspecified state

    return 0;
}

总结

通过自己动手实现一个微型的 std::variant,我们深入理解了类型安全联合体的原理。 关键点在于:

  • 使用 std::aligned_storage 分配原始内存,保证对齐。
  • 使用 index 记录当前存储的类型索引。
  • 使用 placement new 在原始内存上构造对象。
  • 显式调用析构函数释放资源。
  • 使用 std::enable_if_t 进行编译时类型检查。

当然,我们的 MiniVariant 只是一个简化版本,它没有实现 std::variant 的所有功能,比如 std::visitvalueless_by_exception 等。 但是,它足以帮助我们理解 std::variant 的核心思想。

更进一步:std::visit

std::visitstd::variant 中一个非常强大的功能,它允许你根据 variant 中存储的类型,调用不同的函数。 实现 std::visit 需要用到一些高级的模板元编程技巧,比如函数对象、重载、类型推导等。 这部分内容比较复杂,我们这里就不展开讲解了,感兴趣的同学可以自己研究一下。

希望今天的讲解能够帮助大家更好地理解 std::variant,并在实际开发中灵活运用它。 感谢大家的收听!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注