-
Notifications
You must be signed in to change notification settings - Fork 968
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c8767a7
commit fb8c28c
Showing
41 changed files
with
3,249 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from pathlib import Path | ||
from typing import Any, Dict, Optional | ||
|
||
from ...utils import errors | ||
from ..utils.official_models import official_models | ||
from .base import BasePredictor, BasicPredictor | ||
|
||
from .image_classification import ClasPredictor | ||
from .text_detection import TextDetPredictor | ||
from .text_recognition import TextRecPredictor | ||
|
||
# from .table_recognition import TablePredictor | ||
# from .object_detection import DetPredictor | ||
# from .instance_segmentation import InstanceSegPredictor | ||
# from .semantic_segmentation import SegPredictor | ||
# from .general_recognition import ShiTuRecPredictor | ||
# from .ts_fc import TSFcPredictor | ||
# from .ts_ad import TSAdPredictor | ||
# from .ts_cls import TSClsPredictor | ||
# from .image_unwarping import WarpPredictor | ||
# from .multilabel_classification import MLClasPredictor | ||
# from .anomaly_detection import UadPredictor | ||
# from .formula_recognition import LaTeXOCRPredictor | ||
# from .face_recognition import FaceRecPredictor | ||
|
||
|
||
def _create_hp_predictor( | ||
model_name, model_dir, device, config, hpi_params, *args, **kwargs | ||
): | ||
try: | ||
from paddlex_hpi.models import HPPredictor | ||
except ModuleNotFoundError: | ||
raise RuntimeError( | ||
"The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available." | ||
) from None | ||
try: | ||
predictor = HPPredictor.get(model_name)( | ||
model_dir=model_dir, | ||
config=config, | ||
device=device, | ||
*args, | ||
hpi_params=hpi_params, | ||
**kwargs, | ||
) | ||
except errors.others.ClassNotFoundException: | ||
raise ValueError( | ||
f"{model_name} is not supported by the PaddleX HPI plugin." | ||
) from None | ||
return predictor | ||
|
||
|
||
def create_predictor( | ||
model: str, | ||
device=None, | ||
pp_option=None, | ||
use_hpip: bool = False, | ||
hpi_params: Optional[Dict[str, Any]] = None, | ||
*args, | ||
**kwargs, | ||
) -> BasePredictor: | ||
model_dir = check_model(model) | ||
config = BasePredictor.load_config(model_dir) | ||
model_name = config["Global"]["model_name"] | ||
if use_hpip: | ||
return _create_hp_predictor( | ||
model_name=model_name, | ||
model_dir=model_dir, | ||
config=config, | ||
hpi_params=hpi_params, | ||
device=device, | ||
*args, | ||
**kwargs, | ||
) | ||
else: | ||
return BasicPredictor.get(model_name)( | ||
model_dir=model_dir, | ||
config=config, | ||
device=device, | ||
pp_option=pp_option, | ||
*args, | ||
**kwargs, | ||
) | ||
|
||
|
||
def check_model(model): | ||
if Path(model).exists(): | ||
return Path(model) | ||
elif model in official_models: | ||
return official_models[model] | ||
else: | ||
raise Exception( | ||
f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .component import BaseComponent | ||
from .transformer import BaseTransformer | ||
from .predictor import BasePredictor, BasicPredictor | ||
from .result import BaseResult, CVResult | ||
from .batch_sampler import BaseBatchSampler, BatchData |
16 changes: 16 additions & 0 deletions
16
paddlex/inference/new_models/base/batch_sampler/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .batch_sampler import BaseBatchSampler | ||
from .batch_data import BatchData |
38 changes: 38 additions & 0 deletions
38
paddlex/inference/new_models/base/batch_sampler/batch_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
class BatchData(object): | ||
|
||
def __init__(self, data, num): | ||
self._data = data | ||
self._num = num | ||
|
||
@property | ||
def num(self): | ||
return self._num | ||
|
||
def get_by_key(self, key): | ||
assert key in self._data, f"{key}, {list(self._data.keys())}" | ||
return self._data[key] | ||
|
||
def get_by_idx(self, idx): | ||
assert idx <= self.num | ||
return {k: v[idx] for k, v in self._data.items()} | ||
|
||
def update_by_key(self, output): | ||
assert isinstance(output, dict) | ||
for k, v in output.items(): | ||
assert isinstance(v, list) and len(v) == self.num | ||
self._data[k] = v |
48 changes: 48 additions & 0 deletions
48
paddlex/inference/new_models/base/batch_sampler/batch_sampler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from ..component import BaseComponent | ||
from .batch_data import BatchData | ||
|
||
|
||
class BaseBatchSampler(BaseComponent): | ||
|
||
def __init__(self, batch_size=1): | ||
self._batch_size = batch_size | ||
super().__init__() | ||
|
||
@property | ||
def batch_size(self): | ||
return self._batch_size | ||
|
||
@batch_size.setter | ||
def batch_size(self, bs): | ||
assert bs > 0 | ||
self._batch_size = bs | ||
|
||
def __call__(self, *args, **kwargs): | ||
for batch, num in self.apply(*args, **kwargs): | ||
yield BatchData( | ||
{f"{self.name}.{k}": batch[0] for k in self.OUTPUT_KEYS}, num | ||
) | ||
|
||
@abstractmethod | ||
def apply(self, *args, **kwargs): | ||
raise NotImplementedError | ||
|
||
# def set_outputs(self, outputs): | ||
# assert isinstance(outputs, dict) | ||
# self.outputs = outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import inspect | ||
from abc import ABC, abstractmethod | ||
from types import GeneratorType | ||
|
||
from ....utils.flags import INFER_BENCHMARK | ||
from ....utils import logging | ||
from ...utils.benchmark import Timer | ||
|
||
|
||
class Param: | ||
def __init__(self, name, cmpt): | ||
self.name = name | ||
self.cmpt = cmpt | ||
|
||
def __repr__(self): | ||
return f"{self.cmpt.name}.{self.name}" | ||
|
||
|
||
class OutParam(Param): | ||
pass | ||
|
||
|
||
class InParam(Param): | ||
def fetch(self, param: OutParam): | ||
self.cmpt.set_dep(self, param) | ||
|
||
|
||
class InOuts: | ||
def __getattr__(self, key): | ||
if key in self._keys: | ||
return self._keys.get(key) | ||
raise AttributeError( | ||
f"'{self._cmpt.name}.{self.__class__.__name__}' object has no attribute '{key}'" | ||
) | ||
|
||
def __repr__(self): | ||
_str = "" | ||
for key in self._keys: | ||
param = self._keys[key] | ||
_str += f"{param.cmpt.name}.{param.name}" | ||
return _str | ||
|
||
|
||
class Outputs(InOuts): | ||
def __init__(self, cmpt): | ||
self._cmpt = cmpt | ||
self._keys = {} | ||
if cmpt.OUTPUT_KEYS: | ||
for key in cmpt.OUTPUT_KEYS: | ||
self._keys[key] = OutParam(key, cmpt) | ||
|
||
|
||
class Inputs(InOuts): | ||
def __init__(self, cmpt): | ||
self._cmpt = cmpt | ||
self._keys = {} | ||
if cmpt.INPUT_KEYS: | ||
for key in cmpt.INPUT_KEYS: | ||
self._keys[key] = InParam(key, cmpt) | ||
|
||
def __iter__(self): | ||
for in_param, out_param in self._cmpt.dependencies: | ||
out_param_str = f"{out_param.cmpt.name}.{out_param.name}" | ||
yield f"{in_param.name}", f"{out_param_str}" | ||
|
||
# def __repr__(self): | ||
# _str = "" | ||
# for in_param, out_param in self._dependencies: | ||
# out_param_str = f"{out_param.cmpt.name}.{out_param.name}" | ||
# _str += f"{in_param.cmpt.name}.{in_param.name}: {out_param_str}\t" | ||
# return _str | ||
|
||
|
||
# class Dependencies: | ||
# def __init__(self): | ||
# pass | ||
|
||
# def add(self, ) | ||
|
||
|
||
class BaseComponent(ABC): | ||
|
||
INPUT_KEYS = None | ||
OUTPUT_KEYS = None | ||
|
||
def __init__(self): | ||
self.name = getattr(self, "NAME", self.__class__.__name__) | ||
self.inputs = Inputs(self) | ||
self.outputs = Outputs(self) | ||
self.dependencies = [] | ||
|
||
if INFER_BENCHMARK: | ||
self.timer = Timer() | ||
self.apply = self.timer.watch_func(self.apply) | ||
|
||
def set_dep(self, in_param, out_param): | ||
self.dependencies.append((in_param, out_param)) | ||
|
||
@classmethod | ||
def get_input_keys(cls) -> list: | ||
return cls.INPUT_KEYS | ||
|
||
@classmethod | ||
def get_output_keys(cls) -> list: | ||
return cls.OUTPUT_KEYS | ||
|
||
@abstractmethod | ||
def __call__(self, batch_data): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def apply(self, input): | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .base_predictor import BasePredictor | ||
from .basic_predictor import BasicPredictor |
Oops, something went wrong.