From 4a03c71a1511c8f1350655128885a0decdbdd950 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 21 Aug 2024 22:03:27 +0000 Subject: [PATCH] Add float8_e3m4 --- CHANGELOG.md | 4 +- README.md | 5 + ml_dtypes/__init__.py | 3 + ml_dtypes/_finfo.py | 66 ++++++++++++- ml_dtypes/_src/dtypes.cc | 29 ++++++ ml_dtypes/include/float8.h | 126 ++++++++++++++++++++++++ ml_dtypes/tests/custom_float_test.py | 22 ++++- ml_dtypes/tests/finfo_test.py | 1 + ml_dtypes/tests/float8_test.cc | 138 +++++++++++++++++++++++++-- 9 files changed, 377 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6c0eb72..48788cdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,8 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] -* Added new 8-bit float type following IEEE 754 convention: - `ml_dtypes.float8_e4m3`. +* Added new 8-bit float types following IEEE 754 convention: + `ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`. * Fix outputs of float `divmod` and `floor_divide` when denominator is zero. ## [0.4.0] - 2024-04-1 diff --git a/README.md b/README.md index 4921b49b..45a18bd1 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format - `float8_*`: several experimental 8-bit floating point representations including: + * `float8_e3m4` * `float8_e4m3` * `float8_e4m3b11fnuz` * `float8_e4m3fn` @@ -65,6 +66,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits. Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf. +### `float8_e3m4` + +Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf. + ### `float8_e4m3` Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf. 
diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index fe0b1891..3942db9d 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -17,6 +17,7 @@ "__version__", "bfloat16", "finfo", + "float8_e3m4", "float8_e4m3", "float8_e4m3b11fnuz", "float8_e4m3fn", @@ -35,6 +36,7 @@ from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo from ml_dtypes._ml_dtypes_ext import bfloat16 +from ml_dtypes._ml_dtypes_ext import float8_e3m4 from ml_dtypes._ml_dtypes_ext import float8_e4m3 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz from ml_dtypes._ml_dtypes_ext import float8_e4m3fn @@ -48,6 +50,7 @@ import numpy as np bfloat16: Type[np.generic] +float8_e3m4: Type[np.generic] float8_e4m3: Type[np.generic] float8_e4m3b11fnuz: Type[np.generic] float8_e4m3fn: Type[np.generic] diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 3f7aa48d..9d62e3a2 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -17,6 +17,7 @@ from typing import Dict from ml_dtypes._ml_dtypes_ext import bfloat16 +from ml_dtypes._ml_dtypes_ext import float8_e3m4 from ml_dtypes._ml_dtypes_ext import float8_e4m3 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz from ml_dtypes._ml_dtypes_ext import float8_e4m3fn @@ -26,6 +27,7 @@ import numpy as np _bfloat16_dtype = np.dtype(bfloat16) +_float8_e3m4_dtype = np.dtype(float8_e3m4) _float8_e4m3_dtype = np.dtype(float8_e4m3) _float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype = np.dtype(float8_e4m3fn) @@ -43,12 +45,21 @@ def __init__(self): self.smallest_subnormal = bfloat16(smallest_subnormal) +class _Float8E3m4MachArLike: + + def __init__(self): + smallest_normal = float.fromhex("0x1p-2") + self.smallest_normal = float8_e3m4(smallest_normal) + smallest_subnormal = float.fromhex("0x0.1p-2") + self.smallest_subnormal = float8_e3m4(smallest_subnormal) + + class _Float8E4m3MachArLike: def __init__(self): smallest_normal = float.fromhex("0x1p-6") self.smallest_normal = 
float8_e4m3(smallest_normal) - smallest_subnormal = float.fromhex("0x1p-9") + smallest_subnormal = float.fromhex("0x0.2p-6") self.smallest_subnormal = float8_e4m3(smallest_subnormal) @@ -146,6 +157,51 @@ def float_to_str(f): # pylint: enable=protected-access return obj + @staticmethod + def _float8_e3m4_finfo(): + def float_to_str(f): + return "%6.2e" % float(f) + + tiny = float.fromhex("0x1p-2") # 1/4 min normal + resolution = 0.1 + eps = float.fromhex("0x1p-4") # 1/16 + epsneg = float.fromhex("0x1p-5") # 1/32 + max_ = float.fromhex("0x1.Fp3") # 15.5 max normal + + obj = object.__new__(np.finfo) + obj.dtype = _float8_e3m4_dtype + obj.bits = 8 + obj.eps = float8_e3m4(eps) + obj.epsneg = float8_e3m4(epsneg) + obj.machep = -4 + obj.negep = -5 + obj.max = float8_e3m4(max_) + obj.min = float8_e3m4(-max_) + obj.nexp = 3 + obj.nmant = 4 + obj.iexp = obj.nexp + obj.maxexp = 4 + obj.minexp = -2 + obj.precision = 1 + obj.resolution = float8_e3m4(resolution) + # pylint: disable=protected-access + obj._machar = _Float8E3m4MachArLike() + if not hasattr(obj, "tiny"): + obj.tiny = float8_e3m4(tiny) + if not hasattr(obj, "smallest_normal"): + obj.smallest_normal = obj._machar.smallest_normal + obj.smallest_subnormal = obj._machar.smallest_subnormal + + obj._str_tiny = float_to_str(tiny) + obj._str_smallest_normal = float_to_str(tiny) + obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) + obj._str_max = float_to_str(max_) + obj._str_epsneg = float_to_str(epsneg) + obj._str_eps = float_to_str(eps) + obj._str_resolution = float_to_str(resolution) + # pylint: enable=protected-access + return obj + @staticmethod def _float8_e4m3_finfo(): def float_to_str(f): @@ -425,6 +481,14 @@ def __new__(cls, dtype): if _bfloat16_dtype not in cls._finfo_cache: cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo() return cls._finfo_cache[_bfloat16_dtype] + if ( + isinstance(dtype, str) + and dtype == "float8_e3m4" + or dtype == _float8_e3m4_dtype + ): + if _float8_e3m4_dtype not 
in cls._finfo_cache: + cls._finfo_cache[_float8_e3m4_dtype] = cls._float8_e3m4_finfo() + return cls._finfo_cache[_float8_e3m4_dtype] if ( isinstance(dtype, str) and dtype == "float8_e4m3" diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 87f7578f..84914704 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -60,6 +60,20 @@ struct TypeDescriptor : CustomFloatType { static constexpr char kNpyDescrByteorder = '='; }; +template <> +struct TypeDescriptor : CustomFloatType { + typedef float8_e3m4 T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float8_e3m4"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e3m4"; + static constexpr const char* kTpDoc = "float8_e3m4 floating-point values"; + // Set e3m4 kind as Void since kind=f (float) with itemsize=1 is used by e5m2 + static constexpr char kNpyDescrKind = 'V'; // Void + static constexpr char kNpyDescrType = '3'; + static constexpr char kNpyDescrByteorder = '='; // Native byte order +}; + template <> struct TypeDescriptor : CustomFloatType { typedef float8_e4m3 T; @@ -283,6 +297,9 @@ bool Initialize() { if (!RegisterFloatDtype(numpy.get())) { return false; } + if (!RegisterFloatDtype(numpy.get())) { + return false; + } if (!RegisterFloatDtype(numpy.get())) { return false; } @@ -342,6 +359,13 @@ bool Initialize() { success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); success &= RegisterOneWayCustomCast(); success &= RegisterOneWayCustomCast(); return success; @@ -372,6 +396,11 @@ extern "C" EXPORT_SYMBOL PyObject* 
PyInit__ml_dtypes_ext() { return nullptr; } + if (PyObject_SetAttrString(m.get(), "float8_e3m4", + reinterpret_cast( + TypeDescriptor::type_ptr)) < 0) { + return nullptr; + } if (PyObject_SetAttrString(m.get(), "float8_e4m3", reinterpret_cast( TypeDescriptor::type_ptr)) < 0) { diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index 93aa0da4..36a3e3a9 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -43,6 +43,7 @@ namespace ml_dtypes { namespace float8_internal { // Forward-declarations of classes. +class float8_e3m4; class float8_e4m3; class float8_e4m3fn; class float8_e4m3fnuz; @@ -244,6 +245,20 @@ template using RequiresIsDerivedFromFloat8Base = std::enable_if_t, T>, int>; +class float8_e3m4 : public float8_base { + // Exponent: 3, Mantissa: 4, bias: 3. + // IEEE 754. + private: + using Base = float8_base; + friend class float8_base; + using Base::Base; + + public: + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e3m4(T f8) + : float8_e3m4(ConvertFrom(f8)) {} +}; + class float8_e4m3 : public float8_base { // Exponent: 4, Mantissa: 3, bias: 7. // IEEE 754. 
@@ -386,6 +401,8 @@ class float8_e5m2fnuz : public float8_base { : float8_e5m2fnuz(ConvertFrom(f8)) {} explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3& f8) : float8_e5m2fnuz(ConvertFrom(f8)) {} + explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e3m4& f8) + : float8_e5m2fnuz(ConvertFrom(f8)) {} explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3b11fnuz& f8) : float8_e5m2fnuz(ConvertFrom(f8)) {} explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3fn& f8) @@ -463,6 +480,8 @@ constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent, -0.057991946977686754, // log10(1 - 2**-4) -0.028028723600243537, + // log10(1 - 2**-5) + -0.013788284485633295, }; return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 3] + max_exponent * kLog10Of2)); @@ -490,6 +509,70 @@ struct numeric_limits_float8_base { // NOLINTEND }; +struct numeric_limits_float8_e3m4 : public numeric_limits_float8_base { + private: + static inline constexpr const int kExponentBias = 3; + static inline constexpr const int kMantissaBits = 4; + + public: + // NOLINTBEGIN: these names must match std::numeric_limits. 
+ static inline constexpr const int digits = kMantissaBits + 1; + static inline constexpr const int digits10 = Digits10FromDigits(digits); + static inline constexpr const int max_digits10 = + MaxDigits10FromDigits(digits); + static inline constexpr const int min_exponent = (1 - kExponentBias) + 1; + static inline constexpr const int min_exponent10 = + MinExponent10FromMinExponent(min_exponent); + static inline constexpr const int max_exponent = 0b111 - kExponentBias; + static inline constexpr const int max_exponent10 = + MaxExponent10FromMaxExponentAndDigits(max_exponent, digits); + static inline constexpr const bool is_iec559 = true; + static inline constexpr const bool has_infinity = true; + static inline constexpr const bool has_signaling_NaN = true; + // NOLINTEND + + // 1.0 * 2^(0b001 - 3) = 1.0 * 2^-2 = 1/4 (min normal) + static constexpr float8_e3m4 min() { + return float8_e3m4::FromRep(1 << kMantissaBits); + } + // -(1 + 0b1111 * 2^-4) * 2^(0b110 - 3) = -(1 + 15/16) * 2^3 = -15.5 + static constexpr float8_e3m4 lowest() { + return float8_e3m4::FromRep(0b1'110'1111); + } + // (1 + 0b1111 * 2^-4) * 2^(0b110 - 3) = (1 + 15/16) * 2^3 = 15.5 + static constexpr float8_e3m4 max() { + return float8_e3m4::FromRep(0b0'110'1111); + } + // (1 + 1/16) * 2^0 - 1.0 = 1.0 + 1/16 - 1.0 = 1/16 + // Encoded as denormal number 2^-2 * 1/4 + static constexpr float8_e3m4 epsilon() { + return float8_e3m4::FromRep(0b0'000'0100); + } + // 1.0 * 2^-1 = 0.5 + static constexpr float8_e3m4 round_error() { + return float8_e3m4::FromRep((-1 + kExponentBias) << kMantissaBits); + } + static constexpr float8_e3m4 infinity() { + return float8_e3m4::FromRep(0b0'111'0000); + } + static constexpr float8_e3m4 quiet_NaN() { + // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set + // to 0 or 1 and all the bits of the biased exponent field E set to 1 + // (see 3.4). 
A quiet NaN bit string should be encoded with the first bit + // (d1) of the trailing significand field T being 1." + return float8_e3m4::FromRep(0b0'111'1000); + } + static constexpr float8_e3m4 signaling_NaN() { + // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with + // the first bit of the trailing significand field being 0." + return float8_e3m4::FromRep(0b0'111'0100); + } + // 2^(-2) * 2^(-4) = 2^-6 = 1/64 (min denormal) + static constexpr float8_e3m4 denorm_min() { + return float8_e3m4::FromRep(0b0'000'0001); + } +}; + struct numeric_limits_float8_e4m3 : public numeric_limits_float8_base { private: static inline constexpr const int kExponentBias = 7; @@ -850,6 +933,10 @@ struct numeric_limits_float8_e5m2fnuz : public numeric_limits_float8_base { namespace std { // Standard-library overrides. Note that these are picked up by Eigen as well. +template <> +struct numeric_limits + : public ml_dtypes::float8_internal::numeric_limits_float8_e3m4 {}; + template <> struct numeric_limits : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3 {}; @@ -878,6 +965,14 @@ struct numeric_limits namespace ml_dtypes { namespace float8_internal { +constexpr inline float8_e3m4 abs(const float8_e3m4& a) { + return float8_e3m4::FromRep(a.rep() & 0b0'111'1111); +} + +constexpr inline bool(isnan)(const float8_e3m4& a) { + return abs(a).rep() > std::numeric_limits::infinity().rep(); +} + constexpr inline float8_e4m3 abs(const float8_e4m3& a) { return float8_e4m3::FromRep(a.rep() & 0b0'1111'111); } @@ -1371,6 +1466,7 @@ EIGEN_DEVICE_FUNC To float8_base::ConvertTo(Derived from) { } // namespace float8_internal // Exported types. 
+using float8_e3m4 = float8_internal::float8_e3m4; using float8_e4m3 = float8_internal::float8_e4m3; using float8_e4m3fn = float8_internal::float8_e4m3fn; using float8_e4m3fnuz = float8_internal::float8_e4m3fnuz; @@ -1384,6 +1480,18 @@ using float8_e5m2fnuz = float8_internal::float8_e5m2fnuz; namespace Eigen { namespace numext { +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC ml_dtypes::float8_e3m4 +bit_cast(const uint8_t &src) { + return ml_dtypes::float8_e3m4::FromRep(src); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint8_t +bit_cast(const ml_dtypes::float8_e3m4 &src) { + return src.rep(); +} + template <> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC ml_dtypes::float8_e4m3 bit_cast(const uint8_t &src) { @@ -1425,6 +1533,12 @@ bit_cast(const ml_dtypes::float8_e5m2& src) { // Work-around for isinf/isnan/isfinite issue on aarch64. namespace internal { +template <> +EIGEN_DEVICE_FUNC inline bool isinf_impl( + const ml_dtypes::float8_e3m4& x) { + return ml_dtypes::float8_internal::isinf(x); +} + template <> EIGEN_DEVICE_FUNC inline bool isinf_impl( const ml_dtypes::float8_e4m3& x) { @@ -1461,6 +1575,12 @@ EIGEN_DEVICE_FUNC inline bool isinf_impl( return ml_dtypes::float8_internal::isinf(x); } +template <> +EIGEN_DEVICE_FUNC inline bool isnan_impl( + const ml_dtypes::float8_e3m4& x) { + return ml_dtypes::float8_internal::isnan(x); +} + template <> EIGEN_DEVICE_FUNC inline bool isnan_impl( const ml_dtypes::float8_e4m3& x) { @@ -1497,6 +1617,12 @@ EIGEN_DEVICE_FUNC inline bool isnan_impl( return ml_dtypes::float8_internal::isnan(x); } +template <> +EIGEN_DEVICE_FUNC inline bool isfinite_impl( + const ml_dtypes::float8_e3m4& x) { + return ml_dtypes::float8_internal::isfinite(x); +} + template <> EIGEN_DEVICE_FUNC inline bool isfinite_impl( const ml_dtypes::float8_e4m3& x) { diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index b94dc0da..54a37d01 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ 
b/ml_dtypes/tests/custom_float_test.py @@ -30,6 +30,7 @@ import numpy as np bfloat16 = ml_dtypes.bfloat16 +float8_e3m4 = ml_dtypes.float8_e3m4 float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn = ml_dtypes.float8_e4m3fn @@ -109,6 +110,7 @@ def dtype_has_inf(dtype): FLOAT_DTYPES = [ bfloat16, + float8_e3m4, float8_e4m3, float8_e4m3b11fnuz, float8_e4m3fn, @@ -148,6 +150,11 @@ def dtype_has_inf(dtype): # Values that should round trip exactly to integer and back. INT_VALUES = { bfloat16: [0, 1, 2, 10, 34, 47, 128, 255, 256, 512], + float8_e3m4: list( + itertools.chain.from_iterable( + range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(4) + ) + ), float8_e4m3: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8) @@ -178,6 +185,7 @@ def dtype_has_inf(dtype): BITS_TYPE = { bfloat16: np.uint16, + float8_e3m4: np.uint8, float8_e4m3: np.uint8, float8_e4m3b11fnuz: np.uint8, float8_e4m3fn: np.uint8, @@ -631,9 +639,12 @@ def testArray(self, float_type): self.assertTrue((x == x).all()) def testComparisons(self, float_type): - x = np.array([30, 7, -30], dtype=np.float32) + x0, x1, y0 = 30, 7, 17 + if float_type == ml_dtypes.float8_e3m4: + x0, x1, y0 = 15, 3, 9 + x = np.array([x0, x1, -x0], dtype=np.float32) bx = x.astype(float_type) - y = np.array([17, 7, 0], dtype=np.float32) + y = np.array([y0, x1, 0], dtype=np.float32) by = y.astype(float_type) np.testing.assert_equal(x == y, bx == by) np.testing.assert_equal(x != y, bx != by) @@ -749,9 +760,12 @@ def testArange(self, float_type): np.arange(-0.0, -2.0, -0.25, dtype=np.float32).astype(float_type), np.arange(-0.0, -2.0, -0.25, dtype=float_type), ) + m = 16 + if float_type == ml_dtypes.float8_e3m4: + m = 14 np.testing.assert_equal( - np.arange(-16.0, 16.0, 2.0, dtype=np.float32).astype(float_type), - np.arange(-16.0, 16.0, 2.0, dtype=float_type), + np.arange(-m, m, 2.0, dtype=np.float32).astype(float_type), + np.arange(-m, m, 2.0, 
dtype=float_type), ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 3999476b..ab92ea2f 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -19,6 +19,7 @@ ALL_DTYPES = [ ml_dtypes.bfloat16, + ml_dtypes.float8_e3m4, ml_dtypes.float8_e4m3, ml_dtypes.float8_e4m3b11fnuz, ml_dtypes.float8_e4m3fn, diff --git a/ml_dtypes/tests/float8_test.cc b/ml_dtypes/tests/float8_test.cc index c3a4841e..829c5eb9 100644 --- a/ml_dtypes/tests/float8_test.cc +++ b/ml_dtypes/tests/float8_test.cc @@ -40,6 +40,8 @@ struct Float8TestParamNames { return "float8_e4m3fn"; } else if constexpr (std::is_same_v) { return "float8_e4m3b11fnuz"; + } else if constexpr (std::is_same_v) { + return "float8_e3m4"; } else if constexpr (std::is_same_v) { return "float8_e4m3"; } else if constexpr (std::is_same_v) { @@ -54,10 +56,41 @@ struct Float8TestParamNames { }; using Float8Types = - ::testing::Types; TYPED_TEST_SUITE(Float8Test, Float8Types, Float8TestParamNames); +TEST(Float8E3m4Test, NumericLimits) { + EXPECT_TRUE( + Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); + EXPECT_TRUE( + Eigen::numext::isnan(std::numeric_limits::signaling_NaN())); + EXPECT_EQ(static_cast(std::numeric_limits::min()), + 0.25); + EXPECT_EQ(static_cast(std::numeric_limits::max()), 15.5); + EXPECT_EQ(static_cast(std::numeric_limits::lowest()), + -15.5); + EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), + 0.0625); + EXPECT_EQ(static_cast(std::numeric_limits::round_error()), + 0.5); + EXPECT_TRUE( + Eigen::numext::isinf(std::numeric_limits::infinity())); + EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), + std::exp2(-6)); + EXPECT_EQ(std::numeric_limits::digits, 5); + EXPECT_EQ(std::numeric_limits::digits10, 1); + EXPECT_EQ(std::numeric_limits::max_digits10, 3); + EXPECT_EQ(std::numeric_limits::min_exponent, -1); + EXPECT_EQ(std::numeric_limits::min_exponent10, 0); + 
EXPECT_EQ(std::numeric_limits::max_exponent, 4); + EXPECT_EQ(std::numeric_limits::max_exponent10, 1); + EXPECT_EQ(std::numeric_limits::is_iec559, true); + EXPECT_EQ(std::numeric_limits::has_infinity, true); + EXPECT_EQ(std::numeric_limits::has_quiet_NaN, true); + EXPECT_EQ(std::numeric_limits::has_signaling_NaN, true); +} + TEST(Float8E4m3Test, NumericLimits) { EXPECT_TRUE( Eigen::numext::isnan(std::numeric_limits::quiet_NaN())); @@ -610,6 +643,20 @@ TEST(Float8Test, Float8E4m3b11fnuz_To_Float8E4m3fn) { } } +TEST(Float8Test, Float8E3m4_To_Float8E5m2) { + // Truncation and rounding of a number ever-so-slightly less than 2. + float8_e3m4 less_than_two = float8_e3m4::FromRep(0x3F); + float8_e5m2 truncated = + float8_e5m2::template ConvertFrom(less_than_two); + EXPECT_LT(static_cast(truncated), 2); + + float8_e5m2 rounded = + float8_e5m2::template ConvertFrom(less_than_two); + EXPECT_EQ(static_cast(rounded), 2); +} + TEST(Float8Test, Float8E4m3_To_Float8E5m2) { // Truncation and rounding of a number ever-so-slightly less than 2. float8_e4m3 less_than_two = float8_e4m3::FromRep(0x3F); @@ -638,6 +685,67 @@ TEST(Float8Test, Float8E4m3fn_To_Float8E5m2) { EXPECT_EQ(static_cast(rounded), 2); } +TEST(Float8Test, Half_To_Float8E3m4) { + // Special values, NaN. + Eigen::half inf = + Eigen::numext::bit_cast(static_cast(0x7C00)); + EXPECT_EQ(static_cast(inf).rep(), 0x70); + Eigen::half ninf = + Eigen::numext::bit_cast(static_cast(0xFC00)); + EXPECT_EQ(static_cast(ninf).rep(), 0xF0); + + Eigen::half nan = + Eigen::numext::bit_cast(static_cast(0x7C01)); + EXPECT_EQ(static_cast(nan).rep(), 0x78); + Eigen::half nnan = + Eigen::numext::bit_cast(static_cast(0xFC01)); + EXPECT_EQ(static_cast(nnan).rep(), 0xF8); + + // Rounding vs truncation. 
+ Eigen::half less_than_two = + Eigen::numext::bit_cast(static_cast(0x3FFF)); + EXPECT_EQ((float8_e3m4::ConvertFrom(less_than_two) + .rep()), + 0x40); + EXPECT_EQ((float8_e3m4::ConvertFrom(less_than_two) + .rep()), + 0x3F); + EXPECT_EQ((float8_e3m4::ConvertFrom(-less_than_two) + .rep()), + 0xC0); + EXPECT_EQ((float8_e3m4::ConvertFrom(-less_than_two) + .rep()), + 0xBF); + + // Saturation. + // f8e3m4=0.110.1111 0x1.Fp+3 f16=0.10010.1111000000 uint16=0x4BC0 + // f8e3m4=0.111.0000 0x1.0p+4 f16=0.10011.0000000000 uint16=0x4C00 + for (uint16_t i = 0x4BC0; i < 0x4C00; ++i) { + Eigen::half big_half = Eigen::numext::bit_cast(i); + float big_float = static_cast(big_half); + EXPECT_EQ( + (float8_e3m4::ConvertFrom( + big_half) + .rep()), + (float8_e3m4::ConvertFrom( + big_float) + .rep())) + << i; + EXPECT_EQ( + (float8_e3m4::ConvertFrom( + -big_half) + .rep()), + (float8_e3m4::ConvertFrom( + -big_float) + .rep())) + << i; + } +} + TEST(Float8Test, Half_To_Float8E4m3) { // Special values, NaN. Eigen::half inf = @@ -831,22 +939,31 @@ TYPED_TEST(Float8Test, CallTheConstOperator) { } } +TEST(Float8E3m4Test, SmallCastToDenormal) { + // Special edge-case where rounding to a normalized value would + // normally round down, but rounding to a subnormal rounds up. + float x = 0x0.8Ap-2; // between two denormals + float8_e3m4 y = static_cast(x); + float z = static_cast(y); + EXPECT_EQ(z, 0x0.9p-2); // rounded up to the next denormal +} + TEST(Float8E4m3Test, SmallCastToDenormal) { // Special edge-case where rounding to a normalized value would // normally round down, but rounding to a subnormal rounds up. 
- float x = std::ldexp(1.3125, -8); + float x = 0x0.94p-6; // between two denormals float8_e4m3 y = static_cast(x); float z = static_cast(y); - EXPECT_EQ(z, std::ldexp(1.5, -8)); + EXPECT_EQ(z, 0x0.Ap-6); // rounded up to the next denormal } TEST(Float8E5m2Test, SmallCastToDenormal) { // Special edge-case where rounding to a normalized value would // normally round down, but rounding to a subnormal rounds up. - float x = std::ldexp(1.3125, -15); + float x = 0x0.A8p-14; // between two denormals float8_e5m2 y = static_cast(x); float z = static_cast(y); - EXPECT_EQ(z, std::ldexp(1.5, -15)); + EXPECT_EQ(z, 0x0.Cp-14); // rounded up to the next denormal } // Helper utility for prettier test names. @@ -872,16 +989,17 @@ struct Float8CastTestParamNames { GEN_LONG_DOUBLE_PAIR(Type) \ std::pair, std::pair, \ std::pair, std::pair, \ - std::pair, \ + std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair -#define GEN_TYPE_PAIRS() \ - GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \ - GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz), \ - GEN_DEST_TYPES(float8_e5m2fnuz), GEN_DEST_TYPES(float8_e4m3) +#define GEN_TYPE_PAIRS() \ + GEN_DEST_TYPES(float8_e3m4), GEN_DEST_TYPES(float8_e4m3), \ + GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \ + GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz), \ + GEN_DEST_TYPES(float8_e5m2fnuz) using Float8CastTypePairs = ::testing::Types;