
assert len(input_id) == len(target) AssertionError #3

Open
CarlChang39 opened this issue May 15, 2024 · 9 comments
@CarlChang39

Hello, I ran into this problem while trying to reproduce the QLoRA fine-tuning. The error message is:
Traceback (most recent call last):
  File "../finetune_llama3.py", line 452, in <module>
    train()
  File "../finetune_llama3.py", line 445, in train
    trainer.train()
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/transformers/trainer.py", line 1928, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/accelerate/data_loader.py", line 452, in __iter__
    current_batch = next(dataloader_iter)
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "../finetune_llama3.py", line 255, in __getitem__
    ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)
  File "../finetune_llama3.py", line 192, in preprocess
    assert len(input_id) == len(target)
AssertionError
In practice, input_id is always exactly 1 longer than target.
I'm running on six 1080 Ti GPUs; the shell script is:
NCCL_P2P_DISABLE=1 \
NCCL_IB_DISABLE=1 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
torchrun \
  --nproc_per_node 6 \
  --nnodes 1 \
  --node_rank 0 \
  --master_addr localhost \
  --master_port 6601 \
  ../finetune_llama3.py \
  --model_name_or_path "../model_hub/LLM-Research/Meta-Llama-3-8B-Instruct/" \
  --data_path "../data/Belle_sampled_qwen.json" \
  --fp16 True \
  --output_dir "../output/llama3_8B_qlora" \
  --num_train_epochs 100 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 5 \
  --save_total_limit 1 \
  --learning_rate 1e-5 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --lazy_preprocess True \
  --deepspeed "../config/ds_config_zero2.json" \
  --use_lora \
  --load_in_4bit \
  --q_lora
The only things I changed were CUDA_VISIBLE_DEVICES and --nproc_per_node, plus switching bf16 to fp16.
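To pinpoint which sample trips the assertion, the bare assert in preprocess could be replaced with a logged check. This is a hypothetical diagnostic sketch; check_alignment is not part of finetune_llama3.py:

```python
# Hypothetical diagnostic: log the length mismatch instead of dying
# on a bare assert inside preprocess().
def check_alignment(input_id, target, idx=None):
    """Return True if the token ids and labels have the same length."""
    if len(input_id) != len(target):
        print(f"sample {idx}: len(input_id)={len(input_id)}, "
              f"len(target)={len(target)}, diff={len(input_id) - len(target)}")
    return len(input_id) == len(target)

# Mirrors the reported case: input_id one element longer than target.
print(check_alignment(list(range(186)), list(range(185)), idx=0))  # → False
```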

@taishan1994
Owner

Can you print them out and show exactly how they differ?

@taishan1994
Owner

[image]
Looks fine to me.

@CarlChang39
Author

Can you print them out and show exactly how they differ?

I saved the intermediate results; in each filename, the number on the left is len(input_id) and the one on the right is len(target). It comes out this way on every run.

[image]

In the file named 186 185:

input_id:[128000, 128006, 128000, 9125, 128007, 128000, 198, 128000, 2675, 527, 264, 55066, 6369, 6465, 889, 2744, 31680, 304, 55066, 6604, 0, 128009, 128006, 128000, 882, 128007, 128000, 198, 128000, 106161, 100815, 107015, 9554, 106246, 3922, 117805, 115532, 32943, 9554, 127944, 117633, 113333, 96455, 124671, 118402, 3922, 120605, 127944, 9554, 103572, 82317, 75863, 102654, 127198, 9554, 104654, 124662, 124778, 102924, 118742, 1811, 92672, 3922, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 50667, 61826, 42399, 104696, 9554, 109589, 104587, 18184, 102208, 23039, 102208, 42052, 103652, 117661, 117724, 122705, 33748, 110053, 107644, 103624, 122943, 124858, 17297, 3922, 83687, 33208, 47770, 25287, 104724, 112743, 105231, 112157, 28190, 33764, 125648, 1811, 34226, 53901, 30590, 51611, 33764, 83687, 33208, 47770, 25287, 9554, 114099, 102778, 3922, 112026, 121915, 9554, 118556, 34208, 106246, 105000, 86206, 105231, 123882, 103229, 108199, 104696, 109189, 34208, 122705, 106556, 38741, 60843, 37985, 17905, 28190, 44388, 38574, 17161, 22656, 9554, 123092, 127555, 106015, 1811, 128009, 128006, 128000, 78191, 128007, 128000, 198, 128000, 107015, 5486, 127944, 5486, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 5486, 83687, 33208, 47770, 25287, 5486, 105231, 5486, 104696, 109189, 5486, 122705, 106556, 1811, 128009]

target:[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 128000, 107015, 5486, 127944, 5486, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 5486, 83687, 33208, 47770, 25287, 5486, 105231, 5486, 104696, 109189, 5486, 122705, 106556, 1811, 128009]

@CarlChang39
Author

[image] Looks fine to me.

[image]
I ran it again; the mismatches all occur on the assistant turns.

@CarlChang39
Author

I took a look at your code:
[image]
The length difference between _input_id and _target comes out to len(nl_tokens) - 1, right? But when I print nl_tokens I get [128000, 198], which has length 2, so doesn't that mean _input_id will always be 1 longer than _target? That would explain why input_id ends up 1 longer than target.

@taishan1994
Owner

That spot should be changed to the ignore token multiplied by len(nl_tokens); I hadn't accounted for nl_tokens having length 2.

@CarlChang39
Author

That spot should be changed to the ignore token multiplied by len(nl_tokens); I hadn't accounted for nl_tokens having length 2.

You mean the last piece of _target, the one right before tokenizer(value), right?

@taishan1994
Owner

Yes.

@CarlChang39
Author

Yes.

Great, that should do it. Thanks a lot!
