好的,让我们开始这场 C++ std::variant
DIY 之旅!今天我们要一起打造一个属于我们自己的、类型安全的联合体,就像超级英雄DIY装备一样,想想就激动!
开场白:联合体的“前世今生”与 std::variant
的诞生
各位,还记得 C 时代的联合体 (union) 吗?它允许我们在同一块内存空间存储不同类型的数据,就像一个神奇的盒子,今天装苹果,明天装香蕉。但是,这盒子有个毛病:它不告诉你里面装的是啥,全靠你自己记住!一不小心就拿香蕉当苹果啃了,程序崩给你看。
union Data {
int i;
float f;
char str[20];
};
int main() {
Data data;
data.i = 10;
std::cout << data.f << std::endl; // 惨不忍睹的输出
return 0;
}
看到了吧?这就是类型不安全的痛苦。为了解决这个问题,C++17 引入了 std::variant
。它就像一个升级版的联合体,不仅能存储不同类型的数据,还能记住自己存储的是哪种类型,让你再也不用担心拿错东西了。
我们的目标:打造一个简化版的 MyVariant
今天,我们不追求完全复制 std::variant
的所有特性(那太复杂了,会累死的)。我们的目标是创建一个简化版的 MyVariant
,包含以下核心功能:
- 存储多种类型: 能够存储预先定义好的几种类型。
- 类型安全: 知道当前存储的是哪种类型,避免类型错误。
- 访问存储的值: 提供一种安全的方式来访问存储的值。
- 支持
index()
方法: 返回当前存储类型的索引。 - 支持
get<T>()
方法: 安全地获取特定类型的值。 - 支持
holds_alternative<T>()
方法: 判断是否存储了特定类型的值。 - 构造函数: 可以用 variant 支持的类型构造。
第一步:定义 MyVariant
的基本结构
首先,我们需要定义 MyVariant
类。为了简单起见,我们使用模板元编程来实现存储不同类型的功能。
template <typename... Types>
class MyVariant {
private:
std::aligned_storage_t<sizeof...(Types), alignof(Types...)> data_; // 用于存储数据的原始内存
std::size_t index_; // 记录当前存储类型的索引
public:
// 构造函数
MyVariant() : index_(0) {
static_assert(sizeof...(Types) > 0, "MyVariant must contain at least one type.");
using FirstType = std::tuple_element_t<0, std::tuple<Types...>>;
new (&data_) FirstType(); // 默认构造第一个类型
}
template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
MyVariant(T&& value) : index_(find_index<T>()) {
static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
new (&data_) std::decay_t<T>(std::forward<T>(value));
}
// 析构函数
~MyVariant() {
destroy_current();
}
// 赋值运算符
template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
MyVariant& operator=(T&& value) {
static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
if (index_ != find_index<T>()) {
destroy_current();
index_ = find_index<T>();
new (&data_) std::decay_t<T>(std::forward<T>(value));
} else {
using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
*reinterpret_cast<CurrentType*>(&data_) = std::forward<T>(value);
}
return *this;
}
// index() 方法:返回当前存储类型的索引
std::size_t index() const {
return index_;
}
private:
// 销毁当前存储的对象
void destroy_current() {
using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
reinterpret_cast<CurrentType*>(&data_)->~CurrentType();
}
// 查找类型在 Types... 中的索引
template <typename T, std::size_t N = 0>
constexpr std::size_t find_index() const {
if constexpr (N < sizeof...(Types)) {
if constexpr (std::is_same_v<T, std::tuple_element_t<N, std::tuple<Types...>>>) {
return N;
} else {
return find_index<T, N + 1>();
}
} else {
static_assert(false, "Type not found in MyVariant"); // 编译时错误
}
}
};
这段代码做了什么?
std::aligned_storage_t
:它分配一块足够大的、对齐的内存,可以存储Types...
中最大的类型。这块内存是原始的,未初始化的。index_
:记录当前存储的是Types...
中的哪种类型。- 构造函数:默认构造第一个类型。
find_index()
:一个递归的模板函数,用于查找类型T
在Types...
中的索引。如果找不到,编译时会报错。destroy_current()
:析构当前存储的对象。- 赋值运算符:如果赋值类型和当前类型不一样,销毁当前,然后构造新的。
第二步:实现 get<T>()
方法:安全地获取值
get<T>()
是 std::variant
中最重要的成员函数之一。它用于获取 MyVariant
中存储的特定类型的值。如果 MyVariant
存储的不是 T
类型,get<T>()
应该抛出一个异常,以保证类型安全。
template <typename T>
T& get() {
if (index_ != find_index<T>()) {
throw std::bad_variant_access();
}
return *reinterpret_cast<T*>(&data_);
}
template <typename T>
const T& get() const {
if (index_ != find_index<T>()) {
throw std::bad_variant_access();
}
return *reinterpret_cast<const T*>(&data_);
}
这里,我们使用了 reinterpret_cast
将 data_
的地址转换为 T*
,然后解引用返回。注意,如果类型不匹配,我们会抛出一个 std::bad_variant_access
异常。
第三步:实现 holds_alternative<T>()
方法:类型检查
holds_alternative<T>()
用于检查 MyVariant
当前是否存储了 T
类型的值。
template <typename T>
bool holds_alternative() const {
return index_ == find_index<T>();
}
这个方法非常简单,只需要比较 index_
和 T
的索引即可。
完整代码示例 (MyVariant.h):
#ifndef MYVARIANT_H
#define MYVARIANT_H
#include <iostream>
#include <type_traits>
#include <stdexcept>
#include <tuple>
template <typename... Types>
class MyVariant {
private:
std::aligned_storage_t<sizeof...(Types), alignof(Types...)> data_; // 用于存储数据的原始内存
std::size_t index_; // 记录当前存储类型的索引
public:
// 构造函数
MyVariant() : index_(0) {
static_assert(sizeof...(Types) > 0, "MyVariant must contain at least one type.");
using FirstType = std::tuple_element_t<0, std::tuple<Types...>>;
new (&data_) FirstType(); // 默认构造第一个类型
}
template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
MyVariant(T&& value) : index_(find_index<T>()) {
static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
new (&data_) std::decay_t<T>(std::forward<T>(value));
}
// 拷贝构造函数
MyVariant(const MyVariant& other) : index_(other.index_) {
switch (index_) {
// 使用宏展开来避免重复代码
#define CASE(idx, Type)
case idx:
new (&data_) Type(*reinterpret_cast<const Type*>(&other.data_));
break;
EXPAND_VARIANT_CASES(Types...)
#undef CASE
default:
throw std::runtime_error("Invalid variant index in copy constructor.");
}
}
// 赋值运算符
template <typename T, typename = std::enable_if_t<(... && std::is_constructible_v<T> && std::is_convertible_v<T, T>)>>
MyVariant& operator=(T&& value) {
static_assert((... || std::is_same_v<T, Types>), "Type not allowed in MyVariant");
if (index_ != find_index<T>()) {
destroy_current();
index_ = find_index<T>();
new (&data_) std::decay_t<T>(std::forward<T>(value));
} else {
using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
*reinterpret_cast<CurrentType*>(&data_) = std::forward<T>(value);
}
return *this;
}
// 拷贝赋值运算符
MyVariant& operator=(const MyVariant& other) {
if (this == &other) {
return *this; // 避免自我赋值
}
if (index_ != other.index_) {
destroy_current();
index_ = other.index_;
switch (index_) {
// 使用宏展开来避免重复代码
#define CASE(idx, Type)
case idx:
new (&data_) Type(*reinterpret_cast<const Type*>(&other.data_));
break;
EXPAND_VARIANT_CASES(Types...)
#undef CASE
default:
throw std::runtime_error("Invalid variant index in copy assignment.");
}
} else {
// 类型相同,直接赋值
using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
*reinterpret_cast<CurrentType*>(&data_) = *reinterpret_cast<const CurrentType*>(&other.data_);
}
return *this;
}
// 析构函数
~MyVariant() {
destroy_current();
}
// index() 方法:返回当前存储类型的索引
std::size_t index() const {
return index_;
}
// get() 方法:获取特定类型的值
template <typename T>
T& get() {
if (index_ != find_index<T>()) {
throw std::bad_variant_access();
}
return *reinterpret_cast<T*>(&data_);
}
template <typename T>
const T& get() const {
if (index_ != find_index<T>()) {
throw std::bad_variant_access();
}
return *reinterpret_cast<const T*>(&data_);
}
// holds_alternative() 方法:检查是否存储了特定类型
template <typename T>
bool holds_alternative() const {
return index_ == find_index<T>();
}
private:
// 销毁当前存储的对象
void destroy_current() {
using CurrentType = std::tuple_element_t<index_, std::tuple<Types...>>;
reinterpret_cast<CurrentType*>(&data_)->~CurrentType();
}
// 查找类型在 Types... 中的索引
template <typename T, std::size_t N = 0>
constexpr std::size_t find_index() const {
if constexpr (N < sizeof...(Types)) {
if constexpr (std::is_same_v<T, std::tuple_element_t<N, std::tuple<Types...>>>) {
return N;
} else {
return find_index<T, N + 1>();
}
} else {
static_assert(false, "Type not found in MyVariant"); // 编译时错误
}
}
// 宏来生成拷贝构造函数/拷贝赋值运算符需要的 switch case
#define EXPAND_VARIANT_CASES(...) EXPAND_VARIANT_CASES_IMPL(__VA_ARGS__)
#define EXPAND_VARIANT_CASES_IMPL(arg, ...) CASE_IMPL(0, arg) EXPAND_VARIANT_CASES_NEXT(__VA_ARGS__)
#define EXPAND_VARIANT_CASES_NEXT(...) EXPAND_VARIANT_CASES_NEXT_IMPL(__VA_ARGS__, END)
#define EXPAND_VARIANT_CASES_NEXT_IMPL(arg, ...) EXPAND_VARIANT_CASES_IMPL(__VA_ARGS__)
#define EXPAND_VARIANT_CASES_NEXT_IMPL(END)
#define CASE_IMPL(idx, Type) CASE(idx, Type)
};
#endif // MYVARIANT_H
测试代码 (main.cpp):
#include "MyVariant.h"
#include <iostream>
int main() {
MyVariant<int, float, std::string> var1; // 默认构造,存储 int (0)
std::cout << "var1 index: " << var1.index() << std::endl; // 输出 0
std::cout << "var1 holds int: " << var1.holds_alternative<int>() << std::endl; // 输出 1
var1 = 3.14f; // 赋值 float
std::cout << "var1 index: " << var1.index() << std::endl; // 输出 1
std::cout << "var1 holds float: " << var1.holds_alternative<float>() << std::endl; // 输出 1
std::cout << "var1 value: " << var1.get<float>() << std::endl; // 输出 3.14
MyVariant<int, float, std::string> var2 = "hello"; // 构造 string
std::cout << "var2 index: " << var2.index() << std::endl; // 输出 2
std::cout << "var2 holds string: " << var2.holds_alternative<std::string>() << std::endl; // 输出 1
std::cout << "var2 value: " << var2.get<std::string>() << std::endl; // 输出 hello
try {
std::cout << var1.get<int>() << std::endl; // 尝试获取 int,会抛出异常
} catch (const std::bad_variant_access& e) {
std::cerr << "Exception: " << e.what() << std::endl; // 输出 "bad variant access"
}
MyVariant<int, float, std::string> var3 = var2; // 拷贝构造
std::cout << "var3 value: " << var3.get<std::string>() << std::endl; // 输出 hello
var3 = var1; // 拷贝赋值
std::cout << "var3 value: " << var3.get<float>() << std::endl; // 输出 3.14
return 0;
}
总结与展望
我们成功地实现了一个简化版的 MyVariant
!它支持存储多种类型、类型安全访问、index()
、get<T>()
和 holds_alternative<T>()
方法。
当然,我们的 MyVariant
还有很多可以改进的地方,比如:
- 支持移动语义: 目前只支持拷贝构造和拷贝赋值,可以添加移动构造和移动赋值。
- 支持
emplace()
方法: 允许就地构造对象,避免拷贝。 - 支持访问者模式: 提供一种更灵活的方式来处理不同类型的值。
- 异常安全性: 确保在异常发生时,
MyVariant
的状态是正确的。 - constexpr支持: 增加编译期计算能力。
希望这次 DIY 之旅能让你对 std::variant
的内部实现有更深入的了解。记住,编程的乐趣在于不断探索和创造!继续加油,打造更强大的工具!