From 1f7ddf257108087cd8bed3bcb25d9c5f13b32172 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 20 Nov 2024 06:27:10 +0000 Subject: [PATCH] upgrade inference model --- paddlex/inference/__init__.py | 7 +- paddlex/inference/new_models/__init__.py | 108 ++++ paddlex/inference/new_models/base/__init__.py | 19 + .../new_models/base/batch_sampler/__init__.py | 16 + .../base/batch_sampler/batch_data.py | 38 ++ .../base/batch_sampler/batch_sampler.py | 48 ++ .../inference/new_models/base/component.py | 127 ++++ .../new_models/base/predictor/__init__.py | 16 + .../base/predictor/base_predictor.py | 71 +++ .../predictor/basic_predictor/__init__.py | 15 + .../predictor/basic_predictor/predictor.py | 124 ++++ .../basic_predictor/result_packager.py | 37 ++ .../basic_predictor/transformer_engine.py | 45 ++ .../new_models/base/result/__init__.py | 16 + .../inference/new_models/base/result/mixin.py | 215 +++++++ .../new_models/base/result/result.py | 47 ++ .../inference/new_models/base/transformer.py | 34 ++ .../inference/new_models/common/__init__.py | 13 + .../inference/new_models/common/batchable.py | 34 ++ .../common/cv_components/__init__.py | 16 + .../common/cv_components/batch_sampler.py | 73 +++ .../cv_components/transformers/__init__.py | 15 + .../cv_components/transformers/funcs.py | 58 ++ .../transformers/transformers.py | 544 ++++++++++++++++++ .../common/paddle_predictor/__init__.py | 18 + .../common/paddle_predictor/predictor.py | 253 ++++++++ .../image_classification/__init__.py | 15 + .../image_classification/predictor.py | 116 ++++ .../new_models/image_classification/result.py | 82 +++ .../image_classification/transformers.py | 64 +++ .../new_models/text_detection/__init__.py | 15 + .../new_models/text_detection/predictor.py | 121 ++++ .../new_models/text_detection/result.py | 35 ++ .../new_models/text_detection/transformers.py | 416 ++++++++++++++ .../new_models/text_recognition/__init__.py | 15 + .../new_models/text_recognition/predictor.py | 90 +++ 
.../new_models/text_recognition/result.py | 65 +++ .../text_recognition/transformers.py | 196 +++++++ paddlex/inference/utils/benchmark.py | 9 +- paddlex/utils/flags.py | 3 + paddlex/utils/logging.py | 3 +- 41 files changed, 3246 insertions(+), 6 deletions(-) create mode 100644 paddlex/inference/new_models/__init__.py create mode 100644 paddlex/inference/new_models/base/__init__.py create mode 100644 paddlex/inference/new_models/base/batch_sampler/__init__.py create mode 100644 paddlex/inference/new_models/base/batch_sampler/batch_data.py create mode 100644 paddlex/inference/new_models/base/batch_sampler/batch_sampler.py create mode 100644 paddlex/inference/new_models/base/component.py create mode 100644 paddlex/inference/new_models/base/predictor/__init__.py create mode 100644 paddlex/inference/new_models/base/predictor/base_predictor.py create mode 100644 paddlex/inference/new_models/base/predictor/basic_predictor/__init__.py create mode 100644 paddlex/inference/new_models/base/predictor/basic_predictor/predictor.py create mode 100644 paddlex/inference/new_models/base/predictor/basic_predictor/result_packager.py create mode 100644 paddlex/inference/new_models/base/predictor/basic_predictor/transformer_engine.py create mode 100644 paddlex/inference/new_models/base/result/__init__.py create mode 100644 paddlex/inference/new_models/base/result/mixin.py create mode 100644 paddlex/inference/new_models/base/result/result.py create mode 100644 paddlex/inference/new_models/base/transformer.py create mode 100644 paddlex/inference/new_models/common/__init__.py create mode 100644 paddlex/inference/new_models/common/batchable.py create mode 100644 paddlex/inference/new_models/common/cv_components/__init__.py create mode 100644 paddlex/inference/new_models/common/cv_components/batch_sampler.py create mode 100644 paddlex/inference/new_models/common/cv_components/transformers/__init__.py create mode 100644 paddlex/inference/new_models/common/cv_components/transformers/funcs.py 
create mode 100644 paddlex/inference/new_models/common/cv_components/transformers/transformers.py create mode 100644 paddlex/inference/new_models/common/paddle_predictor/__init__.py create mode 100644 paddlex/inference/new_models/common/paddle_predictor/predictor.py create mode 100644 paddlex/inference/new_models/image_classification/__init__.py create mode 100644 paddlex/inference/new_models/image_classification/predictor.py create mode 100644 paddlex/inference/new_models/image_classification/result.py create mode 100644 paddlex/inference/new_models/image_classification/transformers.py create mode 100644 paddlex/inference/new_models/text_detection/__init__.py create mode 100644 paddlex/inference/new_models/text_detection/predictor.py create mode 100644 paddlex/inference/new_models/text_detection/result.py create mode 100644 paddlex/inference/new_models/text_detection/transformers.py create mode 100644 paddlex/inference/new_models/text_recognition/__init__.py create mode 100644 paddlex/inference/new_models/text_recognition/predictor.py create mode 100644 paddlex/inference/new_models/text_recognition/result.py create mode 100644 paddlex/inference/new_models/text_recognition/transformers.py diff --git a/paddlex/inference/__init__.py b/paddlex/inference/__init__.py index e02c9eebd..7233bdbcf 100644 --- a/paddlex/inference/__init__.py +++ b/paddlex/inference/__init__.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .models import create_predictor +from ..utils.flags import NEW_PREDICTOR + +if NEW_PREDICTOR: + from .new_models import create_predictor +else: + from .models import create_predictor from .pipelines import create_pipeline from .utils.pp_option import PaddlePredictorOption diff --git a/paddlex/inference/new_models/__init__.py b/paddlex/inference/new_models/__init__.py new file mode 100644 index 000000000..40ca79273 --- /dev/null +++ b/paddlex/inference/new_models/__init__.py @@ -0,0 +1,108 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from pathlib import Path +from typing import Any, Dict, Optional + +from ...utils import errors +from ..utils.official_models import official_models +from .base import BasePredictor, BasicPredictor + +from .image_classification import ClasPredictor +from .text_detection import TextDetPredictor +from .text_recognition import TextRecPredictor + +# from .table_recognition import TablePredictor +# from .object_detection import DetPredictor +# from .instance_segmentation import InstanceSegPredictor +# from .semantic_segmentation import SegPredictor +# from .general_recognition import ShiTuRecPredictor +# from .ts_fc import TSFcPredictor +# from .ts_ad import TSAdPredictor +# from .ts_cls import TSClsPredictor +# from .image_unwarping import WarpPredictor +# from .multilabel_classification import MLClasPredictor +# from .anomaly_detection import UadPredictor +# from .formula_recognition import LaTeXOCRPredictor +# from .face_recognition import FaceRecPredictor + + +def _create_hp_predictor( + model_name, model_dir, device, config, hpi_params, *args, **kwargs +): + try: + from paddlex_hpi.models import HPPredictor + except ModuleNotFoundError: + raise RuntimeError( + "The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available." + ) from None + try: + predictor = HPPredictor.get(model_name)( + model_dir=model_dir, + config=config, + device=device, + *args, + hpi_params=hpi_params, + **kwargs, + ) + except errors.others.ClassNotFoundException: + raise ValueError( + f"{model_name} is not supported by the PaddleX HPI plugin." 
def check_model(model):
    """Resolve ``model`` to a model location.

    An existing filesystem path takes precedence; otherwise ``model`` is looked
    up as a key of ``official_models``.

    Args:
        model: Path of a local model directory, or a model name.

    Returns:
        ``Path(model)`` when that path exists, otherwise
        ``official_models[model]``.

    Raises:
        Exception: If ``model`` is neither an existing path nor a key of
            ``official_models``.
    """
    local_path = Path(model)
    if local_path.exists():
        return local_path
    if model not in official_models:
        raise Exception(
            f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
        )
    return official_models[model]
+ +from .component import BaseComponent +from .transformer import BaseTransformer +from .predictor import BasePredictor, BasicPredictor +from .result import BaseResult, CVResult +from .batch_sampler import BaseBatchSampler, BatchData diff --git a/paddlex/inference/new_models/base/batch_sampler/__init__.py b/paddlex/inference/new_models/base/batch_sampler/__init__.py new file mode 100644 index 000000000..fcdcc65c2 --- /dev/null +++ b/paddlex/inference/new_models/base/batch_sampler/__init__.py @@ -0,0 +1,16 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .batch_sampler import BaseBatchSampler +from .batch_data import BatchData diff --git a/paddlex/inference/new_models/base/batch_sampler/batch_data.py b/paddlex/inference/new_models/base/batch_sampler/batch_data.py new file mode 100644 index 000000000..64975b56a --- /dev/null +++ b/paddlex/inference/new_models/base/batch_sampler/batch_data.py @@ -0,0 +1,38 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class BatchData(object):
    """Container for one batch of data, stored as per-key sequences.

    Args:
        data: Mapping from key to an indexable sequence holding one entry per
            sample of the batch.
        num: Number of samples in the batch; valid sample indices are
            ``0 .. num - 1``.
    """

    def __init__(self, data, num):
        self._data = data
        self._num = num

    @property
    def num(self):
        """Number of samples in the batch."""
        return self._num

    def get_by_key(self, key):
        """Return the whole per-sample sequence stored under ``key``.

        Raises:
            AssertionError: If ``key`` is not present.
        """
        assert key in self._data, f"{key}, {list(self._data.keys())}"
        return self._data[key]

    def get_by_idx(self, idx):
        """Return a ``{key: value}`` dict holding sample ``idx`` of every key.

        Raises:
            AssertionError: If ``idx`` is not smaller than ``num``.
        """
        # ``idx == num`` is one past the last sample, so the bound is strict.
        assert idx < self.num, f"{idx}, {self.num}"
        return {k: v[idx] for k, v in self._data.items()}

    def update_by_key(self, output):
        """Store every ``key -> list`` entry of ``output`` in the batch.

        Raises:
            AssertionError: If ``output`` is not a dict, or one of its values
                is not a list of length ``num``.
        """
        assert isinstance(output, dict)
        for k, v in output.items():
            assert isinstance(v, list) and len(v) == self.num
            self._data[k] = v
class BaseBatchSampler(BaseComponent):
    """Base class for components that turn raw input into ``BatchData`` objects.

    Subclasses implement ``apply``, which must yield ``(batch, num)`` pairs.
    """

    def __init__(self, batch_size=1):
        # The batch size is stored before the base component is initialised.
        self._batch_size = batch_size
        super().__init__()

    @property
    def batch_size(self):
        """Currently configured batch size."""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, bs):
        assert bs > 0
        self._batch_size = bs

    def __call__(self, *args, **kwargs):
        """Yield one ``BatchData`` per ``(batch, num)`` pair produced by ``apply``.

        Keys of the batch are ``"<component name>.<output key>"``.
        """
        for batch, num in self.apply(*args, **kwargs):
            keyed = {}
            for out_key in self.OUTPUT_KEYS:
                # NOTE(review): every output key is mapped to ``batch[0]``;
                # confirm this is intended when there are several keys.
                keyed[f"{self.name}.{out_key}"] = batch[0]
            yield BatchData(keyed, num)

    @abstractmethod
    def apply(self, *args, **kwargs):
        """Yield ``(batch, num)`` pairs; must be provided by subclasses."""
        raise NotImplementedError
class InOuts:
    """Attribute-style access to the named parameters of a component.

    Subclasses are expected to set ``_cmpt`` (the owning component) and
    ``_keys`` (a dict mapping parameter name to parameter object).
    """

    def __getattr__(self, key):
        """Return the parameter called ``key``.

        Raises:
            AttributeError: If no parameter with that name is registered.
        """
        # ``__getattr__`` only runs after normal lookup has failed. ``_keys``
        # and ``_cmpt`` are read from the instance dict directly, so a lookup
        # on an instance that does not have them yet (for example while it is
        # being copied or unpickled) raises AttributeError instead of
        # re-entering ``__getattr__`` until the recursion limit is hit.
        keys = self.__dict__.get("_keys")
        if keys is not None and key in keys:
            return keys[key]
        cmpt = self.__dict__.get("_cmpt")
        owner = cmpt.name if cmpt is not None else "<uninitialized>"
        raise AttributeError(
            f"'{owner}.{self.__class__.__name__}' object has no attribute '{key}'"
        )

    def __repr__(self):
        # Concatenation of "<component name>.<param name>" for every parameter.
        return "".join(
            f"{param.cmpt.name}.{param.name}" for param in self._keys.values()
        )
