C++ 编写自己的 `std::variant`:实现类型安全的联合体

好的,让我们开始这场 C++ std::variant DIY 之旅!今天我们要一起打造一个属于我们自己的、类型安全的联合体,就像超级英雄DIY装备一样,想想就激动!

开场白:联合体的“前世今生”与 std::variant 的诞生

各位,还记得 C 时代的联合体 (union) 吗?它允许我们在同一块内存空间存储不同类型的数据,就像一个神奇的盒子,今天装苹果,明天装香蕉。但是,这盒子有个毛病:它不告诉你里面装的是啥,全靠你自己记住!一不小心就拿香蕉当苹果啃了,程序崩给你看。

union Data {
    int i;
    float f;
    char str[20];
};

int main() {
    Data data;
    data.i = 10;
    std::cout << data.f << std::endl; // 惨不忍睹的输出
    return 0;
}

看到了吧?这就是类型不安全的痛苦。为了解决这个问题,C++17 引入了 std::variant。它就像一个升级版的联合体,不仅能存储不同类型的数据,还能记住自己存储的是哪种类型,让你再也不用担心拿错东西了。

我们的目标:打造一个简化版的 MyVariant

今天,我们不追求完全复制 std::variant 的所有特性(那太复杂了,会累死的)。我们的目标是创建一个简化版的 MyVariant,包含以下核心功能:

  • 存储多种类型: 能够存储预先定义好的几种类型。
  • 类型安全: 知道当前存储的是哪种类型,避免类型错误。
  • 访问存储的值: 提供一种安全的方式来访问存储的值。
  • 支持 index() 方法: 返回当前存储类型的索引。
  • 支持 get<T>() 方法: 安全地获取特定类型的值。
  • 支持 holds_alternative<T>() 方法: 判断是否存储了特定类型的值。
  • 构造函数: 可以用 variant 支持的类型构造。

第一步:定义 MyVariant 的基本结构

首先,我们需要定义 MyVariant 类。为了简单起见,我们使用模板元编程来实现存储不同类型的功能。

template <typename... Types>
class MyVariant {
private:
    std::aligned_storage_t<sizeof...(Types), alignof(Types...)> data_; // 用于存储数据的原始内存
    std::size_t index_; // 记录当前存储类型的索引

public:
    // 构造函数
    MyVariant() : index_(0) {
        static_assert(sizeof...(Types) > 0, "MyVariant must contain at least one type.");
        using FirstType = std::tuple_element_t<0, std::tuple<Types...>>;
        new (&data_) FirstType(); // 默认构造第一个类型
    }

    template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
    MyVariant(T&& value) : index_(find_index<T>()) {
        static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
        new (&data_) std::decay_t<T>(std::forward<T>(value));
    }

    // 析构函数
    ~MyVariant() {
        destroy_current();
    }

    // 赋值运算符
    template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
    MyVariant& operator=(T&& value) {
        static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
        if (index_ != find_index<T>()) {
            destroy_current();
            index_ = find_index<T>();
            new (&data_) std::decay_t<T>(std::forward<T>(value));
        } else {
            using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
            *reinterpret_cast<CurrentType*>(&data_) = std::forward<T>(value);
        }
        return *this;
    }

    // index() 方法:返回当前存储类型的索引
    std::size_t index() const {
        return index_;
    }

private:
    // 销毁当前存储的对象
    void destroy_current() {
        using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
        reinterpret_cast<CurrentType*>(&data_)->~CurrentType();
    }

    // 查找类型在 Types... 中的索引
    template <typename T, std::size_t N = 0>
    constexpr std::size_t find_index() const {
        if constexpr (N < sizeof...(Types)) {
            if constexpr (std::is_same_v<T, std::tuple_element_t<N, std::tuple<Types...>>>) {
                return N;
            } else {
                return find_index<T, N + 1>();
            }
        } else {
            static_assert(false, "Type not found in MyVariant"); // 编译时错误
        }
    }
};

这段代码做了什么?

  • std::aligned_storage_t:它分配一块足够大的、对齐的内存,可以存储 Types... 中最大的类型。这块内存是原始的,未初始化的。
  • index_:记录当前存储的是 Types... 中的哪种类型。
  • 构造函数:默认构造第一个类型。
  • find_index():一个递归的模板函数,用于查找类型 TTypes... 中的索引。如果找不到,编译时会报错。
  • destroy_current():析构当前存储的对象。
  • 赋值运算符:如果赋值类型和当前类型不一样,销毁当前,然后构造新的。

