Skip to content
This repository has been archived by the owner on Jul 18, 2024. It is now read-only.

Commit

Permalink
simplify API
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouyu5 committed Oct 30, 2023
1 parent 12bcfd9 commit 484cad4
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions e2eAIOK/deltatuner/deltatuner/deltatuner_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .tuner import DeltaLoraModel, DeltaLoraSearchSpace, DeltaSSFModel, DeltaSSFSearchSpace
from .search import SearchEngineFactory, Timer
from .search.utils import network_latency
from .utils import DeltaTunerType, get_deltatuner_model_state_dict, set_deltatuner_model_state_dict
from .utils import DeltaTunerType, get_deltatuner_model_state_dict, set_deltatuner_model_state_dict, BEST_MODEL_STRUCTURE_NAME
from typing import Any, Dict, List, Optional, Union

DELTATUNNER_TO_MODEL_MAPPING = {
Expand Down Expand Up @@ -218,14 +218,7 @@ def from_pretrained(
):
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from .mapping import DELTATUNER_TYPE_TO_CONFIG_MAPPING, MODEL_TYPE_TO_DELTATUNER_MODEL_MAPPING
denas_config = kwargs.pop("denas_config", None)

best_structure_file = os.path.join(model_id, "best_model_structure.txt")
if os.path.isfile(best_structure_file):
denas_config.denas = True
denas_config.best_model_structure = best_structure_file
else:
denas_config.denas = False
denas_config = self._get_denas_config(model_id, **kwargs)

# load the config
if config is None:
Expand Down Expand Up @@ -262,6 +255,45 @@ def from_pretrained(
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
return model

def _get_denas_config(self, model_id: str, **kwargs: Any):
hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)

# load weights if any
path = (
os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
if hf_hub_download_kwargs.get("subfolder", None) is not None
else model_id
)

if os.path.exists(os.path.join(path, BEST_MODEL_STRUCTURE_NAME)):
filename = os.path.join(path, BEST_MODEL_STRUCTURE_NAME)
else:
has_remote_structure_file = hub_file_exists(
model_id,
BEST_MODEL_STRUCTURE_NAME,
revision=hf_hub_download_kwargs.get("revision", None),
repo_type=hf_hub_download_kwargs.get("repo_type", None),
)

if has_remote_structure_file:
filename = hf_hub_download(
model_id,
BEST_MODEL_STRUCTURE_NAME,
**hf_hub_download_kwargs,
)
else:
raise ValueError(
f"Can't find structure for {model_id} in {model_id} or in the Hugging Face Hub. "
f"Please check that the file {BEST_MODEL_STRUCTURE_NAME} is present at {model_id}."
)

denas_config = DeltaTunerArguments()
denas_config.denas = True
denas_config.best_model_structure = filename

return denas_config


def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any):
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from .mapping import DELTATUNER_TYPE_TO_CONFIG_MAPPING
Expand Down

0 comments on commit 484cad4

Please sign in to comment.