self.timer = Timer() + self.apply = self.timer.watch_func(self.apply) + + def set_dep(self, in_param, out_param): + self.dependencies.append((in_param, out_param)) + + @classmethod + def get_input_keys(cls) -> list: + return cls.INPUT_KEYS + + @classmethod + def get_output_keys(cls) -> list: + return cls.OUTPUT_KEYS + + @abstractmethod + def __call__(self, batch_data): + raise NotImplementedError + + @abstractmethod + def apply(self, input): + raise NotImplementedError diff --git a/paddlex/inference/new_models/base/predictor/__init__.py b/paddlex/inference/new_models/base/predictor/__init__.py new file mode 100644 index 000000000..ffb313c09 --- /dev/null +++ b/paddlex/inference/new_models/base/predictor/__init__.py @@ -0,0 +1,16 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_predictor import BasePredictor +from .basic_predictor import BasicPredictor diff --git a/paddlex/inference/new_models/base/predictor/base_predictor.py b/paddlex/inference/new_models/base/predictor/base_predictor.py new file mode 100644 index 000000000..7e026121f --- /dev/null +++ b/paddlex/inference/new_models/base/predictor/base_predictor.py @@ -0,0 +1,71 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class BasePredictor(ABC):
    """Abstract base class of predictors bound to a model directory.

    Args:
        model_dir: Directory that contains the model files.
        config: Already loaded configuration. When falsy, the configuration is
            read from ``<model_dir>/<MODEL_FILE_PREFIX>.yml``.
    """

    MODEL_FILE_PREFIX = "inference"

    def __init__(self, model_dir, config=None):
        super().__init__()
        self.model_dir = Path(model_dir)
        self.config = config if config else self.load_config(self.model_dir)

        # alias predict() to the __call__()
        self.predict = self.__call__
        # Backing field of the ``package_result`` property; without it the
        # getter raises AttributeError until the setter has been used.
        self._pkg_res = True
        # Plain attribute kept for code that reads ``pkg_res`` directly.
        self.pkg_res = True
        self.benchmark = None

    @property
    def config_path(self):
        """Path of the configuration file inside ``model_dir``."""
        return self.get_config_path(self.model_dir)

    @property
    def model_name(self) -> str:
        """Model name taken from ``config["Global"]["model_name"]``."""
        return self.config["Global"]["model_name"]

    @classmethod
    def get_config_path(cls, model_dir):
        """Return ``model_dir / "<MODEL_FILE_PREFIX>.yml"``."""
        return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"

    @classmethod
    def load_config(cls, model_dir):
        """Read the configuration file of ``model_dir`` with ``YAMLReader``."""
        yaml_reader = YAMLReader()
        return yaml_reader.read(cls.get_config_path(model_dir))

    @property
    def package_result(self):
        """Boolean flag stored by the ``package_result`` setter."""
        return self._pkg_res

    @package_result.setter
    def package_result(self, pkg_res):
        assert isinstance(pkg_res, bool)
        self._pkg_res = pkg_res

    @abstractmethod
    def __call__(self, input, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def apply(self, input):
        raise NotImplementedError

    @abstractmethod
    def set_predictor(self):
        raise NotImplementedError
b/paddlex/inference/new_models/base/predictor/basic_predictor/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .predictor import BasicPredictor diff --git a/paddlex/inference/new_models/base/predictor/basic_predictor/predictor.py b/paddlex/inference/new_models/base/predictor/basic_predictor/predictor.py new file mode 100644 index 000000000..303e25244 --- /dev/null +++ b/paddlex/inference/new_models/base/predictor/basic_predictor/predictor.py @@ -0,0 +1,124 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BasicPredictor(
    BasePredictor,
    metaclass=AutoRegisterABCMetaClass,
):
    """Predictor built from a batch sampler, transformers and a result packager.

    Args:
        model_dir: Directory that contains the model files.
        config: Already loaded configuration, forwarded to ``BasePredictor``.
        device: When truthy, assigned to ``pp_option.device``.
        pp_option: Predictor option object. When falsy, a
            ``PaddlePredictorOption`` is created for ``self.model_name``.
    """

    __is_base = True

    def __init__(self, model_dir, config=None, device=None, pp_option=None):
        super().__init__(model_dir=model_dir, config=config)
        if not pp_option:
            pp_option = PaddlePredictorOption(model_name=self.model_name)
        if device:
            pp_option.device = device
        self.pp_option = pp_option

        self.batch_sampler = self._build_batch_sampler()
        self.result_packager = self._build_result_packager()
        self.transformers = {}
        self._build_transformers()
        self._set_dataflow()
        self.engine = TransformerEngine(self.transformers)
        self.rtn_res = True
        logging.debug(f"{self.__class__.__name__}: {self.model_dir}")

        if INFER_BENCHMARK:
            self.benchmark = Benchmark(self.transformers)

    def __call__(self, input, **kwargs):
        """Generator running prediction on ``input``.

        ``kwargs`` are forwarded to ``set_predictor``. Because this is a
        generator function, nothing runs until iteration starts. When
        ``self.benchmark`` is set, the input is consumed for benchmarking and
        no result is yielded; otherwise the items of ``apply(input)`` are
        yielded.
        """
        self.set_predictor(**kwargs)
        if self.benchmark:
            self.benchmark.start()
            if INFER_BENCHMARK_WARMUP > 0:
                output = self.apply(input)
                warmup_num = 0
                # Consume at most INFER_BENCHMARK_WARMUP items as warm-up.
                for _ in range(INFER_BENCHMARK_WARMUP):
                    try:
                        next(output)
                        warmup_num += 1
                    except StopIteration:
                        logging.warning(
                            f"There are only {warmup_num} batches in input data, but `INFER_BENCHMARK_WARMUP` has been set to {INFER_BENCHMARK_WARMUP}."
                        )
                        break
                self.benchmark.warmup_stop(warmup_num)
            output = list(self.apply(input))
            self.benchmark.collect(len(output))
        else:
            yield from self.apply(input)

    def apply(self, input):
        """Yield results for every batch that the batch sampler produces.

        Each batch goes through ``self.engine``. With ``rtn_res`` set, the
        items produced by the result packager are yielded; otherwise the
        engine output itself is yielded.
        """
        for batch in self.batch_sampler(input):
            batch_data = self.engine(batch)
            if self.rtn_res:
                yield from self.result_packager(batch_data)
            else:
                yield batch_data

    def set_predictor(self, batch_size=None, device=None, pp_option=None):
        """Update batch size, device and/or predictor option; falsy values are ignored."""
        if batch_size:
            self.batch_sampler.batch_size = batch_size
            self.pp_option.batch_size = batch_size
        if device and device != self.pp_option.device:
            self.pp_option.device = device
        if pp_option and pp_option != self.pp_option:
            self.pp_option = pp_option

    def _add_transformer(self, cmp):
        """Register ``cmp`` under its ``name``."""
        self.transformers[cmp.name] = cmp

    def __getattr__(self, cmp):
        """Expose registered transformers as attributes.

        Raises:
            AttributeError: If ``cmp`` is not a registered transformer name.
        """
        # ``transformers`` is assigned part-way through ``__init__``. Reading
        # it from the instance dict keeps a failed attribute lookup before
        # that point from re-entering ``__getattr__`` endlessly.
        transformers = self.__dict__.get("transformers")
        if transformers is not None and cmp in transformers:
            return transformers[cmp]
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{cmp}'"
        )

    def _build_result_packager(self):
        """Create the ``ResultPackager`` for the class returned by ``_get_result_class``."""
        return ResultPackager(self._get_result_class())

    @abstractmethod
    def _build_batch_sampler(self):
        raise NotImplementedError

    @abstractmethod
    def _get_result_class(self):
        raise NotImplementedError

    @abstractmethod
    def _build_transformers(self):
        raise NotImplementedError

    @abstractmethod
    def _set_dataflow(self):
        raise NotImplementedError
class ResultPackager(BaseComponent):
    """Component that wraps every sample of a batch in a result object.

    Args:
        result_class: Class called with a dict to build one result per sample;
            its ``INPUT_KEYS`` become this component's ``INPUT_KEYS``.
    """

    OUTPUT_KEYS = None

    def __init__(self, result_class):
        self._result_class = result_class
        # Must be set before the base initialiser runs.
        self.INPUT_KEYS = result_class.INPUT_KEYS
        super().__init__()

    def __call__(self, batch_data):
        yield from self.apply(batch_data)

    def apply(self, batch_data):
        """Yield one ``result_class`` instance per sample of ``batch_data``."""
        for idx in range(batch_data.num):
            single = batch_data.get_by_idx(idx)
            yield self._result_class(
                {k: single[v] for k, v in self._input_pairs()}
            )

    def _input_pairs(self):
        """Return ``(result key, batch key)`` pairs describing the inputs."""
        inputs = self.inputs
        # ``set_inputs`` stores a plain dict, and iterating a dict yields keys
        # only, so the pairs are taken from ``items()`` in that case.
        # NOTE(review): assumes the dict maps result key -> batch data key;
        # confirm against callers of ``set_inputs``.
        if isinstance(inputs, dict):
            return inputs.items()
        return inputs

    def set_inputs(self, inputs):
        """Replace the input description with the dict ``inputs``."""
        assert isinstance(inputs, dict)
        self.inputs = inputs
class TransformerEngine(object):
    """Runs the components of a name -> component mapping one after another."""

    def __init__(self, cmpts):
        self._cmpts = cmpts
        # Execution order is the key order of the mapping at construction time.
        self.keys = [*cmpts.keys()]
        # graph = self._build_graph()

    def _build_graph(self):
        """Build a directed graph of the components from their dependencies."""
        dag = nx.DiGraph()
        for cmpt_name, cmpt in self._cmpts.items():
            dag.add_node(cmpt_name, transform=cmpt)
            for dst_param, src_param in cmpt.dependencies:
                logging.debug(f"{dst_param} <-- {src_param}")
                # Edge points from the producing to the consuming component.
                dag.add_edge(src_param.cmpt.name, dst_param.cmpt.name)

        order = list(nx.topological_sort(dag))
        logging.debug(f"Execution Order: {order}")
        return dag

    def __call__(self, data, i=0):
        """Feed ``data`` through the components, starting at position ``i``.

        Returns:
            The output of the last component.
        """
        pos = i
        while True:
            stage = self._cmpts[self.keys[pos]]
            data = stage(data)
            pos += 1
            if pos >= len(self._cmpts):
                return data
class StrMixin:
    """Mixin adding string rendering and logging of a result object."""

    @property
    def str(self):
        """Plain ``str()`` rendering of this object."""
        # ``_to_str`` requires the object to render as its first argument.
        return self._to_str(self)

    def _to_str(self, data, json_format=False, indent=4, ensure_ascii=False):
        """Render ``data`` as text.

        Args:
            data: Object to render; must expose ``json`` when ``json_format``
                is true.
            json_format: Dump ``data.json`` with ``json.dumps`` instead of
                calling ``str(data)``.
            indent: Forwarded to ``json.dumps``.
            ensure_ascii: Forwarded to ``json.dumps``.
        """
        if json_format:
            return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
        return str(data)

    def print(self, json_format=False, indent=4, ensure_ascii=False):
        """Log the rendering of this object with ``logging.info``."""
        str_ = self._to_str(
            self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
        )
        logging.info(str_)