第二步:实现 get<T>() 方法:安全地获取值

get<T>()std::variant 中最重要的成员函数之一。它用于获取 MyVariant 中存储的特定类型的值。如果 MyVariant 存储的不是 T 类型,get<T>() 应该抛出一个异常,以保证类型安全。

template <typename T>
T& get() {
    if (index_ != find_index<T>()) {
        throw std::bad_variant_access();
    }
    return *reinterpret_cast<T*>(&data_);
}

template <typename T>
const T& get() const {
    if (index_ != find_index<T>()) {
        throw std::bad_variant_access();
    }
    return *reinterpret_cast<const T*>(&data_);
}

这里,我们使用了 reinterpret_castdata_ 的地址转换为 T*,然后解引用返回。注意,如果类型不匹配,我们会抛出一个 std::bad_variant_access 异常。

第三步:实现 holds_alternative<T>() 方法:类型检查

holds_alternative<T>() 用于检查 MyVariant 当前是否存储了 T 类型的值。

template <typename T>
bool holds_alternative() const {
    return index_ == find_index<T>();
}

这个方法非常简单,只需要比较 index_T 的索引即可。

完整代码示例 (MyVariant.h):

#ifndef MYVARIANT_H
#define MYVARIANT_H

#include <iostream>
#include <type_traits>
#include <stdexcept>
#include <tuple>

template <typename... Types>
class MyVariant {
private:
    std::aligned_storage_t<sizeof...(Types), alignof(Types...)> data_; // 用于存储数据的原始内存
    std::size_t index_; // 记录当前存储类型的索引

public:
    // 构造函数
    MyVariant() : index_(0) {
        static_assert(sizeof...(Types) > 0, "MyVariant must contain at least one type.");
        using FirstType = std::tuple_element_t<0, std::tuple<Types...>>;
        new (&data_) FirstType(); // 默认构造第一个类型
    }

    template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
    MyVariant(T&& value) : index_(find_index<T>()) {
        static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
        new (&data_) std::decay_t<T>(std::forward<T>(value));
    }

    // 拷贝构造函数
    MyVariant(const MyVariant& other) : index_(other.index_) {
        switch (index_) {
            // 使用宏展开来避免重复代码
            #define CASE(idx, Type)                                               
                case idx:                                                        
                    new (&data_) Type(*reinterpret_cast<const Type*>(&other.data_)); 
                    break;

            EXPAND_VARIANT_CASES(Types...)

            #undef CASE
            default:
                throw std::runtime_error("Invalid variant index in copy constructor.");
        }
    }

    // 赋值运算符
    template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
    MyVariant& operator=(T&& value) {
        static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
        if (index_ != find_index<T>()) {
            destroy_current();
            index_ = find_index<T>();
            new (&data_) std::decay_t<T>(std::forward<T>(value));
        } else {
            using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
            *reinterpret_cast<CurrentType*>(&data_) = std::forward<T>(value);
        }
        return *this;
    }

    // 拷贝赋值运算符
    MyVariant& operator=(const MyVariant& other) {
        if (this == &other) {
            return *this; // 避免自我赋值
        }

        if (index_ != other.index_) {
            destroy_current();
            index_ = other.index_;

            switch (index_) {
                // 使用宏展开来避免重复代码
                #define CASE(idx, Type)                                               
                    case idx:                                                        
                        new (&data_) Type(*reinterpret_cast<const Type*>(&other.data_)); 
                        break;

                EXPAND_VARIANT_CASES(Types...)

                #undef CASE

                default:
                    throw std::runtime_error("Invalid variant index in copy assignment.");
            }
        } else {
            // 类型相同,直接赋值
            using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
            *reinterpret_cast<CurrentType*>(&data_) = *reinterpret_cast<const CurrentType*>(&other.data_);
        }
        return *this;
    }

    // 析构函数
    ~MyVariant() {
        destroy_current();
    }

    // index() 方法:返回当前存储类型的索引
    std::size_t index() const {
        return index_;
    }

    // get() 方法:获取特定类型的值
    template <typename T>
    T& get() {
        if (index_ != find_index<T>()) {
            throw std::bad_variant_access();
        }
        return *reinterpret_cast<T*>(&data_);
    }

