
fix: cross entropy for transformers>4.45 #123

Merged · 9 commits · Feb 7, 2025

Conversation

@anhuong (Collaborator) commented Feb 5, 2025

  • Checks the transformers version and creates a new custom loss function for llama, granite, mistral, and mixtral models.
  • Adds the shard_checkpoint function from transformers, as it is missing in later versions.

Tested and saw cross-entropy switching correctly based on the transformers version. Individual benchmark images are posted below; here is a single image comparing the previous benchmark to the new one. Overall the train_loss is on par, memory use is on par or a little higher, and train_tokens_per_second is in some cases higher and in others lower.

[Image: benchmark comparison screenshot, Feb 7, 2025]

closes: #98

@anhuong requested a review from fabianlim as a code owner on February 5, 2025 04:40
@anhuong (Collaborator, Author) left a comment

Had some questions when going through the code, thanks so much Fabian. Also, once the new benchmark is complete I will add the results to scripts/benchmarks/refs with the CSV and the requirements file that shows the updated dependencies.

shift_labels = shift_labels.to(shift_logits.device)

reduction = "sum" if num_items_in_batch is not None else "mean"
assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
@anhuong (Collaborator, Author):

What is the -100 ignore_index? I see that ignore_index is the target value that is ignored and does not contribute to the input gradient, but for CausalLMLoss what is at index -100?

@fabianlim (Contributor):

-100 is used extensively throughout HF; while they provide some means for the user to change it, almost nobody will bother to change it.

It is the label that is at -100. For a label with that value, we will ignore that token's contribution to the loss.
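For illustration, here is a minimal standalone example of how ignore_index=-100 behaves with torch's cross entropy (plain PyTorch, not the fused kernel in this PR):

```python
import torch
import torch.nn.functional as F

# 3 positions, vocab size 5; the second label is -100 and is excluded from the loss
logits = torch.randn(3, 5)
labels = torch.tensor([2, -100, 4])

loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction="mean")

# equivalent to averaging only over the non-ignored positions
mask = labels != -100
manual = F.cross_entropy(logits[mask], labels[mask], reduction="mean")
assert torch.allclose(loss, manual)
```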


reduction = "sum" if num_items_in_batch is not None else "mean"
assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
loss = Fast_CrossEntropyLoss.apply(
@anhuong (Collaborator, Author):

Can you describe the difference between Fast_CrossEntropyLoss and FastCrossEntropyLoss?

@fabianlim (Contributor):

  • Fast_CrossEntropyLoss is the autograd function; we inherit this from Unsloth.
  • FastCrossEntropyLoss is a specialization of torch.nn.CrossEntropyLoss that serves as a convenience wrapper, implemented using Fast_CrossEntropyLoss.
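To illustrate the relationship, here is a generic sketch of that pattern in plain PyTorch (the class names and kernel bodies below are placeholders, not the actual Unsloth Triton implementation):

```python
import torch
import torch.nn.functional as F


class Fast_CrossEntropyLossSketch(torch.autograd.Function):
    """Autograd function: in the real code, forward/backward call fused Triton kernels."""

    @staticmethod
    def forward(ctx, logits, labels):
        # placeholder for the fused forward kernel: per-token losses
        losses = F.cross_entropy(logits, labels, ignore_index=-100, reduction="none")
        ctx.save_for_backward(logits, labels)
        return losses

    @staticmethod
    def backward(ctx, grad_output):
        logits, labels = ctx.saved_tensors
        # placeholder for the fused backward kernel: d(loss)/d(logits) = softmax - one_hot
        probs = torch.softmax(logits.float(), dim=-1)
        one_hot = F.one_hot(labels.clamp(min=0), logits.size(-1)).to(probs.dtype)
        grad = (probs - one_hot) * grad_output.unsqueeze(-1)
        grad[labels == -100] = 0.0
        return grad.to(logits.dtype), None


class FastCrossEntropyLossSketch(torch.nn.CrossEntropyLoss):
    """Convenience module that dispatches to the autograd function above."""

    def forward(self, input, target):
        losses = Fast_CrossEntropyLossSketch.apply(input, target)
        return losses.sum() / (target != -100).sum()
```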

# added by [email protected]

# adapted from transformers.loss.loss_utils.ForCausalLMLoss
def FastForCausalLMLoss(
@anhuong (Collaborator, Author):

Would we need to create a similar FastForCausalLMLoss for liger kernel cross entropy?

@fabianlim (Contributor):

Yes, I think we will have a new function for liger cross entropy with the same API; then it's plug and play. But it should only be used if the transformers version is advanced enough.
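For example, such a function could keep the same signature as FastForCausalLMLoss and swap the fused kernel in at the marked line. A hedged sketch under those assumptions (LigerForCausalLMLossSketch is a hypothetical name, and plain F.cross_entropy stands in for the Liger kernel; the shift/flatten logic mirrors transformers.loss.loss_utils.ForCausalLMLoss):

```python
import torch.nn.functional as F


def LigerForCausalLMLossSketch(
    logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs
):
    # shift so that tokens < n predict token n, as in ForCausalLMLoss
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

    # a Liger fused cross-entropy kernel would replace F.cross_entropy here
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        shift_logits.float(), shift_labels, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```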

Comment on lines 128 to 131
rule_id="granite-custom-loss",
trigger=ModelPatcherTrigger(
check=replace_custom_loss_when_triggered(
GraniteForCausalLM, custom_loss_type="granite-custom-loss"
@anhuong (Collaborator, Author):

Thoughts on calling this granite-custom-crossent-loss instead, to be specific that the custom loss is for cross entropy?

@fabianlim (Contributor):

I feel custom-loss is OK, because it mostly refers to the fact that we are using the new custom loss feature.

Signed-off-by: Anh Uong <[email protected]>
@anhuong (Collaborator, Author) commented Feb 5, 2025

I also noticed that in the new transformers version there is a lot of slowness after loading the checkpoint, on this log line:

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`

Is this something we want to do? Is it expected that loading the embeddings will take a long time to run?

@fabianlim (Contributor):

@anhuong the slowness, I feel, is due to the recent changes in fms-hf-tuning, and is there because we now resize the embedding layer if there are special tokens. Previously we didn't do that.
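For reference, the mean-resizing step can be skipped as the log message suggests, by passing mean_resizing=False when resizing embeddings. A minimal sketch, assuming a transformers version that supports this argument (the model name is only a placeholder, and whether disabling it is desirable for convergence is a separate question):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# adding special tokens grows the vocabulary, which forces an embedding resize
tokenizer.add_special_tokens({"pad_token": "<pad>"})

# mean_resizing=False skips estimating the old embeddings' mean/covariance,
# so new rows use the default init and the resize is much faster
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
```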

Signed-off-by: Anh Uong <[email protected]>
@anhuong (Collaborator, Author) commented Feb 5, 2025

Able to verify that the correct cross-entropy is triggered based on the transformers version for granite, llama, mistral, and mixtral models.

With transformers==4.48

***************** Module Forwards Patching *************
Rule: llama-custom-loss Module:                           Class: LlamaForCausalLM Num:  1
INFO:framework.py:Rule: llama-custom-loss Module:                           Class: LlamaForCausalLM Num:  1
Rule: llama-rms       Module: input_layernorm           Class: LlamaRMSNorm    Num: 32
INFO:framework.py:Rule: llama-rms       Module: input_layernorm           Class: LlamaRMSNorm    Num: 32
Rule: llama-rms       Module: model                     Class: LlamaRMSNorm    Num:  1
INFO:framework.py:Rule: llama-rms       Module: model                     Class: LlamaRMSNorm    Num:  1
Rule: llama-rms       Module: post_attention_layernorm  Class: LlamaRMSNorm    Num: 32
INFO:framework.py:Rule: llama-rms       Module: post_attention_layernorm  Class: LlamaRMSNorm    Num: 32
Rule: llama-rope      Module:                           Class: LlamaForCausalLM Num:  1
INFO:framework.py:Rule: llama-rope      Module:                           Class: LlamaForCausalLM Num:  1
***************** Accelerator Patching *************

With transformers==4.45

***************** Module Forwards Patching *************
Rule: llama-cross-ent Module:                           Class: LlamaForCausalLM Num:  1
INFO:framework.py:Rule: llama-cross-ent Module:                           Class: LlamaForCausalLM Num:  1
Rule: llama-rms       Module: input_layernorm           Class: LlamaRMSNorm    Num: 32
INFO:framework.py:Rule: llama-rms       Module: input_layernorm           Class: LlamaRMSNorm    Num: 32
Rule: llama-rms       Module: model                     Class: LlamaRMSNorm    Num:  1
INFO:framework.py:Rule: llama-rms       Module: model                     Class: LlamaRMSNorm    Num:  1
Rule: llama-rms       Module: post_attention_layernorm  Class: LlamaRMSNorm    Num: 32
INFO:framework.py:Rule: llama-rms       Module: post_attention_layernorm  Class: LlamaRMSNorm    Num: 32
Rule: llama-rope      Module:                           Class: LlamaForCausalLM Num:  1
INFO:framework.py:Rule: llama-rope      Module:                           Class: LlamaForCausalLM Num:  1
***************** Accelerator Patching *************
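For context on how the switch above works, the patch rule selection is gated on the installed transformers version, roughly along these lines (a simplified sketch, not the exact fms-acceleration code; the 4.46 cutoff follows the linked issue, and the rule names mirror the log output above):

```python
import transformers
from packaging import version

TRANSFORMERS_GE_4_46 = version.parse(transformers.__version__) >= version.parse("4.46")

if TRANSFORMERS_GE_4_46:
    # newer transformers route loss through the custom loss-function hook
    rule_id = "llama-custom-loss"
else:
    # older transformers: patch the cross entropy inside the model forward
    rule_id = "llama-cross-ent"
```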

Comment on lines +772 to +776
# added by [email protected]
# adapted from transformers.modeling_utils.shard_checkpoint
# from transformers v4.46, removed in later versions
# TODO: split_torch_state_dict_into_shards from huggingface_hub library
def shard_checkpoint(
@anhuong (Collaborator, Author):

After transformers v4.46, this method no longer exists in transformers, so I copied it in here to start. The warning message in the original function says to migrate to split_torch_state_dict_into_shards, as noted in the TODO item here. That method is similar but the differences require more investigation - https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py#L302

@fabianlim (Contributor):

ok this is fine for now
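For future reference, a hedged sketch of what the TODO migration could look like, assuming huggingface_hub's split_torch_state_dict_into_shards keeps its current return structure (shard_with_hub is a hypothetical helper, not part of this PR):

```python
from huggingface_hub import split_torch_state_dict_into_shards


def shard_with_hub(state_dict, max_shard_size="5GB"):
    split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size)
    # rebuild per-file state dicts, similar to what shard_checkpoint returned
    shards = {
        filename: {name: state_dict[name] for name in tensors}
        for filename, tensors in split.filename_to_tensors.items()
    }
    # the index maps each tensor to its shard file when the checkpoint is sharded
    index = (
        {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        if split.is_sharded
        else None
    )
    return shards, index
```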

@@ -0,0 +1,89 @@
bf16,epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
,0.07,,none,2.00E-05,,,15116,11267745280,6770300416,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.33703125,47.6604,8.393,2.098,17188.262
@anhuong (Collaborator, Author) commented Feb 7, 2025

I ran the benchmarks and found the failure in auto-gptq that I commented on with the fix above; I am rerunning the benchmarks for auto-gptq. I also ran benchmarks for the granite3.1 model, but for comparison I had to run against the granite-gptcode model. Do we want to update to the granite3.1 model? It did run successfully with it as well.

Here are the charts comparing against a100_80gb.csv (excluding the failed auto-gptq runs). They show that train_loss is on par, memory use is on par or a little higher, and train_tokens_per_second is in some cases higher and in others lower.
[Charts: peak memory, allocated memory, train_loss, and train_tokens_per_second comparisons]

@fabianlim (Contributor):

these look quite decent

Signed-off-by: Anh Uong <[email protected]>
@fabianlim (Contributor) left a comment

LGTM, the benches look good.

One question about the new benches you ran: rather than have them in a separate file, it is better to replace the previous a100_80gb and update the requirements. But did you run all the cases or only a subset?

@anhuong (Collaborator, Author) commented Feb 7, 2025

I will rename the benchmarks and requirements. I had run all of the benchmarks except for auto-gptq due to the error that came up, so I ran those separately after the fix. I also did not run baseline-bnb but am running it now separately; what is the purpose of this benchmark? I will add all of the runs to the benchmark. I updated the above description with the summary results and image that include auto-gptq; here are the individual images. Overall they continue to look good, and the only outlier identified was:

framework_config           peft_method  model_name_or_path              num_gpus  per_device_train_batch_size  reference  metric                   new
accelerated-peft-bnb       lora         bigcode/gpt_bigcode-santacoder  2         2                            6840.355   train_tokens_per_second  8548.922
accelerated-peft-bnb-foak  lora         bigcode/gpt_bigcode-santacoder  2         2                            10345.932  train_tokens_per_second  11994.044

[Charts: peak memory, allocated memory, train_loss, and train_tokens_per_second comparisons]

@anhuong (Collaborator, Author) commented Feb 7, 2025

I have replaced the benchmark and requirements with my full runs.

@fabianlim (Contributor):

> I also did not run the baseline-bnb but running now separately, what is the purpose of this benchmark?

The purpose of this is to have a baseline so we can compare the accelerations. The baseline could also change due to different transformers versions.

@anhuong (Collaborator, Author) commented Feb 7, 2025

Makes sense. The benchmark I added is the complete benchmark, including the baseline, and it matches the original in the number of runs. With this, I will merge in this change.

@anhuong merged commit 24bdadb into foundation-model-stack:main on Feb 7, 2025
7 checks passed
@fabianlim (Contributor):

sounds good!

Successfully merging this pull request may close these issues.

FOAK Cross Entropy Loss Will Not Work with New Loss Functions After Transformers 4.46