From 15b1f39f2383d81cf7d70cea613b4e4acd28b720 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 12 Nov 2024 15:29:27 -0800 Subject: [PATCH] Use std::variant to implement pytree Key (#6792) Pull Request resolved: https://github.com/pytorch/executorch/pull/6701 Key was a struct that should've been a union; std::variant makes using a union much easier. ghstack-source-id: 253128071 @exported-using-ghexport Differential Revision: [D65575184](https://our.internmc.facebook.com/intern/diff/D65575184/) Co-authored-by: Scott Wolchok --- extension/pytree/pytree.h | 42 +++++++++++++-------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 78e2305fe3..9c7202fb86 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -15,6 +15,7 @@ #include #include #include +#include // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. #include @@ -55,29 +56,28 @@ using KeyInt = int32_t; struct Key { enum class Kind : uint8_t { None, Int, Str } kind_; - KeyInt as_int_ = {}; - KeyStr as_str_ = {}; + private: + std::variant repr_; - Key() : kind_(Kind::None) {} - /*implicit*/ Key(KeyInt key) : kind_(Kind::Int), as_int_(std::move(key)) {} - /*implicit*/ Key(KeyStr key) : kind_(Kind::Str), as_str_(std::move(key)) {} + public: + Key() {} + /*implicit*/ Key(KeyInt key) : repr_(key) {} + /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {} - const Kind& kind() const { - return kind_; + Kind kind() const { + return static_cast(repr_.index()); } - const KeyInt& as_int() const { - pytree_assert(kind_ == Key::Kind::Int); - return as_int_; + KeyInt as_int() const { + return std::get(repr_); } - operator const KeyInt&() const { + operator KeyInt() const { return as_int(); } const KeyStr& as_str() const { - pytree_assert(kind_ == Key::Kind::Str); - return as_str_; + return std::get(repr_); } operator const KeyStr&() const { @@ -85,21 +85,7 @@ struct Key { } bool operator==(const Key& rhs) const { - if (kind_ != rhs.kind_) { - return false; - } - switch (kind_) { - case Kind::Str: { - return as_str_ == rhs.as_str_; - } - case Kind::Int: { - return as_int_ == rhs.as_int_; - } - case Kind::None: { - return true; - } - } - pytree_unreachable(); + return repr_ == rhs.repr_; } bool operator!=(const Key& rhs) const {