    template <typename T>
    const T& get() const {
        if (index_ != find_index<T>()) {
            throw std::bad_variant_access();
        }
        return *reinterpret_cast<const T*>(&data_);
    }

    // holds_alternative() 方法:检查是否存储了特定类型
    template <typename T>
    bool holds_alternative() const {
        return index_ == find_index<T>();
    }

private:
    // 销毁当前存储的对象
    void destroy_current() {
        using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
        reinterpret_cast<CurrentType*>(&data_)->~CurrentType();
    }

    // 查找类型在 Types... 中的索引
    template <typename T, std::size_t N = 0>
    constexpr std::size_t find_index() const {
        if constexpr (N < sizeof...(Types)) {
            if constexpr (std::is_same_v<T, std::tuple_element_t<N, std::tuple<Types...>>>) {
                return N;
            } else {
                return find_index<T, N + 1>();
            }
        } else {
            static_assert(false, "Type not found in MyVariant"); // 编译时错误
        }
    }

    // 宏来生成拷贝构造函数/拷贝赋值运算符需要的 switch case
    #define EXPAND_VARIANT_CASES(...) EXPAND_VARIANT_CASES_IMPL(__VA_ARGS__)
    #define EXPAND_VARIANT_CASES_IMPL(arg, ...) CASE_IMPL(0, arg) EXPAND_VARIANT_CASES_NEXT(__VA_ARGS__)
    #define EXPAND_VARIANT_CASES_NEXT(...) EXPAND_VARIANT_CASES_NEXT_IMPL(__VA_ARGS__, END)
    #define EXPAND_VARIANT_CASES_NEXT_IMPL(arg, ...) EXPAND_VARIANT_CASES_IMPL(__VA_ARGS__)
    #define EXPAND_VARIANT_CASES_NEXT_IMPL(END)
    #define CASE_IMPL(idx, Type) CASE(idx, Type)

};

#endif // MYVARIANT_H

测试代码 (main.cpp):

#include "MyVariant.h"
#include <iostream>

int main() {
    MyVariant<int, float, std::string> var1; // 默认构造,存储 int (0)
    std::cout << "var1 index: " << var1.index() << std::endl; // 输出 0
    std::cout << "var1 holds int: " << var1.holds_alternative<int>() << std::endl; // 输出 1

    var1 = 3.14f; // 赋值 float
    std::cout << "var1 index: " << var1.index() << std::endl; // 输出 1
    std::cout << "var1 holds float: " << var1.holds_alternative<float>() << std::endl; // 输出 1
    std::cout << "var1 value: " << var1.get<float>() << std::endl; // 输出 3.14

    MyVariant<int, float, std::string> var2 = "hello"; // 构造 string
    std::cout << "var2 index: " << var2.index() << std::endl; // 输出 2
    std::cout << "var2 holds string: " << var2.holds_alternative<std::string>() << std::endl; // 输出 1
    std::cout << "var2 value: " << var2.get<std::string>() << std::endl; // 输出 hello

    try {
        std::cout << var1.get<int>() << std::endl; // 尝试获取 int,会抛出异常
    } catch (const std::bad_variant_access& e) {
        std::cerr << "Exception: " << e.what() << std::endl; // 输出 "bad variant access"
    }

    MyVariant<int, float, std::string> var3 = var2; // 拷贝构造
    std::cout << "var3 value: " << var3.get<std::string>() << std::endl; // 输出 hello

    var3 = var1; // 拷贝赋值
    std::cout << "var3 value: " << var3.get<float>() << std::endl; // 输出 3.14

    return 0;
}

总结与展望

我们成功地实现了一个简化版的 MyVariant!它支持存储多种类型、类型安全访问、index()get<T>()holds_alternative<T>() 方法。

当然,我们的 MyVariant 还有很多可以改进的地方,比如:

  • 支持移动语义: 目前只支持拷贝构造和拷贝赋值,可以添加移动构造和移动赋值。
  • 支持 emplace() 方法: 允许就地构造对象,避免拷贝。
  • 支持访问者模式: 提供一种更灵活的方式来处理不同类型的值。
  • 异常安全性: 确保在异常发生时,MyVariant 的状态是正确的。
  • constexpr支持: 增加编译期计算能力。

希望这次 DIY 之旅能让你对 std::variant 的内部实现有更深入的了解。记住,编程的乐趣在于不断探索和创造!继续加油,打造更强大的工具!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注