
Fix model kwargs #35875

Merged · 45 commits · Feb 6, 2025

Commits
d3c618e  Save state (muellerzr, Jan 23, 2025)
c489527  Make a failing test (muellerzr, Jan 24, 2025)
8a58190  Better test (muellerzr, Jan 24, 2025)
4348e36  mpt -> done, many more to go (muellerzr, Jan 24, 2025)
3b3dfd2  Rm extranious (muellerzr, Jan 24, 2025)
2bf5390  Bamba (muellerzr, Jan 24, 2025)
34f9060  Bert (muellerzr, Jan 24, 2025)
3960502  big_bird (muellerzr, Jan 24, 2025)
a87ed15  biogpt (muellerzr, Jan 24, 2025)
2705ae6  bloom (muellerzr, Jan 24, 2025)
33e718b  codegen (muellerzr, Jan 24, 2025)
e215848  ctrl (muellerzr, Jan 24, 2025)
72459fa  data2vec (muellerzr, Jan 24, 2025)
212ee51  dbrx (muellerzr, Jan 24, 2025)
8159793  Through up to Dbrx (muellerzr, Jan 24, 2025)
f5cf781  electra (muellerzr, Jan 24, 2025)
96e26f6  ernie (muellerzr, Jan 24, 2025)
1ac07d3  falcon (muellerzr, Jan 24, 2025)
9666691  Fuyu/persimmon (muellerzr, Jan 24, 2025)
d2d8f8e  Include noop kwargs to base models (muellerzr, Jan 24, 2025)
bf112ca  Rebase (muellerzr, Feb 5, 2025)
308b91d  Skip musigen (muellerzr, Jan 30, 2025)
ad5e487  Refactor/skip mllama (muellerzr, Jan 30, 2025)
14c121d  Revert makefile (muellerzr, Jan 30, 2025)
fcf896c  Rm file (muellerzr, Jan 30, 2025)
24b59bf  Fix PT failing, need to modify rest of loss funcs to not resize (muellerzr, Feb 3, 2025)
6320ab4  Propagate some (muellerzr, Feb 3, 2025)
44530b6  Continue (muellerzr, Feb 3, 2025)
978dbbe  More (muellerzr, Feb 3, 2025)
ea4484e  More options (muellerzr, Feb 3, 2025)
12627ef  Mostly fixed (muellerzr, Feb 3, 2025)
dc42e65  Proved that it's the same (muellerzr, Feb 5, 2025)
9f23ae7  Bloom is good (muellerzr, Feb 5, 2025)
12c00f6  Make ability to override loss func possible (muellerzr, Feb 5, 2025)
b6fb606  Fixup (muellerzr, Feb 5, 2025)
cfb3bcf  Clean (muellerzr, Feb 5, 2025)
f7eda3b  Fix xglm (muellerzr, Feb 5, 2025)
6d34419  Quality tests (muellerzr, Feb 5, 2025)
c103851  Skip OCR2 (muellerzr, Feb 5, 2025)
bde0bef  Make specific loss for xglm (muellerzr, Feb 6, 2025)
2f951dd  Make order the same/line up 1:1 (muellerzr, Feb 6, 2025)
5204b53  xglm (muellerzr, Feb 6, 2025)
038dc55  Skip fx output loss bloom model (muellerzr, Feb 6, 2025)
6033db8  Didn't pass in pad_token_id (muellerzr, Feb 6, 2025)
ff06a1d  Fix quality (muellerzr, Feb 6, 2025)
src/transformers/modeling_utils.py (7 additions, 0 deletions)

@@ -5198,6 +5198,9 @@ def tplize(mod: torch.nn.Module) -> None:

@property
def loss_function(self):
if hasattr(self, "_loss_function"):
return self._loss_function

Comment from the PR author (muellerzr) on lines +5201 to +5203:

> @ArthurZucker this needed to be added for a few models that don't need everything the default loss function does; the case here was xglm.
loss_type = getattr(self, "loss_type", None)

if loss_type is None or loss_type not in LOSS_MAPPING:
Expand All @@ -5208,6 +5211,10 @@ def loss_function(self):
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]

@loss_function.setter
def loss_function(self, value):
self._loss_function = value

def get_compiled_call(self, compile_config: CompileConfig):
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
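For context, a minimal sketch (not part of this diff) of what the new property/setter pair enables: a callable assigned to `model.loss_function` is stored in `_loss_function` and returned before the `LOSS_MAPPING` lookup. The replacement loss below and the `gpt2` checkpoint are illustrative assumptions only.

```python
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM with the new setter

def plain_shifted_cross_entropy(logits, labels, vocab_size=None, num_items_in_batch=None, **kwargs):
    # Classic next-token objective: shift so that tokens < n predict n, then flatten.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# Stored in `_loss_function`; the property above returns it instead of consulting LOSS_MAPPING.
model.loss_function = plain_shifted_cross_entropy
```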
src/transformers/models/bamba/modeling_bamba.py (1 addition, 0 deletions)

@@ -1208,6 +1208,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
src/transformers/models/bamba/modular_bamba.py (1 addition, 0 deletions)

@@ -949,6 +949,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
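The `**kwargs` added to the base models' `forward` is deliberately a no-op: it lets loss-related keywords pass through without the base model raising a `TypeError`. A hedged sketch of the effect (the checkpoint name and kwarg value are assumptions, not taken from this diff):

```python
import torch
from transformers import AutoModel

base_model = AutoModel.from_pretrained("ibm-fms/Bamba-9B")  # assumed checkpoint name
input_ids = torch.tensor([[1, 2, 3, 4]])

# Previously an unexpected keyword such as num_items_in_batch made forward raise a
# TypeError; with this change it is swallowed by **kwargs and ignored ("NOOP ... for now").
outputs = base_model(input_ids=input_ids, num_items_in_batch=8)
```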
@@ -20,7 +20,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
@@ -734,6 +733,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -901,6 +901,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -963,18 +964,20 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(
prediction_scores,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (prediction_scores,) + outputs[1:]
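The decoder heads above drop their hand-rolled shift-and-flatten `CrossEntropyLoss` in favor of `self.loss_function`. Roughly speaking, the shared `ForCausalLM` entry in `LOSS_MAPPING` performs the same shifting internally; the sketch below is a simplified approximation for orientation, not the exact library code:

```python
import torch.nn.functional as F

def for_causal_lm_loss_sketch(logits, labels, vocab_size, num_items_in_batch=None,
                              ignore_index=-100, **kwargs):
    # Shift so that tokens < n predict n, exactly like the removed per-model code.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    # When the total number of label tokens in the (accumulated) batch is known,
    # sum and divide by it so gradient accumulation does not skew the average.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```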
src/transformers/models/big_bird/modeling_big_bird.py (9 additions, 5 deletions)

@@ -1983,6 +1983,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.FloatTensor]]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2540,6 +2541,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.FloatTensor]]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2580,18 +2582,20 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(
prediction_scores,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
src/transformers/models/biogpt/modeling_biogpt.py (8 additions, 5 deletions)

@@ -588,6 +588,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -757,6 +758,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -783,11 +785,12 @@ def forward(

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(
prediction_scores,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (prediction_scores,) + outputs[1:]
src/transformers/models/bloom/modeling_bloom.py (7 additions, 7 deletions)

@@ -967,6 +967,8 @@ def forward(
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
# Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
@@ -999,14 +1001,12 @@
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
loss = self.loss_function(
lm_logits,
labels,
vocab_size=self.config.vocab_size,
num_items_in_batch=num_items_in_batch,
)

if not return_dict:
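Bloom is a special case: its `forward` collects unknown keywords into `**deprecated_arguments` and rejects any leftovers, so `num_items_in_batch` has to be popped out before that check. A condensed sketch of the control flow (simplified, not the full method):

```python
import warnings

def bloom_forward_sketch(labels=None, **deprecated_arguments):
    # Pull the loss kwarg out first so it is not mistaken for a deprecated argument.
    num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
    if deprecated_arguments.pop("position_ids", False) is not False:
        warnings.warn("`position_ids` have no functionality in BLOOM and will be removed.", FutureWarning)
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
    # ... run the transformer, then hand num_items_in_batch to self.loss_function ...
    return num_items_in_batch
```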
src/transformers/models/camembert/modeling_camembert.py (7 additions, 5 deletions)

@@ -1584,6 +1584,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1655,11 +1656,12 @@ def forward(
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(prediction_scores.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(
prediction_scores,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
src/transformers/models/codegen/modeling_codegen.py (8 additions, 6 deletions)

@@ -19,7 +19,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -450,6 +449,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -741,6 +741,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -775,12 +776,13 @@ def forward(
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = self.loss_function(
lm_logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

loss = loss.to(hidden_states.dtype)

src/transformers/models/ctrl/modeling_ctrl.py (8 additions, 6 deletions)

@@ -360,6 +360,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
r"""
Returns:
@@ -537,6 +538,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -593,12 +595,12 @@ def forward(

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = self.loss_function(
lm_logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
src/transformers/models/data2vec/modeling_data2vec_text.py (7 additions, 7 deletions)

@@ -906,6 +906,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -975,13 +976,12 @@ def forward(

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()

labels = labels.to(shifted_prediction_scores.device)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(
prediction_scores,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
src/transformers/models/dbrx/modeling_dbrx.py (8 additions, 10 deletions)

@@ -980,6 +980,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1278,6 +1279,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""Forward function for causal language modeling.

@@ -1344,16 +1346,12 @@

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(
logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

aux_loss = None
if output_router_logits:
src/transformers/models/electra/modeling_electra.py (7 additions, 5 deletions)

@@ -1564,6 +1564,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1633,11 +1634,12 @@ def forward(

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(
prediction_scores,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

if not return_dict:
output = (prediction_scores,) + outputs[1:]
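Taken together, these changes let loss-related kwargs flow from the caller through the LM head into the shared loss function. A hedged end-to-end illustration against a model touched by this PR (the checkpoint and the `num_items_in_batch` value are assumptions for demonstration):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

enc = tokenizer("Model kwargs now reach the loss function.", return_tensors="pt")
labels = enc["input_ids"].clone()

# num_items_in_batch is what Trainer-style gradient accumulation passes so the loss is
# normalized over all label tokens in the accumulated batch, not averaged per micro-batch.
outputs = model(**enc, labels=labels, num_items_in_batch=int(labels.numel()))
print(outputs.loss)
```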