fix: cross entropy for transformers>4.45 (#123)
* trigger-only pattern for custom loss

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add cross ent fix for llama, mistral, mixtral

Signed-off-by: Anh Uong <[email protected]>

* fix linting errors

Signed-off-by: Anh Uong <[email protected]>

* run formatter

Signed-off-by: Anh Uong <[email protected]>

* fix misspelling and error test

Signed-off-by: Anh Uong <[email protected]>

* fix import error with later transformers

Signed-off-by: Anh Uong <[email protected]>

* add benchmarks

Signed-off-by: Anh Uong <[email protected]>

* fix import order

Signed-off-by: Anh Uong <[email protected]>

* replace benchmark and requirements

Signed-off-by: Anh Uong <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
anhuong and fabianlim authored Feb 7, 2025
1 parent 8787ca1 commit 24bdadb
Showing 14 changed files with 465 additions and 224 deletions.
@@ -33,11 +33,14 @@
PreTrainedModel,
)
from transformers.modeling_utils import (
dtype_byte_size,
is_local_dist_rank_0,
no_init_weights,
shard_checkpoint,
)
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import WEIGHTS_NAME
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import convert_file_size_to_int
import accelerate
import torch
import torch.nn as nn
@@ -688,7 +691,7 @@ def save_quantized(
torch.save(model.state_dict(), join(save_dir, model_save_name))
else:
# Shard checkpoint
shards, index = shard_checkpoint(
shards, index = self.shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name=model_save_name
)

@@ -766,6 +769,106 @@ def save_quantized(
quantize_config.model_file_base_name = model_base_name
quantize_config.save_pretrained(save_dir)

# added by [email protected]
# adapted from transformers.modeling_utils.shard_checkpoint
# from transformers v4.46, removed in later versions
# TODO: split_torch_state_dict_into_shards from huggingface_hub library
def shard_checkpoint(
self,
state_dict: Dict[str, torch.Tensor],
max_shard_size: Union[int, str] = "10GB",
weights_name: str = WEIGHTS_NAME,
):
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
<Tip warning={true}>
If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
have a size greater than `max_shard_size`.
</Tip>
Args:
state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
The name of the model save file.
"""
logger.warning(
"Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
"split_torch_state_dict_into_shards from huggingface_hub library"
)
max_shard_size = convert_file_size_to_int(max_shard_size)

sharded_state_dicts = [{}]
last_block_size = 0
total_size = 0
storage_id_to_block = {}

for key, weight in state_dict.items():
# when bnb serialization is used the weights in the state dict can be strings
# check: https://github.com/huggingface/transformers/pull/24416 for more details
if isinstance(weight, str):
continue
else:
storage_id = id_tensor_storage(weight)

# If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
if storage_id in storage_id_to_block and weight.device != torch.device(
"meta"
):
block_id = storage_id_to_block[storage_id]
sharded_state_dicts[block_id][key] = weight
continue

weight_size = weight.numel() * dtype_byte_size(weight.dtype)
# If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
# weight in the current shard.
if (
last_block_size + weight_size > max_shard_size
and len(sharded_state_dicts[-1]) > 0
):
sharded_state_dicts.append({})
last_block_size = 0

sharded_state_dicts[-1][key] = weight
last_block_size += weight_size
total_size += weight_size
storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None

# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(
".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin"
)
shard_file = shard_file.replace(
".safetensors",
f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors",
)
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file

# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index

def save_pretrained(
self,
save_dir: str,
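The backported `shard_checkpoint` above packs weights greedily in key order, starting a new shard only when the next weight would push the current one past `max_shard_size`. A minimal standalone sketch of just that packing rule, reproducing the docstring example (the helper name and toy sizes are illustrative, not part of the diff):

# hedged sketch of the greedy packing rule described in the docstring
def pack_greedily(sizes, max_shard_size):
    shards, current = [[]], 0
    for size in sizes:
        # start a new shard only if this weight would overflow the current one
        # and the current shard already holds at least one weight
        if current + size > max_shard_size and shards[-1]:
            shards.append([])
            current = 0
        shards[-1].append(size)
        current += size
    return shards

# reproduces the docstring example: a 10GB limit over [6, 6, 2, 6, 2, 2] GB weights
print(pack_greedily([6, 6, 2, 6, 2, 2], 10))  # [[6], [6, 2], [6, 2, 2]]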
@@ -20,7 +20,7 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
taken from https://github.com/imoneoi/multipack_sampler with some modifications
taken from https://github.com/imoneoi/multipack_sampler with some modifications
taken from https://github.com/instructlab/training/blob/main/src/instructlab/training/multipack_sampler.py
"""

14 changes: 9 additions & 5 deletions plugins/framework/src/fms_acceleration/model_patcher.py
@@ -184,11 +184,11 @@ def __post_init__(self):
self.import_and_maybe_reload is not None,
]
)
!= 1
> 1
):
raise ValueError(
f"Rule '{self.rule_id}' must only have only one of forward, "
"foward builder, or import_and_maybe_reload, specified."
f"Rule '{self.rule_id}' must only have at most one of forward, "
"forward builder, or import_and_maybe_reload, specified."
)

if self.import_and_maybe_reload is not None and self.trigger is not None:
@@ -425,7 +425,7 @@ def _patch_forwards(
# otherwise triggered
if rule.forward is not None:
forward = rule.forward
else:
elif rule.forward_builder is not None:
fba = {}
if rule.forward_builder_args is not None:
fba = {
@@ -434,6 +434,9 @@
if rule.forward_builder_args
}
forward = rule.forward_builder(mod, **fba)
else:
# trigger-only case
forward = None

if isinstance(forward, list):
# this will be list of tuples case
@@ -468,7 +471,8 @@ def _patch_forwards(
continue

# otherwise
mod.forward = MethodType(forward, mod)
if forward is not None:
mod.forward = MethodType(forward, mod)
ModelPatcher.history.append(
ModelPatcherHistory(
instance=mod_id,
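The new `elif`/`else` branches make a trigger-only rule legal: when neither `forward`, `forward_builder`, nor `import_and_maybe_reload` is given, `forward` stays `None` and the `MethodType` assignment is skipped, so the trigger's `check` callable is the only thing that runs. A minimal sketch of such a rule (the check function and rule id are illustrative, not part of the diff):

# hedged sketch of the trigger-only pattern enabled by this change
import torch
from fms_acceleration.model_patcher import ModelPatcherRule, ModelPatcherTrigger

def _tag_module(mod) -> bool:
    # illustrative side effect only; returning True marks the module as triggered
    if isinstance(mod, torch.nn.Linear):
        mod._was_seen_by_patcher = True
        return True
    return False

rule = ModelPatcherRule(
    rule_id="demo-trigger-only",
    trigger=ModelPatcherTrigger(check=_tag_module),
    # no forward, forward_builder, or import_and_maybe_reload: _patch_forwards
    # now takes the forward = None path and leaves mod.forward untouched
)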
6 changes: 3 additions & 3 deletions plugins/framework/tests/test_model_patcher_helpers.py
@@ -254,12 +254,12 @@ def test_combine_mp_triggers_produces_correct_output(


def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
"Ensure MP rule is throws appropriate error when wrong argument combinations are passed"
"Ensure MP rule throws appropriate error when wrong argument combinations are passed"
# Test mp rule construction raises with multiple arguments
with pytest.raises(
ValueError,
match="must only have only one of forward, "
"foward builder, or import_and_maybe_reload, specified.",
match="must only have at most one of forward, "
"forward builder, or import_and_maybe_reload, specified.",
):
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
@@ -73,7 +73,7 @@ def register_foak_model_patch_rules(
FILTER_MAP = {
"fused_lora": {"qkvo", "mlp"},
"fast_loss": {
True: "cross-ent",
True: {"cross-ent", "custom-loss"},
"fused_ce_liger": "fused-lce",
},
"fast_rms_layernorm": "rms",
@@ -16,6 +16,7 @@
import triton.language as tl
import torch
from .utils import calculate_settings, MAX_FUSED_SIZE
from typing import Type


@triton.jit
Expand Down Expand Up @@ -290,3 +291,55 @@ def forward(self, input, target):
)
n_items = torch.count_nonzero(target != -100)
return loss.sum() / n_items


# added by [email protected]

# adapted from transformers.loss.loss_utils.ForCausalLMLoss
def FastForCausalLMLoss(
logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)

reduction = "sum" if num_items_in_batch is not None else "mean"
assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
loss = Fast_CrossEntropyLoss.apply(
shift_logits, shift_labels
)
if reduction == "sum":
n_items = num_items_in_batch
else:
n_items = torch.count_nonzero(shift_labels != -100)
return loss.sum() / n_items


def replace_custom_loss_when_triggered(
module_cls: Type,
custom_loss_type: str,
):

# this is a special trigger that will perform the replacement
def _trigger(mod):
if isinstance(mod, module_cls) and hasattr(mod, "loss_function"):
# guarded
from transformers.loss.loss_utils import LOSS_MAPPING
LOSS_MAPPING[custom_loss_type] = FastForCausalLMLoss
mod.loss_type = custom_loss_type
return True

return False

return _trigger
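When the returned `_trigger` fires on a matching module, the custom loss is registered in `LOSS_MAPPING` and the instance's `loss_type` is switched, which is what lets transformers 4.46+ resolve `model.loss_function` to the Triton-backed loss. A minimal sketch of that effect (the tiny checkpoint and the absolute import path are assumptions, and transformers >= 4.46 is required):

# hedged sketch, not part of the diff: exercising the trigger on a loaded model
from transformers import AutoModelForCausalLM
# the diff only shows this module relatively as ..kernels.unsloth.cross_entropy_loss;
# the absolute package name below is an assumption
from fms_acceleration_foak.kernels.unsloth.cross_entropy_loss import (
    FastForCausalLMLoss,
    replace_custom_loss_when_triggered,
)

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed test checkpoint
)
check = replace_custom_loss_when_triggered(
    type(model), custom_loss_type="llama-custom-loss"
)
if check(model):
    # the trigger registered the fast loss under the custom key and switched the
    # instance's loss_type, so forward passes with labels now resolve
    # model.loss_function to FastForCausalLMLoss
    from transformers.loss.loss_utils import LOSS_MAPPING
    assert LOSS_MAPPING["llama-custom-loss"] is FastForCausalLMLoss
    assert model.loss_type == "llama-custom-loss"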


@@ -27,7 +27,10 @@

# Local
from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.cross_entropy_loss import (
FastCrossEntropyLoss,
replace_custom_loss_when_triggered,
)
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
@@ -37,6 +40,7 @@
KEY_QKV,
build_lora_fused_ops,
get_hidden_activation_fn_key,
get_transformers_version,
trigger_fused_ops,
)

@@ -122,16 +126,27 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
base_type=base_type,
),
),
# TODO: have a generic version of this rule
# - get the module_name and reload on that
ModelPatcherRule(
rule_id="granite-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.granite.modeling_granite",
),
),
*[
(
ModelPatcherRule(
rule_id="granite-custom-loss",
trigger=ModelPatcherTrigger(
check=replace_custom_loss_when_triggered(
GraniteForCausalLM, custom_loss_type="granite-custom-loss"
)
),
)
if get_transformers_version() >= "4.46"
else ModelPatcherRule(
rule_id="granite-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.granite.modeling_granite",
),
)
)
],
ModelPatcherRule(
rule_id="granite-fused-lce",
trigger=ModelPatcherTrigger(check=GraniteForCausalLM),
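Both branches above hinge on `get_transformers_version() >= "4.46"`; the helper itself is not part of this diff. One way such a gate can be written, sketched under the assumption that a numeric comparison is intended (this is not the repository's implementation):

# hedged sketch of a version gate; not the repository's get_transformers_version
import transformers
from packaging.version import parse

def transformers_at_least(minimum: str) -> bool:
    # parse() compares release segments numerically, so 4.5 < 4.46 < 4.50
    return parse(transformers.__version__) >= parse(minimum)

use_custom_loss_rules = transformers_at_least("4.46")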
@@ -33,7 +33,10 @@

# Local
from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.cross_entropy_loss import (
FastCrossEntropyLoss,
replace_custom_loss_when_triggered,
)
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..utils import filter_mp_rules
@@ -43,6 +46,7 @@
KEY_QKV,
build_lora_fused_ops,
get_hidden_activation_fn_key,
get_transformers_version,
trigger_fused_ops,
)

@@ -122,14 +126,27 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None):
trigger=ModelPatcherTrigger(check=LlamaForCausalLM),
forward=lce_forward,
),
ModelPatcherRule(
rule_id="llama-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.llama.modeling_llama",
),
),
*[
(
ModelPatcherRule(
rule_id="llama-custom-loss",
trigger=ModelPatcherTrigger(
check=replace_custom_loss_when_triggered(
LlamaForCausalLM, custom_loss_type="llama-custom-loss"
)
),
)
if get_transformers_version() >= "4.46"
else ModelPatcherRule(
rule_id="llama-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.llama.modeling_llama",
),
)
)
],
# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
