train_llm.py
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

from general import relative_path

MAX_SEQ_LENGTH: int = 2048  # maximum number of tokens the model can process
DTYPE = None  # None = auto-detect (float16 or bfloat16, depending on GPU support)
LOAD_IN_4BIT: bool = True  # load the base model in 4-bit to reduce memory usage
BASE_MODEL_NAME = 'unsloth/Qwen2-0.5B-Instruct-bnb-4bit'  # pre-quantized 4-bit instruct model


def main() -> None:
    """
    Fine-tune the base language model with LoRA on the synthetic ingredients
    dataset and export the result as a quantized GGUF checkpoint.

    :return: None
    """
    # https://github.com/unsloth/unsloth/
    # Load the synthetic ingredients training data (one JSON record per line)
    dataset = load_dataset(
        'json',
        data_files=relative_path('data/ingredients/synthetic/train.jsonl'),
        split='train',
    )
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=BASE_MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=DTYPE,
        load_in_4bit=LOAD_IN_4BIT,
    )
    # Do model patching and add fast LoRA weights
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',
                        'gate_proj', 'up_proj', 'down_proj'],
        lora_alpha=16,
        lora_dropout=0,  # Supports any, but = 0 is optimized
        bias='none',  # Supports any, but = 'none' is optimized
        use_gradient_checkpointing=False,  # True or 'unsloth' for very long context
        random_state=3407,
        max_seq_length=MAX_SEQ_LENGTH,
        use_rslora=False,  # We support rank stabilized LoRA
        loftq_config=None,  # And LoftQ
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        dataset_text_field='text',
        max_seq_length=MAX_SEQ_LENGTH,
        tokenizer=tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=16,
            max_steps=256,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            output_dir='llm_models',
            optim='adamw_8bit',
            seed=3407,
        ),
    )
    trainer.train()
    # Export the fine-tuned model to GGUF with q4_k_m quantization
    model.save_pretrained_gguf('./llm_models', tokenizer, quantization_method='q4_k_m')


if __name__ == '__main__':
    main()