fix: cross entropy for transformers>4.45 #123
Conversation
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Had some questions when going through the code, thanks so much Fabian. Also, once the new benchmark is complete I will add the results to scripts/benchmarks/refs with the CSV and the requirements file that shows the updated deps.
shift_labels = shift_labels.to(shift_logits.device)

reduction = "sum" if num_items_in_batch is not None else "mean"
assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
What is the -100 ignore_index? I see that ignore_index is the target value that is ignored and does not contribute to the input gradient, but for CausalLMLoss what is at index -100?
-100 is used extensively throughout HF as the default ignore index; while they provide some means for the user to change it, almost nobody bothers to. It is the label that holds the value -100: for a label with that value, we ignore that token's contribution to the loss.
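As a standalone illustration of the -100 convention (plain torch.nn.functional here, not the fused kernel; tensor shapes are made up):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)             # 4 tokens, vocab size 10
labels = torch.tensor([2, -100, 5, 7])  # the second token is masked out

# cross_entropy skips targets equal to ignore_index (default -100), so the
# masked token contributes nothing to the loss or to the input gradient
loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction="mean")
print(loss)  # the mean is taken over the 3 unmasked tokens only
```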
reduction = "sum" if num_items_in_batch is not None else "mean"
assert ignore_index == -100, "FastForCausalLMLoss currently supports only hardcoded ignore index -100."
loss = Fast_CrossEntropyLoss.apply(
Can you describe the difference between Fast_CrossEntropyLoss and FastCrossEntropyLoss?
Fast_CrossEntropyLoss is the autograd function; we inherit this from unsloth. FastCrossEntropyLoss is a specialization of torch.nn.CrossEntropyLoss that serves as a convenience, implemented using Fast_CrossEntropyLoss.
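For readers unfamiliar with the pattern, here is a minimal sketch of the relationship; the internals below are plain PyTorch for illustration only, the real Fast_CrossEntropyLoss wraps the fused unsloth-derived kernel:

```python
import torch


class Fast_CrossEntropyLoss(torch.autograd.Function):
    """Illustrative autograd function; the real one calls a fused kernel."""

    @staticmethod
    def forward(ctx, logits, labels):
        # logits: (N, vocab), labels: (N,) with -100 marking ignored positions
        log_probs = torch.log_softmax(logits, dim=-1)
        mask = labels != -100
        safe_labels = labels.clamp(min=0)
        picked = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        ctx.save_for_backward(log_probs, safe_labels, mask)
        return torch.where(mask, -picked, torch.zeros_like(picked)).sum()

    @staticmethod
    def backward(ctx, grad_output):
        log_probs, safe_labels, mask = ctx.saved_tensors
        probs = log_probs.exp()
        one_hot = torch.nn.functional.one_hot(
            safe_labels, num_classes=probs.size(-1)
        ).to(probs.dtype)
        # d(-log p_y)/d(logits) = softmax(logits) - one_hot(y), zeroed for masked tokens
        grad = (probs - one_hot) * mask.unsqueeze(-1).to(probs.dtype)
        return grad_output * grad, None


class FastCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Convenience module that simply dispatches to the autograd function."""

    def forward(self, input, target):
        return Fast_CrossEntropyLoss.apply(input, target)
```

FastForCausalLMLoss then just flattens the shifted logits/labels and calls Fast_CrossEntropyLoss.apply directly, which is what the hunk quoted above shows.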
# added by [email protected]
# adapted from transformers.loss.loss_utils.ForCausalLMLoss
def FastForCausalLMLoss(
Would we need to create a similar FastForCausalLMLoss for liger kernel cross entropy?
Yes, I think we will have a new function for liger cross entropy with the same API; then it's plug and play. But it should be used only if the transformers version is new enough.
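A rough sketch of what such a version gate could look like; the helper names below are placeholders rather than the actual patch rules in this PR, and the >=4.46 cutoff simply mirrors the "transformers>4.45" requirement in the title:

```python
from packaging.version import Version

import transformers


def can_patch_custom_loss() -> bool:
    # the pluggable loss_function path only exists on newer transformers,
    # matching this PR's "transformers > 4.45" requirement
    return Version(transformers.__version__) >= Version("4.46")


def select_causal_lm_loss(fast_loss, liger_loss=None):
    """Return the loss callable to patch in; both callables are expected to
    share the ForCausalLMLoss-style signature (logits, labels, vocab_size, ...)."""
    if not can_patch_custom_loss():
        raise RuntimeError("custom loss patching needs transformers > 4.45")
    return liger_loss if liger_loss is not None else fast_loss
```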
rule_id="granite-custom-loss", | ||
trigger=ModelPatcherTrigger( | ||
check=replace_custom_loss_when_triggered( | ||
GraniteForCausalLM, custom_loss_type="granite-custom-loss" |
Thoughts on calling this granite-custom-crossent-loss instead, to be specific that the custom loss is for cross entropy?
I feel custom-loss is ok, because it mostly refers to the fact that we are using the new custom loss feature.
Signed-off-by: Anh Uong <[email protected]>
I also noticed in the new transformers version there is a lot of slowness after loading the checkpoint on this log line:
Is this something we want to do? Is it expected that loading the embeddings will take a long time to run?
@anhuong the slowness I feel is due to the recent changes in
Signed-off-by: Anh Uong <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Able to verify that enabling the correct cross-entropy is triggered based on transformers version for granite, llama, mistral, and mixtral models.
With transformers==4.48:
***************** Module Forwards Patching *************
Rule: llama-custom-loss Module: Class: LlamaForCausalLM Num: 1
INFO:framework.py:Rule: llama-custom-loss Module: Class: LlamaForCausalLM Num: 1
Rule: llama-rms Module: input_layernorm Class: LlamaRMSNorm Num: 32
INFO:framework.py:Rule: llama-rms Module: input_layernorm Class: LlamaRMSNorm Num: 32
Rule: llama-rms Module: model Class: LlamaRMSNorm Num: 1
INFO:framework.py:Rule: llama-rms Module: model Class: LlamaRMSNorm Num: 1
Rule: llama-rms Module: post_attention_layernorm Class: LlamaRMSNorm Num: 32
INFO:framework.py:Rule: llama-rms Module: post_attention_layernorm Class: LlamaRMSNorm Num: 32
Rule: llama-rope Module: Class: LlamaForCausalLM Num: 1
INFO:framework.py:Rule: llama-rope Module: Class: LlamaForCausalLM Num: 1
***************** Accelerator Patching *************
With transformers==4.45:
***************** Module Forwards Patching *************
Rule: llama-cross-ent Module: Class: LlamaForCausalLM Num: 1
INFO:framework.py:Rule: llama-cross-ent Module: Class: LlamaForCausalLM Num: 1
Rule: llama-rms Module: input_layernorm Class: LlamaRMSNorm Num: 32
INFO:framework.py:Rule: llama-rms Module: input_layernorm Class: LlamaRMSNorm Num: 32
Rule: llama-rms Module: model Class: LlamaRMSNorm Num: 1
INFO:framework.py:Rule: llama-rms Module: model Class: LlamaRMSNorm Num: 1
Rule: llama-rms Module: post_attention_layernorm Class: LlamaRMSNorm Num: 32
INFO:framework.py:Rule: llama-rms Module: post_attention_layernorm Class: LlamaRMSNorm Num: 32
Rule: llama-rope Module: Class: LlamaForCausalLM Num: 1
INFO:framework.py:Rule: llama-rope Module: Class: LlamaForCausalLM Num: 1
***************** Accelerator Patching *************
Signed-off-by: Anh Uong <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
# added by [email protected]
# adapted from transformers.modeling_utils.shard_checkpoint
# from transformers v4.46, removed in later versions
# TODO: split_torch_state_dict_into_shards from huggingface_hub library
def shard_checkpoint(
After transformers v4.46, this method no longer exists in transformers, so I copied it in here to start. The warning message in the original function says to migrate to split_torch_state_dict_into_shards, as noted in the TODO item here. That method is similar but requires more investigation on the differences - https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py#L302
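For reference, a minimal sketch of what the migration might look like, based on the huggingface_hub documentation example and not validated against this repo's save path (the safetensors filenames used are the library defaults):

```python
import json
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded_state_dict(state_dict, save_directory, max_shard_size="5GB"):
    """Shard and save a torch state dict, writing an index file when sharded."""
    split = split_torch_state_dict_into_shards(
        state_dict, max_shard_size=max_shard_size
    )
    for filename, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(shard, os.path.join(save_directory, filename))
    if split.is_sharded:
        index = {
            "metadata": split.metadata,
            "weight_map": split.tensor_to_filename,
        }
        with open(
            os.path.join(save_directory, "model.safetensors.index.json"), "w"
        ) as f:
            json.dump(index, f, indent=2)
```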
ok this is fine for now
@@ -0,0 +1,89 @@
bf16,epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
,0.07,,none,2.00E-05,,,15116,11267745280,6770300416,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.33703125,47.6604,8.393,2.098,17188.262
I ran the benchmarks and found the failure in auto-gptq that I commented on the fix for above; rerunning benchmarks for auto-gptq. I also ran benchmarks for the granite3.1 model, but for comparison I had to run against the granite-gptcode model. Do we want to update to the granite3.1 model? It did run successfully with it as well.
Here are the charts showing the comparison against a100_80gb.csv, excluding the failed auto-gptq runs: the train_loss is on par, memory use is on par/a little higher, and train_tokens_per_second is a little slower.
these look quite decent
Signed-off-by: Anh Uong <[email protected]>
LGTM, the benches look good.
One question about the new benches you ran: rather than having them in a separate file, it is better to replace the previous a100_80gb and update the requirements. But did you run all the cases or only a subset?
I will rename the benchmarks and requirements. I had run all of the benchmarks except for auto-gptq due to the error that came up, so I ran those separately after the fix. I also did not run baseline-bnb, but I am running it separately now; what is the purpose of this benchmark? I updated the image in the description and will include the individual images here, and I will add all of them to the benchmark. I updated the above description with the summary results that include auto-gptq. Here are the individual images. Overall they continue to look good; the only outlier identified was
Signed-off-by: Anh Uong <[email protected]>
I have replaced the benchmark and requirements with my full runs.
Makes sense. The benchmark I added is the complete set including the baseline, and it matches the original on the number of runs. With this, I will merge in this change.
sounds good!
Tested and saw cross-entropy switching correctly based on transformers version. The individual benchmark images are posted below; here is a single image comparing the previous benchmark to the new one. Overall the train_loss is on par, memory use is on par/a little higher, and train_tokens_per_second is a little slower.
closes: #98