From 41a57e6b0ad05d63d84448747251bc0559d6a826 Mon Sep 17 00:00:00 2001 From: Peixuan Liu Date: Mon, 28 Oct 2024 21:53:04 -0700 Subject: [PATCH] Add ScalarType 22 `BITS16` support in etdump gen and deserialization Differential Revision: D64812253 Pull Request resolved: https://github.com/pytorch/executorch/pull/6504 --- devtools/bundled_program/schema/scalar_type.fbs | 5 +++++ devtools/etdump/etdump_flatcc.cpp | 2 ++ devtools/etdump/scalar_type.fbs | 5 +++++ exir/scalar_type.py | 4 +++- exir/tensor.py | 2 +- schema/scalar_type.fbs | 5 +++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/devtools/bundled_program/schema/scalar_type.fbs b/devtools/bundled_program/schema/scalar_type.fbs index a8da080c67..fc299ac691 100644 --- a/devtools/bundled_program/schema/scalar_type.fbs +++ b/devtools/bundled_program/schema/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, } diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp index 4c05bb5ace..cfd1d2ae14 100644 --- a/devtools/etdump/etdump_flatcc.cpp +++ b/devtools/etdump/etdump_flatcc.cpp @@ -55,6 +55,8 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type( return executorch_flatbuffer_ScalarType_DOUBLE; case exec_aten::ScalarType::Bool: return executorch_flatbuffer_ScalarType_BOOL; + case exec_aten::ScalarType::Bits16: + return executorch_flatbuffer_ScalarType_BITS16; default: ET_CHECK_MSG( 0, diff --git a/devtools/etdump/scalar_type.fbs b/devtools/etdump/scalar_type.fbs index a8da080c67..fc299ac691 100644 --- a/devtools/etdump/scalar_type.fbs +++ b/devtools/etdump/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, } diff --git a/exir/scalar_type.py b/exir/scalar_type.py index b789a09f3a..5d41038610 100644 --- a/exir/scalar_type.py +++ b/exir/scalar_type.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from enum import IntEnum @@ -26,4 +28,4 @@ class ScalarType(IntEnum): BFLOAT16 = 15 QUINT4x2 = 16 QUINT2x4 = 17 - Bits16 = 22 + BITS16 = 22 diff --git a/exir/tensor.py b/exir/tensor.py index d63ed5d262..a40bef4e5e 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -262,7 +262,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int: torch.qint32: ScalarType.QINT32, torch.bfloat16: ScalarType.BFLOAT16, torch.quint4x2: ScalarType.QUINT4x2, - torch.uint16: ScalarType.Bits16, + torch.uint16: ScalarType.BITS16, } diff --git a/schema/scalar_type.fbs b/schema/scalar_type.fbs index a8da080c67..fc299ac691 100644 --- a/schema/scalar_type.fbs +++ b/schema/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, }