Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
context
_keys
are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly.stride
(int) of a KJT, which represents thebatch_size
, is computed by_maybe_compute_stride_kjt
:batch_size
is static, however, this no longer holds true in a variable batch size scenario, where thestride_per_key_per_rank
is notNone
.dedup_ebc
, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the truestride
(static).kjt.stride()
function, which would be incorrect if the pytree specs only contains thekeys
.stride
into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.Differential Revision: D66400821