__init__(self, backend="pillow", *args, **kwargs): + self._img_writer = ImageWriter(backend=backend, *args, **kwargs) + self._show_funcs.append(self.save_to_img) + + @abstractmethod + def _to_img(self): + raise NotImplementedError + + @property + def img(self): + image = self._to_img() + # The img must be a PIL.Image obj + if isinstance(image, np.ndarray): + return Image.fromarray(image) + return image + + def save_to_img(self, save_path, *args, **kwargs): + if not str(save_path).lower().endswith((".jpg", ".png")): + fp = Path(self["input_path"]) + save_path = Path(save_path) / f"{fp.stem}{fp.suffix}" + _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs) + + +class CSVMixin: + def __init__(self, backend="pandas", *args, **kwargs): + self._csv_writer = CSVWriter(backend=backend, *args, **kwargs) + self._show_funcs.append(self.save_to_csv) + + @abstractmethod + def _to_csv(self): + raise NotImplementedError + + def save_to_csv(self, save_path, *args, **kwargs): + if not str(save_path).endswith(".csv"): + save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv" + _save_list_data( + self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs + ) + + +class HtmlMixin: + def __init__(self, *args, **kwargs): + self._html_writer = HtmlWriter(*args, **kwargs) + self._show_funcs.append(self.save_to_html) + + @property + def html(self): + return self._to_html() + + def _to_html(self): + return self["html"] + + def save_to_html(self, save_path, *args, **kwargs): + if not str(save_path).endswith(".html"): + save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html" + _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs) + + +class XlsxMixin: + def __init__(self, *args, **kwargs): + self._xlsx_writer = XlsxWriter(*args, **kwargs) + self._show_funcs.append(self.save_to_xlsx) + + def _to_xlsx(self): + return self["html"] + + def save_to_xlsx(self, save_path, *args, **kwargs): + if not 
str(save_path).endswith(".xlsx"): + save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx" + _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs) diff --git a/paddlex/inference/new_models/base/result/result.py b/paddlex/inference/new_models/base/result/result.py new file mode 100644 index 000000000..fa9e3a343 --- /dev/null +++ b/paddlex/inference/new_models/base/result/result.py @@ -0,0 +1,47 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+
+from .....utils.func_register import FuncRegister
+from ....utils.io import ImageReader, ImageWriter
+from .mixin import JsonMixin, ImgMixin, StrMixin
+
+
+class BaseResult(dict, StrMixin, JsonMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        self._show_funcs = []  # mixin __init__s append save_* here; save_all() reads it
+        StrMixin.__init__(self)
+        JsonMixin.__init__(self)
+
+    def save_all(self, save_path):
+        """Call every registered save function, passing `save_path` if accepted."""
+        for func in self._show_funcs:
+            if "save_path" in inspect.signature(func).parameters:
+                func(save_path=save_path)
+            else:
+                func()
+
+
+class CVResult(BaseResult, ImgMixin):
+    """Result whose "input_img" entry is kept in `_input_img`, outside the dict."""
+
+    INPUT_KEYS = ["input_img"]
+
+    def __init__(self, data):
+        assert set(CVResult.INPUT_KEYS).issubset(self.INPUT_KEYS)
+        self._input_img = data.pop("input_img")  # removes the image from `data`
+        super().__init__(data)
+        ImgMixin.__init__(self, "pillow")  # also creates the pillow `_img_writer`
diff --git a/paddlex/inference/new_models/base/transformer.py b/paddlex/inference/new_models/base/transformer.py
new file mode 100644
index 000000000..e4438c565
--- /dev/null
+++ b/paddlex/inference/new_models/base/transformer.py
@@ -0,0 +1,34 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+
+from ....utils import logging
+from .component import BaseComponent
+
+
+class BaseTransformer(BaseComponent):
+    """Component whose call feeds selected batch fields to `apply()`."""
+
+    def __call__(self, batch_data):
+        logging.debug("Call apply() func...")
+        apply_kwargs = {name: batch_data.get_by_key(key) for name, key in self.inputs}
+        result = self.apply(**apply_kwargs)
+        if result:
+            batch_data.update_by_key({f"{self.name}.{k}": result[k] for k in result})
+        return batch_data
+
+    @abstractmethod
+    def apply(self, input):
+        raise NotImplementedError
diff --git a/paddlex/inference/new_models/common/__init__.py b/paddlex/inference/new_models/common/__init__.py
new file mode 100644
index 000000000..59372f937
--- /dev/null
+++ b/paddlex/inference/new_models/common/__init__.py
@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlex/inference/new_models/common/batchable.py b/paddlex/inference/new_models/common/batchable.py
new file mode 100644
index 000000000..b9a9250de
--- /dev/null
+++ b/paddlex/inference/new_models/common/batchable.py
@@ -0,0 +1,34 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+
+
+def batchable(func):
+    """Lift a per-sample method so that it accepts batched keyword arguments.
+
+    Every output key is mapped to the list of its per-sample values.
+    """
+
+    @functools.wraps(func)
+    def wrap(self, **batch_kwargs):
+        names = list(batch_kwargs.keys())
+        collected = {}
+        for sample_values in zip(*batch_kwargs.values()):
+            sample_output = func(self, **dict(zip(names, sample_values)))
+            for out_key, out_value in sample_output.items():
+                collected.setdefault(out_key, []).append(out_value)
+        return collected
+
+    return wrap
diff --git a/paddlex/inference/new_models/common/cv_components/__init__.py b/paddlex/inference/new_models/common/cv_components/__init__.py
new file mode 100644
index 000000000..83f78d21e
--- /dev/null
+++ b/paddlex/inference/new_models/common/cv_components/__init__.py
@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .transformers import *
+from .batch_sampler import ImageBatchSampler
diff --git a/paddlex/inference/new_models/common/cv_components/batch_sampler.py b/paddlex/inference/new_models/common/cv_components/batch_sampler.py
new file mode 100644
index 000000000..d72fbd4db
--- /dev/null
+++ b/paddlex/inference/new_models/common/cv_components/batch_sampler.py
@@ -0,0 +1,81 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+import numpy as np
+
+from .....utils.download import download
+from .....utils.cache import CACHE_DIR
+from ...base.batch_sampler import BaseBatchSampler
+
+
+class ImageBatchSampler(BaseBatchSampler):
+    """Yield batches built from image arrays, files, directories or URLs."""
+
+    INPUT_KEYS = None
+    OUTPUT_KEYS = ["img"]
+
+    SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp", "PDF", "pdf"]
+
+    # XXX: auto download for url
+    def _download_from_url(self, in_path):
+        """Download an "http..." input to the cache and return its local path."""
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+
+    def _get_files_list(self, fp):
+        """Return the sorted supported files at `fp` (one file or a directory tree).
+
+        Raises:
+            Exception: If `fp` does not exist or no supported file is found.
+        """
+        if fp is None or not os.path.exists(fp):
+            raise Exception(f"Not found any img file in path: {fp}")
+
+        # Match extensions case-insensitively so ".PNG" is accepted like ".JPG".
+        suffixes = {suffix.lower() for suffix in self.SUFFIX}
+        file_list = []
+        if os.path.isfile(fp) and fp.split(".")[-1].lower() in suffixes:
+            file_list.append(fp)
+        elif os.path.isdir(fp):
+            for root, _, files in os.walk(fp):
+                for single_file in files:
+                    if single_file.split(".")[-1].lower() in suffixes:
+                        file_list.append(os.path.join(root, single_file))
+        if len(file_list) == 0:
+            raise Exception("Not found any file in {}".format(fp))
+        return sorted(file_list)
+
+    def apply(self, inputs):
+        """Yield `([batch], size)` pairs; inputs other than arrays/strings are skipped."""
+        if not isinstance(inputs, list):
+            inputs = [inputs]
+        for single_input in inputs:
+            if isinstance(single_input, np.ndarray):
+                yield [[single_input]], 1
+            elif isinstance(single_input, str):
+                local_path = self._download_from_url(single_input)
+                batch = []
+                for file_path in self._get_files_list(local_path):
+                    batch.append(file_path)
+                    if len(batch) == self.batch_size:
+                        yield [batch], len(batch)
+                        batch = []
+                if len(batch) > 0:
+                    yield [batch], len(batch)
diff --git a/paddlex/inference/new_models/common/cv_components/transformers/__init__.py b/paddlex/inference/new_models/common/cv_components/transformers/__init__.py
new file mode 100644
index 000000000..4f786cd6f
--- /dev/null
+++ b/paddlex/inference/new_models/common/cv_components/transformers/__init__.py
@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .transformers import *
diff --git a/paddlex/inference/new_models/common/cv_components/transformers/funcs.py b/paddlex/inference/new_models/common/cv_components/transformers/funcs.py
new file mode 100644
index 000000000..aaed3877d
--- /dev/null
+++ b/paddlex/inference/new_models/common/cv_components/transformers/funcs.py
@@ -0,0 +1,58 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+
+
+def resize(im, target_size, interp):
+    """Resize `im` to `target_size`, given as (width, height), using `interp`."""
+    width, height = target_size
+    return cv2.resize(im, (width, height), interpolation=interp)
+
+
+def flip_h(im):
+    """Flip a 2-D or 3-D image array horizontally; other ranks pass through."""
+    rank = len(im.shape)
+    if rank == 3:
+        return im[:, ::-1, :]
+    if rank == 2:
+        return im[:, ::-1]
+    return im
+
+
+def flip_v(im):
+    """Flip a 2-D or 3-D image array vertically; other ranks pass through."""
+    rank = len(im.shape)
+    if rank == 3:
+        return im[::-1, :, :]
+    if rank == 2:
+        return im[::-1, :]
+    return im
+
+
+def slice(im, coords):
+    """Return the `im[y1:y2, x1:x2]` region for `coords` = (x1, y1, x2, y2)."""
+    left, top, right, bottom = coords
+    return im[top:bottom, left:right, ...]
+
+
+def pad(im, pad, val):
+    """Pad `im` with the constant `val`; `pad` is one int or four border widths."""
+    widths = [pad] * 4 if isinstance(pad, int) else pad
+    if len(widths) != 4:
+        raise ValueError
+    channels = 1 if im.ndim == 2 else im.shape[2]
+    return cv2.copyMakeBorder(
+        im, *widths, cv2.BORDER_CONSTANT, value=(val,) * channels
+    )
diff --git a/paddlex/inference/new_models/common/cv_components/transformers/transformers.py b/paddlex/inference/new_models/common/cv_components/transformers/transformers.py
new file mode 100644
index 000000000..945691b64
--- /dev/null
+++ b/paddlex/inference/new_models/common/cv_components/transformers/transformers.py
@@ -0,0 +1,544 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import ast
+import math
+from pathlib import Path
+from copy import deepcopy
+
+import numpy as np
+import cv2
+
+from ......utils.flags import (
+    INFER_BENCHMARK,
+    INFER_BENCHMARK_ITER,
+    INFER_BENCHMARK_DATA_SIZE,
+)
+from .....utils.io import ImageReader, PDFReader
+from ....base import BaseTransformer
+from ...batchable import batchable
+from .
import funcs as F
+
+
+__all__ = [
+    "ReadImage",
+    "Flip",
+    "Crop",
+    "Resize",
+    "ResizeByLong",
+    "ResizeByShort",
+    "Pad",
+    "Normalize",
+    "ToCHWImage",
+    "PadStride",
+]
+
+
+def _check_image_size(input_):
+    """Raise TypeError unless `input_` is a list/tuple of exactly two ints."""
+    if not (
+        isinstance(input_, (list, tuple))
+        and len(input_) == 2
+        and isinstance(input_[0], int)
+        and isinstance(input_[1], int)
+    ):
+        raise TypeError(f"{input_} cannot represent a valid image size.")
+
+
+class ReadImage(BaseTransformer):
+    """Load image from the file."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["input_path", "img", "img_size"]
+
+    _FLAGS_DICT = {
+        "BGR": cv2.IMREAD_COLOR,
+        "RGB": cv2.IMREAD_COLOR,
+        "GRAY": cv2.IMREAD_GRAYSCALE,
+    }
+
+    def __init__(self, format="BGR"):
+        """
+        Initialize the instance.
+
+        Args:
+            format (str, optional): Target color format to convert the image to.
+                Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
+        """
+        super().__init__()
+        self.format = format
+        flags = self._FLAGS_DICT[self.format]
+        self._img_reader = ImageReader(backend="opencv", flags=flags)
+
+    @batchable
+    def apply(self, img):
+        """Turn one input (array or image file path) into input_path/img/img_size."""
+
+        def rand_data():
+            def parse_size(s):
+                res = ast.literal_eval(s)
+                if isinstance(res, int):
+                    return (res, res)
+                else:
+                    assert isinstance(res, (tuple, list))
+                    assert len(res) == 2
+                    assert all(isinstance(item, int) for item in res)
+                    return res
+
+            size = parse_size(INFER_BENCHMARK_DATA_SIZE)
+            return np.random.randint(0, 256, (*size, 3), dtype=np.uint8)
+
+        def process_ndarray(img):
+            if self.format == "RGB":
+                img = img[:, :, ::-1]
+            return {
+                "input_path": "",
+                "img": img,
+                "img_size": [img.shape[1], img.shape[0]],
+            }
+
+        # if INFER_BENCHMARK and img is None:
+        #     for _ in range(INFER_BENCHMARK_ITER):
+        #         yield [process_ndarray(rand_data()) for _ in range(self.batch_size)]
+
+        if isinstance(img, np.ndarray):
+            return process_ndarray(img)
+        elif isinstance(img, str):
+            return self._read_img(img)
+        else:
+            raise TypeError(
+                f"ReadImage only supports the following types:\n"
+                f"1. str, indicating a image file path or a directory containing image files.\n"
+                f"2. numpy.ndarray.\n"
+                f"However, got type: {type(img).__name__}."
+            )
+
+    def _read_img(self, img_path):
+        blob = self._img_reader.read(img_path)
+        if blob is None:
+            raise Exception("Image read Error")
+
+        if self.format == "RGB":
+            if blob.ndim != 3:
+                raise RuntimeError("Array is not 3-dimensional.")
+            # BGR to RGB
+            blob = blob[..., ::-1]
+        return {
+            "input_path": img_path,
+            "img": blob,
+            "img_size": [blob.shape[1], blob.shape[0]],
+        }
+
+
+class GetImageInfo(BaseTransformer):
+    """Get Image Info"""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img_size"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img_size": "img_size"}
+
+    def __init__(self):
+        super().__init__()
+
+    def apply(self, img):
+        """Return the [width, height] of `img` under the "img_size" key."""
+        return {"img_size": [img.shape[1], img.shape[0]]}
+
+
+class Flip(BaseTransformer):
+    """Flip the image vertically or horizontally."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, mode="H"):
+        """
+        Initialize the instance.
+
+        Args:
+            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
+                flipping. Default: 'H'.
+        """
+        super().__init__()
+        if mode not in ("H", "V"):
+            raise ValueError("`mode` should be 'H' or 'V'.")
+        self.mode = mode
+
+    def apply(self, img):
+        """Flip `img` horizontally or vertically according to `self.mode`."""
+        if self.mode == "H":
+            img = F.flip_h(img)
+        elif self.mode == "V":
+            img = F.flip_v(img)
+        return {"img": img}
+
+
+class Crop(BaseTransformer):
+    """Crop region from the image."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_size"]
+
+    def __init__(self, crop_size, mode="C"):
+        """
+        Initialize the instance.
+
+        Args:
+            crop_size (list|tuple|int): Width and height of the region to crop.
+            mode (str, optional): 'C' for cropping the center part and 'TL' for
+                cropping the top left part. Default: 'C'.
+        """
+        super().__init__()
+        if isinstance(crop_size, int):
+            crop_size = [crop_size, crop_size]
+        _check_image_size(crop_size)
+
+        self.crop_size = crop_size
+
+        if mode not in ("C", "TL"):
+            raise ValueError("`mode` should be 'C' or 'TL'.")
+        self.mode = mode
+
+    @batchable
+    def apply(self, img):
+        """Crop `crop_size` from the center ('C') or top-left ('TL') of `img`."""
+        h, w = img.shape[:2]
+        cw, ch = self.crop_size
+        if self.mode == "C":
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+        elif self.mode == "TL":
+            x1, y1 = 0, 0
+        x2 = min(w, x1 + cw)
+        y2 = min(h, y1 + ch)
+        coords = (x1, y1, x2, y2)
+        if coords == (0, 0, w, h):  # NOTE(review): also true for an exact-size image
+            raise ValueError(
+                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+            )
+        img = F.slice(img, coords=coords)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class _BaseResize(BaseTransformer):
+    _INTERP_DICT = {
+        "NEAREST": cv2.INTER_NEAREST,
+        "LINEAR": cv2.INTER_LINEAR,
+        "CUBIC": cv2.INTER_CUBIC,
+        "AREA": cv2.INTER_AREA,
+        "LANCZOS4": cv2.INTER_LANCZOS4,
+    }
+
+    def __init__(self, size_divisor, interp):
+        super().__init__()
+
+        if size_divisor is not None:
+            assert isinstance(
+                size_divisor, int
+            ), "`size_divisor` should be None or int."
+        self.size_divisor = size_divisor
+
+        try:
+            interp = self._INTERP_DICT[interp]
+        except KeyError:
+            raise ValueError(
+                "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
+            )
+        self.interp = interp
+
+    @staticmethod
+    def _rescale_size(img_size, target_size):
+        """Return `img_size` scaled to fit in `target_size` (ratio kept) and the scale."""
+        scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
+        rescaled_size = [round(i * scale) for i in img_size]
+        return rescaled_size, scale
+
+
+class Resize(_BaseResize):
+    """Resize the image."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_size", "scale_factors"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {
+        "img": "img",
+        "img_size": "img_size",
+        "scale_factors": "scale_factors",
+    }
+
+    def __init__(
+        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
+    ):
+        """
+        Initialize the instance.
+
+        Args:
+            target_size (list|tuple|int): Target width and height.
+            keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
+                image. Default: False.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        _check_image_size(target_size)
+        self.target_size = target_size
+
+        self.keep_ratio = keep_ratio
+
+    def apply(self, img):
+        """Resize `img`; also return the new size and the (w, h) scale factors."""
+        target_size = self.target_size
+        original_size = img.shape[:2][::-1]
+
+        if self.keep_ratio:
+            h, w = img.shape[0:2]
+            target_size, _ = self._rescale_size((w, h), self.target_size)
+
+        if self.size_divisor:
+            target_size = [
+                math.ceil(i / self.size_divisor) * self.size_divisor
+                for i in target_size
+            ]
+
+        img_scale_w, img_scale_h = [
+            target_size[0] / original_size[0],
+            target_size[1] / original_size[1],
+        ]
+        img = F.resize(img, target_size, interp=self.interp)
+        return {
+            "img": img,
+            "img_size": [img.shape[1], img.shape[0]],
+            "scale_factors": [img_scale_w, img_scale_h],
+        }
+
+
+class ResizeByLong(_BaseResize):
+    """
+    Proportionally resize the image by specifying the target length of the
+    longest side.
+    """
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
+        """
+        Initialize the instance.
+
+        Args:
+            target_long_edge (int): Target length of the longest side of image.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+        self.target_long_edge = target_long_edge
+
+    def apply(self, img):
+        """Scale `img` by `target_long_edge / max(h, w)`, then round up to `size_divisor`."""
+        h, w = img.shape[:2]
+        scale = self.target_long_edge / max(h, w)
+        h_resize = round(h * scale)
+        w_resize = round(w * scale)
+        if self.size_divisor is not None:
+            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
+            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
+
+        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class ResizeByShort(_BaseResize):
+    """
+    Proportionally resize the image by specifying the target length of the
+    shortest side.
+    """
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_size"]
+
+    def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
+        """
+        Initialize the instance.
+
+        Args:
+            target_short_edge (int): Target length of the shortest side of image.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+        self.target_short_edge = target_short_edge
+
+    @batchable
+    def apply(self, img):
+        """Scale `img` by `target_short_edge / min(h, w)`, then round up to `size_divisor`."""
+        h, w = img.shape[:2]
+        scale = self.target_short_edge / min(h, w)
+        h_resize = round(h * scale)
+        w_resize = round(w * scale)
+        if self.size_divisor is not None:
+            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
+            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
+
+        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class Pad(BaseTransformer):
+    """Pad the image."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, target_size, val=127.5):
+        """
+        Initialize the instance.
+
+        Args:
+            target_size (list|tuple|int): Target width and height of the image after
+                padding.
+            val (float, optional): Value to fill the padded area. Default: 127.5.
+        """
+        super().__init__()
+
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        _check_image_size(target_size)
+        self.target_size = target_size
+
+        self.val = val
+
+    def apply(self, img):
+        """Pad `img` with `self.val` up to `target_size`; fail if it is already larger."""
+        h, w = img.shape[:2]
+        tw, th = self.target_size
+        ph = th - h
+        pw = tw - w
+
+        if ph < 0 or pw < 0:
+            raise ValueError(
+                f"Input image ({w}, {h}) larger than the target size ({tw}, {th})."
+            )
+        else:
+            img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class PadStride(BaseTransformer):
+    """Zero-pad an image for FPN models (replaces PadBatch(pad_to_stride, pad_gt)).
+    Args:
+        stride (int): model with FPN need image shape % stride == 0
+    """
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, stride=0):
+        super().__init__()
+        self.coarsest_stride = stride
+
+    def apply(self, img):
+        """
+        Args:
+            img (np.ndarray): image whose shape unpacks as (channels, height, width).
+        Returns:
+            dict: {"img": array}; a padded float32 copy, or `img` itself if stride <= 0.
+        """
+        im = img
+        coarsest_stride = self.coarsest_stride
+        if coarsest_stride <= 0:
+            return {"img": im}
+        im_c, im_h, im_w = im.shape
+        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+        padding_im[:, :im_h, :im_w] = im
+        return {"img": padding_im}
+
+
+class Normalize(BaseTransformer):
+    """Normalize the image."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img"]
+
+    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
+        """
+        Initialize the instance.
+
+        Args:
+            scale (float, optional): Scaling factor to apply to the image before
+                applying normalization. Default: 1/255.
+            mean (float|tuple|list, optional): Means for each channel of the image.
+                Default: 0.5.
+            std (float|tuple|list, optional): Standard deviations for each channel
+                of the image. Default: 0.5.
+            preserve_dtype (bool, optional): Whether to preserve the original dtype
+                of the image.
+        """
+        super().__init__()
+
+        self.scale = np.float32(scale)
+        if isinstance(mean, float):
+            mean = [mean]
+        self.mean = np.asarray(mean).astype("float32")
+        if isinstance(std, float):
+            std = [std]
+        self.std = np.asarray(std).astype("float32")
+        self.preserve_dtype = preserve_dtype
+
+    @batchable
+    def apply(self, img):
+        """Return `(img * scale - mean) / std`, as float32 unless `preserve_dtype`."""
+        old_type = img.dtype
+        # XXX: If `old_type` has higher precision than float32,
+        # we will lose some precision.
+        img = img.astype("float32", copy=False)  # float32 input is modified in place
+        img *= self.scale
+        img -= self.mean
+        img /= self.std
+        if self.preserve_dtype:
+            img = img.astype(old_type, copy=False)
+        return {"img": img}
+
+
+class ToCHWImage(BaseTransformer):
+    """Reorder the dimensions of the image from HWC to CHW."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img"]
+
+    @batchable
+    def apply(self, img):
+        """Transpose `img` with axes (2, 0, 1); the result is a view, not a copy."""
+        img = img.transpose((2, 0, 1))
+        return {"img": img}
diff --git a/paddlex/inference/new_models/common/paddle_predictor/__init__.py b/paddlex/inference/new_models/common/paddle_predictor/__init__.py
new file mode 100644
index 000000000..2c3d6b108
--- /dev/null
+++ b/paddlex/inference/new_models/common/paddle_predictor/__init__.py
@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import (
+    BasePaddlePredictor,
+    ImagePredictor,
+)
diff --git a/paddlex/inference/new_models/common/paddle_predictor/predictor.py b/paddlex/inference/new_models/common/paddle_predictor/predictor.py
new file mode 100644
index 000000000..315482b28
--- /dev/null
+++ b/paddlex/inference/new_models/common/paddle_predictor/predictor.py
@@ -0,0 +1,253 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import inspect
+from abc import abstractmethod
+import lazy_paddle as paddle
+import numpy as np
+
+from .....utils.flags import FLAGS_json_format_model
+from .....utils import logging
+from ....utils.pp_option import PaddlePredictorOption
+from ...base import BaseTransformer
+
+
+class Copy2GPU(BaseTransformer):
+    """Feed a list of host arrays into the predictor's input handles."""
+
+    INPUT_KEYS = None
+    OUTPUT_KEYS = None
+
+    def __init__(self, input_handlers):
+        super().__init__()
+        self.input_handlers = input_handlers
+
+    def apply(self, x):
+        """Reshape the idx-th input handle to ``x[idx]`` and copy the data in."""
+        for idx, arr in enumerate(x):
+            handler = self.input_handlers[idx]
+            handler.reshape(arr.shape)
+            handler.copy_from_cpu(arr)
+
+
+class Copy2CPU(BaseTransformer):
+    """Fetch every output handle of the predictor back to host memory."""
+
+    INPUT_KEYS = None
+    OUTPUT_KEYS = None
+
+    def __init__(self, output_handlers):
+        super().__init__()
+        self.output_handlers = output_handlers
+
+    def apply(self):
+        """Return one ``copy_to_cpu()`` result per output handle, in order."""
+        return [out_tensor.copy_to_cpu() for out_tensor in self.output_handlers]
+
+
+class Infer(BaseTransformer):
+    """Run the wrapped predictor on the data currently held by its handles."""
+
+    INPUT_KEYS = None
+    OUTPUT_KEYS = None
+
+    def __init__(self, predictor):
+        super().__init__()
+        self.predictor = predictor
+
+    def apply(self):
+        """Execute ``predictor.run()``; inputs and outputs travel via handles."""
+        self.predictor.run()
+
+
+class BasePaddlePredictor(BaseTransformer):
+    """Predictor based on Paddle Inference.
+
+    ``apply`` batches the inputs with ``to_batch``, feeds them to the input
+    handles, runs the model, fetches the outputs and returns
+    ``format_output`` of them. Subclasses implement ``to_batch`` and
+    ``format_output``.
+    """
+
+    def __init__(self, model_dir, model_prefix, option):
+        """Store the model location and build the predictor for ``option``.
+
+        ``model_dir`` is joined with ``/`` in ``_create`` (so it is expected to
+        be path-like); ``option`` may be ``None``, in which case a default
+        ``PaddlePredictorOption()`` is used.
+        """
+        super().__init__()
+        self.model_dir = model_dir
+        self.model_prefix = model_prefix
+        self._update_option(option)
+
+    def _update_option(self, option):
+        """Install ``option`` and rebuild the predictor when it differs.
+
+        A falsy ``option`` keeps the current option if there is one; if there
+        is none yet, a default ``PaddlePredictorOption()`` is installed so the
+        predictor is always usable after construction.
+        """
+        if not option:
+            if self.option:
+                return
+            # Previously nothing was built in this case and `apply` failed
+            # later on `self.option.changed` with `option` being None.
+            option = PaddlePredictorOption()
+        elif self.option and option == self.option:
+            return
+        self._option = option
+        self._reset()
+
+    @property
+    def option(self):
+        """The active option object, or ``None`` before one is installed."""
+        return self._option if hasattr(self, "_option") else None
+
+    @option.setter
+    def option(self, option):
+        self._update_option(option)
+
+    def _reset(self):
+        """(Re)create the predictor and the three sub-components around it."""
+        if not self.option:
+            # Assign the backing field directly: going through the `option`
+            # setter re-enters `_reset` and would build the predictor twice.
+            self._option = PaddlePredictorOption()
+        logging.debug(f"Env: {self.option}")
+        (
+            predictor,
+            input_handlers,
+            output_handlers,
+        ) = self._create()
+        self.copy2gpu = Copy2GPU(input_handlers)
+        self.copy2cpu = Copy2CPU(output_handlers)
+        self.infer = Infer(predictor)
+        self.option.changed = False
+
+    def _create(self):
+        """Build the Paddle Inference predictor described by ``self.option``.
+
+        Returns:
+            ``(predictor, input_handlers, output_handlers)``; input handles are
+            ordered by sorted input name, output handles by the order the
+            predictor reports its output names.
+        """
+        from lazy_paddle.inference import Config, create_predictor
+
+        model_postfix = ".json" if FLAGS_json_format_model else ".pdmodel"
+        model_file = (self.model_dir / f"{self.model_prefix}{model_postfix}").as_posix()
+        params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
+        config = Config(model_file, params_file)
+
+        config.enable_memory_optim()
+        if self.option.device in ("gpu", "dcu"):
+            if self.option.device == "gpu":
+                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
+            config.enable_use_gpu(100, self.option.device_id)
+            if self.option.device == "gpu":
+                # NOTE: The pptrt settings are not aligned with those of FD.
+                precision_map = {
+                    "trt_int8": Config.Precision.Int8,
+                    "trt_fp32": Config.Precision.Float32,
+                    "trt_fp16": Config.Precision.Half,
+                }
+                if self.option.run_mode in precision_map:
+                    config.enable_tensorrt_engine(
+                        workspace_size=(1 << 25) * self.option.batch_size,
+                        max_batch_size=self.option.batch_size,
+                        min_subgraph_size=self.option.min_subgraph_size,
+                        precision_mode=precision_map[self.option.run_mode],
+                        use_static=self.option.trt_use_static,
+                        use_calib_mode=self.option.trt_calib_mode,
+                    )
+
+                    if self.option.shape_info_filename is not None:
+                        if not os.path.exists(self.option.shape_info_filename):
+                            config.collect_shape_range_info(
+                                self.option.shape_info_filename
+                            )
+                            logging.info(
+                                f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
+                            )
+                        else:
+                            # Implicit concatenation instead of a backslash
+                            # continuation, which embedded the source
+                            # indentation in the logged message.
+                            logging.info(
+                                f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. "
+                                "No need to generate again."
+                            )
+                        config.enable_tuned_tensorrt_dynamic_shape(
+                            self.option.shape_info_filename, True
+                        )
+        elif self.option.device == "npu":
+            config.enable_custom_device("npu")
+        elif self.option.device == "xpu":
+            pass
+        elif self.option.device == "mlu":
+            config.enable_custom_device("mlu")
+        else:
+            assert self.option.device == "cpu"
+            config.disable_gpu()
+            if "mkldnn" in self.option.run_mode:
+                # Best effort: fall back to plain CPU execution with a warning
+                # if MKL-DNN cannot be enabled.
+                try:
+                    config.enable_mkldnn()
+                    if "bf16" in self.option.run_mode:
+                        config.enable_mkldnn_bfloat16()
+                except Exception:
+                    logging.warning(
+                        "MKL-DNN is not available. We will disable MKL-DNN."
+                    )
+                config.set_mkldnn_cache_capacity(-1)
+            else:
+                if hasattr(config, "disable_mkldnn"):
+                    config.disable_mkldnn()
+
+        # Disable paddle inference logging
+        config.disable_glog_info()
+
+        config.set_cpu_math_library_num_threads(self.option.cpu_threads)
+
+        if self.option.device in ("cpu", "gpu"):
+            if not (
+                self.option.device == "gpu" and self.option.run_mode.startswith("trt")
+            ):
+                # Guarded with hasattr: these switches are only called when the
+                # installed Config exposes them.
+                if hasattr(config, "enable_new_ir"):
+                    config.enable_new_ir(self.option.enable_new_ir)
+                if hasattr(config, "enable_new_executor"):
+                    config.enable_new_executor()
+                config.set_optimization_level(3)
+
+        for del_p in self.option.delete_pass:
+            config.delete_pass(del_p)
+
+        if self.option.device in ("gpu", "dcu"):
+            if paddle.is_compiled_with_rocm():
+                # Delete unsupported passes in dcu
+                config.delete_pass("conv2d_add_act_fuse_pass")
+                config.delete_pass("conv2d_add_fuse_pass")
+
+        predictor = create_predictor(config)
+
+        # Get input and output handlers
+        input_names = predictor.get_input_names()
+        input_names.sort()
+        input_handlers = [
+            predictor.get_input_handle(input_name) for input_name in input_names
+        ]
+        output_handlers = [
+            predictor.get_output_handle(output_name)
+            for output_name in predictor.get_output_names()
+        ]
+        return predictor, input_handlers, output_handlers
+
+    def apply(self, **kwargs):
+        """Run one batch through the model and return the formatted output."""
+        if self.option.changed:
+            self._reset()
+        batches = self.to_batch(**kwargs)
+        self.copy2gpu.apply(batches)
+        self.infer.apply()
+        pred = self.copy2cpu.apply()
+        return self.format_output(pred)
+
+    @property
+    def sub_cmps(self):
+        """The three internal stages keyed by their component names."""
+        return {
+            "Copy2GPU": self.copy2gpu,
+            "Infer": self.infer,
+            "Copy2CPU": self.copy2cpu,
+        }
+
+    @abstractmethod
+    def to_batch(self, **kwargs):
+        """Turn the keyword inputs of ``apply`` into a list of input arrays."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def format_output(self, pred):
+        """Turn the list of fetched output arrays into the component output."""
+        raise NotImplementedError
+
+
+class ImagePredictor(BasePaddlePredictor):
+    """Predictor taking a batch of images as its single model input."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["pred"]
+
+    def to_batch(self, img):
+        """Stack the images along a new leading axis as one float32 array."""
+        return [np.stack(img, axis=0).astype(dtype=np.float32, copy=False)]
+
+    def format_output(self, pred):
+        """Expose the raw list of output arrays under the ``pred`` key."""
+        return {"pred": pred}
diff --git a/paddlex/inference/new_models/image_classification/__init__.py b/paddlex/inference/new_models/image_classification/__init__.py
new file mode 100644
index 000000000..d11379783
--- /dev/null
+++ b/paddlex/inference/new_models/image_classification/__init__.py
@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import ClasPredictor
diff --git a/paddlex/inference/new_models/image_classification/predictor.py b/paddlex/inference/new_models/image_classification/predictor.py
new file mode 100644
index 000000000..8596e17a6
--- /dev/null
+++ b/paddlex/inference/new_models/image_classification/predictor.py
@@ -0,0 +1,116 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ....utils.func_register import FuncRegister
+from ....modules.image_classification.model_list import MODELS
+from ..base import BasicPredictor
+from ..common.cv_components import *
+from ..common.paddle_predictor import ImagePredictor
+from .transformers import *
+from .result import TopkResult
+
+
+class ClasPredictor(BasicPredictor):
+    """Image-classification predictor assembled from the model's config.
+
+    Pre- and post-processing operators named in ``self.config`` are mapped to
+    the ``build_*`` factories registered in ``_FUNC_MAP``.
+    """
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def _build_batch_sampler(self):
+        return ImageBatchSampler()
+
+    def _get_result_class(self):
+        return TopkResult
+
+    def _build_registered_op(self, key, args):
+        """Instantiate the operator registered under ``key``.
+
+        Raises:
+            KeyError: if no factory is registered for ``key``. The message
+                names the offending key and the supported ones.
+        """
+        try:
+            func = self._FUNC_MAP[key]
+        except KeyError:
+            raise KeyError(
+                f"Unsupported operator {key!r} for {type(self).__name__}; "
+                f"supported operators: {sorted(self._FUNC_MAP)}"
+            ) from None
+        # An empty/None argument block means "use the factory defaults".
+        return func(self, **args) if args else func(self)
+
+    def _build_transformers(self):
+        """Add, in order: image reader, pre-process ops, model, post-process ops."""
+        self._add_transformer(ReadImage(format="RGB"))
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            # Each entry is a one-item mapping: {operator name: arguments}.
+            tf_key = next(iter(cfg))
+            self._add_transformer(self._build_registered_op(tf_key, cfg[tf_key]))
+
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        self._add_transformer(predictor)
+
+        for key, args in self.config["PostProcess"].items():
+            # Previously an unknown key produced "'NoneType' object is not
+            # callable"; it now fails with the same descriptive KeyError as
+            # the pre-process path.
+            self._add_transformer(self._build_registered_op(key, args))
+
+    def _set_dataflow(self):
+        """Wire each component's inputs to the outputs of the previous stage."""
+        self.ReadImage.inputs.img.fetch(self.batch_sampler.outputs.img)
+        self.Resize.inputs.img.fetch(self.ReadImage.outputs.img)
+        self.Crop.inputs.img.fetch(self.Resize.outputs.img)
+        self.Normalize.inputs.img.fetch(self.Crop.outputs.img)
+        self.ToCHWImage.inputs.img.fetch(self.Normalize.outputs.img)
+        self.ImagePredictor.inputs.img.fetch(self.ToCHWImage.outputs.img)
+        self.Topk.inputs.pred.fetch(self.ImagePredictor.outputs.pred)
+        self.result_packager.inputs.input_path.fetch(self.batch_sampler.outputs.img)
+        self.result_packager.inputs.class_ids.fetch(self.Topk.outputs.class_ids)
+        self.result_packager.inputs.scores.fetch(self.Topk.outputs.scores)
+        # Label names are only wired when Topk was given a label list.
+        if self.Topk.class_id_map is not None:
+            self.result_packager.inputs.label_names.fetch(self.Topk.outputs.label_names)
+
+    @register("ResizeImage")
+    # TODO(gaotingquan): backend & interpolation
+    def build_resize(
+        self, resize_short=None, size=None, backend="cv2", interpolation="LINEAR"
+    ):
+        """Build the resize op; ``backend``/``interpolation`` are not used yet."""
+        assert resize_short or size
+        if resize_short:
+            op = ResizeByShort(
+                target_short_edge=resize_short, size_divisor=None, interp="LINEAR"
+            )
+        else:
+            op = Resize(target_size=size)
+        # Both variants are addressed as `self.Resize` in `_set_dataflow`.
+        op.name = "Resize"
+        return op
+
+    @register("CropImage")
+    def build_crop(self, size=224):
+        return Crop(crop_size=size)
+
+    @register("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        scale=1 / 255,
+        order="",
+        channel_num=3,
+    ):
+        """Build the normalize op; only 3 channels and the default order are accepted."""
+        assert channel_num == 3
+        assert order == ""
+        return Normalize(scale=scale, mean=mean, std=std)
+
+    @register("ToCHWImage")
+    def build_to_chw(self):
+        return ToCHWImage()
+
+    @register("Topk")
+    def build_topk(self, topk, label_list=None):
+        return Topk(topk=int(topk), class_ids=label_list)
+
+    @register("MultiLabelThreshOutput")
+    def build_threshoutput(self, threshold, label_list=None):
+        return MultiLabelThreshOutput(threshold=float(threshold), class_ids=label_list)
diff --git a/paddlex/inference/new_models/image_classification/result.py b/paddlex/inference/new_models/image_classification/result.py
new file mode 100644
index 000000000..81851d4e5
--- /dev/null
+++ b/paddlex/inference/new_models/image_classification/result.py
@@ -0,0 +1,82 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...utils.color_map import get_colormap
+from ..base import CVResult
+
+
+class TopkResult(CVResult):
+    """Classification result that can render its top-1 label onto the image."""
+
+    INPUT_KEYS = ["input_img", "input_path", "class_ids", "scores", "label_names"]
+
+    @staticmethod
+    def _measure_text(draw, text, font):
+        """Return ``(width, height)`` of ``text`` rendered with ``font``.
+
+        Uses ``ImageDraw.textsize`` where the installed Pillow still provides
+        it and ``textbbox`` otherwise. The previous version gate
+        (``<= (10, 0, 0)``) selected ``textsize`` on Pillow 10.0.0, the release
+        that removed it, and its ``int()`` parsing broke on non-numeric
+        version parts.
+        """
+        if hasattr(draw, "textsize"):
+            return draw.textsize(text, font)
+        left, top, right, bottom = draw.textbbox((0, 0), text, font)
+        return right - left, bottom - top
+
+    def _to_img(self):
+        """Draw label on image"""
+        labels = self.get("label_names", self["class_ids"])
+        label_str = f"{labels[0]} {self['scores'][0]:.2f}"
+
+        image = Image.fromarray(self._input_img)
+        image_size = image.size
+        draw = ImageDraw.Draw(image)
+        min_font_size = int(image_size[0] * 0.02)
+        max_font_size = int(image_size[0] * 0.05)
+        # Take the largest font (5% .. 2% of the image width) whose rendered
+        # label still fits the image width; fall back to the smallest one.
+        for font_size in range(max_font_size, min_font_size - 1, -1):
+            font = ImageFont.truetype(
+                PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8"
+            )
+            text_width_tmp, _ = self._measure_text(draw, label_str, font)
+            if text_width_tmp <= image_size[0]:
+                break
+        else:
+            font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, min_font_size)
+        color_list = get_colormap(rgb=True)
+        color = tuple(color_list[0])
+        font_color = tuple(self._get_font_colormap(3))
+        text_width, text_height = self._measure_text(draw, label_str, font)
+
+        # Filled background box near the top-left corner, with small margins.
+        rect_left = 3
+        rect_top = 3
+        rect_right = rect_left + text_width + 3
+        rect_bottom = rect_top + text_height + 6
+
+        draw.rectangle([(rect_left, rect_top), (rect_right, rect_bottom)], fill=color)
+
+        text_x = rect_left + 3
+        text_y = rect_top
+        draw.text((text_x, text_y), label_str, fill=font_color, font=font)
+        return image
+
+    def _get_font_colormap(self, color_index):
+        """
+        Get font colormap
+
+        Returns the light colour for the indices listed in ``light_indexs``
+        and the dark colour for every other index, as an int32 array.
+        """
+        dark = np.array([0x14, 0x0E, 0x35])
+        light = np.array([0xFF, 0xFF, 0xFF])
+        light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+        if color_index in light_indexs:
+            return light.astype("int32")
+        else:
+            return dark.astype("int32")
diff --git a/paddlex/inference/new_models/image_classification/transformers.py b/paddlex/inference/new_models/image_classification/transformers.py
new file mode 100644
index 000000000..a8cff4417
--- /dev/null
+++ b/paddlex/inference/new_models/image_classification/transformers.py
@@ -0,0 +1,64 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ....utils import logging
+from ..base import BaseTransformer
+from ..common.batchable import batchable
+
+
+__all__ = ["Topk"]
+
+
+def _parse_class_id_map(class_ids):
+    """Map each position in ``class_ids`` to ``str(label)``; ``None`` passes through."""
+    if class_ids is None:
+        return None
+    return {class_id: str(label) for class_id, label in enumerate(class_ids)}
+
+
+class Topk(BaseTransformer):
+    """Topk Transform"""
+
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = ["class_ids", "scores", "label_names"]
+
+    def __init__(self, topk, class_ids=None):
+        """Keep the ``topk`` best classes; ``class_ids`` is an optional label list."""
+        super().__init__()
+        assert isinstance(topk, (int,))
+        self.topk = topk
+        self.class_id_map = _parse_class_id_map(class_ids)
+
+    @batchable
+    def apply(self, pred):
+        """Pick the ``topk`` highest-scoring classes from ``pred[0]``.
+
+        Returns a dict with ``class_ids`` and ``scores`` (rounded to 5
+        decimals), best first. ``label_names`` is included only when a label
+        list was supplied; the old ``is not None`` test on a list was always
+        true, so an empty ``label_names`` was emitted without a label map.
+        """
+        cls_pred = pred[0]
+        # Ascending argsort: take the last `topk` indices and reverse them.
+        index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
+        clas_id_list = [i.item() for i in index]
+        score_list = [cls_pred[i].item() for i in index]
+        result = {
+            "class_ids": clas_id_list,
+            "scores": np.around(score_list, decimals=5),
+        }
+        if self.class_id_map is not None:
+            result["label_names"] = [
+                self.class_id_map[class_id] for class_id in clas_id_list
+            ]
+        return result
diff --git a/paddlex/inference/new_models/text_detection/__init__.py b/paddlex/inference/new_models/text_detection/__init__.py
new file mode 100644
index 000000000..396b257e4
--- /dev/null
+++ b/paddlex/inference/new_models/text_detection/__init__.py
@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .predictor import TextDetPredictor diff --git a/paddlex/inference/new_models/text_detection/predictor.py b/paddlex/inference/new_models/text_detection/predictor.py new file mode 100644 index 000000000..3780e592c --- /dev/null +++ b/paddlex/inference/new_models/text_detection/predictor.py @@ -0,0 +1,121 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from ....utils.func_register import FuncRegister
+from ....modules.text_detection.model_list import MODELS
+from ..base import BasicPredictor
+from ..common.cv_components import *
+from ..common.paddle_predictor import ImagePredictor
+from .transformers import *
+from .result import TextDetResult
+
+
+class TextDetPredictor(BasicPredictor):
+    """Text-detection predictor assembled from the model's config.
+
+    Pre-processing operators named in ``self.config`` are mapped to the
+    factories registered in ``_FUNC_MAP``; a factory returning ``None`` means
+    the operator is skipped.
+    """
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def _build_batch_sampler(self):
+        return ImageBatchSampler()
+
+    def _get_result_class(self):
+        return TextDetResult
+
+    def _build_transformers(self):
+        """Add, in order: pre-process ops, the model, the post-process op."""
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            # Each entry is a one-item mapping: {operator name: arguments}.
+            tf_key = next(iter(cfg))
+            func = self._FUNC_MAP[tf_key]
+            args = cfg[tf_key]
+            op = func(self, **args) if args else func(self)
+            if op:
+                self._add_transformer(op)
+
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        self._add_transformer(predictor)
+
+        op = self.build_postprocess(**self.config["PostProcess"])
+        self._add_transformer(op)
+
+    def _set_dataflow(self):
+        """Wire each component's inputs to the outputs of the previous stage."""
+        self.ReadImage.inputs.img.fetch(self.batch_sampler.outputs.img)
+        self.DetResizeForTest.inputs.img.fetch(self.ReadImage.outputs.img)
+        self.NormalizeImage.inputs.img.fetch(self.DetResizeForTest.outputs.img)
+        self.ToCHWImage.inputs.img.fetch(self.NormalizeImage.outputs.img)
+        self.ImagePredictor.inputs.img.fetch(self.ToCHWImage.outputs.img)
+        self.DBPostProcess.inputs.pred.fetch(self.ImagePredictor.outputs.pred)
+        self.DBPostProcess.inputs.img_shape.fetch(
+            self.DetResizeForTest.outputs.img_shape
+        )
+        self.result_packager.inputs.input_path.fetch(self.batch_sampler.outputs.img)
+        self.result_packager.inputs.polys.fetch(self.DBPostProcess.outputs.polys)
+        self.result_packager.inputs.scores.fetch(self.DBPostProcess.outputs.scores)
+
+    @register("DecodeImage")
+    def build_readimg(self, channel_first, img_mode):
+        """Build the image reader; channel-first decoding is not supported."""
+        assert channel_first == False
+        return ReadImage(format=img_mode)
+
+    @register("DetResizeForTest")
+    def build_resize(self, **kwargs):
+        """Build the resize op, with a fixed setup for the PP-OCRv4 det models."""
+        # TODO: align to PaddleOCR
+        if self.model_name in ("PP-OCRv4_server_det", "PP-OCRv4_mobile_det"):
+            resize_long = kwargs.get("resize_long", 960)
+            return DetResizeForTest(limit_side_len=resize_long, limit_type="max")
+        return DetResizeForTest(**kwargs)
+
+    @register("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        scale=1 / 255,
+        order="",
+        channel_num=3,
+    ):
+        return NormalizeImage(
+            mean=mean, std=std, scale=scale, order=order, channel_num=channel_num
+        )
+
+    @register("ToCHWImage")
+    def build_to_chw(self):
+        return ToCHWImage()
+
+    def build_postprocess(self, **kwargs):
+        """Build the post-process op described by the ``PostProcess`` config.
+
+        Raises:
+            ValueError: if ``name`` is anything other than ``"DBPostProcess"``
+                (previously a bare ``Exception`` without a message).
+        """
+        name = kwargs.get("name")
+        if name != "DBPostProcess":
+            raise ValueError(
+                f"Unsupported text detection post-process: {name!r}; "
+                "only 'DBPostProcess' is supported."
+            )
+        return DBPostProcess(
+            thresh=kwargs.get("thresh", 0.3),
+            box_thresh=kwargs.get("box_thresh", 0.7),
+            max_candidates=kwargs.get("max_candidates", 1000),
+            unclip_ratio=kwargs.get("unclip_ratio", 2.0),
+            use_dilation=kwargs.get("use_dilation", False),
+            score_mode=kwargs.get("score_mode", "fast"),
+            box_type=kwargs.get("box_type", "quad"),
+        )
+
+    # Config operators with no inference-time counterpart: the factories return
+    # None, which `_build_transformers` skips.
+    @register("DetLabelEncode")
+    def foo(self, *args, **kwargs):
+        return None
+
+    @register("KeepKeys")
+    def foo(self, *args, **kwargs):
+        return None
diff --git a/paddlex/inference/new_models/text_detection/result.py b/paddlex/inference/new_models/text_detection/result.py
new file mode 100644
index 000000000..189604717
--- /dev/null
+++ b/paddlex/inference/new_models/text_detection/result.py
@@ -0,0 +1,35 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import cv2
+
+from ..base import CVResult
+
+
+class TextDetResult(CVResult):
+    """Text-detection result that can draw its polygons onto the input image."""
+
+    INPUT_KEYS = ["input_path", "polys", "scores"]
+
+    def __init__(self, data):
+        super().__init__(data)
+        self._img_reader.set_backend("opencv")
+
+    def _to_img(self):
+        """Outline every detected polygon on the image read from ``input_path``."""
+        polys = self["polys"]
+        canvas = self._img_reader.read(self["input_path"])
+        for poly in polys:
+            # cv2.polylines is given the points as an int64 array of shape (N, 1, 2).
+            points = np.array(poly).astype(int)
+            points = np.reshape(points, [-1, 1, 2]).astype(np.int64)
+            cv2.polylines(canvas, [points], True, (0, 0, 255), 2)
+        return canvas
diff --git a/paddlex/inference/new_models/text_detection/transformers.py b/paddlex/inference/new_models/text_detection/transformers.py
new file mode 100644
index 000000000..9e5561f93
--- /dev/null
+++ b/paddlex/inference/new_models/text_detection/transformers.py
@@ -0,0 +1,416 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+import cv2
+import copy
+import math
+import pyclipper
+import numpy as np
+from numpy.linalg import norm
+from PIL import Image
+from shapely.geometry import Polygon
+
+from ...utils.io import ImageReader
+from ....utils import logging
+from ..base import BaseTransformer
+from ..common.batchable import batchable
+
+__all__ = ["DetResizeForTest", "NormalizeImage", "DBPostProcess"]
+
+
+class DetResizeForTest(BaseTransformer):
+    """Resize an image for text detection and report the applied ratios.
+
+    The keyword arguments select one of three strategies:
+    ``image_shape`` (fixed size, optional ``keep_ratio``), ``limit_side_len``
+    with ``limit_type`` (the default, limit 736 / type "min"), or
+    ``resize_long``.
+    """
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_shape"]
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.resize_type = 0
+        self.keep_ratio = False
+        if "image_shape" in kwargs:
+            self.image_shape = kwargs["image_shape"]
+            self.resize_type = 1
+            if "keep_ratio" in kwargs:
+                self.keep_ratio = kwargs["keep_ratio"]
+        elif "limit_side_len" in kwargs:
+            self.limit_side_len = kwargs["limit_side_len"]
+            self.limit_type = kwargs.get("limit_type", "min")
+        elif "resize_long" in kwargs:
+            self.resize_type = 2
+            self.resize_long = kwargs.get("resize_long", 960)
+        else:
+            self.limit_side_len = 736
+            self.limit_type = "min"
+
+    @batchable
+    def apply(self, img):
+        """Resize ``img`` (h, w, c) and return it with ``img_shape``.
+
+        ``img_shape`` is ``[src_h, src_w, ratio_h, ratio_w]`` where the source
+        size is taken before any padding of very small images.
+        """
+        src_h, src_w, _ = img.shape
+        if sum([src_h, src_w]) < 64:
+            img = self.image_padding(img)
+
+        if self.resize_type == 0:
+            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+        elif self.resize_type == 2:
+            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+        else:
+            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+        return {"img": img, "img_shape": np.array([src_h, src_w, ratio_h, ratio_w])}
+
+    def image_padding(self, im, value=0):
+        """Pad ``im`` at the bottom/right with ``value`` to at least 32x32."""
+        h, w, c = im.shape
+        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
+        im_pad[:h, :w, :] = im
+        return im_pad
+
+    def resize_image_type1(self, img):
+        """Resize to ``self.image_shape``; with ``keep_ratio`` the width follows
+        the aspect ratio, rounded up to a multiple of 32."""
+        resize_h, resize_w = self.image_shape
+        ori_h, ori_w = img.shape[:2]  # (h, w, c)
+        if self.keep_ratio is True:
+            resize_w = ori_w * resize_h / ori_h
+            N = math.ceil(resize_w / 32)
+            resize_w = N * 32
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type0(self, img):
+        """
+        resize image to a size multiple of 32 which is required by the network
+        args:
+            img(array): array with shape [h, w, c]
+        return(tuple):
+            img, [ratio_h, ratio_w]
+        raises:
+            Exception: for an unknown ``limit_type``; any error raised by
+            ``cv2.resize`` is logged and re-raised.
+        """
+        limit_side_len = self.limit_side_len
+        h, w, c = img.shape
+
+        # limit the max side
+        if self.limit_type == "max":
+            if max(h, w) > limit_side_len:
+                if h > w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.0
+        elif self.limit_type == "min":
+            if min(h, w) < limit_side_len:
+                if h < w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.0
+        elif self.limit_type == "resize_long":
+            ratio = float(limit_side_len) / max(h, w)
+        else:
+            raise Exception("not support limit type, image ")
+        resize_h = int(h * ratio)
+        resize_w = int(w * ratio)
+
+        # Round to the nearest multiple of 32, never below 32 (so the target
+        # size is always positive and no separate "<= 0" guard is needed).
+        resize_h = max(int(round(resize_h / 32) * 32), 32)
+        resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+        try:
+            img = cv2.resize(img, (resize_w, resize_h))
+        except Exception:
+            # Previously a bare `except` logged at info level and called
+            # `sys.exit(0)`, ending the process with a success status.
+            logging.warning(
+                f"Failed to resize image of shape {img.shape} to "
+                f"(w={resize_w}, h={resize_h})."
+            )
+            raise
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type2(self, img):
+        """Scale the long side to ``self.resize_long``, then round both sides up
+        to a multiple of 128."""
+        h, w, _ = img.shape
+
+        resize_w = w
+        resize_h = h
+
+        if resize_h > resize_w:
+            ratio = float(self.resize_long) / resize_h
+        else:
+            ratio = float(self.resize_long) / resize_w
+
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+
+        max_stride = 128
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+
+        return img, [ratio_h, ratio_w]
+
+
+class NormalizeImage(BaseTransformer):
+    """Normalize an image: ``(img * scale - mean) / std`` in float32."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img"]
+
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        """``scale`` defaults to 1/255; ``mean``/``std`` default to the values
+        below. ``order == "chw"`` shapes them (3, 1, 1), anything else (1, 1, 3).
+        """
+        super().__init__()
+        if isinstance(scale, str):
+            # SECURITY NOTE(review): `eval` executes arbitrary code. `scale`
+            # must only ever come from trusted model configs (e.g. "1./255.").
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    @batchable
+    def apply(self, img):
+        """Normalize one image (PIL image or ndarray) and return it under ``img``."""
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        img = (img.astype("float32") * self.scale - self.mean) / self.std
+        return {"img": img}
+
+
+class DBPostProcess(BaseTransformer):
+    """
+    The post process for Differentiable Binarization (DB).
+ """ + + INPUT_KEYS = ["pred", "img_shape"] + OUTPUT_KEYS = ["polys", "scores"] + + def __init__( + self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + box_type="quad", + **kwargs + ): + super().__init__() + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + self.box_type = box_type + assert score_mode in [ + "slow", + "fast", + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + + self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]]) + + def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}""" + + bitmap = _bitmap + height, width = bitmap.shape + + boxes = [] + scores = [] + + contours, _ = cv2.findContours( + (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE + ) + + for contour in contours[: self.max_candidates]: + epsilon = 0.002 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + + score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + + if points.shape[0] > 2: + box = self.unclip(points, self.unclip_ratio) + if len(box) > 1: + continue + else: + continue + box = box.reshape(-1, 2) + + if len(box) > 0: + _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_size + 2: + continue + else: + continue + + box = np.array(box) + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height + ) + boxes.append(box) + scores.append(score) + return boxes, scores + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + 
def unclip(self, box, unclip_ratio):
    """Expand a polygon outward with a rounded offset.

    The offset distance is ``area * unclip_ratio / perimeter`` of the
    polygon built from ``box``.

    Args:
        box: Polygon vertices, passed to ``Polygon`` and to
            ``PyclipperOffset.AddPath``.
        unclip_ratio: Multiplier applied to ``area / perimeter``.

    Returns:
        np.ndarray: The clipper output as an array. If converting the whole
        output raises ``ValueError``, only its first path is converted.
    """
    polygon = Polygon(box)
    grow_by = polygon.area * unclip_ratio / polygon.length
    clipper = pyclipper.PyclipperOffset()
    clipper.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    try:
        return np.array(clipper.Execute(grow_by))
    except ValueError:
        # The full result could not be stacked; fall back to the first path.
        return np.array(clipper.Execute(grow_by)[0])


