diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md new file mode 100644 index 00000000000..d6297bda5f2 --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md @@ -0,0 +1,141 @@ +# LoRA Fine-Tuning on ChatGLM3-6B with IPEX-LLM + +This example ports [ChatGLM3-6B lora_finetune](https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb) demo to IPEX-LLM on [Intel Arc GPU](../../README.md). + +### 1. Install + +```bash +conda create -n llm python=3.11 +conda activate llm +pip install "jieba>=0.42.1" +pip install "ruamel_yaml>=0.18.6" +pip install "rouge_chinese>=1.0.3" +pip install "jupyter>=1.0.0" +pip install "datasets>=2.18.0" +pip install "peft>=0.10.0" +pip install "typer" +pip install "sentencepiece" +pip install "nltk" +pip install "numpy<2.0.0" +# below command will install intel_extension_for_pytorch==2.1.10+xpu as default +pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +``` + +### 2. Configures OneAPI Environment Variables +```bash +source /opt/intel/oneapi/setvars.sh +``` + +### 3. LoRA Fine-Tune on ChatGLM3-6B + +First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script: + +```bash +python process_advertise_gen_dataset.py +``` + +Then, './AdvertiseGen' will be converted to './AdvertiseGen_fix'. Now, we have prepared the dataset, and are going to start LoRA fine-tuning on ChatGLM3-6B. + +#### 3.1. Fine-Tune with a Single Arc Card + +Start the fine-tuning by: + + +```bash +bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh +``` + +Then, you will get output are as below: + +```bash +2024-06-27 13:47:02,680 - root - INFO - intel_extension_for_pytorch auto imported +Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.47it/s] +2024-06-27 13:47:03,794 - ipex_llm.transformers.utils - INFO - Converting the current model to bf16 format...... +[2024-06-27 13:47:04,105] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to xpu (auto detect) +trainable params: 487,424 || all params: 6,244,071,424 || trainable%: 0.0078 +PeftModelForCausalLM( + (base_model): LoraModel( + (model): ChatGLMForConditionalGeneration( + (transformer): ChatGLMModel( + (embedding): Embedding( + (word_embeddings): Embedding(65024, 4096) + ) + (rotary_pos_emb): RotaryEmbedding() + (encoder): GLMTransformer( + (layers): ModuleList( + (0-27): 28 x GLMBlock( + (input_layernorm): RMSNorm() + (self_attention): SelfAttention( + (query_key_value): LoraLowBitLinear( + (base_layer): BF16Linear(in_features=4096, out_features=4608, bias=True) + (lora_dropout): ModuleDict( + (default): Dropout(p=0.1, inplace=False) + ) + (lora_A): ModuleDict( + (default): Linear(in_features=4096, out_features=2, bias=False) + ) + (lora_B): ModuleDict( + (default): Linear(in_features=2, out_features=4608, bias=False) + ) + (lora_embedding_A): ParameterDict() + (lora_embedding_B): ParameterDict() + (qa_pool): Identity() + ) + (core_attention): CoreAttention( + (attention_dropout): Dropout(p=0.0, inplace=False) + ) + (dense): BF16Linear(in_features=4096, out_features=4096, bias=False) + ) + (post_attention_layernorm): RMSNorm() + (mlp): MLP( + (dense_h_to_4h): BF16Linear(in_features=4096, out_features=27392, bias=False) + (dense_4h_to_h): BF16Linear(in_features=13696, out_features=4096, bias=False) + ) + ) + ) + (final_layernorm): RMSNorm() + ) + (output_layer): BF16Linear(in_features=4096, out_features=65024, bias=False) + ) + ) + ) +) +--> Model + +--> model has 0.487424M params + +train_dataset: Dataset({ + features: ['input_ids', 'labels'], + num_rows: 114599 +}) +val_dataset: Dataset({ + features: ['input_ids', 'output_ids'], + num_rows: 1070 +}) +test_dataset: Dataset({ + features: ['input_ids', 'output_ids'], + num_rows: 1070 +}) +--> Sanity check + '[gMASK]': 64790 -> -100 + 'sop': 64792 -> -100 + '<|user|>': 64795 -> -100 + '': 30910 -> -100 + '\n': 13 -> -100 +...... + +# Here it takes time to finish the whole fine-tuning + +...... + +Training completed. Do not forget to share your model on huggingface.co/models =) + + +{'train_runtime': xxxx.xxxx, 'train_samples_per_second': x.xxx, 'train_steps_per_second': x.xxx, 'train_loss': xx.xx, 'epoch': x.xx} +100%|████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [xx:xx<00:00, x.xxit/s] +***** Running Prediction ***** + Num examples = 1070 + Batch size = 4 +100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [xx:xx<00:00, x.xxs/it] +``` diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_config.yaml b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_config.yaml new file mode 100644 index 00000000000..107ba91e54a --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_config.yaml @@ -0,0 +1,47 @@ +# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/configs/lora.yaml +data_config: + train_file: train.json + val_file: dev.json + test_file: dev.json + num_proc: 16 +max_input_length: 128 +max_output_length: 128 +training_args: + # see `transformers.Seq2SeqTrainingArguments` + output_dir: ./output + max_steps: 3000 + # needed to be fit for the dataset + learning_rate: 5e-5 + # settings for data loading + per_device_train_batch_size: 1 + dataloader_num_workers: 16 + remove_unused_columns: false + # settings for saving checkpoints + save_strategy: steps + save_steps: 500 + # settings for logging + log_level: info + logging_strategy: steps + logging_steps: 10 + # settings for evaluation + per_device_eval_batch_size: 4 + evaluation_strategy: steps + eval_steps: 1000 + # settings for optimizer + # adam_epsilon: 1e-6 + # uncomment the following line to detect nan or inf values + # debug: underflow_overflow + predict_with_generate: true + # see `transformers.GenerationConfig` + generation_config: + max_new_tokens: 128 + # set your absolute deepspeed path here + #deepspeed: ds_zero_2.json + # set to true if train with cpu. + use_cpu: false +peft_config: + peft_type: LORA + task_type: CAUSAL_LM + r: 2 + lora_alpha: 8 + lora_dropout: 0.1 diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py new file mode 100644 index 00000000000..2e99b8e4406 --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py @@ -0,0 +1,601 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/tloen/alpaca-lora/blob/main/finetune.py +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This example is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/finetune_hf.py + + +import os +import jieba +import dataclasses as dc +import functools +from collections.abc import Callable, Mapping, Sequence +from pathlib import Path +from typing import Annotated, Any, Optional, Union +import numpy as np +import ruamel.yaml as yaml +import torch +import typer +from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu +from peft import ( + PeftConfig, + PeftModelForCausalLM, + get_peft_config +) +from rouge_chinese import Rouge +from torch import nn +from transformers import ( + AutoTokenizer, + EvalPrediction, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + Seq2SeqTrainingArguments, AutoConfig, +) +from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq + +from transformers import Seq2SeqTrainer as _Seq2SeqTrainer + +ModelType = Union[PreTrainedModel, PeftModelForCausalLM] +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +app = typer.Typer(pretty_exceptions_show_locals=False) + + +class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + output_ids = ( + [feature['output_ids'] for feature in features] + if 'output_ids' in features[0].keys() + else None + ) + if output_ids is not None: + max_output_length = max(len(out) for out in output_ids) + if self.pad_to_multiple_of is not None: + max_output_length = ( + ( + max_output_length + self.pad_to_multiple_of - 1) // + self.pad_to_multiple_of * self.pad_to_multiple_of + ) + for feature in features: + remainder = [self.tokenizer.pad_token_id] * ( + max_output_length - len(feature['output_ids']) + ) + if isinstance(feature['output_ids'], list): + feature['output_ids'] = feature['output_ids'] + remainder + else: + feature['output_ids'] = np.concatenate( + [feature['output_ids'], remainder] + ).astype(np.int64) + return super().__call__(features, return_tensors) + + +class Seq2SeqTrainer(_Seq2SeqTrainer): + def prediction_step( + self, + model: nn.Module, + inputs: dict[str, Any], + prediction_loss_only: bool, + ignore_keys=None, + **gen_kwargs, + ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + if self.args.predict_with_generate: + output_ids = inputs.pop('output_ids') + input_ids = inputs['input_ids'] + loss, generated_tokens, labels = super().prediction_step( + model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs + ) + generated_tokens = generated_tokens[:, input_ids.size()[1]:] + if self.args.predict_with_generate: + labels = output_ids + return loss, generated_tokens, labels + + +def _resolve_path(path: Union[str, Path]) -> Path: + return Path(path).expanduser().resolve() + + +def _sanity_check( + input_ids: Sequence[int], + output_ids: Sequence[int], + tokenizer: PreTrainedTokenizer, +): + print('--> Sanity check') + for in_id, out_id in zip(input_ids, output_ids): + if in_id == 0: + continue + if in_id in tokenizer.tokenizer.index_special_tokens: + in_text = tokenizer.tokenizer.index_special_tokens[in_id] + else: + in_text = tokenizer.decode([in_id]) + print(f'{repr(in_text):>20}: {in_id} -> {out_id}') + + +@functools.cache +def _get_yaml_parser() -> yaml.YAML: + parser = yaml.YAML(typ='safe', pure=True) + parser.indent(mapping=2, offset=2, sequence=4) + parser.default_flow_style = False + return parser + + +@dc.dataclass +class DataConfig(object): + train_file: str + val_file: Optional[str] = None + test_file: Optional[str] = None + + num_proc: Optional[int] = None + + @property + def data_format(self) -> str: + return Path(self.train_file).suffix + + @property + def data_files(self) -> dict[NamedSplit, str]: + return { + split: data_file + for split, data_file in zip( + [Split.TRAIN, Split.VALIDATION, Split.TEST], + [self.train_file, self.val_file, self.test_file], + ) + if data_file is not None + } + + +@dc.dataclass +class FinetuningConfig(object): + data_config: DataConfig + + max_input_length: int + max_output_length: int + + training_args: Seq2SeqTrainingArguments = dc.field( + default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output') + ) + peft_config: Optional[PeftConfig] = None + + def __post_init__(self): + if not self.training_args.do_eval or self.data_config.val_file is None: + # skips the evaluation stage when `do_eval` or `eval_file` is not provided + self.training_args.do_eval = False + self.training_args.evaluation_strategy = 'no' + self.data_config.val_file = None + else: + self.training_args.per_device_eval_batch_size = ( + self.training_args.per_device_eval_batch_size + or self.training_args.per_device_train_batch_size + ) + + @classmethod + def from_dict(cls, **kwargs) -> 'FinetuningConfig': + training_args = kwargs.get('training_args', None) + if training_args is not None and not isinstance( + training_args, Seq2SeqTrainingArguments + ): + gen_config = training_args.get('generation_config') + # TODO: a bit hacky + if not isinstance(gen_config, GenerationConfig): + training_args['generation_config'] = GenerationConfig( + **gen_config + ) + kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args) + + data_config = kwargs.get('data_config') + if not isinstance(data_config, DataConfig): + kwargs['data_config'] = DataConfig(**data_config) + + peft_config = kwargs.get('peft_config', None) + if peft_config is not None and not isinstance(peft_config, PeftConfig): + kwargs['peft_config'] = get_peft_config(peft_config) + return cls(**kwargs) + + @classmethod + def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig': + path = _resolve_path(path) + kwargs = _get_yaml_parser().load(path) + return cls.from_dict(**kwargs) + + +def _load_datasets( + data_dir: Path, + data_format: str, + data_files: dict[NamedSplit, str], + num_proc: Optional[int], +) -> DatasetDict: + if data_format in ('.csv', '.json', '.jsonl'): + dataset_dct = load_dataset( + data_format[1:], + data_dir=data_dir, + data_files=data_files, + num_proc=num_proc, + ) + else: + err_msg = f"Cannot load dataset in the '{data_format}' format." + raise NotImplementedError(err_msg) + + return dataset_dct + + +class DataManager(object): + def __init__(self, data_dir: str, data_config: DataConfig): + self._num_proc = data_config.num_proc + + self._dataset_dct = _load_datasets( + _resolve_path(data_dir), + data_config.data_format, + data_config.data_files, + self._num_proc, + ) + + def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]: + return self._dataset_dct.get(split, None) + + def get_dataset( + self, + split: NamedSplit, + process_fn: Callable[[dict[str, Any]], dict[str, Any]], + batched: bool = True, + remove_orig_columns: bool = True, + ) -> Optional[Dataset]: + orig_dataset = self._get_dataset(split) + if orig_dataset is None: + return + + if remove_orig_columns: + remove_columns = orig_dataset.column_names + else: + remove_columns = None + return orig_dataset.map( + process_fn, + batched=batched, + remove_columns=remove_columns, + num_proc=self._num_proc, + ) + + +def print_model_size(model: PreTrainedModel): + print("--> Model") + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"\n--> model has {total_params / 1e6}M params\n") + + +def process_batch( + batch: Mapping[str, Sequence], + tokenizer: PreTrainedTokenizer, + max_input_length: int, + max_output_length: int, +) -> dict[str, list]: + batched_tools = batch.get('tools', None) + batched_conv = batch['conversations'] + batched_input_ids = [] + batched_labels = [] + + if batched_tools is None: + batched_tools = [None] * len(batched_conv) + + for tools, conv in zip(batched_tools, batched_conv): + input_ids, loss_masks = [ + tokenizer.get_command('[gMASK]'), + tokenizer.get_command('sop'), + ], [False, False] + + if tools is not None: + raise NotImplementedError() + + for message in conv: + if message['role'] in ('system', 'user'): + loss_mask_val = False + else: + loss_mask_val = True + + if message['role'] == 'tool': + raise NotImplementedError() + else: + new_input_ids = tokenizer.build_single_message( + message['role'], '', message['content'] + ) + new_loss_masks = [loss_mask_val] * len(new_input_ids) + + input_ids += new_input_ids + loss_masks += new_loss_masks + + input_ids.append(tokenizer.eos_token_id) + loss_masks = [False, *loss_masks] + labels = [] + for input_id, mask in zip(input_ids, loss_masks): + if mask: + labels.append(input_id) + else: + labels.append(-100) + max_length = max_input_length + max_output_length + 1 + batched_input_ids.append(input_ids[:max_length]) + batched_labels.append(labels[:max_length]) + return {'input_ids': batched_input_ids, 'labels': batched_labels} + + +def process_batch_eval( + batch: Mapping[str, Sequence], + tokenizer: PreTrainedTokenizer, + max_input_length: int, + max_output_length: int, +) -> dict[str, list]: + batched_tools = batch.get('tools', None) + batched_conv = batch['conversations'] + batched_input_ids = [] + # To avoid computing loss, we do not provide the `labels` field in the input dictionary. + batched_output_ids = [] + + if batched_tools is None: + batched_tools = [None] * len(batched_conv) + + for tools, conv in zip(batched_tools, batched_conv): + input_ids = [ + tokenizer.get_command('[gMASK]'), + tokenizer.get_command('sop'), + ] + + if tools is not None: + raise NotImplementedError() + + for message in conv: + if len(input_ids) >= max_input_length: + break + if message['role'] == 'tool': + raise NotImplementedError() + else: + new_input_ids = tokenizer.build_single_message( + message['role'], '', message['content'] + ) + if message['role'] == 'assistant': + output_prompt, output_ids = ( + new_input_ids[:1], + new_input_ids[1:], + ) + output_ids.append(tokenizer.eos_token_id) + batched_input_ids.append( + input_ids[:max_input_length] + output_prompt[:1] + ) + batched_output_ids.append(output_ids[:max_output_length]) + input_ids += new_input_ids + return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids} + + +# Not sure if this is necessary, can set it to half. +# If train with cpu, cast all params to fp32 instead of trainable ones. +def _prepare_model_for_training(model: nn.Module, use_cpu: bool): + for param in model.parameters(): + if param.requires_grad or use_cpu: + param.data = param.data.to(torch.float32) + + +def load_tokenizer_and_model( + model_dir: str, + peft_config: Optional[PeftConfig] = None, +) -> tuple[PreTrainedTokenizer, nn.Module]: + tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + if peft_config is not None: + if peft_config.peft_type.name == "PREFIX_TUNING": + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + config.pre_seq_len = peft_config.num_virtual_tokens + config.use_cache = False + model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + config=config, + ) + if peft_config.peft_type.name == "LORA": + from ipex_llm.transformers import AutoModelForCausalLM + from ipex_llm.transformers.qlora import get_peft_model + import os + os.environ["ACCELERATE_USE_XPU"] = "true" + model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + load_in_low_bit="bf16", + optimize_model=False, + empty_init=False, + use_cache=False, + torch_dtype=torch.bfloat16 + ) + + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + print(model) + else: + model = AutoModelForCausalLM.from_pretrained( + model_dir, + trust_remote_code=True, + empty_init=False, + use_cache=False + ) + print_model_size(model) + return tokenizer, model + + +def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer): + batched_pred_ids, batched_label_ids = eval_preds + + metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []} + for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids): + pred_txt = tokenizer.decode(pred_ids).strip() + label_txt = tokenizer.decode(label_ids).strip() + pred_tokens = list(jieba.cut(pred_txt)) + label_tokens = list(jieba.cut(label_txt)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens)) + for k, v in scores[0].items(): + metrics_dct[k].append(round(v['f'] * 100, 4)) + metrics_dct['bleu-4'].append( + sentence_bleu( + [label_tokens], + pred_tokens, + smoothing_function=SmoothingFunction().method3, + ) + ) + return {k: np.mean(v) for k, v in metrics_dct.items()} + + +@app.command() +def main( + data_dir: Annotated[str, typer.Argument(help='')], + model_dir: Annotated[ + str, + typer.Argument( + help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.' + ), + ], + config_file: Annotated[str, typer.Argument(help='')], + auto_resume_from_checkpoint: str = typer.Argument( + default='', + help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training' + ), + +): + ft_config = FinetuningConfig.from_file(config_file) + tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config) + data_manager = DataManager(data_dir, ft_config.data_config) + + train_dataset = data_manager.get_dataset( + Split.TRAIN, + functools.partial( + process_batch, + tokenizer=tokenizer, + max_input_length=ft_config.max_input_length, + max_output_length=ft_config.max_output_length, + ), + batched=True, + ) + print('train_dataset:', train_dataset) + val_dataset = data_manager.get_dataset( + Split.VALIDATION, + functools.partial( + process_batch_eval, + tokenizer=tokenizer, + max_input_length=ft_config.max_input_length, + max_output_length=ft_config.max_output_length, + ), + batched=True, + ) + if val_dataset is not None: + print('val_dataset:', val_dataset) + test_dataset = data_manager.get_dataset( + Split.TEST, + functools.partial( + process_batch_eval, + tokenizer=tokenizer, + max_input_length=ft_config.max_input_length, + max_output_length=ft_config.max_output_length, + ), + batched=True, + ) + if test_dataset is not None: + print('test_dataset:', test_dataset) + + # checks encoded dataset + _sanity_check( + train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer + ) + + # turn model to fp32 + _prepare_model_for_training(model, ft_config.training_args.use_cpu) + + ft_config.training_args.generation_config.pad_token_id = ( + tokenizer.pad_token_id + ) + ft_config.training_args.generation_config.eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.get_command('<|user|>'), + tokenizer.get_command('<|observation|>'), + ] + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + use_tokenizer = True + if ft_config.peft_config is not None: + use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True + + trainer = Seq2SeqTrainer( + model=model, + args=ft_config.training_args, + data_collator=DataCollatorForSeq2Seq( + tokenizer=tokenizer, + padding='longest', + return_tensors='pt', + ), + train_dataset=train_dataset, + eval_dataset=val_dataset.select(list(range(50))), + tokenizer=tokenizer if use_tokenizer else None, # LORA does not need tokenizer + compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer), + ) + + if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None: + trainer.train() + else: + output_dir = ft_config.training_args.output_dir + dirlist = os.listdir(output_dir) + checkpoint_sn = 0 + for checkpoint_str in dirlist: + if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1: + checkpoint = int(checkpoint_str.replace("checkpoint-", "")) + if checkpoint > checkpoint_sn: + checkpoint_sn = checkpoint + if auto_resume_from_checkpoint.upper() == "YES": + if checkpoint_sn > 0: + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn)) + print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) + trainer.train(resume_from_checkpoint=checkpoint_directory) + else: + trainer.train() + else: + if auto_resume_from_checkpoint.isdigit(): + if int(auto_resume_from_checkpoint) > 0: + checkpoint_sn = int(auto_resume_from_checkpoint) + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn)) + print("resume checkpoint from checkpoint-" + str(checkpoint_sn)) + trainer.train(resume_from_checkpoint=checkpoint_directory) + else: + print(auto_resume_from_checkpoint, + "The specified checkpoint sn(" + + auto_resume_from_checkpoint + + ") has not been saved. Please search for the correct chkeckpoint in the model output directory") + + # test stage + if test_dataset is not None: + trainer.predict(test_dataset) + + +if __name__ == '__main__': + app() diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh new file mode 100644 index 00000000000..a4fc762b17b --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh @@ -0,0 +1,21 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# You can also set the remote model repository to a local model path +python lora_finetune_chatglm.py \ + ./AdvertiseGen_fix \ + THUDM/chatglm3-6b \ + ./lora_config.yaml diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/process_advertise_gen_dataset.py b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/process_advertise_gen_dataset.py new file mode 100644 index 00000000000..18f1040ef01 --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/process_advertise_gen_dataset.py @@ -0,0 +1,59 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb + + +import json +from typing import Union +from pathlib import Path + + +def _resolve_path(path: Union[str, Path]) -> Path: + return Path(path).expanduser().resolve() + + +def _mkdir(dir_name: Union[str, Path]): + dir_name = _resolve_path(dir_name) + if not dir_name.is_dir(): + dir_name.mkdir(parents=True, exist_ok=False) + + +def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]): + def _convert(in_file: Path, out_file: Path): + _mkdir(out_file.parent) + with open(in_file, encoding='utf-8') as fin: + with open(out_file, 'wt', encoding='utf-8') as fout: + for line in fin: + dct = json.loads(line) + sample = {'conversations': [{'role': 'user', 'content': dct['content']}, + {'role': 'assistant', 'content': dct['summary']}]} + fout.write(json.dumps(sample, ensure_ascii=False) + '\n') + + data_dir = _resolve_path(data_dir) + save_dir = _resolve_path(save_dir) + + train_file = data_dir / 'train.json' + if train_file.is_file(): + out_file = save_dir / train_file.relative_to(data_dir) + _convert(train_file, out_file) + + dev_file = data_dir / 'dev.json' + if dev_file.is_file(): + out_file = save_dir / dev_file.relative_to(data_dir) + _convert(dev_file, out_file) + + +convert_adgen('./AdvertiseGen', './AdvertiseGen_fix')