fix: cross entropy for transformers>4.45 (#123)
* trigger-only pattern for custom loss
  Signed-off-by: Yu Chin Fabian Lim <[email protected]>
* add cross ent fix for llama, mistral, mixtral
  Signed-off-by: Anh Uong <[email protected]>
* fix linting errors
  Signed-off-by: Anh Uong <[email protected]>
* run formatter
  Signed-off-by: Anh Uong <[email protected]>
* fix misspelling and error test
  Signed-off-by: Anh Uong <[email protected]>
* fix import error with later transformers
  Signed-off-by: Anh Uong <[email protected]>
* add benchmarks
  Signed-off-by: Anh Uong <[email protected]>
* fix import order
  Signed-off-by: Anh Uong <[email protected]>
* replace benchmark and requirements
  Signed-off-by: Anh Uong <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Showing 14 changed files with 465 additions and 224 deletions.
@@ -33,11 +33,14 @@
    PreTrainedModel,
)
from transformers.modeling_utils import (
    dtype_byte_size,
    is_local_dist_rank_0,
    no_init_weights,
    shard_checkpoint,
)
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import WEIGHTS_NAME
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import convert_file_size_to_int
import accelerate
import torch
import torch.nn as nn

@@ -688,7 +691,7 @@ def save_quantized(
            torch.save(model.state_dict(), join(save_dir, model_save_name))
        else:
            # Shard checkpoint
-           shards, index = shard_checkpoint(
+           shards, index = self.shard_checkpoint(
                state_dict, max_shard_size=max_shard_size, weights_name=model_save_name
            )

@@ -766,6 +769,106 @@ def save_quantized(
            quantize_config.model_file_base_name = model_base_name
            quantize_config.save_pretrained(save_dir)

    # added by [email protected]
    # adapted from transformers.modeling_utils.shard_checkpoint
    # from transformers v4.46, removed in later versions
    # TODO: split_torch_state_dict_into_shards from huggingface_hub library
    def shard_checkpoint(
        self,
        state_dict: Dict[str, torch.Tensor],
        max_shard_size: Union[int, str] = "10GB",
        weights_name: str = WEIGHTS_NAME,
    ):
        """
        Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not
        exceed a given size.

        The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there
        is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
        example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get
        sharded as [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

        <Tip warning={true}>

        If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which
        will have a size greater than `max_shard_size`.

        </Tip>

        Args:
            state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a
                unit (like `"5MB"`).
            weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
                The name of the model save file.
        """
        logger.warning(
            "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
            "split_torch_state_dict_into_shards from huggingface_hub library"
        )
        max_shard_size = convert_file_size_to_int(max_shard_size)

        sharded_state_dicts = [{}]
        last_block_size = 0
        total_size = 0
        storage_id_to_block = {}

        for key, weight in state_dict.items():
            # when bnb serialization is used the weights in the state dict can be strings
            # check: https://github.com/huggingface/transformers/pull/24416 for more details
            if isinstance(weight, str):
                continue
            else:
                storage_id = id_tensor_storage(weight)

            # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
            if storage_id in storage_id_to_block and weight.device != torch.device(
                "meta"
            ):
                block_id = storage_id_to_block[storage_id]
                sharded_state_dicts[block_id][key] = weight
                continue

            weight_size = weight.numel() * dtype_byte_size(weight.dtype)
            # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
            # weight in the current shard.
            if (
                last_block_size + weight_size > max_shard_size
                and len(sharded_state_dicts[-1]) > 0
            ):
                sharded_state_dicts.append({})
                last_block_size = 0

            sharded_state_dicts[-1][key] = weight
            last_block_size += weight_size
            total_size += weight_size
            storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

        # If we only have one shard, we return it
        if len(sharded_state_dicts) == 1:
            return {weights_name: sharded_state_dicts[0]}, None

        # Otherwise, let's build the index
        weight_map = {}
        shards = {}
        for idx, shard in enumerate(sharded_state_dicts):
            shard_file = weights_name.replace(
                ".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin"
            )
            shard_file = shard_file.replace(
                ".safetensors",
                f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors",
            )
            shards[shard_file] = shard
            for key in shard.keys():
                weight_map[key] = shard_file

        # Add the metadata
        metadata = {"total_size": total_size}
        index = {"metadata": metadata, "weight_map": weight_map}
        return shards, index

    def save_pretrained(
        self,
        save_dir: str,
@@ -16,6 +16,7 @@
import triton.language as tl
import torch
from .utils import calculate_settings, MAX_FUSED_SIZE
from typing import Type


@triton.jit

@@ -290,3 +291,55 @@ def forward(self, input, target):
        )
        n_items = torch.count_nonzero(target != -100)
        return loss.sum() / n_items


# added by [email protected]
# adapted from transformers.loss.loss_utils.ForCausalLMLoss
def FastForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    **kwargs,
):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    reduction = "sum" if num_items_in_batch is not None else "mean"
    assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
    loss = Fast_CrossEntropyLoss.apply(
        shift_logits, shift_labels
    )
    if reduction == "sum":
        n_items = num_items_in_batch
    else:
        n_items = torch.count_nonzero(shift_labels != -100)
    return loss.sum() / n_items


def replace_custom_loss_when_triggered(
    module_cls: Type,
    custom_loss_type: str,
):

    # this is a special trigger that will perform the replacement
    def _trigger(mod):
        if isinstance(mod, module_cls) and hasattr(mod, "loss_function"):
            # guarded
            from transformers.loss.loss_utils import LOSS_MAPPING
            LOSS_MAPPING[custom_loss_type] = FastForCausalLMLoss
            mod.loss_type = custom_loss_type
            return True

        return False

    return _trigger
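
The fused loss above is meant to follow the semantics of transformers' `ForCausalLMLoss`. The following plain-PyTorch sketch spells out the two reduction paths it implements, with made-up shapes and no Triton kernel involved, so it is a reference point rather than a test of `Fast_CrossEntropyLoss` itself.

```python
# Reference sketch of the two reduction paths in FastForCausalLMLoss, written
# with plain torch.nn.functional.cross_entropy. Shapes and masking are made up.
import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 2, 16, 128
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))
labels[:, :4] = -100  # e.g. masked prompt tokens

# Same shift-and-flatten performed by FastForCausalLMLoss.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

# num_items_in_batch is None -> "mean": average over non-ignored tokens.
ref_mean = F.cross_entropy(shift_logits.float(), shift_labels, ignore_index=-100)

# num_items_in_batch given -> "sum": total loss divided by the caller-supplied
# token count, which is how transformers>=4.45 handles gradient accumulation.
num_items_in_batch = int((shift_labels != -100).sum())
ref_sum = (
    F.cross_entropy(
        shift_logits.float(), shift_labels, ignore_index=-100, reduction="sum"
    )
    / num_items_in_batch
)

# The two agree when num_items_in_batch equals the unmasked-token count.
print(ref_mean.item(), ref_sum.item())
```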
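
And here is a sketch of how the trigger factory might be exercised by hand, assuming transformers>4.45 and that `FastForCausalLMLoss` and `replace_custom_loss_when_triggered` above are in scope. The tiny Llama config and the `fast_causal_lm` key are illustrative choices; the plugin presumably registers this trigger with its model-patching machinery rather than calling it directly.

```python
# Sketch only: applying the trigger by hand to a small Llama model.
# Assumes FastForCausalLMLoss and replace_custom_loss_when_triggered are in scope.
from transformers import LlamaConfig, LlamaForCausalLM

# tiny config so the sketch is cheap to instantiate
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=1024,
)
model = LlamaForCausalLM(config)

trigger = replace_custom_loss_when_triggered(
    module_cls=LlamaForCausalLM,
    custom_loss_type="fast_causal_lm",
)

# Firing the trigger on a matching module registers FastForCausalLMLoss under
# the custom key in transformers' LOSS_MAPPING and points the module's
# loss_type at it, so its loss_function should now resolve to the fused loss.
assert trigger(model) is True
print(model.loss_type)  # "fast_causal_lm"
```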