From 0945080b6172a6b08578dcf08322e09820b1c123 Mon Sep 17 00:00:00 2001
From: kong13661
Date: Sat, 30 Sep 2023 02:52:28 +0800
Subject: [PATCH] pipeline

delete _internal_config_dict_converted
convert init_subclass to wrapper
trainer_configmixin
sha256 and pretrained. fix wrapper
download from obs
remove test code
image_classification_trainer
append sys
add dict mindspore dataset
fix bug findFromPreConfig
fix path bug
fix path bug
pathlib
class check
add document
checkfiles
autopipeline documentation
doc init
---
 docs/en/source/pipeline/overview.md           | 196 +++++
 tinyms/pipeline/__init__.py                   |   4 +
 tinyms/pipeline/auto_pipeline.py              |   5 +
 tinyms/pipeline/configmixin.py                | 695 ++++++++++++++++++
 tinyms/pipeline/download.py                   | 204 +++++
 .../pipeline/image_classification_trainer.py  | 339 +++++++++
 tinyms/pipeline/trainer_configmixin.py        | 287 ++++++++
 7 files changed, 1730 insertions(+)
 create mode 100644 docs/en/source/pipeline/overview.md
 create mode 100644 tinyms/pipeline/__init__.py
 create mode 100644 tinyms/pipeline/auto_pipeline.py
 create mode 100644 tinyms/pipeline/configmixin.py
 create mode 100644 tinyms/pipeline/download.py
 create mode 100644 tinyms/pipeline/image_classification_trainer.py
 create mode 100644 tinyms/pipeline/trainer_configmixin.py

diff --git a/docs/en/source/pipeline/overview.md b/docs/en/source/pipeline/overview.md
new file mode 100644
index 00000000..2be8141c
--- /dev/null
+++ b/docs/en/source/pipeline/overview.md
@@ -0,0 +1,196 @@
+# Pipeline
+
+## Introduction
+
+To enable fast downloading of pre-trained models and to run model training and inference with minimal code, tinyms provides the pipeline API. With the pipeline, a few lines of code are enough to download a model from the cloud and run inference locally. Likewise, only a few lines of code are needed to train or fine-tune the model.
+
+## Examples
+
+### Load a pretrained model
+
+``` python
+>>> from tinyms.pipeline import AutoModelPipeline
+>>> model = AutoModelPipeline.from_pretrained("model_cache_path", "model_repo")
+>>> pred = model(your_input)
+```
+
+### Load a trainer
+
+```python
+>>> from tinyms.pipeline import AutoTrainerPipeline
+>>> trainer = AutoTrainerPipeline.from_pretrained("trainer_cache_path", "trainer_repo")
+>>> trainer.init_model(model)
+>>> trainer.train()
+```
+
+You can also pass arguments such as `epoch` to the `train` method; arguments passed at call time override the default values stored in the config.
+
+After training, you can save your model by invoking the `save_pretrained` method.
+
+```python
+>>> model.save_pretrained("path_to_save")
+```
+
+This saves the config and checkpoint to `path_to_save`.
+
+## For Developers
+
+## Details of the config
+
+### Structure of the folder
+
+The folder structure that `from_pretrained` loads is shown below.
+
+    model_cache_path
+    └── demo_model
+        ├── model
+        │   ├── config.json
+        │   ├── filelist.txt
+        │   └── weight.ckpt
+        ├── model_code
+        │   ├── filelist.txt
+        │   └── networks_test.py
+        └── trainer
+            ├── config.json
+            ├── filelist.txt
+            └── train_config
+                └── config.json
+
+`model_cache_path` is the path where downloaded repos are saved.
+
+`demo_model` is the name of the downloaded repo.
+
+`model` is the folder holding the model config and checkpoint.
+
+`model_code` is the folder holding the extra code. If there is no extra code, this folder does not exist.
+
+`trainer` is the folder holding the config used to train the model.
+
+The config folder can be generated by invoking `save_pretrained`.
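+
+For example, the minimal sketch below produces and then reloads such a layout with the pipeline API. The repo name `demo_model` is an illustrative placeholder, and `model` is assumed to be a network instance whose class inherits from `ConfigMixin`.
+
+```python
+>>> # writes config.json, weight.ckpt and filelist.txt into model_cache_path/demo_model/model
+>>> model.save_pretrained("model_cache_path/demo_model")
+>>> # reload from the local folder without downloading from the cloud repo
+>>> from tinyms.pipeline import AutoModelPipeline
+>>> new_model = AutoModelPipeline.from_pretrained("model_cache_path", "demo_model", download=False)
+```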
+
+### Structure of config.json
+
+The block below is an example of the `config.json` saved in `trainer`.
+
+``` json
+{
+  "__module__": "tinyms.pipeline.image_classification_trainer.ImageClassificationTrainer",
+  "__version__": "0.3.2",
+  "build_config": null,
+  "eval_config": null,
+  "fit_config": null,
+  "loss": {
+    "loss": "SoftmaxCrossEntropyWithLogits",
+    "params": {
+      "sparse": true
+    }
+  },
+  "metrics": [
+    "accuracy"
+  ],
+  "optim": {
+    "optimizer": "Momentum",
+    "params": {
+      "learning_rate": 0.1,
+      "momentum": 0.9
+    }
+  },
+  "predict_config": null,
+  "train_config": {
+    "__module__": "tinyms.pipeline.image_classification_trainer.ImageClassificationTrainConfig",
+    "__subfolder__": null
+  }
+}
+```
+
+`save_pretrained` records the `__module__` needed to rebuild the object and the tinyms version. The other keys, such as `loss`, are the arguments used to instantiate `tinyms.pipeline.image_classification_trainer.ImageClassificationTrainer`. If a nested dictionary contains the key `__subfolder__`, the details of that argument are saved into a subfolder named after the argument (here, `train_config`).
+
+
+## Define a new model
+
+To define a new model class, define a class that inherits from `ConfigMixin`.
+
+```python
+>>> from typing import Union
+>>> from tinyms.pipeline.configmixin import ConfigMixin, save_config, Ignore, SubFolder
+>>> class Model(ConfigMixin):
+...     @save_config
+...     def __init__(
+...         self, a: Ignore=1, b: SubFolder=2, c: Union[Ignore, int]=3, d: Union[SubFolder, int]=4):
+...         ...
+>>>
+>>> model = Model()
+>>> model.save_pretrained("config_path")
+```
+
+`ConfigMixin` is the base class for the pipeline mixin. It provides methods for saving and loading the model config. A class that inherits from it can apply `@save_config` to its `__init__` method to record the config of the class.
+
+If you wrap `__init__` with `@save_config`, arguments of `Ignore` type are not saved into the config, and arguments of `SubFolder` type are saved into a subfolder.
+
+Set `__prefix__` to change the name of the folder used to save the config and checkpoint.
+Set `__weight__` to change the name of the checkpoint file.
+
+
+tinyms also provides a function `wrap_config_mixin` to define a new model.
+
+This function adds `ConfigMixin` to the base classes of the given class and wraps its `__init__` method with `@save_config`.
+
+```python
+>>> from tinyms.pipeline import wrap_config_mixin
+>>>
+>>> class Model:
+...     ...
+>>>
+>>> Model = wrap_config_mixin(Model)
+>>> model = Model()
+>>> model.save_pretrained("config_path")
+```
+
+### Extra code
+
+If the defined model is not part of tinyms, you need to save the code containing the definition of the model class. The code should be saved to the `model_code` folder in the repo folder. When `from_pretrained` is called, the path of `model_code` is appended to `sys.path`.
+
+
+## Define a new trainer
+
+To define a new trainer class, define a class that inherits from `TrainerConfigMixin`.
+
+`TrainerConfigMixin` is the base class for the trainer pipeline mixin. It provides methods for saving and loading the trainer config. A class that inherits from it can apply `@save_config` to its `__init__` method to record the config of the class.
+
+If you wrap `__init__` with `@save_config`, arguments of `Ignore` type are not saved into the config, and arguments of `SubFolder` type are saved into a subfolder.
+
+For a trainer, you may want to implement methods such as `train`, `eval` and `predict`. You can use `@set_from_config` to fill their arguments from the config. Arguments of `FromConfig` type have the following properties.
+
+1. Arguments that are not in the config are set to their default values.
+2. 
The arguments set in running time will override the arguments in the config. + +Once you wrap a method with `@set_from_config`, you can use `BaseArgsFromConfig` to generate the arguments class. To use `BaseArgsFromConfig`, you should wrap the `__init__` method with `@copy_signature(Trainer.method)`. The arguments of `__init__` method should be `__init__(self, *args, **kwargs)`. + +You should define the arguments class in `__init__` method. The default name of the arguments class is `{method_name}_config`. You can change the name by passing the name to `@set_from_config(name)`. + +```python +>>> class Trainer(TrainerConfigMixin): +... @save_config +... def __init__(self, train_args=None): +... self.train_args = train_args +... +... @set_from_config +... def train(self, epoch: FromConfig): +... ... +>>> +>>> class TrainConfig(BaseArgsFromConfig): +... @copy_signature(Trainer.train) +... def __init__(self, *args, **kwargs): +... super().__init__(*args, **kwargs) +>>> +>>> train_config = TrainConfig(2) +>>> trainer = Trainer(train_config=train_config) +>>> +>>> trainer.save_pretrained('model_config') +>>> new_trainer = trainer.from_pretrained('model_config') +>>> +>>> new_trainer.train() +>>> +>>> new_trainer.train(4) + +``` + diff --git a/tinyms/pipeline/__init__.py b/tinyms/pipeline/__init__.py new file mode 100644 index 00000000..e5cac21b --- /dev/null +++ b/tinyms/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .auto_pipeline import AutoModelPipeline, AutoTrainerPipeline +from .configmixin import ConfigMixin, wrap_config_mixin, save_config +from .trainer_configmixin import TrainerConfigMixin +from .download import make_filelist diff --git a/tinyms/pipeline/auto_pipeline.py b/tinyms/pipeline/auto_pipeline.py new file mode 100644 index 00000000..0fd68279 --- /dev/null +++ b/tinyms/pipeline/auto_pipeline.py @@ -0,0 +1,5 @@ +from .configmixin import ConfigMixin +from .trainer_configmixin import TrainerConfigMixin + +AutoModelPipeline = ConfigMixin +AutoTrainerPipeline = TrainerConfigMixin diff --git a/tinyms/pipeline/configmixin.py b/tinyms/pipeline/configmixin.py new file mode 100644 index 00000000..8917a2e7 --- /dev/null +++ b/tinyms/pipeline/configmixin.py @@ -0,0 +1,695 @@ +import copy + +import functools +import inspect +from typing import Dict, Callable, Union +import json +import pathlib +import logging +from typing import NewType, Any, get_origin, get_args, TypeVar +from importlib.metadata import version +from mindspore.train.serialization import load_checkpoint, save_checkpoint +from .download import RepoDownloaderWithCode, make_filelist +import shutil + +logger = logging.getLogger(__name__) + + +Ignore = NewType('Ignore', Any) +SubFolder = NewType('SubFolder', Any) + +_T = TypeVar('_T') + + +def _key_var_keyword(func_signature): + name = None + for key, v in func_signature.parameters.items(): + if v.kind == inspect.Parameter.VAR_KEYWORD: + name = key + break + return name + + +def _key_var_position(func_signature): + name = None + for key, v in func_signature.parameters.items(): + if v.kind == inspect.Parameter.VAR_POSITIONAL: + name = key + break + return name + + +def wrap_config_mixin(cls: _T) -> _T: + """ + Wrap a class with ConfigMixin. This function will add ConfigMixin to the base class of + the class and wrap `__init__` method with `@save_config`. + + Args: + cls (object): a class to be wrapped. + + Returns: + return a wrapped class. This class is a subclass of ConfigMixin and the `__init__` + method is wrapped with `@save_config`. + + Examples: + >>> class Model: + ... ... 
+ >>> + >>> Model = wrap_config_mixin(Model) + >>> model = Model() + >>> model.save_pretrained("config_path") + """ + class Wrapped(cls, ConfigMixin): + ... + Wrapped.__module__ = cls.__module__ + Wrapped.__name__ = cls.__name__ + Wrapped.__init__ = save_config(cls.__init__) + return Wrapped + + +def _walk(obj, func: Callable, walk_path=None, replace=False, invoke_func_on_first=False): + if walk_path is None: + walk_path = [] + elif not isinstance(walk_path, (list, tuple)): + walk_path = [walk_path] + + having_walk_path = 'walk_path' in inspect.signature(func).parameters + + if isinstance(obj, dict): + new_dict = [] + for k, v in obj.items(): + kwargs = {'walk_path': walk_path + [k]} if having_walk_path else {} + new_dict.append((func(k, **kwargs), func(_walk(v, func, replace=replace, **kwargs), **kwargs))) + if replace: + obj = dict(new_dict) + + elif isinstance(obj, list): + new_list = [] + for idx, v in enumerate(obj): + kwargs = {'walk_path': walk_path + [idx]} if having_walk_path else {} + new_list.append(func(_walk(v, func, replace=replace, **kwargs), **kwargs)) + if replace: + obj = new_list + + elif isinstance(obj, tuple): + new_tuple = [] + for idx, v in enumerate(obj): + kwargs = {'walk_path': walk_path + [idx]} if having_walk_path else {} + new_tuple.append(func(_walk(v, func, replace=replace, **kwargs), **kwargs)) + if replace: + obj = tuple(new_tuple) + + elif isinstance(obj, PRIMITIVE_TYPE + CONFIG_TYPE) or obj is inspect.Parameter.empty: + kwargs = {'walk_path': walk_path[:]} if having_walk_path else {} + r = func(obj, **kwargs) + if replace: + obj = r + + else: + raise TypeError(f'walk obj type {type(obj)} is not supported') + + if invoke_func_on_first: + kwargs = {'walk_path': walk_path[:]} if having_walk_path else {} + return func(obj, **kwargs) + return obj + + +def _func_args_dict_with_default_value(func, args: list, kwargs: dict): + init_args = inspect.signature(func) + kwargs = copy.copy(kwargs) + + init_args_keys = list(init_args.parameters.keys()) + + args_kwargs = {} + for idx, v in enumerate(args): + args_kwargs[init_args.parameters[init_args_keys[idx]].name] = v + kwargs.update(args_kwargs) + + params = dict(init_args.parameters) + for k in kwargs: + if k in params: + del params[k] + for k in params: + kwargs[k] = params[k].default + return kwargs + + +def _attrtype_args_list_factory(Type): + def _args_list(func): + init_args = inspect.signature(func) + args_with_type = [] + for k, v in init_args.parameters.items(): + if v.annotation is Type: + args_with_type.append(k) + elif get_origin(v.annotation) is Union: + type_list = get_args(v.annotation) + if Type in type_list: + args_with_type.append(k) + return args_with_type + return _args_list + + +_ignore_args_list = _attrtype_args_list_factory(Ignore) +_subfolder_args_list = _attrtype_args_list_factory(SubFolder) + + +class TypeConverter: + + @classmethod + def obj_to_config(cls, obj): + if isinstance(obj, CONFIG_TYPE): + module_path = cls._module_path(obj) + _internal_config = obj._config + return cls._config_class_to_dict(module_path, _internal_config) + return obj + + @classmethod + def obj_to_config_and_seperate_save(cls, obj, walk_path): + if not isinstance(walk_path, (list, tuple)): + walk_path = [walk_path] + if isinstance(obj, CONFIG_TYPE): + module_path = cls._module_path(obj) + path = pathlib.Path() + for p in walk_path: + path = path / str(p) + _internal_config = obj._internal_config_dict_converted_seperate_save(path, module_path) + if obj.__subfolder_save__: + _internal_config = {'__subfolder__': None} 
+ return cls._config_class_to_dict(module_path, _internal_config) + return obj + + @classmethod + def _load_config(cls, config, walk_path): + if isinstance(config, dict) and '__module__' in config: + path = pathlib.Path() + for p in walk_path: + path = path / str(p) + module = cls._load_module(config) + if '__subfolder__' in config: + config = module._load_config(path) + return config + return config + + @staticmethod + def _module_path(obj): + return f"{obj.__module__}.{obj.__class__.__name__}" + + @classmethod + def _config_class_to_dict(cls, module_path, _internal_config): + cls.attribution_check(_internal_config) + config = {'__module__': module_path} + config.update(_internal_config) + return config + + @staticmethod + def _add_module_to_dict(config, module): + config = copy.copy(config) + config['__module__'] = module + return config + + @classmethod + def config_to_obj(cls, config): + if isinstance(config, Dict): + if '__module__' in config: + module = cls._load_module(config) + config = cls._remove_internal_argument(config) + if issubclass(module, CONFIG_TYPE): + config = module._from_config_pre_process(config) + return module(**config) + return config + + @classmethod + def _load_module(cls, config): + cls._version_check(config) + loaded_cls = _locate(config['__module__']) + if not isinstance(loaded_cls, ConfigMixin): + loaded_cls = wrap_config_mixin(loaded_cls) + return loaded_cls + + @classmethod + def _load_check(cls, config): + if '__module__' not in config: + raise KeyError(f'__module__ is not found in config: {config}') + + @staticmethod + def is_valid_type(obj): + if isinstance(obj, VALID_TYPE) or obj is inspect.Parameter.empty: + return True + raise TypeError(f'obj type {type(obj)} is not supported') + + @staticmethod + def attribution_check(config): + if '__module__' in config: + raise KeyError('__module__ is not allowed as a argument in __init__ method') + if '__version__' in config: + raise KeyError('__version__ is not allowed as a argument in __init__ method') + + @staticmethod + def _version_check(config): + pass + + @staticmethod + def _remove_internal_argument(config): + config = copy.copy(config) + if '__module__' in config: + del config['__module__'] + + if '__version__' in config: + del config['__version__'] + return config + + @staticmethod + def add_version(config): + config = copy.copy(config) + config['__version__'] = version('tinyms') + return config + + @classmethod + def check_module(cls, config_dict, module): + if '__module__' not in config_dict: + raise KeyError(f'__module__ is not found in {module}') + # if config_dict['__module__'] != cls._module_path(module): + # raise KeyError(f'__module__ is not {module}') + + +def _locate(path: str): + """ + Locate an object by name or dotted path, importing as necessary. + This is similar to the pydoc function `locate`, except that it checks for + the module from the given path from back to front. + """ + if path == "": + raise ImportError("Empty path") + from importlib import import_module + from types import ModuleType + + parts = [part for part in path.split(".")] + for part in parts: + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." + ) + assert len(parts) > 0 + part0 = parts[0] + try: + obj = import_module(part0) + except Exception as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that module '{part0}' is installed?" 
+ ) from exc_import + for m in range(1, len(parts)): + part = parts[m] + + parent_dotpath = ".".join(parts[:m]) + if isinstance(obj, ModuleType): + mod = ".".join(parts[: m + 1]) + try: + obj = import_module(mod) + continue + except ModuleNotFoundError as exc_import: + try: + obj = getattr(obj, part) + except AttributeError: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?" + ) from exc_import + except Exception as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + ) from exc_import + return obj + + +def save_config(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + self = args[0] + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@save_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + if func.__name__ != '__init__': + raise RuntimeError( + f"`@save_config` was applied to {self.__class__.__name__} init method, but this mathod is " + f"`{func.__name__}`." + ) + + config = _func_args_dict_with_default_value(func, args, kwargs) + config = self._filter_ignore_keys(config, func) + config = self._set_subfolder_save(config, func) + self.__type_converter__.attribution_check(config) + + self._config_value_type_check(config) + self._internal_config = config + func(*args, **kwargs) + + return wrapped + + +class ConfigMixin: + """ + A base class for pipeline mixin. This class provides methods for + saving and loading model config. A class that inherits from this class + can apply `@save_config` to `__init__` method to record the + config of the class. + If you wrap `__init__` with `@save_config`, the argument of Ignore type + will not be saved into the config. The SubFolder type will be saved into a sub folder. + + Set `__prefix__` to change the name of the folder to save config and checkpoint. + Set `__weight__` to change the name of the checkpoint file. + + Examples: + >>> class Model(ConfigMixin): + ... @save_config + ... def __init__( + ... self, a: Ignore=1, b: SubFolder=2, c: Union[Ignore, int]=3, d: Union[SubFolder, int]=4): + ... ... + >>> + >>> model = Model() + >>> model.save_pretrained('model_config') + >>> new_model = Model.from_pretrained('model_config') + + >>> config = model.config + >>> new_model = Model.from_config(config) + + """ + __save_name__ = None + __type_converter__ = TypeConverter() + __subfolder_save__ = None + _internal_config: Dict + __prefix__ = "model" + __weight__ = 'weight.ckpt' + __cache_path_before__: pathlib.Path + __repo__: str + + @classmethod + def from_config(cls, config: Dict): + """Instantiates a Model from config dictionary. + + Args: + config (Dict): A dictionary specifying the config of the model. This can be + obtained by calling `model.config` or `model.load_config('config_path'))`. + + Examples: + >>> class Model(ConfigMixin): + ... @save_config + ... def __init__( + ... self, a: Ignore=1, b: SubFolder=2, c: Union[Ignore, int]=3, d: Union[SubFolder, int]=4): + ... ... + >>> + >>> model = Model() + >>> config = model.config + >>> new_model = Model.from_config(config) + + Returns: + A instance of a subclass of ConfigMixin. + """ + if cls is not ConfigMixin: + cls.__type_converter__.check_module(config, cls) + return _walk(config, cls.__type_converter__.config_to_obj, replace=True, invoke_func_on_first=True) + + def save_config(self, path: Union[str, pathlib.Path]): + """ + Save model config to a json file. 
+ + Args: + path (Union[str, pathlib.Path]): The path to save the config. + + Examples: + >>> class Model(ConfigMixin): + ... @save_config + ... def __init__(self): + ... ... + >>> + >>> model = Model() + >>> model.save_config("config_path") + """ + if not hasattr(self, '_internal_config'): + raise AttributeError(f'`{self.__class__.__name__}` object has no atrribute `_internal_config`.' + 'You can apply @save_config to init method to record `_internal_config` automatically.') + + path = pathlib.Path(path) + path = path / self.__prefix__ + path.mkdir(parents=True, exist_ok=True) + path = path / self._save_name() + if path.exists(): + raise FileExistsError(f'{path} already exists') + + config = self.__type_converter__.obj_to_config_and_seperate_save(self, path.parent) + config = self.__type_converter__.add_version(config) + with path.open("w") as f: + json.dump(config, f, indent=2, sort_keys=True) + + logger.info(f"Model config saved in directory {path.parent}") + + @property + def config(self): + return self.__type_converter__.obj_to_config(self) + + @property + def _config(self): + return _walk(self._internal_config, self.__type_converter__.obj_to_config, replace=True) + + def _internal_config_dict_converted_seperate_save(self, path, module_name=None): + path = pathlib.Path(path) + + config = _walk( + self._internal_config, + self.__type_converter__.obj_to_config_and_seperate_save, + replace=True, + walk_path=path,) + + if self.__subfolder_save__: + path.mkdir(parents=True, exist_ok=True) + with (path / self._save_name()).open("w") as f: + if module_name is not None: + _config = self.__type_converter__._add_module_to_dict(config, module_name) + else: + _config = config + json.dump(_config, f, indent=2, sort_keys=True) + + return config + + @property + def _internal_config_dict(self): + return copy.copy(self._internal_config) + + @classmethod + def _from_config_pre_process(cls, config): + cls._config_value_type_check(config) + cls._class_attribution_check(config) + config = cls._filter_unexpected_keys(config) + return config + + @classmethod + def load_config(cls, path: Union[str, pathlib.Path]): + """ + Load model config from a path. + + Args: + path (Union[str, pathlib.Path]): Path to load the config. + + Returns: + A instance of a subclass of ConfigMixin. 
+ """ + path = pathlib.Path(path) + path = path / cls.__prefix__ + return cls._load_config(path) + + @classmethod + def _load_config(cls, path): + path = pathlib.Path(path) + path = path / cls._save_name() + if not path.exists(): + raise FileNotFoundError(f'{path} not found') + with open(path, 'r') as f: + config = json.load(f) + cls.__type_converter__._load_check(config) + return _walk(config, cls.__type_converter__._load_config, replace=True, walk_path=path.parent) + + @classmethod + def _config_value_type_check(cls, config): + for k, v in config.items(): + try: + _walk(v, cls.__type_converter__.is_valid_type) + except TypeError: + raise TypeError(f'The value of {k}={v} is not valid type') + + @classmethod + def _class_attribution_check(cls, config): + config_self = _func_args_dict_with_default_value(cls.__init__, [], {}) + config_self = cls._filter_ignore_keys(config_self, cls.__init__) + config_self_keys = set(config_self.keys()) + config_keys = set(config.keys()) + missing_keys = config_self_keys - config_keys + unexpected_keys = config_keys - config_self_keys + + if len(missing_keys) > 0: + non_default_key = [k for k in missing_keys if config_self[k] is inspect.Parameter.empty] + if len(non_default_key) > 0: + raise AttributeError(f'missing keys {list(non_default_key)} to instantiate {cls.__name__}') + missing_keys = tuple(missing_keys) + if len(missing_keys) == 1: + missing_keys = missing_keys[0] + logger.warning(f'missing keys {missing_keys} to instantiate {cls.__name__}, using default value') + + if len(unexpected_keys) > 0: + unexpected_keys = tuple(unexpected_keys) + if len(unexpected_keys) == 1: + unexpected_keys = unexpected_keys[0] + logger.warning(f'unexpected keys {unexpected_keys} to instantiate {cls.__name__}, ignore them') + + @classmethod + def _filter_unexpected_keys(cls, config): + config_self = _func_args_dict_with_default_value(cls.__init__, [], {}) + config_self = cls._filter_ignore_keys(config_self, cls.__init__) + config_self_keys = set(config_self.keys()) + config_keys = set(config.keys()) + unexpected_keys = config_keys - config_self_keys + + config = copy.copy(config) + for k in unexpected_keys: + del config[k] + return config + + @classmethod + def _filter_ignore_keys(cls, config, func): + config = copy.copy(config) + del config['self'] + ignore_keys = _ignore_args_list(func) + sig = inspect.signature(func) + var_position = _key_var_position(sig) + var_keyword = _key_var_keyword(sig) + if var_position is not None: + ignore_keys.append(var_position) + if var_keyword is not None: + ignore_keys.append(var_keyword) + if ignore_keys: + for k in ignore_keys: + if k in config: + del config[k] + return config + + def _save_checkpoint(self, path): + path = pathlib.Path(path) + path = path / self.__prefix__ / self.__weight__ + save_checkpoint(self, str(path)) + + def _load_checkpoint(self, path): + path = pathlib.Path(path) + path = path / self.__prefix__ / self.__weight__ + load_checkpoint(str(path), self) + + def save_pretrained(self, path: Union[str, pathlib.Path]): + """ + Save model config and checkpoint to a path. + + Args: + path (Union[str, pathlib.Path]): The path to save the config and checkpoint. + + Examples: + >>> class Model(ConfigMixin): + ... @save_config + ... def __init__(self): + ... ... 
+ >>> + >>> model = Model() + >>> model.save_pretrained("model_path") + """ + path = pathlib.Path(path) + if hasattr(self, "__repo__"): + code_path = self.__cache_path_before__ / f"{self.__repo__}_code" + code_new_path = path / f"{path.name}_code" + if code_new_path.exists(): + code_new_path.unlink() + shutil.copytree(code_path, code_new_path) + + self.save_config(path) + self._save_checkpoint(path) + path = path / self.__prefix__ + make_filelist(path) + + @classmethod + def from_pretrained( + cls, path: Union[str, pathlib.Path], + repo, checkfiles: bool = True, download: bool = True): + """ + Load model config and checkpoint from a path or a repo. + + Args: + path (Union[str, pathlib.Path]): Path to save repo. Defaults to None. + + repo (str): The repo name. + + checkfiles (bool, optional): If this is set to False, this method will not check + the files in `path`, and will not download from `repo`. Defaults to True. + + download (bool, optional): If this is set to False, this method will not download. + If download is true, this method will first check whether the repo has been + downloaded or completed, if not, it will download the wrong or missing files. + If download is false, this method will check the whether the local repo in `path` + is completed, if not, it will raise an Error. + + Returns: + A instance of a subclass of ConfigMixin. + """ + if not download: + path = pathlib.Path(path) / repo + repo = None + assert path is not None or repo is not None, \ + 'You have to set path or repo to load pretrained model' + path = pathlib.Path(path) + _path = path + if repo is None: + _path = pathlib.Path(path) / cls.__prefix__ + downloader = RepoDownloaderWithCode(_path, repo, checkfiles) + downloader.download() + + __repo__ = path.name if repo is None else repo + __path__ = path.parent if repo is None else path + + if repo is not None: + path = path / (repo).split('/')[0] + loaded_cls = cls.load_config(path) + loaded_cls = cls.from_config(loaded_cls) + + cls._check_cls_loaded(loaded_cls) + loaded_cls._load_checkpoint(path) + loaded_cls.__cache_path_before__ = __path__ + loaded_cls.__repo__ = __repo__ + return loaded_cls + + @classmethod + def _check_cls_loaded(cls, loaded_cls): + if cls is not ConfigMixin: + assert loaded_cls is cls, f"The repo loaded is not the same as the \ + class {cls.__name__}, it is {loaded_cls.__class__.__name__}" + + @classmethod + def _set_subfolder_save(cls, config, func): + subfolder_keys = _subfolder_args_list(func) + for k in subfolder_keys: + if k in config: + if config[k] is not None: + if not isinstance(config[k], CONFIG_TYPE): + raise TypeError(f'{k} must be a ConfigMixin') + config[k].__subfolder_save__ = True + return config + + @classmethod + def _save_name(cls): + save_name = cls.__save_name__ or 'config.json' + if not save_name.endswith('.json'): + save_name += '.json' + return save_name + + +ITERALE_TYPE = (list, tuple, set, dict) +PRIMITIVE_TYPE = (str, int, float, bool, type(None)) +CONFIG_TYPE = (ConfigMixin,) + +VALID_TYPE = ITERALE_TYPE + PRIMITIVE_TYPE + CONFIG_TYPE diff --git a/tinyms/pipeline/download.py b/tinyms/pipeline/download.py new file mode 100644 index 00000000..8c060a36 --- /dev/null +++ b/tinyms/pipeline/download.py @@ -0,0 +1,204 @@ +from ..hub.utils.download import url_exist +import hashlib +import urllib +from urllib.request import urlretrieve, HTTPError, URLError +import pathlib +import tqdm + + +def sha256sum(file_name): + fp = open(file_name, 'rb') + content = fp.read() + fp.close() + m = hashlib.sha256() + m.update(content) + 
sha256 = m.hexdigest() + return sha256 + + +def is_directory_empty(path): + path = pathlib.Path(path) + if not path.exists(): + return True + if any(path.iterdir()): + return False + return True + + +def glob_files(path): + path = pathlib.Path(path) + all_files = [str(i.relative_to(path)) for i in path.rglob('*') if i.is_file()] + sha256 = [sha256sum(str(path / i)) for i in all_files] + return list(zip(all_files, sha256)) + + +def make_filelist(path): + path = pathlib.Path(path) + file_sha256 = glob_files(path) + file_sha256 = [' '.join(i) for i in file_sha256] + with (path / 'filelist.txt').open('w') as f: + f.write('\n'.join(file_sha256)) + + +class RepoDownloader: + __prefix__ = "https://kaiyuanzhixia.obs.cn-east-3.myhuaweicloud.com/" + __filelist__ = "filelist.txt" + + def __init__(self, path, repo="lenet5/model", force_download=False, checkfiles=True): + self.repo = repo + self.path = pathlib.Path(path) + if repo is not None: + self.path = self.path / repo + self.force_download = force_download + self.checkfiles = checkfiles + + def append_sys_path(self): + import sys + sys.path.append(str(self.path)) + + def download(self): + if not self.checkfiles: + return + if not self.validate_repo(): + if self.repo is None: + _info = self.path + else: + _info = self.repo + raise RuntimeError(f"Invalid repo: {_info}") + + if not is_directory_empty(self.path): + if not self.filelist.exists(): + if not self.force_download: + raise RuntimeError(f"repo cached is broken: {self.path}") + else: + return + else: + if not self.force_download: + return + + download_file_from_url( + self.get_file_url(self.__filelist__), save_path=self.filelist) + + self.filter_files() + self._download() + + def validate_repo(self): + if not self.checkfiles: + return True + if self.repo is None: + return self.filelist.exists() + + return url_exist(self.get_file_url(self.__filelist__)) + + @property + def repo_url(self): + return self.__prefix__ + f"{self.repo}" + + def get_file_url(self, filename): + if self.repo is None: + return None + return self.repo_url + "/" + filename + + def get_file_save_path(self, filename): + return self.path / filename + + def _download(self): + file_sha256 = self.all_files_sha256 + for f, s in tqdm.tqdm(file_sha256): + download_file_from_url( + self.get_file_url(f), hash_sha256=s, save_path=self.get_file_save_path(f)) + + @property + def filelist(self): + return self.path / self.__filelist__ + + @property + def all_files_sha256(self): + all_files = [] + with self.filelist.open() as f: + for line in f.readlines(): + file_sha256 = line.strip().split() + assert len(file_sha256) <= 2 + + if len(file_sha256) == 1: + file_sha256.append(None) + all_files.append(file_sha256) + + return all_files + + def filter_files(self): + file_sha256 = glob_files(self.path) + file_sha256 = dict(file_sha256) + all_files_sha256 = dict(self.all_files_sha256) + exclude_files = [] + exclude_files.append(str(self.filelist.relative_to(self.path))) + + for f in file_sha256: + if f not in exclude_files: + if f not in all_files_sha256: + (self.path / f).unlink() + else: + if file_sha256[f] != all_files_sha256[f]: + (self.path / f).unlink() + + +class RepoDownloaderWithCode: + def __init__(self, path, repo="lenet5/model", force_download=False, checkfiles=True): + self.repo = repo + self.path = pathlib.Path(path) + self.force_download = force_download + self.checkfiles = checkfiles + + self.model_repo_downloader = RepoDownloader( + self.path, repo=self.repo, force_download=self.force_download) + assert 
self.model_repo_downloader.validate_repo(), f"Invalid repo: {repo}" + + if repo is None: + self.code_repo_downloader = RepoDownloader( + str(self.path) + "_code", repo=self.repo, force_download=self.force_download) + else: + self.code_repo_downloader = RepoDownloader( + self.path, repo=self.repo + "_code", force_download=self.force_download) + + def download(self): + self.model_repo_downloader.download() + + if self.code_repo_downloader.validate_repo(): + self.code_repo_downloader.download() + self.code_repo_downloader.append_sys_path() + + +def download_file_from_url(url, hash_sha256=None, save_path='.'): + def reporthook(a, b, c): + percent = a * b * 100.0 / c + percent = 100 if percent > 100 else percent + if c > 0: + print("\rDownloading...%5.1f%%" % percent, end="") + + save_path = pathlib.Path(save_path) + if not save_path.parent.exists(): + save_path.parent.mkdir(parents=True) + if not save_path.exists(): + if url is None: + raise RuntimeError("A valid repo is not given.") + try: + opener = urllib.request.build_opener() + opener.addheaders = [('User-Agent', 'Mozilla/5.0')] + urllib.request.install_opener(opener) + urlretrieve(url, str(save_path), reporthook=reporthook) + except HTTPError as e: + raise Exception(e.code, e.msg, url) + except URLError as e: + raise Exception(e.errno, e.reason, url) + + # Check file integrity + if hash_sha256: + result = sha256sum(save_path) + result = result == hash_sha256 + if not result: + raise Exception('INTEGRITY ERROR: File: {} is not integral'.format(save_path)) + + +if __name__ == '__main__': + a = RepoDownloaderWithCode('cache_dir', force_download=True) + a.download() diff --git a/tinyms/pipeline/image_classification_trainer.py b/tinyms/pipeline/image_classification_trainer.py new file mode 100644 index 00000000..bf6cee4f --- /dev/null +++ b/tinyms/pipeline/image_classification_trainer.py @@ -0,0 +1,339 @@ +from .trainer_configmixin import TrainerConfigMixin, copy_signature, \ + FromConfig, set_from_config, BaseArgsFromConfig +from ..model import Model +from ..losses import SoftmaxCrossEntropyWithLogits +from ..optimizers import Momentum +from .configmixin import SubFolder, save_config, Ignore + + +class ImageClassificationTrainer(Model, TrainerConfigMixin): + @save_config + def __init__(self, + model: Ignore = None, + optim: dict = {'optimizer': 'Momentum', + 'params': {'learning_rate': 0.1, 'momentum': 0.9}}, + loss: dict = {'loss': 'SoftmaxCrossEntropyWithLogits', 'params': {'sparse': True}}, + metrics=['accuracy'], + train_config: SubFolder = None, + fit_config: SubFolder = None, + build_config: SubFolder = None, + eval_config: SubFolder = None, + predict_config: SubFolder = None): + self._model = model + self.optim = self._make_params(optim) + self.loss = self._make_params(loss) + self.metrics = metrics + self.train_config = train_config + self.fit_config = fit_config + self.build_config = build_config + self.eval_config = eval_config + self.predict_config = predict_config + + def _make_params(self, config): + if 'params' not in config: + config['params'] = {} + return config + + def _compile(self): + if self.optim['optimizer'] == 'Momentum': + optimizer = Momentum(params=self._network.trainable_params(), **self.optim['params']) + if self.loss['loss'] == 'SoftmaxCrossEntropyWithLogits': + loss_fn = SoftmaxCrossEntropyWithLogits(**self.loss['params']) + if isinstance(self.metrics, (tuple, list)): + self.metrics = set(self.metrics) + return super().compile(loss_fn, optimizer, self.metrics) + + def init_model(self, model): + """ + Send the 
model to trainer. + """ + if model is None: + model = self._model + super().__init__(model) + self._compile() + + @set_from_config + def train(self, + epoch: FromConfig, + train_dataset: FromConfig, + callbacks: FromConfig = None, + dataset_sink_mode: FromConfig = False, + sink_size: FromConfig = -1, + initial_epoch: FromConfig = 0, + **kwargs: FromConfig): + """ + Training API. + + When setting pynative mode or CPU, the training process will be performed with dataset not sink. + + Note: + If dataset_sink_mode is True, data will be sent to device. If the device is Ascend, features + of data will be transferred one by one. The limitation of data transmission per time is 256M. + + When dataset_sink_mode is True, the `step_end` method of the instance of Callback will be called at the end + of epoch. + + If dataset_sink_mode is True, dataset will be bound to this model and cannot be used by other models. + + If sink_size > 0, each epoch of the dataset can be traversed unlimited times until you get sink_size + elements of the dataset. The next epoch continues to traverse from the end position of the previous + traversal. + + The interface builds the computational graphs and then executes the computational graphs. However, when + the `Model.build` is executed first, it only performs the graphs execution. + + Args: + epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch. + If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch will + train `sink_size` steps instead of total steps of dataset. + If `epoch` used with `initial_epoch`, it is to be understood as "final epoch". + train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label will be + passed to the `network` and the `loss_fn` respectively, so a tuple (data, label) + should be returned from dataset. If there is multiple data or labels, set `loss_fn` + to None and implement calculation of loss in `network`, + then a tuple (data1, data2, data3, ...) with all data returned from dataset will be + passed to the `network`. + callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, + which should be executed while training. + Default: None. + dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. + Configure pynative mode or CPU, the training process will be performed with + dataset not sink. Default: True. + sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if `dataset_sink_mode` + is False. + If sink_size = -1, sink the complete dataset for each epoch. + If sink_size > 0, sink sink_size data for each epoch. + Default: -1. + initial_epoch (int): Epoch at which to start train, it used for resuming a previous training run. + Default: 0. + + Examples: + >>> import mindspore as ms + >>> from mindspore import nn + >>> + >>> # For details about how to build the dataset, please refer to the tutorial + >>> # document on the official website. + >>> dataset = create_custom_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> loss_scale_manager = ms.FixedLossScaleManager() + >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None, + ... 
loss_scale_manager=loss_scale_manager) + >>> model.train(2, dataset) + """ + train_dataset = train_dataset.run() + super().train(epoch, train_dataset, callbacks, dataset_sink_mode, sink_size, initial_epoch) + + @set_from_config + def fit(self, + epoch: FromConfig, + train_dataset: FromConfig, + valid_dataset: FromConfig = None, + valid_frequency: FromConfig = 1, + callbacks: FromConfig = None, + dataset_sink_mode: FromConfig = False, + valid_dataset_sink_mode: FromConfig = False, + sink_size: FromConfig = -1, + initial_epoch: FromConfig = 0): + """ + Fit API. + + Evaluation process will be performed during training process if `valid_dataset` is provided. + + More details please refer to `mindspore.Model.train` and `mindspore.Model.eval`. + + Args: + epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch. + If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch will + train `sink_size` steps instead of total steps of dataset. + If `epoch` used with `initial_epoch`, it is to be understood as "final epoch". + train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label will be + passed to the `network` and the `loss_fn` respectively, so a tuple (data, label) + should be returned from dataset. If there is multiple data or labels, set `loss_fn` + to None and implement calculation of loss in `network`, + then a tuple (data1, data2, data3, ...) with all data returned from dataset + will be passed to the `network`. + valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process + will be performed on the end of training process. Default: None. + valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies + how many training epochs to run before a new validation run is performed, + e.g. `valid_frequency=2` runs validation every 2 epochs. + If a list, specifies the epochs on which to run validation, + e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st, 5th epochs. + Default: 1 + callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, + which should be executed while training. + Default: None. + dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel. + Configure pynative mode or CPU, the training process will be performed with + dataset not sink. Default: True. + valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel. + Default: True. + sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if `dataset_sink_mode` + is False. + If sink_size = -1, sink the complete dataset for each epoch. + If sink_size > 0, sink sink_size data for each epoch. + Default: -1. + initial_epoch (int): Epoch at which to start train, it useful for resuming a previous training run. + Default: 0. + + Examples: + >>> import mindspore as ms + >>> from mindspore import nn + >>> + >>> # For details about how to build the dataset, please refer to the tutorial + >>> # document on the official website. 
+ >>> train_dataset = create_custom_dataset() + >>> valid_dataset = create_custom_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"}) + >>> model.fit(2, train_dataset, valid_dataset) + """ + super().fit(epoch, train_dataset, valid_dataset, valid_frequency, callbacks, dataset_sink_mode, + valid_dataset_sink_mode, sink_size, initial_epoch) + + @set_from_config + def build(self, + train_dataset: FromConfig = None, + valid_dataset: FromConfig = None, + sink_size: FromConfig = -1, + epoch: FromConfig = 1): + """ + Build computational graphs and data graphs with the sink mode. + + .. warning:: + This is an experimental prototype that is subject to change or deletion. + + Note: + The interface builds the computational graphs, when the interface is executed first, 'Model.train' only + performs the graphs execution. Pre-build process only supports `GRAPH_MODE` and `Ascend` target currently. + It only supports dataset sink mode. + + Args: + train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be + built. Default: None. + valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs + will be built, and `metrics` in `Model` can not be None. Default: None. + sink_size (int): Control the amount of data in each sink. Default: -1. + epoch (int): Control the training epochs. Default: 1. + + Examples: + >>> import mindspore as ms + >>> from mindspore import nn + >>> + >>> # For details about how to build the dataset, please refer to the tutorial + >>> # document on the official website. + >>> dataset = create_custom_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> loss_scale_manager = ms.FixedLossScaleManager() + >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None, + ... loss_scale_manager=loss_scale_manager) + >>> model.build(dataset, epoch=2) + >>> model.train(2, dataset) + """ + super().build(train_dataset, valid_dataset, sink_size, epoch) + + @set_from_config + def eval(self, + valid_dataset: FromConfig, + callbacks: FromConfig = None, + dataset_sink_mode: FromConfig = False): + """ + Evaluation API. + + Configure to pynative mode or CPU, the evaluating process will be performed with dataset non-sink mode. + + Note: + If dataset_sink_mode is True, data will be sent to device. At this point, the dataset will be bound to this + model, so the dataset cannot be used by other models. If the device is Ascend, features + of data will be transferred one by one. The limitation of data transmission per time is 256M. + + The interface builds the computational graphs and then executes the computational graphs. However, when + the `Model.build` is executed first, it only performs the graphs execution. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + callbacks (Optional[list(Callback), Callback]): List of callback objects or callback object, + which should be executed while evaluation. + Default: None. + dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. + Default: True. + + Returns: + Dict, the key is the metric name defined by users and the value is the metrics value for + the model in the test mode. 
+ + Examples: + >>> import mindspore as ms + >>> from mindspore import nn + >>> + >>> # For details about how to build the dataset, please refer to the tutorial + >>> # document on the official website. + >>> dataset = create_custom_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = ms.Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) + >>> acc = model.eval(dataset, dataset_sink_mode=False) + """ + super().eval(valid_dataset, callbacks, dataset_sink_mode) + + @set_from_config + def predict(self, + *predict_data: FromConfig, + backend: FromConfig = None): + """ + Generate output predictions for the input samples. + + Args: + predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional): + The predict data, can be a single tensor, + a list of tensor, or a tuple of tensor. + + Returns: + Tensor, array(s) of predictions. + + Examples: + >>> import numpy as np + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> + >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32) + >>> model = ms.Model(Net()) + >>> result = model.predict(input_data) + """ + super().predict(*predict_data, backend) + + +class ImageClassificationTrainConfig(BaseArgsFromConfig): + @copy_signature(ImageClassificationTrainer.train) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ImageClassificationFitConfig(BaseArgsFromConfig): + @copy_signature(ImageClassificationTrainer.fit) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ImageClassificationBuildConfig(BaseArgsFromConfig): + @copy_signature(ImageClassificationTrainer.build) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ImageClassificationEvalConfig(BaseArgsFromConfig): + @copy_signature(ImageClassificationTrainer.eval) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ImageClassificationPredictConfig(BaseArgsFromConfig): + @copy_signature(ImageClassificationTrainer.predict) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/tinyms/pipeline/trainer_configmixin.py b/tinyms/pipeline/trainer_configmixin.py new file mode 100644 index 00000000..7cc71c70 --- /dev/null +++ b/tinyms/pipeline/trainer_configmixin.py @@ -0,0 +1,287 @@ +from .configmixin import ConfigMixin, _key_var_keyword, _key_var_position +from typing import NewType, Any, Union +from functools import partial, wraps +import inspect +from typing import TypeVar, Callable, Generic +import logging +from typing import get_origin, get_args +import pathlib + +logger = logging.getLogger(__name__) +FromConfig = NewType('FromConfig', Any) + + +def _is_from_pre_config(annotation): + if annotation is FromConfig: + return True + elif get_origin(annotation) is Union: + type_list = get_args(annotation) + if FromConfig in type_list: + return True + return False + + +_F = TypeVar('_F', bound=Callable[..., Any]) + + +class TrainerConfigMixin(ConfigMixin): + """ + A base class for trainer pipeline mixin. This class provides methods for + saving and loading model config. A class that inherits from this class + can apply `@save_config` to `__init__` method to record the + config of the class. + + If you wrap `__init__` with `@save_config`, the argument of Ignore type + will not be saved into the config. The SubFolder type will be saved into a sub folder. + + For a trainer, you may want to implement some methods like `train`, `eval`, `predict`. 
+ You can use `@set_from_config` to set the arguments from the config. The FromConfig + type arguments having the following property. + + The arguments that are not in the config will be set to default value. The arguments set + in running time will override the arguments in the config. + + If you wrap a method with `@set_from_config`, you can use `BaseArgsFromConfig` + to generate the arguments class. To use `BaseArgsFromConfig`, you should wrap the `__init__` + method with `@copy_signature(Trainer.method)`. The arguments of `__init__` method should + be `__init__(self, *args, **kwargs)`. + + You should define the arguments class in `__init__` + method. The default name of the arguments class is `{method_name}_config`. You can + change the name by passing the name to `@set_from_config(name)`. + + + Examples: + >>> class Trainer(TrainerConfigMixin): + ... @save_config + ... def __init__(self, train_args=None): + ... self.train_args = train_args + ... + ... @set_from_config + ... def train(self, epoch: FromConfig): + ... ... + >>> + >>> class TrainConfig(BaseArgsFromConfig): + ... @copy_signature(Trainer.train) + ... def __init__(self, *args, **kwargs): + ... super().__init__(*args, **kwargs) + >>> + >>> train_config = TrainConfig(2) + >>> trainer = Trainer(train_config=train_config) + >>> + >>> trainer.save_pretrained('model_config') + >>> new_trainer = trainer.from_pretrained('model_config') + >>> + >>> new_trainer.train() + >>> + >>> new_trainer.train(4) + + """ + __prefix__ = "trainer" + + def _save_checkpoint(self, path): + raise NotImplementedError + + def _load_checkpoint(self, path): + raise NotImplementedError + + @classmethod + def from_pretrained(cls, path: Union[str, pathlib.Path], repo: str, checkfiles: bool = True, download: bool = True) -> 'TrainerConfigMixin': + """ + path (Union[str, pathlib.Path]): Path to save repo. Defaults to None. + + repo (str): The repo name. + + checkfiles (bool, optional): If this is set to False, this method will not check + the files in `path`, and will not download from `repo`. Defaults to True. + + download (bool, optional): If this is set to False, this method will not download. + If download is true, this method will first check whether the repo has been + downloaded or completed, if not, it will download the wrong or missing files. + If download is false, this method will check the whether the local repo in `path` + is completed, if not, it will raise an Error. + + Returns: + A instance of a subclass of TrainerConfigMixin. + """ + return super().from_pretrained(path, repo, checkfiles) + + def init_model(self, model=None): + ... + + def _compile(self): + ... + + @classmethod + def _check_cls_loaded(cls, loaded_cls): + if cls is not TrainerConfigMixin: + assert loaded_cls is cls, f"The repo loaded is not the same as the \ + class {cls.__name__}, it is {loaded_cls.__class__.__name__}" + + +class copy_signature(Generic[_F]): + """ + A decorator to copy the signature of a function to another function. + Using with `BaseArgsFromConfig` to generate the arguments class. 
+ """ + + def __init__(self, target: _F) -> None: + self.target_signature = inspect.signature(target) + + def __call__(self, wrapped: Callable[..., Any]) -> _F: + def wrapped_with_signature(self_, *args, **kwargs): + self_.func_signature = self.target_signature + return wrapped(self_, *args, **kwargs) + wrapped_with_signature.__signature__ = self.target_signature + wrapped_with_signature.__name__ = wrapped.__name__ + return wrapped_with_signature + + +def _is_var_keyword_from_pre_config(func_signature): + if _key_var_keyword(func_signature) is not None: + return _is_from_pre_config( + func_signature.parameters[_key_var_keyword(func_signature)].annotation) + return False + + +def _func_keys_from_pre_config(func_signature): + preconfg = [] + flag = False + for v in func_signature.parameters.values(): + if not flag: + flag = True + continue + if _is_from_pre_config(v.annotation): + if v.name == _key_var_keyword(func_signature): + preconfg.append(f"**{v.name}") + elif v.name == _key_var_position(func_signature): + preconfg.append(f"*{v.name}") + else: + preconfg.append(v.name) + return tuple(preconfg) + + +def _func_keys_add_star_to_var(func_signature): + func_keys = [] + flag = False + for k in func_signature.parameters: + if not flag: + flag = True + continue + if k == _key_var_keyword(func_signature): + func_keys.append(f"**{k}") + elif k == _key_var_position(func_signature): + func_keys.append(f"*{k}") + else: + func_keys.append(k) + return tuple(func_keys) + + +class BaseArgsFromConfig(ConfigMixin): + """ + The base class to generate arguments class for `@set_from_config`. + """ + + def __init__(self, *args, **kwargs): + self.func_signature: inspect.Signature + if _key_var_position(self.func_signature) is not None: + raise AttributeError(f"Var position args {_key_var_position(self.func_signature)} is not supported") + + self.func_signature.bind(self, *args, **kwargs) + + not_in_preconfg_warning = [] + default_value_warning = [] + + self._internal_config = {} + allocated_kwargs = [] + func_keys = _func_keys_add_star_to_var(self.func_signature) + func_from_pre_config = _func_keys_from_pre_config(self.func_signature) + is_var_keyword_from_pre_config = _is_var_keyword_from_pre_config(self.func_signature) + + for idx, v in enumerate(args): + kwargs[func_keys[idx]] = v + + for k, v in kwargs.items(): + if k in func_keys: + if k in func_from_pre_config: + self._internal_config[k] = v + else: + not_in_preconfg_warning.append(k) + else: + if is_var_keyword_from_pre_config: + self._internal_config[k] = v + else: + not_in_preconfg_warning.append(k) + allocated_kwargs.append(k) + + for k in set(func_from_pre_config) - set(allocated_kwargs): + default_value_warning.append(k) + + if len(not_in_preconfg_warning) > 0: + logger.warning(f"Parameters {not_in_preconfg_warning} will be ignored in pre config") + if len(default_value_warning) > 0: + logger.warning(f"Parameters {default_value_warning} will use default value") + + def __getattr__(self, item): + return self._internal_config[item] + + def __getitem__(self, item): + return self._internal_config[item] + + def keys(self): + return self._internal_config.keys() + + def __contains__(self, key): + return key in self._internal_config + + +def set_from_config(config_name_or_callable): + """ + A decorator to set the arguments from the config. Using with `BaseArgsFromConfig`. 
+ """ + def _set_from_config(func: Callable[..., Any], config_name=None): + if config_name is None: + config_name = f"{func.__name__}_config" + sig = inspect.signature(func) + keys = _func_keys_add_star_to_var(sig) + pre_config_keys = _func_keys_from_pre_config(sig) + if config_name is None: + config_name = f"{func.__name__}_config" + + @wraps(func) + def wrapper(self, *args, **kwargs): + if _key_var_position(sig) is not None: + raise AttributeError(f"Var position args {_key_var_position(sig)} is not supported") + + if not hasattr(self, config_name): + func(self, *args, **kwargs) + + allocated_keys = [] + for idx in range(len(args)): + if idx < len(keys): + k = keys[idx] + if k in sig.parameters: + allocated_keys.append(k) + + allocated_keys += list(kwargs.keys()) + pre_config = getattr(self, config_name) + + for k in set(pre_config_keys) - set(allocated_keys): + if k[0] != "*": + if k in pre_config: + kwargs[k] = pre_config[k] + allocated_keys.append(k) + + if _is_var_keyword_from_pre_config(sig) is not None: + for k in set(getattr(self, config_name).keys()) - set(allocated_keys): + kwargs[k] = getattr(self, config_name)[k] + allocated_keys.append(k) + + return func(self, *args, **kwargs) + + return wrapper + + if callable(config_name_or_callable): + return _set_from_config(config_name_or_callable, None) + + return partial(_set_from_config, config_name=config_name_or_callable)