Fix error during merging adapter (#11145)
plusbang authored May 27, 2024
1 parent daf7b1c commit c9168b8
Showing 1 changed file with 66 additions and 0 deletions.
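
For context, the error this commit fixes surfaces when a LoRA adapter is merged back into its base model and peft reloads the adapter config along the way. A minimal sketch of that flow, assuming a standard peft setup (the adapter path below is a hypothetical placeholder):

    from transformers import AutoModelForCausalLM
    from peft import PeftConfig, PeftModel

    adapter_path = "./my-lora-adapter"  # hypothetical adapter directory
    config = PeftConfig.from_pretrained(adapter_path)  # the call patched in this commit
    base = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(base, adapter_path)
    merged = model.merge_and_unload()  # merge the LoRA weights into the base model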
python/llm/src/ipex_llm/transformers/training_patch.py
@@ -47,6 +47,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/config.py
# Copyright [yyyy] [name of copyright owner]

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def patch_prepare_ipex(self, *args):
@@ -58,6 +74,7 @@ def patch_prepare_ipex(self, *args):
is_sagemaker_mp_enabled,
is_accelerate_available,
is_torch_xpu_available,
is_peft_available,
is_sagemaker_dp_enabled,
is_torch_tpu_available,
is_torch_npu_available)
@@ -69,6 +86,8 @@ def patch_prepare_ipex(self, *args):
import os
import warnings
from datetime import timedelta
from huggingface_hub import hf_hub_download
from ipex_llm.utils.common import invalidInputError

if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState
@@ -196,3 +215,50 @@ def _setup_devices(self) -> "torch.device":
# patch transformers for xpu DDP training
from transformers import TrainingArguments
TrainingArguments._setup_devices = _setup_devices

CONFIG_NAME = "adapter_config.json"


@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
    # Avoid circular dependency .. TODO: fix this with a larger refactor
    from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

    # resolve the adapter directory, honoring an optional subfolder
    path = (
        os.path.join(pretrained_model_name_or_path, subfolder)
        if subfolder is not None
        else pretrained_model_name_or_path
    )

    hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)

    # prefer a local adapter_config.json; otherwise fetch it from the Hugging Face Hub
    if os.path.isfile(os.path.join(path, CONFIG_NAME)):
        config_file = os.path.join(path, CONFIG_NAME)
    else:
        try:
            config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME,
                                          subfolder=subfolder, **hf_hub_download_kwargs)
        except Exception:
            invalidInputError(False,
                              f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")

    loaded_attributes = cls.from_json_file(config_file)

    # dispatch to the concrete config class recorded in the file, if any
    if "peft_type" in loaded_attributes:
        peft_type = loaded_attributes["peft_type"]
        config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
    else:
        config_cls = cls

    config = config_cls(**class_kwargs)

    # copy every known attribute from the JSON file onto the config object
    for key, value in loaded_attributes.items():
        if hasattr(config, key):
            setattr(config, key, value)

    return config


# patch peft for merging adapter into the original model
if is_peft_available():
    from peft.config import PeftConfigMixin
    PeftConfigMixin.from_pretrained = from_pretrained
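
Once this module is imported with peft available, PeftConfig and its subclasses resolve adapter_config.json through the patched classmethod above, so a missing config file is reported via invalidInputError rather than an unhandled exception. A usage sketch under that assumption (both paths are hypothetical placeholders):

    from peft import LoraConfig

    cfg = LoraConfig.from_pretrained("./my-lora-adapter")            # local directory
    cfg = LoraConfig.from_pretrained("some-user/some-lora-adapter")  # Hub repo id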
