diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 78e2305fe3..9c7202fb86 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -15,6 +15,7 @@ #include #include #include +#include // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. #include @@ -55,29 +56,28 @@ using KeyInt = int32_t; struct Key { enum class Kind : uint8_t { None, Int, Str } kind_; - KeyInt as_int_ = {}; - KeyStr as_str_ = {}; + private: + std::variant repr_; - Key() : kind_(Kind::None) {} - /*implicit*/ Key(KeyInt key) : kind_(Kind::Int), as_int_(std::move(key)) {} - /*implicit*/ Key(KeyStr key) : kind_(Kind::Str), as_str_(std::move(key)) {} + public: + Key() {} + /*implicit*/ Key(KeyInt key) : repr_(key) {} + /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {} - const Kind& kind() const { - return kind_; + Kind kind() const { + return static_cast(repr_.index()); } - const KeyInt& as_int() const { - pytree_assert(kind_ == Key::Kind::Int); - return as_int_; + KeyInt as_int() const { + return std::get(repr_); } - operator const KeyInt&() const { + operator KeyInt() const { return as_int(); } const KeyStr& as_str() const { - pytree_assert(kind_ == Key::Kind::Str); - return as_str_; + return std::get(repr_); } operator const KeyStr&() const { @@ -85,21 +85,7 @@ struct Key { } bool operator==(const Key& rhs) const { - if (kind_ != rhs.kind_) { - return false; - } - switch (kind_) { - case Kind::Str: { - return as_str_ == rhs.as_str_; - } - case Kind::Int: { - return as_int_ == rhs.as_int_; - } - case Kind::None: { - return true; - } - } - pytree_unreachable(); + return repr_ == rhs.repr_; } bool operator!=(const Key& rhs) const {