Bf16 gpu support #3630

Merged Dec 14, 2024 (92 commits into develop from bf16_gpu_support)

Changes from all commits:
c51c1ce - first pass at integrating generic float (richagadgil, Oct 10, 2024)
134b408 - fix namespaces (richagadgil, Oct 10, 2024)
d4fa6eb - fix mantissa (richagadgil, Oct 10, 2024)
0b60841 - refactor (richagadgil, Oct 11, 2024)
7a646f1 - refactor (richagadgil, Oct 11, 2024)
ebe819b - add fp (richagadgil, Oct 11, 2024)
379a77a - fixed generic float class (richagadgil, Oct 14, 2024)
174384c - add fp32 test (richagadgil, Oct 14, 2024)
787b651 - remove import (richagadgil, Oct 14, 2024)
1d1fa1c - update tests (richagadgil, Oct 15, 2024)
1791092 - fp16 tests that work (richagadgil, Oct 17, 2024)
a2eb005 - update tests (richagadgil, Oct 18, 2024)
ff8ffc7 - updated fp16 and fp32 tests (richagadgil, Oct 18, 2024)
e36fd65 - half tests (richagadgil, Oct 22, 2024)
9ac4e2a - underflow and overflow tests (richagadgil, Oct 22, 2024)
f05fd31 - generate map (richagadgil, Oct 22, 2024)
cb4d92d - add more tests (richagadgil, Oct 22, 2024)
0cc1946 - fix names (richagadgil, Oct 22, 2024)
85a761b - update tests (richagadgil, Oct 23, 2024)
65cf9ae - remove and (richagadgil, Oct 24, 2024)
fbabf54 - disable warning (richagadgil, Oct 24, 2024)
549f5e6 - fix tidy warning (richagadgil, Oct 24, 2024)
d302e5d - migraphx py fix (richagadgil, Oct 25, 2024)
8d475e3 - add increments (richagadgil, Oct 25, 2024)
a0fd055 - fix warnings (richagadgil, Oct 25, 2024)
41379fe - disable duplicate branch warning (richagadgil, Oct 25, 2024)
0c29c7b - add countzero_std (richagadgil, Oct 28, 2024)
4b012a8 - ci error (richagadgil, Oct 28, 2024)
dbaa3a8 - simplify countl (richagadgil, Oct 28, 2024)
b2bd2a0 - fix ci (richagadgil, Oct 28, 2024)
6f328f0 - src (richagadgil, Oct 29, 2024)
e6d9763 - remove flag (richagadgil, Oct 29, 2024)
6538050 - hide abi warning (richagadgil, Oct 29, 2024)
4e96d4d - revert changes (richagadgil, Oct 29, 2024)
ef11f1f - Merge branch 'develop' into generic_float (richagadgil, Oct 29, 2024)
e4a25bd - change half in tests (richagadgil, Oct 29, 2024)
3354c6e - Update generic_float.hpp (richagadgil, Oct 29, 2024)
6de079b - format (richagadgil, Oct 29, 2024)
7750874 - Merge branch 'develop' into generic_float (richagadgil, Oct 29, 2024)
801f485 - Merge branch 'develop' into generic_float (causten, Oct 30, 2024)
33e2c8d - fix bug (richagadgil, Oct 30, 2024)
9bb7198 - Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene… (richagadgil, Oct 30, 2024)
b3c345d - fix err (richagadgil, Oct 30, 2024)
03df6f9 - edits (richagadgil, Oct 31, 2024)
ad817b2 - tidy and format (richagadgil, Oct 31, 2024)
898417b - tidy etc (richagadgil, Oct 31, 2024)
aa5b9c9 - gf (richagadgil, Oct 31, 2024)
6f72370 - fix tidy errs (richagadgil, Nov 1, 2024)
0aab1a0 - bf16 changes (richagadgil, Nov 4, 2024)
7b965c0 - add flag to trace quantization passes (#3571) (shivadbhavsar, Oct 30, 2024)
5f5f13d - bf16 (richagadgil, Oct 30, 2024)
d64b124 - Update bf16.cpp (richagadgil, Nov 1, 2024)
a064eaa - Update bf16.hpp (richagadgil, Nov 2, 2024)
befbd9e - Update bf16.hpp (richagadgil, Nov 2, 2024)
08b9511 - update files with working version (richagadgil, Nov 4, 2024)
b9d204e - Update bf16.cpp (richagadgil, Nov 4, 2024)
fb6df2d - Update generic_float.hpp (richagadgil, Nov 4, 2024)
bb78138 - Merge branch 'develop' into bf16 (richagadgil, Nov 8, 2024)
8e1f99e - add extra common type (richagadgil, Nov 8, 2024)
6192970 - tidy (richagadgil, Nov 8, 2024)
c0d6bc4 - Update bf16.hpp (richagadgil, Nov 11, 2024)
7bfc407 - Update generic_float.hpp (richagadgil, Nov 11, 2024)
4cb96ad - Merge branch 'develop' into bf16 (richagadgil, Nov 11, 2024)
ffd4ba2 - remove imports (richagadgil, Nov 12, 2024)
8a10da3 - Merge branch 'develop' into bf16 (richagadgil, Nov 12, 2024)
1565a0e - ref tests (richagadgil, Nov 13, 2024)
e6d1155 - migraphx_py fix (richagadgil, Nov 13, 2024)
867e960 - fix test cae by index (richagadgil, Nov 13, 2024)
9852da5 - add rocblas type (richagadgil, Nov 13, 2024)
bf50653 - fix tgts err (richagadgil, Nov 13, 2024)
0ebd220 - address changes (richagadgil, Nov 18, 2024)
043e322 - Merge branch 'develop' into bf16 (richagadgil, Nov 18, 2024)
a3ca184 - bf16 gpu support (richagadgil, Nov 19, 2024)
490d326 - add vector types (richagadgil, Nov 19, 2024)
a63ac1e - rocblas (richagadgil, Nov 19, 2024)
94990bb - bf16 gpu testing (shivadbhavsar, Nov 19, 2024)
8aaae90 - mlir bf16 (shivadbhavsar, Nov 19, 2024)
208232e - fix type (richagadgil, Nov 19, 2024)
d4866d5 - fix type (richagadgil, Nov 19, 2024)
59eec66 - add type (richagadgil, Nov 19, 2024)
79c0bfb - Merge branch 'develop' into bf16_gpu_support (richagadgil, Dec 6, 2024)
e52b95e - Update hip.hpp (richagadgil, Dec 6, 2024)
79e9809 - Merge branch 'develop' into bf16_gpu_support (richagadgil, Dec 7, 2024)
6dff501 - add bf16 support (richagadgil, Dec 9, 2024)
5d6eeba - working float equals (richagadgil, Dec 11, 2024)
6772242 - update verify tolerance for bf16 (shivadbhavsar, Dec 11, 2024)
dca3048 - Merge branch 'develop' into bf16_gpu_support (richagadgil, Dec 11, 2024)
4ca5ba1 - add supported type filter to miopen pooling (shivadbhavsar, Dec 11, 2024)
0726f8b - remove repetitions (richagadgil, Dec 12, 2024)
d10aadf - Merge branch 'bf16_gpu_support' of github.com:ROCm/AMDMIGraphX into b… (richagadgil, Dec 12, 2024)
d258b6d - format (richagadgil, Dec 12, 2024)
62fec32 - ci (richagadgil, Dec 13, 2024)
10 changes: 10 additions & 0 deletions src/driver/main.cpp
@@ -482,6 +482,7 @@ struct compiler
compiler_target ct;
compile_options co;
bool to_fp16 = false;
bool to_bf16 = false;
bool to_fp8 = false;
bool to_int8 = false;
bool to_int4 = false;
@@ -506,6 +507,7 @@
ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
@@ -555,6 +557,10 @@
{
quantize_fp16(p);
}
if(to_bf16)
{
quantize_bf16(p);
}
if(to_int8)
{
quantize_int8(p, t, {host_params(p)});
@@ -639,6 +645,10 @@ struct verify : command<verify>
{
vo.quantize = precision::fp16;
}
if(c.to_bf16)
{
vo.quantize = precision::bf16;
}
if(c.to_int8)
{
vo.quantize = precision::int8;
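The new flag mirrors the existing --fp16 path: when --bf16 is set, the driver quantizes the program to bf16 before compiling, and the verifier selects the matching precision. A minimal sketch of the same flow through the public C++ API (the parse_onnx call, the file name, and the GPU target header are illustrative assumptions, not part of this diff):

#include <migraphx/onnx.hpp>          // parse_onnx (assumed entry point)
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>  // quantize_bf16, added by this PR
#include <migraphx/gpu/target.hpp>    // assumed GPU target header

int main()
{
    // Load a model, truncate its float instructions to bf16, then compile,
    // which is what the driver now does when --bf16 is passed.
    migraphx::program p = migraphx::parse_onnx("model.onnx");
    migraphx::quantize_bf16(p); // ins_names defaults to {"all"}
    p.compile(migraphx::gpu::target{});
    return 0;
}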
1 change: 1 addition & 0 deletions src/driver/precision.hpp
@@ -32,6 +32,7 @@ enum class precision
{
fp32,
fp16,
bf16,
int8
};

13 changes: 10 additions & 3 deletions src/driver/verify.cpp
@@ -50,11 +50,14 @@ verify::tolerance get_tolerances(const program& p,
std::optional<double> atol,
std::optional<double> rtol)
{
bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
bool has_16bit = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) {
return (ins.get_shape().type() == shape::half_type or
ins.get_shape().type() == shape::bf16_type);
});
});
migraphx::verify::tolerance result{};
if(has_fp16 or vo.quantize == precision::fp16)
if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
@@ -100,6 +103,10 @@ std::vector<argument> run_target(program p,
{
quantize_fp16(p);
}
if(vo.quantize == precision::bf16)
{
quantize_bf16(p);
}
p.compile(t, options);

parameter_map m;
3 changes: 3 additions & 0 deletions src/include/migraphx/quantization.hpp
@@ -51,6 +51,9 @@ quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& c

MIGRAPHX_EXPORT void quantize_int4_weights(program& prog);

MIGRAPHX_EXPORT void quantize_bf16(program& prog,
const std::vector<std::string>& ins_names = {"all"});

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

4 changes: 4 additions & 0 deletions src/py/migraphx_py.cpp
@@ -651,6 +651,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Auto-convert FP8 parameters and return values to Float for MIGraphX Program",
py::arg("prog"));
m.def("quantize_bf16",
&migraphx::quantize_bf16,
py::arg("prog"),
py::arg("ins_names") = std::vector<std::string>{"all"});

#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
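The binding follows the same pybind11 pattern as the existing quantize_* entries, forwarding to migraphx::quantize_bf16 with ins_names defaulting to {"all"}, so Python callers can write migraphx.quantize_bf16(prog). A standalone sketch of that def pattern with a stand-in type, so it builds without the MIGraphX tree (hypothetical module, not PR code):

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <vector>

namespace py = pybind11;

struct program_stub
{
}; // stand-in for migraphx::program

// Stand-in for migraphx::quantize_bf16, to keep the example self-contained.
void quantize_bf16_stub(program_stub&, const std::vector<std::string>&) {}

PYBIND11_MODULE(example, m)
{
    py::class_<program_stub>(m, "program").def(py::init<>());
    m.def("quantize_bf16",
          &quantize_bf16_stub,
          py::arg("prog"),
          py::arg("ins_names") = std::vector<std::string>{"all"});
}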
10 changes: 10 additions & 0 deletions src/quantization.cpp
@@ -74,6 +74,16 @@
quant_tracer());
}

void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)
{
    run_passes(prog,
               {normalize_ops{},
                optimize_module{{"quantizelinear", "dequantizelinear"}},
                truncate_float_pass{ins_names, shape::bf16_type},
                optimize_module{{"quantizelinear", "dequantizelinear"}}},
               quant_tracer());
}

(Codecov/patch annotation: added lines 77-85 in src/quantization.cpp are not covered by tests.)

void quantize_8bits(program& prog,
const target& t,
shape::type_t precision,
@@ -44,6 +44,16 @@ __device__ bool float_equal_device(T x, T y)
std::nextafter(x, std::numeric_limits<T>::max()) >= y;
}

template <>
__device__ bool float_equal_device(__bf16 x, __bf16 y) // NOLINT(misc-definitions-in-headers)
{
float xf = x;
float yf = y;
return std::isfinite(xf) and std::isfinite(yf) and
std::nextafter(xf, std::numeric_limits<float>::lowest()) <= yf and
std::nextafter(xf, std::numeric_limits<float>::max()) >= yf;
}

template <class T, MIGRAPHX_REQUIRES(not is_floating_point<T>{})>
__device__ bool float_equal_device(T x, T y)
{
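Device code has no std::nextafter overload for __bf16, so the specialization widens both operands to float and applies the same one-ULP window as the generic overload above. A host-side sketch of that predicate (an illustration of the comparison rule, not code from this PR):

#include <cassert>
#include <cmath>
#include <limits>

// x and y compare equal when y lies within one float ULP of x.
bool float_equal_window(float x, float y)
{
    return std::isfinite(x) and std::isfinite(y) and
           std::nextafter(x, std::numeric_limits<float>::lowest()) <= y and
           std::nextafter(x, std::numeric_limits<float>::max()) >= y;
}

int main()
{
    assert(float_equal_window(1.0f, 1.0f));
    assert(not float_equal_window(1.0f, 1.0f + 1e-6f)); // several ULPs apart
    return 0;
}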
74 changes: 59 additions & 15 deletions src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
@@ -27,6 +27,7 @@

#include <hip/hip_runtime.h>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tensor_view.hpp>

@@ -67,6 +68,7 @@ auto pack_vec(Ts... xs)
}

using gpu_half = __fp16;
using gpu_bf16 = __bf16;

namespace detail {
template <class T>
@@ -87,6 +89,12 @@ struct device_type<half>
using type = gpu_half;
};

template <>
struct device_type<bf16>
{
using type = gpu_bf16;
};

template <class T>
struct host_type
{
@@ -99,6 +107,12 @@ struct host_type<gpu_half>
using type = half;
};

template <>
struct host_type<gpu_bf16>
{
using type = bf16;
};

} // namespace detail

template <class T>
@@ -143,23 +157,53 @@ __device__ __host__ T to_hip_type(T x)
return x;
}

// Hip doesn't support __fp16
// Hip doesn't support __fp16 and __bf16
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
inline __device__ __host__ float to_hip_type(gpu_bf16 x) { return x; }

template <class X>
struct is_floating_point : std::is_floating_point<X>
{
};

template <>
struct is_floating_point<__fp16> : std::true_type
{
};

template <class X>
struct is_signed : std::is_signed<X>
{
};

template <>
struct is_signed<__fp16> : std::true_type
{
};

template <class X>
struct is_arithmetic : std::is_arithmetic<X>
{
};

template <>
struct is_arithmetic<__fp16> : std::true_type
{
};

#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};

MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
// Redo for __bf16
template <>
struct is_floating_point<__bf16> : std::true_type
{
};
template <>
struct is_signed<__bf16> : std::true_type
{
};
template <>
struct is_arithmetic<__bf16> : std::true_type
{
};

} // namespace device
} // namespace gpu
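The trait specializations above let device code ask is_floating_point / is_signed / is_arithmetic about compiler-extension types that the std traits report as false. A standalone re-creation of the pattern using a stand-in type, since compiling against the real __bf16 requires a HIP/Clang toolchain (hypothetical example, not PR code):

#include <type_traits>

struct mock_bf16
{
}; // stand-in for __bf16

// Wrap the std trait so it can be extended for custom types.
template <class X>
struct is_floating_point : std::is_floating_point<X>
{
};

template <>
struct is_floating_point<mock_bf16> : std::true_type
{
};

static_assert(is_floating_point<float>::value, "inherits the std answer");
static_assert(is_floating_point<mock_bf16>::value, "extended for the new type");
static_assert(not std::is_floating_point<mock_bf16>::value, "std trait unchanged");

int main() { return 0; }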
@@ -98,6 +98,10 @@ template <>
struct is_hip_type<std::int32_t> : std::true_type
{
};
template <>
struct is_hip_type<bf16> : std::true_type
{
};

template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v)
8 changes: 7 additions & 1 deletion src/targets/gpu/fuse_mlir.cpp
@@ -390,6 +390,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::bf16_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e5m2fnuz_type,
@@ -439,6 +440,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
};
std::set<shape::type_t> float_types = {type_t::float_type,
type_t::half_type,
type_t::bf16_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e5m2fnuz_type,
type_t::fp8e4m3fn_type,
@@ -459,7 +461,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
return contains({type_t::float_type, type_t::half_type, type_t::bf16_type},
arg->get_shape().type());
});
}
return false;
@@ -472,10 +475,12 @@ bool is_reduce_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::bf16_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e5m2fnuz_type,
type_t::fp8e4m3fn_type,
type_t::fp8e5m2_type};

// Preliminary type check.
if(not contains(allowed_types, result_type))
{
@@ -732,6 +737,7 @@ struct find_mlir_standalone_op
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::bf16_type,
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type,
shape::type_t::fp8e5m2fnuz_type,
2 changes: 1 addition & 1 deletion src/targets/gpu/gemm_impl.cpp
@@ -224,7 +224,7 @@ struct gemm_impl
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
if(arg_type == rocblas_datatype_f16_r or arg_type == rocblas_datatype_bf16_r)
compute_type = rocblas_datatype_f32_r;
}
if(arg_type == rocblas_datatype_f8_r)
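Routing bf16 through the fp32 compute type matters because a bf16 accumulator keeps only 8 significand bits (7 stored), so a long running sum stalls once its rounding step exceeds the addend. A self-contained sketch that simulates a bf16 accumulator by rounding after every add (illustration only; rocBLAS internals are not shown):

#include <cstdint>
#include <cstring>
#include <cstdio>

// Round a float to the nearest bf16 value, returned as a float
// (round-to-nearest-even; NaN/Inf handling omitted for brevity).
static float round_to_bf16(float f)
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u = (u + 0x7fffu + ((u >> 16) & 1u)) & 0xffff0000u;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main()
{
    float acc_bf16 = 0.0f;
    float acc_fp32 = 0.0f;
    for(int i = 0; i < 4096; ++i)
    {
        acc_bf16 = round_to_bf16(acc_bf16 + round_to_bf16(0.01f));
        acc_fp32 += round_to_bf16(0.01f);
    }
    // The bf16 accumulator stops growing around 4.0; fp32 stays near 41.
    std::printf("bf16 acc: %g, fp32 acc: %g\n", acc_bf16, acc_fp32);
    return 0;
}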
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/miopen.hpp
@@ -143,6 +143,8 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os)
d = miopenInt32;
else if(s.type() == shape::int8_type)
d = miopenInt8;
else if(s.type() == shape::bf16_type)
d = miopenBFloat16;
else
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
@@ -76,6 +76,7 @@ using vec = T __attribute__((ext_vector_type(N)));

using half = _Float16;
using half2 = migraphx::vec<half, 2>;
using bf16 = __bf16;

} // namespace migraphx

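For reference, __bf16 is the top half of an IEEE-754 float: 1 sign bit, 8 exponent bits (the same dynamic range as fp32), and 7 stored mantissa bits, so a value survives a float round trip to only about 2-3 significant decimal digits. A quick sketch of the layout (illustrative; the real conversions are emitted by the compiler's __bf16 support):

#include <cstdint>
#include <cstring>
#include <cstdio>

int main()
{
    // A bf16 value is the high 16 bits of the float encoding.
    float x = 3.14159265f;
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    std::uint32_t hi = bits & 0xffff0000u; // truncating float -> bf16 -> float
    float back;
    std::memcpy(&back, &hi, sizeof(back));
    std::printf("%.8f -> %.8f\n", x, back); // 3.14159274 -> 3.14062500
    return 0;
}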
3 changes: 2 additions & 1 deletion src/targets/gpu/lowering.cpp
@@ -325,7 +325,8 @@ struct miopen_apply

static bool use_miopen_pooling(instruction_ref ins)
{
if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}))
if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}) or
not contains({shape::float_type, shape::half_type}, ins->get_shape().type()))
return false;
auto&& op = ins->get_operator();
auto op_val = op.to_value();
24 changes: 13 additions & 11 deletions src/targets/gpu/mlir.cpp
@@ -312,6 +312,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::bf16_type)
result = mlirBF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e5m2fnuz_type)
@@ -444,15 +446,15 @@
}

using attribute_t = std::variant<std::nullptr_t,
std::uint64_t,
unsigned char,
bool,
double,
std::string,
value,
std::vector<value>,
MlirType,
MlirAttribute>;
std::uint64_t,
unsigned char,
bool,
double,
std::string,
value,
std::vector<value>,
MlirType,
MlirAttribute>;
using named_attribute_t = std::pair<std::string_view, attribute_t>;

MlirNamedAttribute name_attribute(const named_attribute_t& na) const
@@ -1155,7 +1157,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx,
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution);
auto co = mp.compile(solution);

co.expected_inputs = in_shapes;
auto out_shapes = m.get_output_shapes();
@@ -1248,7 +1250,7 @@ void dump_mlir_to_mxr(module m,
sizes.insert(sizes.end(), ins->inputs().begin(), ins->inputs().end());
}
auto name = compute_dump_name(m, ".mxr");
auto f = location / name;
auto f = location / name;
std::cout << "Dumping MXR file to: " << f << std::endl;
save(program{std::move(m)}, f.string());
}
1 change: 1 addition & 0 deletions src/targets/gpu/target.cpp
@@ -101,6 +101,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
unsupported_types.erase(shape::type_t::bf16_type);

// whitelist supported Ops for the FP8 types
// different between fp8e4m3fnuz and OCP types because rocBLAS only has