def get_mini_boxes(self, contour):
    """Return the ordered corners of the minimum-area rectangle of a contour.

    The four corners from ``cv2.boxPoints`` are sorted by x. Within the two
    lowest-x corners the one with the smaller y comes first and the other
    last; within the two highest-x corners the one with the smaller y is
    second and the other third.

    Args:
        contour: Point set accepted by ``cv2.minAreaRect``.

    Returns:
        tuple: ``(box, short_side)`` where ``box`` is the list of four ordered
        corners and ``short_side`` is the smaller rectangle dimension.
    """
    rect = cv2.minAreaRect(contour)
    corners = sorted(list(cv2.boxPoints(rect)), key=lambda pt: pt[0])
    low_x, high_x = corners[:2], corners[2:]

    if low_x[1][1] > low_x[0][1]:
        first, fourth = low_x[0], low_x[1]
    else:
        first, fourth = low_x[1], low_x[0]

    if high_x[1][1] > high_x[0][1]:
        second, third = high_x[0], high_x[1]
    else:
        second, third = high_x[1], high_x[0]

    return [first, second, third, fourth], min(rect[1])


def box_score_fast(self, bitmap, _box):
    """Score a box as the mean of ``bitmap`` inside the filled box polygon.

    Only the axis-aligned window around the box (clipped to the bitmap) is
    examined. ``_box`` itself is not modified; a copy is shifted into window
    coordinates.

    Args:
        bitmap: 2-D score map (only its first two dimensions are read).
        _box: Array of ``(x, y)`` points.

    Returns:
        The first channel of ``cv2.mean`` over the masked window.
    """
    rows, cols = bitmap.shape[:2]
    pts = _box.copy()
    left = np.clip(np.floor(pts[:, 0].min()).astype("int"), 0, cols - 1)
    right = np.clip(np.ceil(pts[:, 0].max()).astype("int"), 0, cols - 1)
    top = np.clip(np.floor(pts[:, 1].min()).astype("int"), 0, rows - 1)
    bottom = np.clip(np.ceil(pts[:, 1].max()).astype("int"), 0, rows - 1)

    window_mask = np.zeros((bottom - top + 1, right - left + 1), dtype=np.uint8)
    # Shift the polygon so it is expressed relative to the window origin.
    pts[:, 0] = pts[:, 0] - left
    pts[:, 1] = pts[:, 1] - top
    cv2.fillPoly(window_mask, pts.reshape(1, -1, 2).astype(np.int32), 1)
    window = bitmap[top : bottom + 1, left : right + 1]
    return cv2.mean(window, window_mask)[0]


def box_score_slow(self, bitmap, contour):
    """Score a contour as the mean of ``bitmap`` inside the filled contour.

    Same windowed computation as the fast variant, but the mask is filled
    from the contour points themselves rather than from a box.

    Args:
        bitmap: 2-D score map (only its first two dimensions are read).
        contour: Contour points; reshaped to ``(-1, 2)`` on a copy.

    Returns:
        The first channel of ``cv2.mean`` over the masked window.
    """
    rows, cols = bitmap.shape[:2]
    pts = np.reshape(contour.copy(), (-1, 2))

    left = np.clip(np.min(pts[:, 0]), 0, cols - 1)
    right = np.clip(np.max(pts[:, 0]), 0, cols - 1)
    top = np.clip(np.min(pts[:, 1]), 0, rows - 1)
    bottom = np.clip(np.max(pts[:, 1]), 0, rows - 1)

    window_mask = np.zeros((bottom - top + 1, right - left + 1), dtype=np.uint8)

    # Shift the contour so it is expressed relative to the window origin.
    pts[:, 0] = pts[:, 0] - left
    pts[:, 1] = pts[:, 1] - top

    cv2.fillPoly(window_mask, pts.reshape(1, -1, 2).astype(np.int32), 1)
    window = bitmap[top : bottom + 1, left : right + 1]
    return cv2.mean(window, window_mask)[0]


