From ca1ea7cff8c0415fad5fbb10f7281bd961b01fb2 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 2 Oct 2024 12:32:58 -0700 Subject: [PATCH] PR #16585: Add support for float8_e4m3 and float8_e3m4 types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](https://github.com/llvm/llvm-project/pull/101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](https://github.com/openxla/stablehlo/pull/2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](https://github.com/openxla/stablehlo/pull/2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](https://github.com/jax-ml/ml_dtypes/pull/161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](https://github.com/jax-ml/ml_dtypes/pull/171/) Add float8_e3m4 (Merged) - XLA [PR-17075](https://github.com/openxla/xla/pull/17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](https://github.com/openxla/xla/pull/3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](https://github.com/google/jax/pull/23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov : Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 PiperOrigin-RevId: 681551979 --- tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir | 14 ++++++++++++++ tests/Dialect/mhlo/ops.mlir | 14 ++++++++++++++ tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir | 14 ++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 8baa3e0d3..59618001c 100644 --- a/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1805,6 +1805,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/tests/Dialect/mhlo/ops.mlir b/tests/Dialect/mhlo/ops.mlir index 65594f55f..03b6a21e0 100644 --- a/tests/Dialect/mhlo/ops.mlir +++ b/tests/Dialect/mhlo/ops.mlir @@ -6832,6 +6832,20 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @f8e4m3(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e4m3fn(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor diff --git a/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 0f2e1b108..66c388b9e 100644 --- a/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1787,6 +1787,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor