Partial revert of D66986498 #3620

Open · wants to merge 1 commit into base: main

@@ -528,320 +528,3 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
////////////////////////////////////////////////////////////////////////////////

{%- endif %}

{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
#include "gen_embedding_backward_split_{{ desc_suffix }}{{ ndesc }}_device_kernel_hip.hip"

template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread,
int32_t kThreadGroupSize,
bool kUseVecBlocking,
int32_t embedding_dim,
int32_t weight_decay_mode_v>
__global__ void
hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1(
const pta::PackedTensorAccessor64<grad_t, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output,
{%- if optimizer != "none" %}
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
{%- if not dense %}
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
{%- endif %}
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
{%- if not nobag or is_index_select %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
{%- else %}
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- else %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
const bool use_uniq_cache_locations,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> table_unique_indices_offsets,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
{%- endif %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
int32_t max_segment_length_per_warp,
{%- if not dense and optimizer != "none" %}
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{%- else %}
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> grad_dev_weights,
{%- endif %} // if not dense and optimizer != "none"
{%- if not nobag and vbe %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> B_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_output_offsets,
{%- endif %}
{%- if not nobag %}
const int32_t info_B_num_bits,
const uint32_t info_B_mask,
{%- endif %}
const int32_t max_D,
const int32_t max_vecs_per_thread,
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
{%- endif %}
) {
{%- if not nobag %}
int32_t T = D_offsets.size(0) - 1;
{%- else %}
int32_t T = weights_offsets.size(0);
{%- endif %}

auto p_output_grad = grad_output.data();
auto p_emb_table = dev_weights.data();
auto p_hash_size_cumsum = hash_size_cumsum.data();
auto p_sorted_linear_indices_run = sorted_linear_indices_run.data();
auto p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.data();
auto p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.data();
auto p_sorted_infos = sorted_infos.data();
{%- if weighted %}
auto p_indice_weights_sorted = sorted_indice_weights.data();
{%- endif %}
auto emb_dim = embedding_dim;
constexpr int32_t segment_prefetch = 2;
constexpr int32_t segment_unroll = 8;
constexpr int32_t segment_split = 0;
auto batch = grad_output.size(0);
auto num_rows = dev_weights.size(0) / T / max_D;
{%- if weighted %}
constexpr bool is_weighted = true;
{%- else %}
constexpr bool is_weighted = false;
{%- endif %}
rocm::{{ optimizer }}_kernel_arg_t opt_karg;
opt_karg.p_momentum = momentum1_dev.data();
opt_karg.eps = eps;
opt_karg.learning_rate = learning_rate;
// weight_decay and weight_decay_mode are supplied via args.split_function_args_no_defaults
opt_karg.weight_decay_mode = weight_decay_mode_v;
opt_karg.weight_decay = weight_decay;
// Precompute magic-number constants so the kernel can divide by `batch`
// with one multiply and a shift instead of a hardware integer division.
auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
assert(d >= 1 && d <= INT32_MAX);
// shift = ceil(log2(d)): the smallest shift with (1U << shift) >= d
uint8_t shift;
for(shift = 0; shift < 32; shift++)
if((1U << shift) >= d)
break;

// magic = floor(2^32 * (2^shift - d) / d) + 1; asserted to fit in 32 bits
uint64_t one = 1;
uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
assert(magic <= 0xffffffffUL);

rocm::magic_div_u32_t result;
result.magic = magic;
result.shift = shift;
return result;
}(batch);
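// A minimal sketch of how such magic constants are consumed on the device
// side. The real helper lives in fbgemm_gpu/rocm/split_embeddings_common.h;
// the name and exact form below are an assumption, shown for illustration:
//
//   __device__ inline uint32_t
//   magic_div_u32_run(const rocm::magic_div_u32_t m, const uint32_t n) {
//       // floor(n / d) without an integer divide
//       return (__umulhi(n, m.magic) + n) >> m.shift;
//   }
//
// Worked example for d = 6: shift = 3 (since 2^3 >= 6) and
// magic = floor(2^32 * (8 - 6) / 6) + 1 = 1431655766. For n = 100,
// __umulhi(100, 1431655766) = 33 and (33 + 100) >> 3 = 16 = 100 / 6.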
rocm::split_tbe_backward_hip_kernel_{{ kdesc }}<
rocm::{{ optimizer }}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_kernel_arg_t,
emb_t,
cache_t,
grad_t,
index_t,
BLOCK_SIZE,
embedding_dim,
segment_prefetch,
segment_unroll,
segment_split,
is_weighted>(p_output_grad,
p_emb_table,
p_hash_size_cumsum,
p_sorted_linear_indices_run,
p_sorted_linear_indices_cumulative_run_lengths,
p_sorted_linear_indices_num_runs,
{%- if not nobag %}
info_B_num_bits,
info_B_mask,
{%- endif %}
p_sorted_infos,
batch_mdiv,
max_segment_length_per_warp,
emb_dim,
batch,
num_rows,
T,
opt_karg
{%- if weighted %}
, p_indice_weights_sorted
{%- endif %});
}

{%- macro hip_template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
kEmbeddingDim,
kWeightDecayMode
)
%}
template __global__ __launch_bounds__(kBackwardMaxThreads) void
hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ index_type }},
{{ kFixedMaxVecsPerThread }},
{{ kThreadGroupSize }},
{{ kUseVecBlocking }},
{{ kEmbeddingDim }},
{{ kWeightDecayMode }}
> (
const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output,
{%- if optimizer != "none" %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
{%- if not dense %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
{%- endif %}
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
{%- if not nobag or is_index_select %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
{%- else %}
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- else %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
const bool use_uniq_cache_locations,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> table_unique_indices_offsets,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
{%- endif %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
int32_t max_segment_length_per_warp,
{%- if not dense and optimizer != "none" %}
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{%- else %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights,
{%- endif %} // if not dense and optimizer != "none"
{%- if not nobag and vbe %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> B_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_output_offsets,
{%- endif %}
{%- if not nobag %}
const int32_t info_B_num_bits,
const uint32_t info_B_mask,
{%- endif %}
const int32_t max_D,
const int32_t max_vecs_per_thread,
{%- if is_index_select %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const bool permute_output_dim_0_1
{%- else %}
{{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}
{%- endif %}
);
{%- endmacro %}

{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %}
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
{%- for kWeightDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
kEmbeddingDim,
kWeightDecayMode
)
}}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}
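
{#- /*
Each expansion of hip_bulk_template_instantiations above emits
3 grad types x 2 emb types x 2 cache types x 2 index types
x 5 embedding dims x 3 weight-decay modes = 360 explicit instantiations
for a single (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking)
configuration.
*/ #}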

{%- macro hip_instantiate_templates(use_subwarp_shuffle) %}
{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking)
in get_max_vecs_template_configs(
items_per_warp,
fixed_max_vecs_per_thread["backward"],
use_subwarp_shuffle,
use_vec_blocking=True,
)
%}
{{
hip_bulk_template_instantiations(
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
)
}}
{%- endfor %}
{%- endmacro %}

////////////////////////////////////////////////////////////////////////////////
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
////////////////////////////////////////////////////////////////////////////////

{#- /*
Explicitly instantiate kernels for the FBGEMM_USE_SUBWARP_SHUFFLE case
Please see get_max_vecs_template_configs in
codegen/embedding_common_code_generator.py for more details
*/ #}

{{ hip_instantiate_templates(use_subwarp_shuffle=True) }}

////////////////////////////////////////////////////////////////////////////////
#else
////////////////////////////////////////////////////////////////////////////////

{#- /*
Explicitly instantiate kernels for the non-FBGEMM_USE_SUBWARP_SHUFFLE case
Please see get_max_vecs_template_configs in
codegen/embedding_common_code_generator.py for more details
*/ #}

{{ hip_instantiate_templates(use_subwarp_shuffle=False) }}

////////////////////////////////////////////////////////////////////////////////
#endif
////////////////////////////////////////////////////////////////////////////////
{%- endif %}
// clang-format on