@batchable
def apply(self, pred, img_shape):
    """Turn one prediction into text polygons and their scores.

    Args:
        pred: Prediction container; ``pred[0][0, :, :]`` is used as the
            score map.
        img_shape: Four values unpacked as
            ``(src_h, src_w, ratio_h, ratio_w)``; only the first two are used.

    Returns:
        dict: ``{"polys": boxes, "scores": scores}``.

    Raises:
        ValueError: If ``self.box_type`` is neither ``"poly"`` nor ``"quad"``.
    """
    score_map = pred[0][0, :, :]
    binary = score_map > self.thresh

    # Unpacking all four values keeps the length check on img_shape.
    src_h, src_w, _ratio_h, _ratio_w = img_shape

    mask = binary
    if self.dilation_kernel is not None:
        mask = cv2.dilate(
            np.array(binary).astype(np.uint8),
            self.dilation_kernel,
        )

    if self.box_type == "poly":
        extract = self.polygons_from_bitmap
    elif self.box_type == "quad":
        extract = self.boxes_from_bitmap
    else:
        raise ValueError("box_type can only be one of ['quad', 'poly']")

    boxes, scores = extract(score_map, mask, src_w, src_h)
    return {"polys": boxes, "scores": scores}
class TextRecPredictor(BasicPredictor):
    """Text recognition predictor assembled from the model's config.

    Pre-processing operators are built from
    ``config["PreProcess"]["transform_ops"]`` through the builders registered
    in ``_FUNC_MAP``, followed by an ``ImagePredictor`` and the post-process
    operator described by ``config["PostProcess"]``.
    """

    entities = MODELS

    # Maps a config operator name to the builder method registered for it.
    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

    def _build_batch_sampler(self):
        """Return the batch sampler used to feed images."""
        return ImageBatchSampler()

    def _get_result_class(self):
        """Return the result class produced by this predictor."""
        return TextRecResult

    def _build_transformers(self):
        """Build and add pre-process ops, the inference op and the post-process op.

        Each entry of ``transform_ops`` is a mapping whose first key names the
        operator and whose value holds its keyword arguments (possibly empty).
        Builders that return a falsy value are skipped.
        """
        for cfg in self.config["PreProcess"]["transform_ops"]:
            tf_key = list(cfg.keys())[0]
            assert (
                tf_key in self._FUNC_MAP
            ), f"No builder is registered for transform op {tf_key!r}"
            func = self._FUNC_MAP[tf_key]
            args = cfg.get(tf_key, {})
            # An empty/None argument block means "call the builder bare".
            op = func(self, **args) if args else func(self)
            if op:
                self._add_transformer(op)

        predictor = ImagePredictor(
            model_dir=self.model_dir,
            model_prefix=self.MODEL_FILE_PREFIX,
            option=self.pp_option,
        )
        self._add_transformer(predictor)

        op = self.build_postprocess(**self.config["PostProcess"])
        self._add_transformer(op)

    def _set_dataflow(self):
        """Wire sampler -> ReadImage -> resize/normalize -> predictor -> decode -> result."""
        self.ReadImage.inputs.img.fetch(self.batch_sampler.outputs.img)
        self.OCRReisizeNormImg.inputs.img.fetch(self.ReadImage.outputs.img)
        self.OCRReisizeNormImg.inputs.img_size.fetch(self.ReadImage.outputs.img_size)
        self.ImagePredictor.inputs.img.fetch(self.OCRReisizeNormImg.outputs.img)
        self.CTCLabelDecode.inputs.pred.fetch(self.ImagePredictor.outputs.pred)
        self.result_packager.inputs.input_path.fetch(self.batch_sampler.outputs.img)
        self.result_packager.inputs.text.fetch(self.CTCLabelDecode.outputs.text)
        self.result_packager.inputs.score.fetch(self.CTCLabelDecode.outputs.score)

    @register("DecodeImage")
    def build_readimg(self, channel_first, img_mode):
        """Build the image reader; only channel-last decoding is supported."""
        assert not channel_first, "DecodeImage with channel_first=True is not supported"
        return ReadImage(format=img_mode)

    @register("RecResizeImg")
    def build_resize(self, image_shape):
        """Build the resize/normalize op for the configured ``image_shape``."""
        return OCRReisizeNormImg(rec_image_shape=image_shape)

    def build_postprocess(self, **kwargs):
        """Build the post-process op described by the ``PostProcess`` config.

        Args:
            **kwargs: Post-process config; ``name`` selects the operator and
                ``character_dict`` supplies its character list.

        Returns:
            The ``CTCLabelDecode`` operator.

        Raises:
            ValueError: If ``name`` is not ``"CTCLabelDecode"``.
        """
        name = kwargs.get("name")
        if name == "CTCLabelDecode":
            return CTCLabelDecode(
                character_list=kwargs.get("character_dict"),
            )
        raise ValueError(
            f"Unsupported text recognition post-process {name!r}; "
            "only 'CTCLabelDecode' is available"
        )

    # The two builders below intentionally produce no operator. They share the
    # name ``foo``, so the second definition rebinds the class attribute;
    # lookups in _build_transformers go through _FUNC_MAP by config key instead.
    # NOTE(review): presumably FuncRegister stores each function under its key
    # at decoration time - confirm against FuncRegister.
    @register("MultiLabelEncode")
    def foo(self, *args, **kwargs):
        """Ignore the ``MultiLabelEncode`` config entry."""
        return None

    @register("KeepKeys")
    def foo(self, *args, **kwargs):
        """Ignore the ``KeepKeys`` config entry."""
        return None
