
Commit

Cleaner .to()
Rocketknight1 committed Jan 29, 2025
1 parent 8ccde63 commit 8c69579
Showing 32 changed files with 32 additions and 32 deletions.
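Every file gets the same one-line change: the boolean padding mask is cast and moved in a single .to(device, dtype) call instead of chaining .int() and .to(device). A minimal standalone sketch of the equivalence being relied on (the tensors and the pad_token_id value below are invented for illustration, not taken from the diff):

import torch

# Hypothetical inputs: a small batch where pad_token_id is 0.
pad_token_id = 0
input_ids = torch.tensor([[0, 0, 5, 6], [7, 8, 9, 0]])
device = input_ids.device  # stands in for logits.device

# Old form: torch.ne, cast to int32 with .int(), then move to the target device.
old_mask = torch.ne(input_ids, pad_token_id).int().to(device)

# New form: comparison operator, then one .to() call that sets device and dtype together.
new_mask = (input_ids != pad_token_id).to(device, torch.int32)

assert torch.equal(old_mask, new_mask)  # same int32 values on the same device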
2 changes: 1 addition & 1 deletion src/transformers/models/bloom/modeling_bloom.py
@@ -1126,7 +1126,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
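As the comment in the hunk states, the mask is combined with a running position index so that the rightmost non-pad token is found whether the sequence is left- or right-padded. A small worked example under the same assumed inputs as above (pad_token_id = 0, invented sequences):

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [0, 0, 5, 6],  # left-padded: last real token sits at index 3
    [7, 8, 9, 0],  # right-padded: last real token sits at index 2
])

non_pad_mask = (input_ids != pad_token_id).to(input_ids.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=input_ids.device)

# Multiplying zeroes out the indices of pad positions, so the per-row max is the
# index of the rightmost non-pad token.
last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
print(last_non_pad_token)  # tensor([3, 2])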
2 changes: 1 addition & 1 deletion src/transformers/models/ctrl/modeling_ctrl.py
@@ -794,7 +794,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/diffllama/modeling_diffllama.py
@@ -1220,7 +1220,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -1363,7 +1363,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -952,7 +952,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -1042,7 +1042,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -962,7 +962,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
@@ -1400,7 +1400,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
@@ -1287,7 +1287,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1105,7 +1105,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -1210,7 +1210,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -1247,7 +1247,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/helium/modeling_helium.py
@@ -949,7 +949,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/jamba/modeling_jamba.py
@@ -1683,7 +1683,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1459,7 +1459,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -951,7 +951,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -1040,7 +1040,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -1193,7 +1193,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/mpt/modeling_mpt.py
@@ -685,7 +685,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/nemotron/modeling_nemotron.py
@@ -1198,7 +1198,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/openai/modeling_openai.py
@@ -810,7 +810,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -1302,7 +1302,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
@@ -1013,7 +1013,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -925,7 +925,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/phi3/modeling_phi3.py
@@ -1061,7 +1061,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
@@ -1601,7 +1601,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
@@ -936,7 +936,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1442,7 +1442,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/stablelm/modeling_stablelm.py
@@ -1269,7 +1269,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -948,7 +948,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/zamba/modeling_zamba.py
@@ -1436,7 +1436,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
2 changes: 1 addition & 1 deletion src/transformers/models/zamba2/modeling_zamba2.py
@@ -1862,7 +1862,7 @@ def forward(
         else:
             if input_ids is not None:
                 # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
+                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                 token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
                 last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
             else:
