Skip to content

Commit

Permalink
Remove using splitkv from fmha-fwd training path
Browse files Browse the repository at this point in the history
  • Loading branch information
qianfengz committed Oct 16, 2024
1 parent d4437ad commit f94fdfd
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 52 deletions.
32 changes: 6 additions & 26 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#pragma once

#include "ck_tiled_fmha_batched_forward_dispatch.h"
#include "ck_tiled_fmha_batched_forward_splitkv_dispatch.h"

template <
typename ScalarType,
Expand All @@ -18,29 +17,10 @@ template <
void run_batched_forward_causalmask_bias_dropout_dispatch(
BatchedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
if constexpr (!kHasDropout) {
#ifndef FMHA_FWD_SPLITKV_NOT_USED
if (param.use_split_kv)
batched_forward_splitkv_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
MaxK>::Run(param, stream);
else
#endif
batched_forward_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
} else {
batched_forward_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
}
batched_forward_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
};
32 changes: 6 additions & 26 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#pragma once

#include "ck_tiled_fmha_grouped_forward_dispatch.h"
#include "ck_tiled_fmha_grouped_forward_splitkv_dispatch.h"

template <
typename ScalarType,
Expand All @@ -18,29 +17,10 @@ template <
void run_grouped_forward_causalmask_bias_dropout_dispatch(
GroupedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
if constexpr (!kHasDropout) {
#ifndef FMHA_FWD_SPLITKV_NOT_USED
if (param.use_split_kv)
grouped_forward_splitkv_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
MaxK>::Run(param, stream);
else
#endif
grouped_forward_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
} else {
grouped_forward_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
}
grouped_forward_causalmask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
};

0 comments on commit f94fdfd

Please sign in to comment.