class TextRecResult(CVResult):
    """Text recognition result that can render its text below the input image."""

    INPUT_KEYS = ["input_path", "text", "score"]

    def _to_img(self):
        """Return a copy of the input image with ``"<text> (<score>)"`` drawn under it.

        The canvas is the original image extended downwards by 1.2x the text
        row height, filled white, with the label drawn in black.
        """
        image = self._img_reader.read(self["input_path"])
        rec_text = self["text"]
        rec_score = self["score"]
        image = image.convert("RGB")
        image_width, image_height = image.size
        text = f"{rec_text} ({rec_score})"
        font = self.adjust_font_size(image_width, text, PINGFANG_FONT_FILE_PATH)
        row_height = font.getbbox(text)[3]
        new_image_height = image_height + int(row_height * 1.2)
        new_image = Image.new("RGB", (image_width, new_image_height), (255, 255, 255))
        new_image.paste(image, (0, 0))

        draw = ImageDraw.Draw(new_image)
        draw.text(
            (0, image_height),
            text,
            fill=(0, 0, 0),
            font=font,
        )
        return new_image

    def adjust_font_size(self, image_width, text, font_path):
        """Pick the largest font, starting at 6% of the image width, that fits.

        The size is reduced one point at a time until ``text`` is no wider
        than ``image_width``. It never drops below 1, because
        ``ImageFont.truetype`` cannot load a zero or negative size (which the
        6% rule would otherwise request for images narrower than 17 px).

        Args:
            image_width: Width in pixels available for the text.
            text: Text to be rendered.
            font_path: Path of the TrueType font file.

        Returns:
            The loaded font object.
        """
        font_size = max(1, int(image_width * 0.06))
        font = ImageFont.truetype(font_path, font_size)

        while font_size > 1 and self._text_width(font, text) > image_width:
            font_size -= 1
            font = ImageFont.truetype(font_path, font_size)

        return font

    @staticmethod
    def _text_width(font, text):
        """Return the rendered width of ``text``, using the API the PIL version offers."""
        if int(PIL.__version__.split(".")[0]) < 10:
            return font.getsize(text)[0]
        # getbbox returns (left, top, right, bottom); index 2 is the right edge.
        return font.getbbox(text)[2]
