Skip to content

Commit

Permalink
Add ScalarType 22 BITS16 support in etdump gen and deserialization
Browse files Browse the repository at this point in the history
Differential Revision: D64812253

Pull Request resolved: pytorch#6504
  • Loading branch information
Olivia-liu authored Oct 29, 2024
1 parent 6b01b91 commit 41a57e6
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 2 deletions.
5 changes: 5 additions & 0 deletions devtools/bundled_program/schema/scalar_type.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ enum ScalarType : byte {
QINT32 = 14,
QUINT4X2 = 16,
QUINT2X4 = 17,
BITS16 = 22,
// Types currently not implemented.
// COMPLEXHALF = 8,
// COMPLEXFLOAT = 9,
// COMPLEXDOUBLE = 10,
// BFLOAT16 = 15,
// BITS1x8 = 18,
// BITS2x4 = 19,
// BITS4x2 = 20,
// BITS8 = 21,
}
2 changes: 2 additions & 0 deletions devtools/etdump/etdump_flatcc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type(
return executorch_flatbuffer_ScalarType_DOUBLE;
case exec_aten::ScalarType::Bool:
return executorch_flatbuffer_ScalarType_BOOL;
case exec_aten::ScalarType::Bits16:
return executorch_flatbuffer_ScalarType_BITS16;
default:
ET_CHECK_MSG(
0,
Expand Down
5 changes: 5 additions & 0 deletions devtools/etdump/scalar_type.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ enum ScalarType : byte {
QINT32 = 14,
QUINT4X2 = 16,
QUINT2X4 = 17,
BITS16 = 22,
// Types currently not implemented.
// COMPLEXHALF = 8,
// COMPLEXFLOAT = 9,
// COMPLEXDOUBLE = 10,
// BFLOAT16 = 15,
// BITS1x8 = 18,
// BITS2x4 = 19,
// BITS4x2 = 20,
// BITS8 = 21,
}
4 changes: 3 additions & 1 deletion exir/scalar_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from enum import IntEnum


Expand All @@ -26,4 +28,4 @@ class ScalarType(IntEnum):
BFLOAT16 = 15
QUINT4x2 = 16
QUINT2x4 = 17
Bits16 = 22
BITS16 = 22
2 changes: 1 addition & 1 deletion exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
torch.qint32: ScalarType.QINT32,
torch.bfloat16: ScalarType.BFLOAT16,
torch.quint4x2: ScalarType.QUINT4x2,
torch.uint16: ScalarType.Bits16,
torch.uint16: ScalarType.BITS16,
}


Expand Down
5 changes: 5 additions & 0 deletions schema/scalar_type.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ enum ScalarType : byte {
QINT32 = 14,
QUINT4X2 = 16,
QUINT2X4 = 17,
BITS16 = 22,
// Types currently not implemented.
// COMPLEXHALF = 8,
// COMPLEXFLOAT = 9,
// COMPLEXDOUBLE = 10,
// BFLOAT16 = 15,
// BITS1x8 = 18,
// BITS2x4 = 19,
// BITS4x2 = 20,
// BITS8 = 21,
}

0 comments on commit 41a57e6

Please sign in to comment.