Make inputs actually contiguously laid out in memory (pytorch#7072)
dvorjackz authored Nov 26, 2024
1 parent 97a8a89 commit 2a292c3
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -168,11 +168,22 @@ def get_example_inputs(self):
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+
+        # Make input_pos and mask contiguous in memory.
+        input_pos = self.input_pos[None, : self.n_tokens]
+        mask = self.causal_mask[None, : self.n_tokens]
+        contiguous_input_pos = torch.empty_like(
+            input_pos, memory_format=torch.contiguous_format
+        )
+        contiguous_input_pos.data.copy_(input_pos.data)
+        contiguous_mask = torch.empty_like(mask, memory_format=torch.contiguous_format)
+        contiguous_mask.data.copy_(mask.data)
+
         # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, : self.n_tokens],
-                "mask": self.causal_mask[None, : self.n_tokens],
+                "input_pos": contiguous_input_pos,
+                "mask": contiguous_mask,
                 "encoder_input": torch.randn(
                     1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                 ),
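For context, here is a minimal standalone sketch of the pattern introduced above. It is not part of the commit, and the tensor names and sizes are made up for illustration: a fresh buffer is allocated with torch.empty_like(..., memory_format=torch.contiguous_format) and the sliced view is copied into it, so the example input owns its own contiguous storage instead of aliasing the full-length buffer.

import torch

# Hypothetical sizes, chosen only for illustration.
max_seq_len, n_tokens = 8, 3
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len))

# Slicing produces a view that shares storage with the full causal_mask buffer.
mask_view = causal_mask[None, :n_tokens]

# Allocate a fresh contiguous buffer of the same shape and dtype, then copy the
# data in, mirroring the empty_like/copy_ pattern from the commit.
contiguous_mask = torch.empty_like(mask_view, memory_format=torch.contiguous_format)
contiguous_mask.data.copy_(mask_view.data)

# The copy is contiguous and no longer shares storage with the original buffer.
assert contiguous_mask.is_contiguous()
assert mask_view.data_ptr() == causal_mask.data_ptr()
assert contiguous_mask.data_ptr() != causal_mask.data_ptr()

One plausible reason for copying into a freshly allocated tensor rather than calling .contiguous() on the slice: .contiguous() returns the original view unchanged whenever PyTorch already reports the slice as contiguous, so the result would still alias the full-length buffer.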
