Skip to content

Commit

Permalink
Merge branch 'develop' into qdq-packed-out
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Nov 13, 2024
2 parents c8ab8e2 + 1cfd6c2 commit f2e2f90
Show file tree
Hide file tree
Showing 17 changed files with 474 additions and 202 deletions.
2 changes: 1 addition & 1 deletion docs/install/installing_with_package.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ROCm must be installed before installing MIGraphX. See `ROCm installation for Li

Installing MIGraphX using the package installer is sufficient for users who want to use the MIGraphX API.

If you want to develop for MIGraphX and contribute to the source code, see `Building MIGraphX <https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/install/docs/install/building_migraphx.html>`_ and `Developing for MIGraphX <https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/dev/contributing-to-migraphx.html>`_
If you want to develop for MIGraphX and contribute to the source code, see :doc:`Building MIGraphX </install/building_migraphx>` and :doc:`Developing for MIGraphX <../dev/contributing-to-migraphx>`.

The package installer will install all the prerequisites needed for MIGraphX.

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@e454b5d06fc2f099f7de3ee43450e7a6b1efe015 -DBUILD_FAT_LIBROCKCOMPILER=On
ROCm/rocMLIR@99fc9d24714ee7eae75ef8e414df4f2dacd3af16 -DBUILD_FAT_LIBROCKCOMPILER=On
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ add_library(migraphx
insert_pad.cpp
instruction.cpp
json.cpp
layout_nhwc.cpp
layout_convolution.cpp
lexing.cpp
load_save.cpp
make_op.cpp
Expand Down
108 changes: 65 additions & 43 deletions src/include/migraphx/float8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/float8_impl.hpp>
#include <migraphx/generic_float.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -379,52 +380,73 @@ class numeric_limits<fp8e5m2>

// =================================================================================================
// define numeric limits for the new data type
// NOLINTBEGIN
// NOLINTBEGIN(cert-dcl58-cpp)
namespace std {
#define MIGRAPHX_FP8_STD_OVERLOADS(T) \
inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \
inline bool isnan(T x) { return x.is_nan(); } \
template <> \
class numeric_limits<T> : public migraphx::fp8::numeric_limits<T> \
{ \
}; \
template <class U> \
struct common_type<T, U> : std::common_type<float, U> \
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<U, float> \
{ \
}; \
template <> \
struct common_type<T, T> \
{ \
using type = T; \
};

MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)

// needed to resolve between multiple ambiguous definition from previous templates
#define MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(T, U) \
template <> \
struct common_type<T, U> : std::common_type<float, float> \
{ \
}; \
template <> \
struct common_type<U, T> : std::common_type<float, float> \
{ \
};
template <migraphx::fp8::f8_type T, bool FNUZ>
inline bool isfinite(migraphx::fp8::float8<T, FNUZ> x)
{
return not x.is_inf() and not x.is_nan();
}

template <migraphx::fp8::f8_type T, bool FNUZ>
inline bool isnan(migraphx::fp8::float8<T, FNUZ> x)
{
return x.is_nan();
}

template <migraphx::fp8::f8_type T, bool FNUZ>
class numeric_limits<migraphx::fp8::float8<T, FNUZ>>
: public migraphx::fp8::numeric_limits<migraphx::fp8::float8<T, FNUZ>>
{
};
template <migraphx::fp8::f8_type T, bool FNUZ, class U>
struct common_type<migraphx::fp8::float8<T, FNUZ>, U> : std::common_type<float, U>
{
};
template <migraphx::fp8::f8_type T, bool FNUZ, class U>
struct common_type<U, migraphx::fp8::float8<T, FNUZ>> : std::common_type<U, float>
{
};
template <migraphx::fp8::f8_type T, bool FNUZ>
struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::fp8::float8<T, FNUZ>>
{
using type = migraphx::fp8::float8<T, FNUZ>;
};

template <migraphx::fp8::f8_type T1, bool FNUZ1, migraphx::fp8::f8_type T2, bool FNUZ2>
struct common_type<migraphx::fp8::float8<T1, FNUZ1>, migraphx::fp8::float8<T2, FNUZ2>>
{
using type = float;
};

template <unsigned int E, unsigned int M, unsigned int F, bool FNUZ>
struct common_type<migraphx::generic_float<E, M, F>,
migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, FNUZ>>
{
using type = float;
};

template <unsigned int E, unsigned int M, unsigned int F, bool FNUZ>
struct common_type<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, FNUZ>,
migraphx::generic_float<E, M, F>>
{
using type = float;
};

template <unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
struct common_type<migraphx::generic_float<E, M, F>, migraphx::fp8::float8<T, FNUZ>>
: std::common_type<float, float>
{
};

template <unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::generic_float<E, M, F>>
: std::common_type<float, float>
{
};

MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e5m2fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fnuz, migraphx::fp8::fp8e5m2fnuz)
} // namespace std
// NOLINTEND
// NOLINTEND(cert-dcl58-cpp)
// =================================================================================================
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
118 changes: 79 additions & 39 deletions src/include/migraphx/generic_float.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
* THE SOFTWARE.
*/

#ifndef MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP

#include <migraphx/config.hpp>
#include <migraphx/bit_cast.hpp>
#include <algorithm>
Expand All @@ -47,6 +50,54 @@ constexpr int countl_zero(T value)
return 8 * sizeof(value) - r;
}

constexpr std::size_t bit_ceil(std::size_t v)
{
if(v <= 1)
return 1;
v--;
v |= v >> 1u;
v |= v >> 2u;
v |= v >> 4u;
v |= v >> 8u;
v |= v >> 16u;
v |= v >> 32u;
return v + 1;
}

constexpr std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
{
return (x + y - std::size_t{1}) / y;
}

template <unsigned int Bytes>
struct unsigned_type
{
};

template <>
struct unsigned_type<1>
{
using type = std::uint8_t;
};

template <>
struct unsigned_type<2>
{
using type = std::uint16_t;
};

template <>
struct unsigned_type<4>
{
using type = std::uint32_t;
};

template <>
struct unsigned_type<8>
{
using type = std::uint64_t;
};

struct float32_parts
{
unsigned int mantissa : 23;
Expand All @@ -70,9 +121,12 @@ constexpr float32_parts get_parts(float f) { return migraphx::bit_cast<float32_p
template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct __attribute__((packed, may_alias)) generic_float
{
unsigned int mantissa : MantissaSize;
unsigned int exponent : ExponentSize;
unsigned int sign : 1;
using type = typename unsigned_type<bit_ceil(
integer_divide_ceil(MantissaSize + ExponentSize + 1, 8))>::type;

type mantissa : MantissaSize;
type exponent : ExponentSize;
type sign : 1;

static constexpr int exponent_bias() { return all_ones<ExponentSize - 1>(); }

Expand Down Expand Up @@ -108,13 +162,13 @@ struct __attribute__((packed, may_alias)) generic_float
}
else
{
unsigned int shift = 0;
type shift = 0;
f.mantissa = mantissa;

if(MantissaSize < float32_parts::mantissa_width())
{
shift = MantissaSize - ((sizeof(unsigned int) * 8) - countl_zero(mantissa));
f.mantissa <<= (shift + 1);
shift = MantissaSize - ((sizeof(type) * 8) - countl_zero(mantissa));
f.mantissa <<= (shift + 1u);
}

f.exponent = float32_parts::exponent_bias() - exponent_bias() - shift;
Expand Down Expand Up @@ -184,7 +238,7 @@ struct __attribute__((packed, may_alias)) generic_float
}
}

exponent = std::min(exponent, all_ones<ExponentSize>());
exponent = std::min<type>(exponent, all_ones<ExponentSize>());
}

constexpr bool is_normal() const noexcept
Expand Down Expand Up @@ -343,10 +397,11 @@ struct __attribute__((packed, may_alias)) generic_float
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

// NOLINTBEGIN(cert-dcl58-cpp)
namespace std {

template <unsigned int E, unsigned int M, unsigned int F>
class numeric_limits<migraphx::generic_float<E, M, F>> // NOLINT(cert-dcl58-cpp)
class numeric_limits<migraphx::generic_float<E, M, F>>
{
public:
static constexpr bool has_infinity = true;
Expand Down Expand Up @@ -392,48 +447,33 @@ class numeric_limits<migraphx::generic_float<E, M, F>> // NOLINT(cert-dcl58-cpp)
};

template <unsigned int E, unsigned int M, unsigned int F, class T>
struct common_type<migraphx::generic_float<E, M, F>, T> // NOLINT(cert-dcl58-cpp)
: std::common_type<float, T>
struct common_type<migraphx::generic_float<E, M, F>, T> : std::common_type<float, T>
{
};

template <unsigned int E, unsigned int M, unsigned int F, class T>
struct common_type<T, migraphx::generic_float<E, M, F>> // NOLINT(cert-dcl58-cpp)
: std::common_type<float, T>
struct common_type<T, migraphx::generic_float<E, M, F>> : std::common_type<float, T>
{
};

// template<unsigned int E, unsigned int M, unsigned int F, bool FNUZ>
// struct common_type<migraphx::generic_float<E, M, F>,
// migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, FNUZ>> : std::common_type<float, float>
// {};

// template<unsigned int E, unsigned int M, unsigned int F, bool FNUZ>
// struct common_type<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, FNUZ>,
// migraphx::generic_float<E, M, F>> : std::common_type<float, float>
// {};

// template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
// struct common_type<migraphx::generic_float<E, M, F>, migraphx::fp8::float8<T, FNUZ>> :
// std::common_type<float, float>
// {};

// template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
// struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::generic_float<E, M, F>> :
// std::common_type<float, float>
// {};

template <unsigned int E, unsigned int M, unsigned int F>
struct common_type<migraphx::generic_float<E, M, F>, // NOLINT(cert-dcl58-cpp)
migraphx::generic_float<E, M, F>>
struct common_type<migraphx::generic_float<E, M, F>, migraphx::generic_float<E, M, F>>
{
using type = migraphx::generic_float<E, M, F>;
};

// template<unsigned int E, unsigned int M, unsigned int F, unsigned int E1, .....>
// struct common_type<migraphx::generic_float<E, M, F>, migraphx::generic_float<E1, M1, F1>>
// {
// using type = float;
// };
template <unsigned int E1,
unsigned int M1,
unsigned int F1,
unsigned int E2,
unsigned int M2,
unsigned int F2>
struct common_type<migraphx::generic_float<E1, M1, F1>, migraphx::generic_float<E2, M2, F2>>
{
using type = float;
};

} // namespace std
// NOLINTEND(cert-dcl58-cpp)

#endif // MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP
56 changes: 0 additions & 56 deletions src/include/migraphx/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,60 +48,4 @@ using deduce = typename detail::deduce<T>::type;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

namespace std {

template <class T>
struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
{
};

template <class T>
struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};

template <>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
{
using type = float;
};

template <>
struct common_type<migraphx::fp8::fp8e4m3fn, migraphx::half>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fn>
{
using type = float;
};

template <>
struct common_type<migraphx::fp8::fp8e5m2, migraphx::half>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e5m2>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::half>
{
using type = migraphx::half;
};

} // namespace std

#endif
Loading

0 comments on commit f2e2f90

Please sign in to comment.