Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] upgrade model inference #2456

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion paddlex/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .models import create_predictor
from ..utils.flags import NEW_PREDICTOR

if NEW_PREDICTOR:
from .new_models import create_predictor
else:
from .models import create_predictor
from .pipelines import create_pipeline
from .utils.pp_option import PaddlePredictorOption
108 changes: 108 additions & 0 deletions paddlex/inference/new_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path
from typing import Any, Dict, Optional

from ...utils import errors
from ..utils.official_models import official_models
from .base import BasePredictor, BasicPredictor

from .image_classification import ClasPredictor
from .text_detection import TextDetPredictor
from .text_recognition import TextRecPredictor

# from .table_recognition import TablePredictor
# from .object_detection import DetPredictor
# from .instance_segmentation import InstanceSegPredictor
# from .semantic_segmentation import SegPredictor
# from .general_recognition import ShiTuRecPredictor
# from .ts_fc import TSFcPredictor
# from .ts_ad import TSAdPredictor
# from .ts_cls import TSClsPredictor
# from .image_unwarping import WarpPredictor
# from .multilabel_classification import MLClasPredictor
# from .anomaly_detection import UadPredictor
# from .formula_recognition import LaTeXOCRPredictor
# from .face_recognition import FaceRecPredictor


def _create_hp_predictor(
model_name, model_dir, device, config, hpi_params, *args, **kwargs
):
try:
from paddlex_hpi.models import HPPredictor
except ModuleNotFoundError:
raise RuntimeError(
"The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available."
) from None
try:
predictor = HPPredictor.get(model_name)(
model_dir=model_dir,
config=config,
device=device,
*args,
hpi_params=hpi_params,
**kwargs,
)
except errors.others.ClassNotFoundException:
raise ValueError(
f"{model_name} is not supported by the PaddleX HPI plugin."
) from None
return predictor


def create_predictor(
model: str,
device=None,
pp_option=None,
use_hpip: bool = False,
hpi_params: Optional[Dict[str, Any]] = None,
*args,
**kwargs,
) -> BasePredictor:
model_dir = check_model(model)
config = BasePredictor.load_config(model_dir)
model_name = config["Global"]["model_name"]
if use_hpip:
return _create_hp_predictor(
model_name=model_name,
model_dir=model_dir,
config=config,
hpi_params=hpi_params,
device=device,
*args,
**kwargs,
)
else:
return BasicPredictor.get(model_name)(
model_dir=model_dir,
config=config,
device=device,
pp_option=pp_option,
*args,
**kwargs,
)


def check_model(model):
if Path(model).exists():
return Path(model)
elif model in official_models:
return official_models[model]
else:
raise Exception(
f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
)
19 changes: 19 additions & 0 deletions paddlex/inference/new_models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .component import BaseComponent
from .transformer import BaseTransformer
from .predictor import BasePredictor, BasicPredictor
from .result import BaseResult, CVResult
from .batch_sampler import BaseBatchSampler, BatchData
16 changes: 16 additions & 0 deletions paddlex/inference/new_models/base/batch_sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .batch_sampler import BaseBatchSampler
from .batch_data import BatchData
38 changes: 38 additions & 0 deletions paddlex/inference/new_models/base/batch_sampler/batch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class BatchData(object):

def __init__(self, data, num):
self._data = data
self._num = num

@property
def num(self):
return self._num

def get_by_key(self, key):
assert key in self._data, f"{key}, {list(self._data.keys())}"
return self._data[key]

def get_by_idx(self, idx):
assert idx <= self.num
return {k: v[idx] for k, v in self._data.items()}

def update_by_key(self, output):
assert isinstance(output, dict)
for k, v in output.items():
assert isinstance(v, list) and len(v) == self.num
self._data[k] = v
48 changes: 48 additions & 0 deletions paddlex/inference/new_models/base/batch_sampler/batch_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from ..component import BaseComponent
from .batch_data import BatchData


class BaseBatchSampler(BaseComponent):

def __init__(self, batch_size=1):
self._batch_size = batch_size
super().__init__()

@property
def batch_size(self):
return self._batch_size

@batch_size.setter
def batch_size(self, bs):
assert bs > 0
self._batch_size = bs

def __call__(self, *args, **kwargs):
for batch, num in self.apply(*args, **kwargs):
yield BatchData(
{f"{self.name}.{k}": batch[0] for k in self.OUTPUT_KEYS}, num
)

@abstractmethod
def apply(self, *args, **kwargs):
raise NotImplementedError

# def set_outputs(self, outputs):
# assert isinstance(outputs, dict)
# self.outputs = outputs
127 changes: 127 additions & 0 deletions paddlex/inference/new_models/base/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC, abstractmethod
from types import GeneratorType

from ....utils.flags import INFER_BENCHMARK
from ....utils import logging
from ...utils.benchmark import Timer


class Param:
def __init__(self, name, cmpt):
self.name = name
self.cmpt = cmpt

def __repr__(self):
return f"{self.cmpt.name}.{self.name}"


class OutParam(Param):
pass


class InParam(Param):
def fetch(self, param: OutParam):
self.cmpt.set_dep(self, param)


class InOuts:
def __getattr__(self, key):
if key in self._keys:
return self._keys.get(key)
raise AttributeError(
f"'{self._cmpt.name}.{self.__class__.__name__}' object has no attribute '{key}'"
)

def __repr__(self):
_str = ""
for key in self._keys:
param = self._keys[key]
_str += f"{param.cmpt.name}.{param.name}"
return _str


class Outputs(InOuts):
def __init__(self, cmpt):
self._cmpt = cmpt
self._keys = {}
if cmpt.OUTPUT_KEYS:
for key in cmpt.OUTPUT_KEYS:
self._keys[key] = OutParam(key, cmpt)


class Inputs(InOuts):
def __init__(self, cmpt):
self._cmpt = cmpt
self._keys = {}
if cmpt.INPUT_KEYS:
for key in cmpt.INPUT_KEYS:
self._keys[key] = InParam(key, cmpt)

def __iter__(self):
for in_param, out_param in self._cmpt.dependencies:
out_param_str = f"{out_param.cmpt.name}.{out_param.name}"
yield f"{in_param.name}", f"{out_param_str}"

# def __repr__(self):
# _str = ""
# for in_param, out_param in self._dependencies:
# out_param_str = f"{out_param.cmpt.name}.{out_param.name}"
# _str += f"{in_param.cmpt.name}.{in_param.name}: {out_param_str}\t"
# return _str


# class Dependencies:
# def __init__(self):
# pass

# def add(self, )


class BaseComponent(ABC):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base component和pipeline中的component是否可复用?


INPUT_KEYS = None
OUTPUT_KEYS = None

def __init__(self):
self.name = getattr(self, "NAME", self.__class__.__name__)
self.inputs = Inputs(self)
self.outputs = Outputs(self)
self.dependencies = []

if INFER_BENCHMARK:
self.timer = Timer()
self.apply = self.timer.watch_func(self.apply)

def set_dep(self, in_param, out_param):
self.dependencies.append((in_param, out_param))

@classmethod
def get_input_keys(cls) -> list:
return cls.INPUT_KEYS

@classmethod
def get_output_keys(cls) -> list:
return cls.OUTPUT_KEYS

@abstractmethod
def __call__(self, batch_data):
raise NotImplementedError

@abstractmethod
def apply(self, input):
raise NotImplementedError
16 changes: 16 additions & 0 deletions paddlex/inference/new_models/base/predictor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_predictor import BasePredictor
from .basic_predictor import BasicPredictor
Loading