GPU memory allocated increase during finetuning #792

Open
rong-hash opened this issue Nov 17, 2024 · 5 comments

rong-hash commented Nov 17, 2024

System Info

PyTorch version: 2.4.1+cu124
CUDA version: 12.7
GPU: 1× A100 80GB

Information

  • The official example scripts
  • My own modified scripts

🐛 Describe the bug

My script is:

python -m llama_recipes.finetuning \
--model_name /model/Meta-Llama-3.1-8B-Instruct \
--output_dir /model/llama3_ft \
--use_wandb \
--dataset alpaca_dataset \
--data_path /data/<dataset_path>.json  \
--project <my_project_name> \
--batching_strategy padding \
--context_length 2048 \
--num_epochs 3 \
--lr 3e-4 \
--gradient_accumulation_steps 4 \
--use_fast_kernels \
--quantization 8bit \
--use_peft \
--peft_method lora 

Everything runs fine except for the GPU memory allocation. I found that the allocated memory increases at certain steps, which seems abnormal. The allocated-memory plot is shown below.

[image: plot of allocated GPU memory, showing step-wise increases during training]

Eventually this causes an OOM error.

Error logs

 1 Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.08s/it]
 2 --> Model /home/chenzhirong/model/Meta-Llama-3.1-8B-Instruct
 3 
 4 --> /home/chenzhirong/model/Meta-Llama-3.1-8B-Instruct has 1050.939392 Million params
 5 
 6 trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
 7 --> Training Set Length = 14507
 8 --> Validation Set Length = 763
 9 length of dataset_train 14507
10 --> Num of Training Set Batches loaded = 14507
11 --> Num of Validation Set Batches loaded = 763
12 --> Num of Validation Set Batches loaded = 763
13 Starting epoch 0/3
14 train_config.max_train_step: 0
15 /home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
16   warnings.warn(
17 Training Epoch: 1:   0%|                                                                                                          | 0/3626 [00:00<?, ?it/s]/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
18   warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
19 Training Epoch: 1/3, step 6609/14507 completed (loss: 0.007389023434370756):  46%|███████████████                  | 1652/3626 [1:32:27<1:46:50,  3.25s/it]Traceback (most recent call last):
20   File "<frozen runpy>", line 198, in _run_module_as_main
21   File "<frozen runpy>", line 88, in _run_code
22   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/llama_recipes/finetuning.py", line 332, in <module>
23     fire.Fire(main)
24   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/fire/core.py", line 135, in Fire
25     component_trace = _Fire(component, args, parsed_flag_args, context, name)
26                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
27   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/fire/core.py", line 468, in _Fire
28     component, remaining_args = _CallAndUpdateTrace(
29                                 ^^^^^^^^^^^^^^^^^^^^
30   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
31     component = fn(*varargs, **kwargs)
32                 ^^^^^^^^^^^^^^^^^^^^^^
33   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/llama_recipes/finetuning.py", line 311, in main
34     results = train(
35               ^^^^^^
36   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/llama_recipes/utils/train_utils.py", line 153, in train
37     loss = model(**batch).loss
38            ^^^^^^^^^^^^^^
39   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
40     return self._call_impl(*args, **kwargs)
41            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
42   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
43     return forward_call(*args, **kwargs)
44            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
45   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/peft/peft_model.py", line 1644, in forward
46     return self.base_model(
47            ^^^^^^^^^^^^^^^^
48   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
49     return self._call_impl(*args, **kwargs)
50            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
51   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
52     return forward_call(*args, **kwargs)
53            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
54   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
55     return self.model.forward(*args, **kwargs)
56            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
57   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
58     output = module._old_forward(*args, **kwargs)
59              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
60   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1210, in forward
61     logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
62              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
63   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
64     return self._call_impl(*args, **kwargs)
65            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
66   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
67     return forward_call(*args, **kwargs)
68            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
69   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
70     output = module._old_forward(*args, **kwargs)
71              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
72   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 117, in forward
73     return F.linear(input, self.weight, self.bias)
74            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
75 torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.75 GiB. GPU 0 has a total capacity of 79.25 GiB of which 1.21 GiB is free. Including non-PyTorch memory, this process has 78.02 GiB memory in use. Of the allocated memory 70.12 GiB is allocated by PyTorch, and 7.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Expected behavior

I believe the normal behavior is that the allocated GPU memory stays stable during training.
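
A side note on the traceback: the error message itself suggests setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce fragmentation. A minimal sketch of trying that from Python, assuming a small wrapper script (the file name run_finetune.py is made up; exporting the variable in the shell before launching python -m llama_recipes.finetuning works just as well, and whether it actually avoids this particular OOM is untested):

# run_finetune.py -- hypothetical wrapper: set the allocator option before
# PyTorch initializes CUDA, then hand off to the usual module entry point.
import os
import runpy

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Equivalent to: python -m llama_recipes.finetuning <args...>
# (command-line flags are picked up from sys.argv as usual)
runpy.run_module("llama_recipes.finetuning", run_name="__main__")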

HamidShojanazeri (Contributor) commented Nov 18, 2024

Quick question @rong-hash: would changing "--gradient_accumulation_steps 4" change the behavior for you? cc: @mreso

mreso (Contributor) commented Nov 20, 2024

Hi @rong-hash, you're using padding, which means that samples are first bucketed together with respect to their length. This minimizes excessive padding when short and long sequences would otherwise be batched together. Because the batches have different sequence lengths, I suspect you're seeing jumps in memory usage whenever a batch with longer sequences is processed. The OOM then occurs when an even longer sequence length is processed. Have you tried reducing the batch size?
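
(For readers unfamiliar with the padding strategy, here is a rough sketch of what length-based bucketing looks like; it illustrates the general idea only and is not the exact llama_recipes sampler:)

# Illustrative sketch of length-grouped batching: sort samples by token length
# and cut consecutive chunks into batches, so each batch pads to a similar length.
# The maximum length still varies from batch to batch, which is what causes the
# step-wise jumps in allocated memory.
from typing import List

def length_grouped_batches(lengths: List[int], batch_size: int) -> List[List[int]]:
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

if __name__ == "__main__":
    lengths = [12, 800, 35, 2048, 40, 790, 33, 1900]
    for batch in length_grouped_batches(lengths, batch_size=2):
        print([lengths[i] for i in batch])  # each batch pads to its own max length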

rong-hash (Author) commented Nov 20, 2024

Hi @mreso @HamidShojanazeri, I tried your suggestions but still get the same problem. I set --batch_size_training to 1 and removed the --gradient_accumulation_steps argument, but the allocated memory still jumps at certain steps.

[image: plot of allocated GPU memory with batch size 1, still showing jumps at certain steps]

mreso (Contributor) commented Nov 20, 2024

The jumps come from the different sequence lengths. The first time a sequence is longer than all previous ones, more memory is allocated to fit the intermediate tensors. Allocated memory is usually not freed unless you explicitly tell PyTorch to do so, even if subsequent samples are shorter and require less memory. Are you still seeing OOMs with bs=1?
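
(One way to observe this directly is to log the allocator statistics around a training step; a sketch below using standard torch.cuda calls, with the exact placement in train_utils.py left open:)

# Sketch: inspect PyTorch's caching-allocator stats around a training step.
# "allocated" is memory held by live tensors; "reserved" is what the allocator
# keeps cached for reuse and does not hand back to the driver on its own.
import torch

def log_cuda_memory(tag: str) -> None:
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f"{tag}: allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB, peak={peak:.2f} GiB")

# Hypothetical placement inside the training loop:
#   log_cuda_memory(f"step {step} before forward")
#   loss = model(**batch).loss
#   log_cuda_memory(f"step {step} after forward")
# torch.cuda.empty_cache() only releases cached-but-unused blocks; it does not
# shrink the "allocated" number, which is driven by the live tensors themselves.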

rong-hash (Author) commented Nov 22, 2024

@mreso Yes, OOM still occurs with bs=1, even at the same step number.

 1 Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00,  4.80s/it]
 2 --> Model /home/chenzhirong/model/Meta-Llama-3.1-8B-Instruct
 3 
 4 --> /home/chenzhirong/model/Meta-Llama-3.1-8B-Instruct has 1050.939392 Million params
 5 
 6 trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
 7 --> Training Set Length = 14507
 8 --> Validation Set Length = 763
 9 length of dataset_train 14507
10 --> Num of Training Set Batches loaded = 14507
11 --> Num of Validation Set Batches loaded = 763
12 --> Num of Validation Set Batches loaded = 763
13 Starting epoch 0/5
14 train_config.max_train_step: 0
15 /home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
16   warnings.warn(
17 Training Epoch: 1:   0%|                                                                                                                       | 0/3626 [00:00<?, ?it/s]/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
18   warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
19 Training Epoch: 1/5, step 6609/14507 completed (loss: 0.005838111508637667):  46%|████████████████████▉                         | 1652/3626 [1:28:21<1:45:47,  3.22s/it]Traceback (most recent call last):
20   File "<frozen runpy>", line 198, in _run_module_as_main
21   File "<frozen runpy>", line 88, in _run_code
22   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/llama_recipes/finetuning.py", line 332, in <module>
23     fire.Fire(main)
24   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/fire/core.py", line 135, in Fire
25     component_trace = _Fire(component, args, parsed_flag_args, context, name)
26                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
27   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/fire/core.py", line 468, in _Fire
28     component, remaining_args = _CallAndUpdateTrace(
29                                 ^^^^^^^^^^^^^^^^^^^^
30   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
31     component = fn(*varargs, **kwargs)
32                 ^^^^^^^^^^^^^^^^^^^^^^
33   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/llama_recipes/finetuning.py", line 311, in main
34     results = train(
35               ^^^^^^
36   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/llama_recipes/utils/train_utils.py", line 153, in train
37     loss = model(**batch).loss
38            ^^^^^^^^^^^^^^
39   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
40     return self._call_impl(*args, **kwargs)
41            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
42   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
43     return forward_call(*args, **kwargs)
44            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
45   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/peft/peft_model.py", line 1644, in forward
46     return self.base_model(
47            ^^^^^^^^^^^^^^^^
48   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
49     return self._call_impl(*args, **kwargs)
50            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
51   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
52     return forward_call(*args, **kwargs)
53            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
54   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
55     return self.model.forward(*args, **kwargs)
56            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
57   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
58     output = module._old_forward(*args, **kwargs)
59              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
60   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1210, in forward
61     logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
62              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
63   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
64     return self._call_impl(*args, **kwargs)
65            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
66   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
67     return forward_call(*args, **kwargs)
68            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
69   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
70     output = module._old_forward(*args, **kwargs)
71              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
72   File "/home/chenzhirong/anaconda3/envs/powerinfer/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 117, in forward
73     return F.linear(input, self.weight, self.bias)
74            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
75 torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.75 GiB. GPU 0 has a total capacity of 79.25 GiB of which 1.17 GiB is free. Including non-PyTorch memory, this process has 78.06 GiB memory in use. Of the allocated memory 70.11 GiB is allocated by PyTorch, and 7.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[image: plot of allocated GPU memory for the batch size 1 run]

Here's the latest script:

python -m llama_recipes.finetuning \
--model_name /home/chenzhirong/model/Meta-Llama-3.1-8B-Instruct \
--output_dir /home/chenzhirong/model/llama3_ft \
--use_wandb \
--dataset alpaca_dataset \
--data_path <dataset_path>  \
--project chipgptmm \
--batching_strategy padding \
--batch_size_training 1 \
--context_length 2048 \
--num_epochs 5 \
--lr 3e-4 \
--gradient_accumulation_steps 4 \
--use_fast_kernels \
--quantization 8bit \
--use_peft \
--peft_method lora \
--one_gpu
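
Since the OOM recurs at the same step even with --batch_size_training 1, one quick sanity check is whether a single unusually long sample in the dataset is responsible. A sketch of scanning tokenized sample lengths (the dataset path and the Alpaca-style field names below are assumptions):

# Sketch: find the longest tokenized samples in an Alpaca-style JSON dataset.
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/home/chenzhirong/model/Meta-Llama-3.1-8B-Instruct")

with open("/data/alpaca_style_dataset.json") as f:  # hypothetical path
    samples = json.load(f)

lengths = sorted(
    (len(tokenizer(s.get("instruction", "") + s.get("input", "") + s.get("output", ""))["input_ids"]), idx)
    for idx, s in enumerate(samples)
)
print("10 longest samples (tokens, index):", lengths[-10:])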
