PR #16585: Add support for float8_e4m3 and float8_e3m4 types
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 type support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```
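
To make the bit layout concrete, here is a minimal Python sketch that decodes a raw f8E4M3 byte according to the parameters above. This is an illustration only, not code from this PR; the helper name `decode_f8e4m3` is hypothetical.

```python
def decode_f8e4m3(byte: int) -> float:
    """Decode one f8E4M3 byte (layout S.EEEE.MMM) into a Python float."""
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 3) & 0xF   # 4 exponent bits, bias 7
    mant = byte & 0x7         # 3 mantissa bits
    if exp == 0xF:            # all-ones exponent: IEEE special values
        return sign * float("inf") if mant == 0 else float("nan")
    if exp == 0:              # subnormals: 2^(-6) x mant/8
        return sign * (mant / 8.0) * 2.0**-6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

# Spot-check the extremes listed above:
assert decode_f8e4m3(0b0_1110_111) == 240.0    # max normal
assert decode_f8e4m3(0b0_0001_000) == 2.0**-6  # min normal
assert decode_f8e4m3(0b0_0000_001) == 2.0**-9  # min subnormal
```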

### `f8E3M4` type follows IEEE 754 convention.

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```
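
The same sketch adapted to f8E3M4's field widths and bias (again a hypothetical illustration, not code from this PR):

```python
def decode_f8e3m4(byte: int) -> float:
    """Decode one f8E3M4 byte (layout S.EEE.MMMM) into a Python float."""
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 4) & 0x7   # 3 exponent bits, bias 3
    mant = byte & 0xF         # 4 mantissa bits
    if exp == 0x7:            # all-ones exponent: IEEE special values
        return sign * float("inf") if mant == 0 else float("nan")
    if exp == 0:              # subnormals: 2^(-2) x mant/16
        return sign * (mant / 16.0) * 2.0**-2
    return sign * (1.0 + mant / 16.0) * 2.0 ** (exp - 3)

assert decode_f8e3m4(0b0_110_1111) == 15.5     # max normal
assert decode_f8e3m4(0b0_001_0000) == 2.0**-2  # min normal
assert decode_f8e3m4(0b0_000_0001) == 2.0**-6  # min subnormal
```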

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged; see the NumPy sketch after this list)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
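
The ml_dtypes PRs above make the new types usable from NumPy. A minimal sketch, assuming an ml_dtypes release that already includes `float8_e4m3` and `float8_e3m4`; the printed values follow from the format parameters above, though exact rounding behavior depends on the library version:

```python
import ml_dtypes
import numpy as np

# finfo should reproduce the parameters listed earlier in this description.
print(ml_dtypes.finfo(ml_dtypes.float8_e4m3).max)  # 240.0
print(ml_dtypes.finfo(ml_dtypes.float8_e3m4).max)  # 15.5

# Both types follow IEEE 754 conventions, so values well beyond the max
# normal are expected to overflow to +/-inf rather than saturate.
x = np.array([1.0, 2.5, 300.0], dtype=np.float32).astype(ml_dtypes.float8_e4m3)
print(x.astype(np.float32))  # [1., 2.5, inf]
```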
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

PiperOrigin-RevId: 681551979
apivovarov authored and TensorFlow MLIR Team committed Oct 2, 2024
1 parent 9f778be commit ca1ea7c
Showing 3 changed files with 42 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
@@ -1805,6 +1805,20 @@ func.func @type_ui64(%arg0: tensor<ui64>, %arg1: tensor<ui64>) -> tensor<ui64> {
func.return %0 : tensor<ui64>
}

// CHECK-LABEL: "type_f8E3M4"
func.func @type_f8E3M4(%arg0: tensor<f8E3M4>, %arg1: tensor<f8E3M4>) -> tensor<f8E3M4> {
// CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
func.return %0 : tensor<f8E3M4>
}

// CHECK-LABEL: "type_f8E4M3"
func.func @type_f8E4M3(%arg0: tensor<f8E4M3>, %arg1: tensor<f8E4M3>) -> tensor<f8E4M3> {
// CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
func.return %0 : tensor<f8E4M3>
}

// CHECK-LABEL: "type_f8E4M3FN"
func.func @type_f8E4M3FN(%arg0: tensor<f8E4M3FN>, %arg1: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
// CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
14 changes: 14 additions & 0 deletions tests/Dialect/mhlo/ops.mlir
@@ -6832,6 +6832,20 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b

// -----

func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
func.return %0 : tensor<f8E3M4>
}

// -----

func.func @f8e4m3(%arg0: tensor<f16>) -> tensor<f8E4M3> {
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E4M3>
func.return %0 : tensor<f8E4M3>
}

// -----

func.func @f8e4m3fn(%arg0: tensor<f16>) -> tensor<f8E4M3FN> {
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E4M3FN>
func.return %0 : tensor<f8E4M3FN>
14 changes: 14 additions & 0 deletions tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
@@ -1787,6 +1787,20 @@ func.func @type_ui64(%arg0: tensor<ui64>, %arg1: tensor<ui64>) -> tensor<ui64> {
func.return %0 : tensor<ui64>
}

// CHECK-LABEL: "type_f8E3M4"
func.func @type_f8E3M4(%arg0: tensor<f8E3M4>, %arg1: tensor<f8E3M4>) -> tensor<f8E3M4> {
// CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E3M4>, tensor<f8E3M4>) -> tensor<f8E3M4>
func.return %0 : tensor<f8E3M4>
}

// CHECK-LABEL: "type_f8E4M3"
func.func @type_f8E4M3(%arg0: tensor<f8E4M3>, %arg1: tensor<f8E4M3>) -> tensor<f8E4M3> {
// CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3>, tensor<f8E4M3>) -> tensor<f8E4M3>
func.return %0 : tensor<f8E4M3>
}

// CHECK-LABEL: "type_f8E4M3FN"
func.func @type_f8E4M3FN(%arg0: tensor<f8E4M3FN>, %arg1: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
// CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor<f8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
