do not call permute on empty tensor (#3705)
Summary:
X-link: facebookresearch/FBGEMM#787

Pull Request resolved: #3705

Short-circuit if the pooled embeddings tensor is empty; otherwise we run into invalid CUDA kernel launch arguments.

This change is necessary to support zero-batch-size KeyedJaggedTensor embedding lookups.

Conceptually, we cannot permute an empty tensor (i.e. `tensor.numel() == 0`), so we simply return the empty tensor.

Reviewed By: sryap

Differential Revision: D69635568

fbshipit-source-id: 4a8f710072a6de6d83c86a56c01e5434366d80dc
sarckk authored and facebook-github-bot committed Feb 19, 2025
1 parent 607861a · commit 95099a6
Showing 4 changed files with 13 additions and 0 deletions.
@@ -65,6 +65,10 @@ Tensor permute_pooled_embs_gpu_impl(
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list,
const bool& allow_duplicates = false) {
if (pooled_embs.numel() == 0) {
return pooled_embs;
}

// inv_permute_list is not being used so it's not checked here.
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list);
@@ -22,6 +22,9 @@ Tensor permute_pooled_embs_cpu_impl(
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list,
const bool& allow_duplicates) {
if (pooled_embs.numel() == 0) {
return pooled_embs;
}
TORCH_CHECK(
offset_dim_list.scalar_type() == at::ScalarType::Long,
"offset_dim_list needs to have long/int64 type")
@@ -65,6 +65,9 @@ Tensor permute_pooled_embs_split_gpu_impl(
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list,
const bool& allow_duplicates) {
if (pooled_embs.numel() == 0) {
return pooled_embs;
}
// inv_permute_list is not being used so it's not checked here.
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list);
@@ -31,6 +31,9 @@ Tensor permute_pooled_embs_split_cpu_impl(
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list,
const bool& allow_duplicates) {
if (pooled_embs.numel() == 0) {
return pooled_embs;
}
TORCH_CHECK(
offset_dim_list.scalar_type() == at::ScalarType::Long,
"offset_dim_list needs to have long/int64 type")
