assert len(input_id) == len(target) AssertionError #3
Could you print out exactly how they differ?
I saved the intermediate results; in the file names, the number on the left is the length of input_id and the one on the right is the length of target. It comes out like this on every run. In the file 186 185:
input_id: [128000, 128006, 128000, 9125, 128007, 128000, 198, 128000, 2675, 527, 264, 55066, 6369, 6465, 889, 2744, 31680, 304, 55066, 6604, 0, 128009, 128006, 128000, 882, 128007, 128000, 198, 128000, 106161, 100815, 107015, 9554, 106246, 3922, 117805, 115532, 32943, 9554, 127944, 117633, 113333, 96455, 124671, 118402, 3922, 120605, 127944, 9554, 103572, 82317, 75863, 102654, 127198, 9554, 104654, 124662, 124778, 102924, 118742, 1811, 92672, 3922, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 50667, 61826, 42399, 104696, 9554, 109589, 104587, 18184, 102208, 23039, 102208, 42052, 103652, 117661, 117724, 122705, 33748, 110053, 107644, 103624, 122943, 124858, 17297, 3922, 83687, 33208, 47770, 25287, 104724, 112743, 105231, 112157, 28190, 33764, 125648, 1811, 34226, 53901, 30590, 51611, 33764, 83687, 33208, 47770, 25287, 9554, 114099, 102778, 3922, 112026, 121915, 9554, 118556, 34208, 106246, 105000, 86206, 105231, 123882, 103229, 108199, 104696, 109189, 34208, 122705, 106556, 38741, 60843, 37985, 17905, 28190, 44388, 38574, 17161, 22656, 9554, 123092, 127555, 106015, 1811, 128009, 128006, 128000, 78191, 128007, 128000, 198, 128000, 107015, 5486, 127944, 5486, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 5486, 83687, 33208, 47770, 25287, 5486, 105231, 5486, 104696, 109189, 5486, 122705, 106556, 1811, 128009]
target: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 128000, 107015, 5486, 127944, 5486, 17792, 49792, 118034, 110745, 20675, 5486, 103458, 112352, 5486, 83687, 33208, 47770, 25287, 5486, 105231, 5486, 104696, 109189, 5486, 122705, 106556, 1811, 128009]
That spot should be changed to the ignore token repeated for the length of the nl tokens; it didn't account for the nl token having length 2.
You mean the last one in _target, the one just before tokenizer(value), right?
Yes.
Got it, that seems to fix it. Thanks a lot!
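For reference, the off-by-one can be reproduced without running the tokenizer at all. This is a minimal sketch, not the repo's actual code: the token ids and helper names are illustrative. The point is that the role header is followed by "\n", which this tokenizer encodes as two tokens, while the original target construction reserved only a single ignore slot for it:

```python
IGNORE_TOKEN_ID = -100  # same sentinel the script masks prompt tokens with

def build_target_buggy(role_tokens, nl_tokens, value_tokens):
    # BUG: hard-codes one ignore slot for the newline, regardless of
    # how many tokens "\n" actually encodes to
    return [IGNORE_TOKEN_ID] * len(role_tokens) + [IGNORE_TOKEN_ID] + value_tokens

def build_target_fixed(role_tokens, nl_tokens, value_tokens):
    # FIX: reserve one ignore slot per newline token
    return ([IGNORE_TOKEN_ID] * len(role_tokens)
            + [IGNORE_TOKEN_ID] * len(nl_tokens)
            + value_tokens)

role_tokens = [128006, 128000, 78191, 128007]  # assistant header (illustrative ids)
nl_tokens = [128000, 198]                      # "\n" encodes to TWO tokens here
value_tokens = [107015, 5486, 1811, 128009]    # response tokens (illustrative)

input_id = role_tokens + nl_tokens + value_tokens

# buggy target is 1 short of input_id, tripping the assertion
assert len(build_target_buggy(role_tokens, nl_tokens, value_tokens)) == len(input_id) - 1
# fixed target matches input_id exactly
assert len(build_target_fixed(role_tokens, nl_tokens, value_tokens)) == len(input_id)
```

With the multiplication by `len(nl_tokens)`, the mask stays aligned even if the tokenizer's newline encoding changes length.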
Hello, I ran into this problem while reproducing the QLoRA fine-tuning. The error message is:
Traceback (most recent call last):
File "../finetune_llama3.py", line 452, in <module>
train()
File "../finetune_llama3.py", line 445, in train
trainer.train()
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/transformers/trainer.py", line 1624, in train
return inner_training_loop(
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/transformers/trainer.py", line 1928, in _inner_training_loop
for step, inputs in enumerate(epoch_iterator):
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/accelerate/data_loader.py", line 452, in __iter__
current_batch = next(dataloader_iter)
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
data = self._next_data()
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/nlpir/miniconda3/envs/cjy_llama/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "../finetune_llama3.py", line 255, in __getitem__
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)
File "../finetune_llama3.py", line 192, in preprocess
assert len(input_id) == len(target)
AssertionError
What actually happens is that input_id is always exactly 1 longer than target.
I'm running on six 1080 Ti GPUs; the shell script is as follows:
NCCL_P2P_DISABLE=1 \
NCCL_IB_DISABLE=1 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \
torchrun \
  --nproc_per_node 6 \
  --nnodes 1 \
  --node_rank 0 \
  --master_addr localhost \
  --master_port 6601 \
  ../finetune_llama3.py \
  --model_name_or_path "../model_hub/LLM-Research/Meta-Llama-3-8B-Instruct/" \
  --data_path "../data/Belle_sampled_qwen.json" \
  --fp16 True \
  --output_dir "../output/llama3_8B_qlora" \
  --num_train_epochs 100 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 5 \
  --save_total_limit 1 \
  --learning_rate 1e-5 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --lazy_preprocess True \
  --deepspeed "../config/ds_config_zero2.json" \
  --use_lora \
  --load_in_4bit \
  --q_lora
I only changed CUDA_VISIBLE_DEVICES and nproc_per_node, and switched bf16 to fp16.