
Commit

fix
firefighter-eric committed Sep 17, 2023
1 parent 6f1c066 commit 513f644
Showing 6 changed files with 40 additions and 37 deletions.
5 changes: 2 additions & 3 deletions inference/generate.py
@@ -6,9 +6,8 @@

# %%
parser = ArgumentParser()
parser.add_argument('--model_path', '-m', type=str, default='/mnt/h/models/llama-7b')
parser.add_argument('--model_path', '-m', type=str, default='')
parser.add_argument('--lora_path', '-l', type=str, default='')
# parser.add_argument('--precision', '-p', type=str, default='bf16')
parser.add_argument('--prompt', '-t', type=str, default='你好')

args = parser.parse_args()
@@ -27,7 +26,7 @@
print(output)
while True:
    prompt = input('Human: ')
    output = pipeline(prompt, max_new_tokens=32, do_sample=False, top_p=1, num_return_sequences=1, return_full_text=False)
    output = pipeline(prompt, max_new_tokens=32, num_beams=10, num_return_sequences=5, return_full_text=False)
    for o in output:
        print('Bot:', o['generated_text'])
    print('-' * 100)
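
Note: the new call replaces greedy decoding with beam search (do_sample defaults to False), keeping the best 5 of 10 beams per prompt. The pipeline construction sits above this hunk and is not shown, so the sketch below assumes a standard transformers text-generation pipeline with placeholder paths and an optional LoRA adapter; it is not the repo's exact code.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel  # only needed when --lora_path is given

model_path = "/path/to/base-model"  # placeholder; the commit empties the old default
lora_path = ""                      # optional LoRA adapter directory

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
if lora_path:
    model = PeftModel.from_pretrained(model, lora_path)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# num_beams=10 runs beam search; num_return_sequences=5 keeps the 5 best beams.
outputs = pipe("你好", max_new_tokens=32, num_beams=10, num_return_sequences=5, return_full_text=False)
for o in outputs:
    print("Bot:", o["generated_text"])
```
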
4 changes: 3 additions & 1 deletion requirements.txt
@@ -21,4 +21,6 @@ pandas
scipy

wandb
tensorboard
tensorboard

opencc
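
Note: opencc is presumably pulled in for Traditional/Simplified Chinese conversion during data preprocessing; the preprocessing code itself is not part of this diff. A minimal usage sketch, assuming the pure-Python opencc package (the official binding names its configs "t2s.json" rather than "t2s"):

```python
from opencc import OpenCC

cc = OpenCC("t2s")         # Traditional -> Simplified
print(cc.convert("漢字"))   # -> "汉字"
```
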
26 changes: 12 additions & 14 deletions training/config/config.yml
@@ -15,6 +15,12 @@ do_eval: true
do_predict: false
do_train: true

# lora
lora: true
load_in_4bit: true
load_in_8bit: false
target_modules: [ q_proj, v_proj, k_proj, o_proj, gate_proj, down_proj, up_proj ]

# hyperparameters
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
@@ -31,9 +37,11 @@ optim: adamw_torch
gradient_accumulation_steps: 1
gradient_checkpointing: true

# backend
torch_compile: true
#torch_compile_backend: None
#torch_compile_mode: None
deepspeed: None

# precision
bf16: true
@@ -45,13 +53,12 @@ fp16_opt_level: O1
tf32: true

# data
seed: 42
data_seed: None
dataloader_drop_last: false
dataloader_num_workers: 0
dataloader_pin_memory: true

deepspeed: None
disable_tqdm: false
remove_unused_columns: true

# eval
eval_delay: 0
@@ -62,9 +69,6 @@ prediction_loss_only: True
#greater_is_better: None
#metric_for_best_model: None

label_smoothing_factor: 0.0


# log
log_level: passive
log_level_replica: warning
@@ -75,10 +79,7 @@ logging_steps: 10
logging_strategy: steps
report_to:
- wandb

remove_unused_columns: true

#resume_from_checkpoint: None
disable_tqdm: false

# save
save_safetensors: false
@@ -87,9 +88,6 @@ save_strategy: steps
save_total_limit: 2
load_best_model_at_end: true

seed: 42
skip_memory_metrics: true


#resume_from_checkpoint: None


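
Note: the new lora / load_in_4bit / load_in_8bit / target_modules keys correspond to the fields added to MyTrainingArguments and to the kwargs consumed by load_model further down. The config is passed to the training script via -c (see the docstring in train_gpt_hf.py); the repo's own loading helper is outside this diff, so the sketch below shows one typical way to feed such a YAML file to HfArgumentParser, using transformers' stock TrainingArguments as a stand-in for the repo's dataclasses.

```python
import yaml
from transformers import HfArgumentParser, TrainingArguments

with open("training/config/config.yml") as f:
    config = yaml.safe_load(f)

# Unquoted `None` in YAML is the string "None", not null; normalize it first.
config = {k: (None if v == "None" else v) for k, v in config.items()}

# allow_extra_keys tolerates repo-specific fields such as lora, target_modules, project_name.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_dict(config, allow_extra_keys=True)
print(training_args.per_device_train_batch_size)
```
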
34 changes: 19 additions & 15 deletions training/train_gpt_hf.py
@@ -51,13 +51,14 @@ def __post_init__(self):
class MyTrainingArguments(TrainingArguments):
    project_name: str = ''
    # lora
    lora: bool = False
    target_modules: list[str] = None

    def __post_init__(self):
        super().__post_init__()
        # wandb
        os.environ['WANDB_PROJECT'] = self.project_name
        self.run_name += f'-{time.time()}'
        self.run_name += f'-{int(time.time())}'


# %% config
@@ -88,18 +89,20 @@ def __post_init__(self):

# %% model
model = load_model(**model_args.__dict__)
model.enable_input_require_grads()
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    # target_modules=training_args.target_modules
    # bias="all"
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

if training_args.lora:
    model.enable_input_require_grads()
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        # target_modules=training_args.target_modules
        # bias="all"
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

# %% optimizer
# if data_args.optim == 'adam_bf16':
@@ -122,6 +125,7 @@ def __post_init__(self):
tokenizer.save_pretrained(f'{training_args.output_dir}/best')

"""
CONFIG=''
torchrun --nnodes 1 --nproc-per-node 1 training/train_gpt_hf.py -c $CONFIG
config=''
python training/train_gpt_hf.py -c $config
torchrun --nnodes 1 --nproc-per-node 1 training/train_gpt_hf.py -c $config
"""
6 changes: 3 additions & 3 deletions utils/data/load_data.py
@@ -64,9 +64,9 @@ def load_data(path: str, tokenizer: PreTrainedTokenizer, max_length: int = 1024,
    data = processor.load_json(path=path)
    # else:
    # data = processor.load_dataset_dict(path=path)
    print(d)
    print(d['train'][0]['input_ids'])
    print(tokenizer.decode(d['train'][0]['input_ids']))
    print(data)
    print(data['train'][0]['input_ids'])
    print(tokenizer.decode(data['train'][0]['input_ids']))
    return data


2 changes: 1 addition & 1 deletion utils/model/load_model.py
@@ -5,7 +5,7 @@


def load_model(model_name_or_path, model_type, **kwargs) -> PreTrainedModel:
    args = {'device_map': 'auto'}
    args = {'device_map': 'auto', 'torch_dtype': 'auto'}
    if kwargs.get('load_in_8bit'):
        args['load_in_8bit'] = True
        quantization_config = BitsAndBytesConfig(
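
Note: the hunk is cut off right after the BitsAndBytesConfig( line, so the 4-bit branch that config.yml's load_in_4bit would hit is not visible here. A self-contained sketch of a typical QLoRA-style 4-bit loader, under assumed defaults rather than the repo's exact settings:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


def load_model_4bit(model_name_or_path: str):
    # NF4 quantization with bf16 compute and double quantization.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        quantization_config=quantization_config,
    )
```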
