diff --git a/python/llm/src/ipex_llm/transformers/training_patch.py b/python/llm/src/ipex_llm/transformers/training_patch.py
index 0c1addd6ce2..7db17d5baaa 100644
--- a/python/llm/src/ipex_llm/transformers/training_patch.py
+++ b/python/llm/src/ipex_llm/transformers/training_patch.py
@@ -47,6 +47,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/config.py
+# Copyright [yyyy] [name of copyright owner]
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 
 def patch_prepare_ipex(self, *args):
@@ -58,6 +74,7 @@ def patch_prepare_ipex(self, *args):
                                 is_sagemaker_mp_enabled,
                                 is_accelerate_available,
                                 is_torch_xpu_available,
+                                is_peft_available,
                                 is_sagemaker_dp_enabled,
                                 is_torch_tpu_available,
                                 is_torch_npu_available)
@@ -69,6 +86,8 @@ def patch_prepare_ipex(self, *args):
 import os
 import warnings
 from datetime import timedelta
+from huggingface_hub import hf_hub_download
+from ipex_llm.utils.common import invalidInputError
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -196,3 +215,50 @@ def _setup_devices(self) -> "torch.device":
 # patch transformer for xpu DDP traing
 from transformers import TrainingArguments
 TrainingArguments._setup_devices = _setup_devices
+
+CONFIG_NAME = "adapter_config.json"
+
+
+@classmethod
+def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
+    # Avoid circular dependency .. TODO: fix this with a larger refactor
+    from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
+
+    path = (
+        os.path.join(pretrained_model_name_or_path, subfolder)
+        if subfolder is not None
+        else pretrained_model_name_or_path
+    )
+
+    hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
+
+    if os.path.isfile(os.path.join(path, CONFIG_NAME)):
+        config_file = os.path.join(path, CONFIG_NAME)
+    else:
+        try:
+            config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME,
+                                          subfolder=subfolder, **hf_hub_download_kwargs)
+        except Exception:
+            invalidInputError(False,
+                              f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
+
+    loaded_attributes = cls.from_json_file(config_file)
+
+    if "peft_type" in loaded_attributes:
+        peft_type = loaded_attributes["peft_type"]
+        config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
+    else:
+        config_cls = cls
+
+    config = config_cls(**class_kwargs)
+
+    for key, value in loaded_attributes.items():
+        if hasattr(config, key):
+            setattr(config, key, value)
+
+    return config
+
+# patch peft for merging adapter into the original model
+if is_peft_available():
+    from peft.config import PeftConfigMixin
+    PeftConfigMixin.from_pretrained = from_pretrained
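
Usage note, not part of the patch: once PeftConfigMixin.from_pretrained is overridden as above, the adapter config is resolved from a local directory first and only then fetched with hf_hub_download, with invalidInputError raised when adapter_config.json cannot be found. Below is a minimal sketch of how a saved LoRA adapter might then be loaded through this code path and merged into its base model; the paths, dtype, and device choices are illustrative assumptions, not taken from this diff.

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM

adapter_path = "./lora_adapter"                    # hypothetical adapter directory
config = PeftConfig.from_pretrained(adapter_path)  # resolves adapter_config.json via the patched method

# Load the base model named in the adapter config (fp16 on CPU is an assumption).
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

# Attach the adapter, then fold its weights back into the base model.
model = PeftModel.from_pretrained(base_model, adapter_path)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged_model")     # hypothetical output directory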