Adaptive batch size supports XPU (#3256)
* draft implementation

* draft implementation

* successfully execute adaptive bs w/ XPU

* refactor code

* align with pre-commit

* add meta argument

* make CustomOTXHead pickleable

* fix bug

* draft implementation for nncf auto bs

* refactor a little bit

* import nncf patch according to task

* support nncf task

* minor bugfix

* refactor code

* align with pre-commit

* update unit test

* update unit test

* change bs search algo

* fix a bug that occurred when the length of data is smaller than the class number in the class-incremental case

* apply sync batch norm after adaptive bs

* update unit test

* align with pre-commit

* fix wrong fixture

* add comments
eunwoosh authored Apr 3, 2024
1 parent 620dfd1 commit 6e33590
Showing 24 changed files with 693 additions and 473 deletions.
10 changes: 5 additions & 5 deletions src/otx/algorithms/action/adapters/mmaction/task.py
@@ -19,7 +19,6 @@
import time
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Dict, Optional, Union

import torch
@@ -324,12 +323,13 @@ def _train_model(
validate = bool(cfg.data.get("val", None))

if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE:
train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
adapt_batch_size(
train_func,
cfg,
train_model,
model,
datasets,
validate,
cfg,
cfg.distributed,
meta=meta,
not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE),
)

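The call site above no longer wraps train_model in functools.partial; adapt_batch_size now receives the train function, model, datasets, and config directly and drives the trial iterations itself. As a rough, self-contained illustration of the auto-decrease idea only (not OTX's actual BsSearchAlgo; the helper names and the toy model below are made up for the example), a search can run one training step at a candidate batch size and halve it whenever the step runs out of memory:

import torch
from torch import nn


def try_one_step(model: nn.Module, batch_size: int) -> bool:
    """Run a single forward/backward pass and report whether it fit in memory."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    try:
        data = torch.randn(batch_size, 64, device=device)   # toy input
        target = torch.randn(batch_size, 8, device=device)  # toy target
        loss = nn.functional.mse_loss(model(data), target)
        loss.backward()
        return True
    except RuntimeError as err:  # CUDA OOM surfaces as a RuntimeError
        if "out of memory" in str(err).lower():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return False
        raise


def auto_decrease_batch_size(model: nn.Module, default_bs: int) -> int:
    """Halve the batch size until a single training step succeeds."""
    bs = default_bs
    while bs >= 1:
        if try_one_step(model, bs):
            return bs
        bs //= 2
    raise RuntimeError("Even a batch size of 1 does not fit in memory.")


print(auto_decrease_batch_size(nn.Sequential(nn.Linear(64, 8)), default_bs=512))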
19 changes: 11 additions & 8 deletions src/otx/algorithms/classification/adapters/mmcls/task.py
@@ -8,7 +8,6 @@
import time
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Optional, Type, Union

import torch
@@ -380,9 +379,6 @@ def _train_model(
htcore.hpu.ModuleCacher(max_graphs=10)(model=model.backbone, inplace=True)
htcore.hpu.ModuleCacher(max_graphs=10)(model=model.head, inplace=True)

if cfg.distributed:
convert_sync_batchnorm(model)

validate = bool(cfg.data.get("val", None))
if validate:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
@@ -410,15 +406,22 @@
)

if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE:
train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
is_nncf = isinstance(self, NNCFBaseTask)
adapt_batch_size(
train_func,
cfg,
train_model,
model,
datasets,
isinstance(self, NNCFBaseTask), # nncf needs eval hooks
cfg,
cfg.distributed,
is_nncf,
meta=meta,
not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE),
model_builder=getattr(self, "model_builder") if is_nncf else None,
)

if cfg.distributed:
convert_sync_batchnorm(model)

train_model(
model,
datasets,
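Note that convert_sync_batchnorm(model) now runs after the batch-size adaptation: the trial iterations execute in a single non-distributed process (and _train_func_single_iter calls sync_batchnorm_2_batchnorm(model) for the same reason, as shown further down), so SyncBN is only applied once real distributed training is about to start. A minimal sketch of the forward conversion using PyTorch's built-in helper, which stands in here for OTX's convert_sync_batchnorm wrapper (whose exact signature is not shown in this diff):

from torch import nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# Plain BatchNorm is what the single-process batch-size trials should see ...
print(type(model[1]).__name__)  # BatchNorm2d

# ... and the SyncBN conversion happens only afterwards, right before distributed training.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)  # SyncBatchNorm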
@@ -8,9 +8,7 @@

from otx.algorithms.classification.adapters.mmcls.utils.builder import build_classifier
from otx.algorithms.common.adapters.mmcv.tasks.exporter import Exporter
from otx.algorithms.common.adapters.mmdeploy.utils.utils import (
sync_batchnorm_2_batchnorm,
)
from otx.algorithms.common.adapters.torch.utils import sync_batchnorm_2_batchnorm
from otx.utils.logger import get_logger

logger = get_logger()
264 changes: 214 additions & 50 deletions src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py
@@ -3,14 +3,26 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
import inspect
from copy import copy
from importlib import import_module
from math import sqrt
from typing import Callable, Dict, List
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
from mmcv import Config
from mmcv.runner import wrap_fp16_model
from torch import distributed as dist
from torch.cuda import is_available as cuda_available
from torch.utils.data import Dataset

from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
from otx.algorithms.common.adapters.mmcv.utils.config_utils import OTXConfig
from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo, sync_batchnorm_2_batchnorm
from otx.algorithms.common.utils import is_xpu_available
from otx.core.data import caching
from otx.utils.logger import get_logger

logger = get_logger()
@@ -38,7 +50,142 @@ def _set_value_at_dict_in_dict(target: Dict, key_path: str, value):
target[keys[-1]] = value


def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = False, not_increase: bool = True):
def _build_model(model_builder: Callable, cfg: Config) -> torch.nn.Module:
model = model_builder(cfg)
if cfg.get("fp16", False):
wrap_fp16_model(model)
return model


NNCF_PATCH_MODULE = {
"mmcls": "otx.algorithms.classification.adapters.mmcls.nncf.patches",
"mmdet": "otx.algorithms.detection.adapters.mmdet.nncf.patches",
"mmseg": "otx.algorithms.segmentation.adapters.mmseg.nncf.patches",
}


def _train_func_single_iter(
batch_size: int,
train_func: Callable,
datasets: List[Dataset],
cfg: OTXConfig,
is_nncf: bool = False,
meta: Optional[Dict[str, Any]] = None,
model: Optional[torch.nn.Module] = None,
model_builder: Optional[Callable] = None,
) -> None:
caching.MemCacheHandlerSingleton.create("null", 0) # initialize mem cache
_set_batch_size(cfg, batch_size)
_set_max_epoch(cfg, 1) # setup for training a single iter to save time

new_dataset = [SubDataset(datasets[0], batch_size)]

validate = is_nncf # nncf needs eval hooks
if is_nncf:
pkg_name = inspect.getmodule(train_func).__package__
for framework in ["mmcls", "mmdet", "mmseg"]:
if framework in pkg_name:
import_module(NNCF_PATCH_MODULE[framework])
break
else:
framework = None

if framework == "mmcls":
validate = False # classification task has own custom eval hook

if model is None:
model = _build_model(model_builder, cfg)

if is_nncf:
model.nncf._uncompressed_model_accuracy = 0

sync_batchnorm_2_batchnorm(model)

train_func(
model=model,
dataset=new_dataset,
cfg=cfg,
distributed=False,
validate=validate,
meta=meta,
)
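_train_func_single_iter trains for a single epoch over SubDataset(datasets[0], batch_size), i.e. a view of the full dataset whose length equals the trial batch size, so one iteration is enough to probe memory usage. A toy, self-contained version of that wrapper idea (FixedLengthView is a made-up name, not OTX's SubDataset, which additionally maintains img_indices for the class-incremental case, as shown near the end of this file):

from torch.utils.data import Dataset


class FixedLengthView(Dataset):
    """Expose only the first num_samples items of fullset (illustrative only)."""

    def __init__(self, fullset: Dataset, num_samples: int):
        if num_samples > len(fullset):
            raise ValueError("num_samples must not exceed the size of the full dataset.")
        self.fullset = fullset
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int):
        return self.fullset[idx]

# With batch_size == len(view) and max_epoch == 1, one epoch is exactly one iteration.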


def _save_nncf_model_weight(model: torch.nn.Module, cfg: OTXConfig, save_path: Path) -> str:
"""Save nncf model weight after nncf finishes to build a model.
NNCF analyzes and get some statistics when buliding a model, which is time consuming.
To skip this part, nncf model weight is saved and load it on new process.
"""
from otx.algorithms.common.adapters.nncf.compression import NNCFMetaState

file_path = save_path / "nncf_model.pth"
for custom_hook in cfg.custom_hooks:
if custom_hook["type"] == "CompressionHook":
compression_ctrl = custom_hook["compression_ctrl"].get_compression_state()
break
else:
msg = "CompressionHook doesn't exist in custom hooks."
raise RuntimeError(msg)

torch.save(
{
"state_dict": model.state_dict(),
"meta": {
"nncf_meta": NNCFMetaState(
state_to_build=cfg.runner.nncf_meta.state_to_build,
data_to_build=cfg.runner.nncf_meta.data_to_build,
compression_ctrl=compression_ctrl,
),
"nncf_enable_compression": True,
},
},
file_path,
)

return str(file_path)


def _organize_custom_hooks(custom_hooks: List, is_nncf: bool = False) -> None:
# Remove hooks due to reasons below
# for nncf task
# OTXProgressHook and CompressionHook are added when building a model. Need to remove them to avoid duplication.
# for normal task
# OTXProgressHook => prevent the progress bar from repeatedly jumping between 0 and 100
# CancelInterfaceHook => avoid segmentation fault
# earlystoppinghook => if the eval hook is excluded, this hook raises an error due to the absence of a score history
# CustomEvalHook => exclude validation in classification task

if is_nncf:
hooks_to_remove = ["OTXProgressHook", "CompressionHook"]
else:
hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook", "CancelInterfaceHook"]

idx_hooks_to_remove = []
for i, hook in enumerate(custom_hooks):
if not is_nncf and hook["type"] == "AdaptiveTrainSchedulingHook":
hook["enable_eval_before_run"] = False
for hook_to_remove in hooks_to_remove:
if hook_to_remove.lower() in hook["type"].lower():
idx_hooks_to_remove.append(i)

if idx_hooks_to_remove:
idx_hooks_to_remove.sort()
for i in reversed(idx_hooks_to_remove):
custom_hooks.pop(i)


def adapt_batch_size(
train_func: Callable,
model: torch.nn.Module,
datasets: List[Dataset],
cfg: OTXConfig,
distributed: bool = False,
is_nncf: bool = False,
meta: Optional[Dict[str, Any]] = None,
not_increase: bool = True,
model_builder: Optional[Callable] = None,
) -> None:
"""Decrease batch size if default batch size isn't fit to current GPU device.
This function just setup for single iteration training to reduce time for adapting.
@@ -47,59 +194,69 @@ def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool =
Args:
train_func (Callable): The function to train a model.
Only cfg, dataset and meta are passed to the function when invoking it.
cfg: Configuration of a training.
meta (Dict): A dict records some meta information of a training.
model (torch.nn.Module): Model to train.
datasets (List): List of datasets.
validate (bool): Whether to do validation or not.
cfg (OTXConfig): Configuration of a training.
distributed (bool): Whether distributed training is enabled or not.
is_nncf (bool): Whether nncf or not.
meta (Optional[Dict[str, Any]]): meta information.
not_increase (bool): Whether to prevent the batch size from being adapted to a value larger than the default.
model_builder (Optional[Callable]):
Function for building a model. If it exists, a model built by model_builder is used instead of the model
passed as an argument. It's required for nncf because NNCF changes the model, which prevents it from being pickled.
"""

if not cuda_available():
logger.warning("Skip Auto-adaptive batch size: CUDA should be available, but it isn't.")
if not (cuda_available() or is_xpu_available()):
logger.warning("Skip Auto-adaptive batch size: Adaptive batch size supports CUDA or XPU.")
return

def train_func_single_iter(batch_size):
copied_cfg = deepcopy(cfg)
_set_batch_size(copied_cfg, batch_size)
_set_max_epoch(copied_cfg, 1) # setup for training a single iter to reduce time

# Remove hooks due to reasons below
# OTXProgressHook => prevent the progress bar from repeatedly jumping between 0 and 100
# earlystoppinghook => if the eval hook is excluded, this hook raises an error due to the absence of a score history
# CustomEvalHook => exclude validation in classification task
idx_hooks_to_remove = []
hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook"]
for i, hook in enumerate(copied_cfg.custom_hooks):
if not validate and hook["type"] == "AdaptiveTrainSchedulingHook":
hook["enable_eval_before_run"] = False
for hook_to_remove in hooks_to_remove:
if hook_to_remove.lower() in hook["type"].lower():
idx_hooks_to_remove.append(i)

if idx_hooks_to_remove:
idx_hooks_to_remove.sort()
for i in reversed(idx_hooks_to_remove):
del copied_cfg.custom_hooks[i]

new_datasets = [SubDataset(datasets[0], batch_size)]

train_func(
dataset=new_datasets,
cfg=copied_cfg,
validate=validate,
)
copied_cfg = copy(cfg)
copied_cfg.custom_hooks = copy(cfg.custom_hooks)
copied_cfg.pop("algo_backend", None)

if is_nncf:
if model_builder is None:
msg = "model_builder should be possed for building a nncf model."
raise RuntimeError(msg)
temp_dir = TemporaryDirectory("adaptive-bs")
copied_cfg.load_from = _save_nncf_model_weight(model, cfg, Path(temp_dir.name))

_organize_custom_hooks(copied_cfg.custom_hooks, is_nncf)

default_bs = _get_batch_size(cfg)
bs_search_algo = BsSearchAlgo(
train_func=train_func_single_iter,
default_bs=default_bs,
max_bs=len(datasets[0]),
)
if not_increase:
new_batch_size = bs_search_algo.auto_decrease_batch_size()
else:
drop_last = cfg.data.get("train_dataloader", {}).get("drop_last", False)
new_batch_size = bs_search_algo.find_big_enough_batch_size(drop_last)
if not distributed or (rank := dist.get_rank()) == 0:
train_func_kwargs = {
"train_func": train_func,
"datasets": datasets,
"cfg": copied_cfg,
"is_nncf": is_nncf,
"meta": meta,
}
if model_builder is None:
train_func_kwargs["model"] = model
else:
train_func_kwargs["model_builder"] = model_builder

bs_search_algo = BsSearchAlgo(
train_func=_train_func_single_iter,
train_func_kwargs=train_func_kwargs,
default_bs=default_bs,
max_bs=len(datasets[0]),
)
if not_increase:
new_batch_size = bs_search_algo.auto_decrease_batch_size()
else:
drop_last = cfg.data.get("train_dataloader", {}).get("drop_last", False)
new_batch_size = bs_search_algo.find_big_enough_batch_size(drop_last)

if distributed:
if rank == 0:
total_try_result = torch.tensor([new_batch_size], dtype=torch.int)
else:
total_try_result = torch.empty(1, dtype=torch.int)
total_try_result = total_try_result.cuda() if torch.cuda.is_available() else total_try_result.xpu()
dist.broadcast(total_try_result, src=0)
new_batch_size = total_try_result[0].item()

if default_bs != new_batch_size:
_set_batch_size(cfg, new_batch_size)
@@ -158,11 +315,18 @@ def __init__(self, fullset, num_samples: int):

self.fullset = fullset
self.num_samples = num_samples
self.img_indices = { # for class incremental case
self._img_indices = { # for class incremental case
"old": [i for i in range(num_samples // 2)],
"new": [i for i in range(num_samples // 2, num_samples)],
}

@property
def img_indices(self):
"""img_indices getter."""
img_indices = copy(getattr(self.fullset, "img_indices", {}))
img_indices.update(self._img_indices)
return img_indices

def __len__(self) -> int:
"""Get length of subset."""
return self.num_samples
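When training is distributed, only rank 0 runs the search and the resulting batch size is then shared with the other ranks as a one-element tensor (the dist.broadcast block above). A minimal sketch of that synchronization step, assuming torch.distributed has already been initialized and that every rank calls the function; sync_batch_size is a made-up name for illustration:

import torch
from torch import distributed as dist


def sync_batch_size(new_batch_size: int, distributed: bool) -> int:
    """Share the batch size found on rank 0 with every other rank."""
    if not distributed:
        return new_batch_size

    if dist.get_rank() == 0:
        result = torch.tensor([new_batch_size], dtype=torch.int)
    else:
        result = torch.empty(1, dtype=torch.int)  # placeholder, filled by the broadcast

    if torch.cuda.is_available():  # NCCL needs the tensor on the device
        result = result.cuda()
    dist.broadcast(result, src=0)
    return int(result[0].item())


print(sync_batch_size(32, distributed=False))  # a single-process call returns the value unchanged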