Skip to content

Commit

Permalink
add backward compatible reference for _get_unflattened_lengths (#2541)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2541

Reviewed By: PaulZhang12

Differential Revision: D65490058

fbshipit-source-id: 7fe44cc56bf5b72abe20e10911c0c288905c2dd1
  • Loading branch information
seanx92 authored and facebook-github-bot committed Nov 5, 2024
1 parent 0512183 commit 7e867ad
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions torchrec/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor
return tensor[start:end]


# PLEASE DO NOT USE THIS FUNCTION, THIS FUNCTION IS FOR BACKWARD COMPATIBILITY ONLY
# USE THE ONE IN torchrec/quant/embedding_modules.py
# TODO(@shuaoxiong): remove this function after we make sure all models switch to the new reference
@torch.fx.wrap
def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
"""
Unflatten lengths tensor from [F * B] to [F, B].
"""
return lengths.view(num_features, -1)


def extract_module_or_tensor_callable(
module_or_callable: Union[
Callable[[], torch.nn.Module],
Expand Down

0 comments on commit 7e867ad

Please sign in to comment.