class OCRReisizeNormImg(BaseTransformer):
    """Resize an image to the recognition height, normalize it and right-pad it."""

    INPUT_KEYS = ["img", "img_size"]
    OUTPUT_KEYS = ["img"]

    def __init__(self, rec_image_shape=(3, 48, 320)):
        """Store the target shape.

        Args:
            rec_image_shape: Three values unpacked as ``(channels, height, width)``.
        """
        super().__init__()
        # Copy so the instance never shares state with the default or the caller.
        self.rec_image_shape = list(rec_image_shape)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize, normalize and pad a single image.

        The image is resized to the configured height while keeping its
        aspect ratio, capped at ``int(height * max_wh_ratio)`` columns. Pixel
        values are mapped with ``(x / 255 - 0.5) / 0.5`` and the result is
        written into a zero-filled ``(channels, height, width)`` float32 array,
        so any remaining columns on the right stay zero.

        Args:
            img: Image array indexed ``(row, column, channel)``.
            max_wh_ratio: Width/height ratio that fixes the padded width.

        Returns:
            np.ndarray: The padded float32 array.
        """
        imgC, imgH, _ = self.rec_image_shape
        assert (
            imgC == img.shape[2]
        ), f"Expected {imgC} channels but the image has {img.shape[2]}"
        imgW = int(imgH * max_wh_ratio)

        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))

        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5

        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    @batchable
    def apply(self, img, img_size):
        """Resize and normalize one image.

        Args:
            img: Image array.
            img_size: Sequence whose first two entries are ``(width, height)``.

        Returns:
            dict: ``{"img": processed_array}``.
        """
        _, imgH, imgW = self.rec_image_shape
        w, h = img_size[:2]
        # Never pad to less than the image's own aspect ratio.
        max_wh_ratio = max(imgW / imgH, w * 1.0 / h)
        return {"img": self.resize_norm_img(img, max_wh_ratio)}


class BaseRecLabelDecode(BaseTransformer):
    """Convert predicted class indices into text and a confidence score."""

    INPUT_KEYS = ["pred"]
    OUTPUT_KEYS = ["text", "score"]

    # Characters that pred_reverse groups into runs whose internal order is kept.
    _KEEP_ORDER_PATTERN = re.compile("[a-zA-Z0-9 :*./%+-]")

    def __init__(self, character_str=None, use_space_char=True):
        """Build the index <-> character tables.

        Args:
            character_str: Iterable of characters; defaults to digits followed
                by lowercase ASCII letters when ``None``.
            use_space_char: Whether to append ``" "`` to the character set.
        """
        super().__init__()
        self.reverse = False
        character_list = (
            list(character_str)
            if character_str is not None
            else list("0123456789abcdefghijklmnopqrstuvwxyz")
        )
        if use_space_char:
            character_list.append(" ")

        character_list = self.add_special_char(character_list)
        self.dict = {char: i for i, char in enumerate(character_list)}
        self.character = character_list

    def pred_reverse(self, pred):
        """Reverse ``pred`` while keeping runs of matching characters intact.

        Consecutive characters matching ``_KEEP_ORDER_PATTERN`` form a run that
        is kept in its original order; every other character is its own piece.
        The pieces are then concatenated in reverse order.
        """
        pieces = []
        run = ""
        for c in pred:
            if self._KEEP_ORDER_PATTERN.search(c):
                run += c
                continue
            if run:
                pieces.append(run)
            pieces.append(c)
            run = ""
        if run:
            pieces.append(run)

        return "".join(pieces[::-1])

    def add_special_char(self, character_list):
        """Hook for subclasses to extend the character list; returns it unchanged here."""
        return character_list

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert per-sample index sequences into ``(text, score)`` pairs.

        Args:
            text_index: Sequence of 1-D index arrays, one per sample.
            text_prob: Optional probabilities aligned with ``text_index``.
            is_remove_duplicate: Drop a position whose index equals the
                previous position's index.

        Returns:
            list: One ``(text, score)`` tuple per sample. ``score`` is the
            mean probability of the kept positions (1 per kept position when
            ``text_prob`` is ``None``), and 0 when nothing is kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            indices = text_index[batch_idx]
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token

            char_list = [self.character[text_id] for text_id in indices[selection]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # One placeholder per kept character, so an empty decode falls
                # through to the zero score below just as it does with probs.
                conf_list = [1] * len(char_list)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """Return the indices removed during decoding (index 0, the CTC blank)."""
        return [0]  # for ctc blank

    @staticmethod
    def _select_output(pred):
        """Return the last output when ``pred`` is a container of 3-D outputs.

        The check must run before any array conversion: once a list of outputs
        has been stacked into a single array it can no longer be told apart
        from a plain batch. Inputs that already are one 3-D batch (including
        nested lists or lists of 2-D arrays) are returned untouched.
        """
        if (
            isinstance(pred, (tuple, list))
            and len(pred) > 0
            and np.ndim(pred[-1]) == 3
        ):
            return pred[-1]
        return pred

    def _decode_batch(self, preds):
        """Decode a 3-D score array into ``{"text": [...], "score": [...]}``.

        Axis 2 is reduced with ``argmax``/``max``.
        NOTE(review): presumably the layout is (batch, step, class) - confirm
        against the model output.
        """
        preds = np.array(preds)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        decoded = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        return {
            "text": [text for text, _ in decoded],
            "score": [score for _, score in decoded],
        }

    def apply(self, pred):
        """Decode ``pred``; a tuple/list of 3-D outputs contributes its last element."""
        return self._decode_batch(self._select_output(pred))


class CTCLabelDecode(BaseRecLabelDecode):
    """CTC decoder: index 0 is the blank token and repeated indices are merged."""

    def __init__(self, character_list=None, use_space_char=True):
        """Forward the character list and space option to the base decoder."""
        super().__init__(character_list, use_space_char=use_space_char)

    def apply(self, pred):
        """Decode the first element of ``pred``."""
        return self._decode_batch(pred[0])

    def add_special_char(self, character_list):
        """Prepend the ``"blank"`` entry so it occupies index 0."""
        character_list = ["blank"] + character_list
        return character_list
"postprocess": 0} @@ -158,13 +163,11 @@ def collect(self, e2e_num): save_dir = Path(INFER_BENCHMARK_OUTPUT) save_dir.mkdir(parents=True, exist_ok=True) csv_data = [detail_head, *detail] - # csv_data.extend(detail) with open(Path(save_dir) / "detail.csv", "w", newline="") as file: writer = csv.writer(file) writer.writerows(csv_data) csv_data = [summary_head, *summary] - # csv_data.extend(summary) with open(Path(save_dir) / "summary.csv", "w", newline="") as file: writer = csv.writer(file) writer.writerows(csv_data) diff --git a/paddlex/utils/flags.py b/paddlex/utils/flags.py index 9adab97d7..37e35fb54 100644 --- a/paddlex/utils/flags.py +++ b/paddlex/utils/flags.py @@ -20,6 +20,7 @@ "DRY_RUN", "CHECK_OPTS", "EAGER_INITIALIZATION", + "NEW_PREDICTOR", "INFER_BENCHMARK", "INFER_BENCHMARK_ITER", "INFER_BENCHMARK_WARMUP", @@ -47,6 +48,8 @@ def get_flag_from_env_var(name, default, format_func=str): EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True) FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", None) +NEW_PREDICTOR = get_flag_from_env_var("PADDLE_PDX_NEW_PREDICTOR", False) + # Inference Benchmark INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None) INFER_BENCHMARK_WARMUP = get_flag_from_env_var( diff --git a/paddlex/utils/logging.py b/paddlex/utils/logging.py index f66f20ca2..20ee70dc2 100644 --- a/paddlex/utils/logging.py +++ b/paddlex/utils/logging.py @@ -54,8 +54,7 @@ def debug(msg, *args, **kwargs): else: caller_info = f"{caller_func_name}" msg = f"【{caller_info}】{msg}" - - _logger.debug(msg, *args, **kwargs) + _logger.debug(msg, *args, **kwargs) def info(msg, *args, **kwargs):