Skip to content

Commit

Permalink
upgrade inference model
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Nov 28, 2024
1 parent c8767a7 commit fb8c28c
Show file tree
Hide file tree
Showing 41 changed files with 3,249 additions and 6 deletions.
7 changes: 6 additions & 1 deletion paddlex/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .models import create_predictor
from ..utils.flags import NEW_PREDICTOR

if NEW_PREDICTOR:
from .new_models import create_predictor
else:
from .models import create_predictor
from .pipelines import create_pipeline
from .utils.pp_option import PaddlePredictorOption
108 changes: 108 additions & 0 deletions paddlex/inference/new_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path
from typing import Any, Dict, Optional

from ...utils import errors
from ..utils.official_models import official_models
from .base import BasePredictor, BasicPredictor

from .image_classification import ClasPredictor
from .text_detection import TextDetPredictor
from .text_recognition import TextRecPredictor

# from .table_recognition import TablePredictor
# from .object_detection import DetPredictor
# from .instance_segmentation import InstanceSegPredictor
# from .semantic_segmentation import SegPredictor
# from .general_recognition import ShiTuRecPredictor
# from .ts_fc import TSFcPredictor
# from .ts_ad import TSAdPredictor
# from .ts_cls import TSClsPredictor
# from .image_unwarping import WarpPredictor
# from .multilabel_classification import MLClasPredictor
# from .anomaly_detection import UadPredictor
# from .formula_recognition import LaTeXOCRPredictor
# from .face_recognition import FaceRecPredictor


def _create_hp_predictor(
model_name, model_dir, device, config, hpi_params, *args, **kwargs
):
try:
from paddlex_hpi.models import HPPredictor
except ModuleNotFoundError:
raise RuntimeError(
"The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available."
) from None
try:
predictor = HPPredictor.get(model_name)(
model_dir=model_dir,
config=config,
device=device,
*args,
hpi_params=hpi_params,
**kwargs,
)
except errors.others.ClassNotFoundException:
raise ValueError(
f"{model_name} is not supported by the PaddleX HPI plugin."
) from None
return predictor


def create_predictor(
model: str,
device=None,
pp_option=None,
use_hpip: bool = False,
hpi_params: Optional[Dict[str, Any]] = None,
*args,
**kwargs,
) -> BasePredictor:
model_dir = check_model(model)
config = BasePredictor.load_config(model_dir)
model_name = config["Global"]["model_name"]
if use_hpip:
return _create_hp_predictor(
model_name=model_name,
model_dir=model_dir,
config=config,
hpi_params=hpi_params,
device=device,
*args,
**kwargs,
)
else:
return BasicPredictor.get(model_name)(
model_dir=model_dir,
config=config,
device=device,
pp_option=pp_option,
*args,
**kwargs,
)


def check_model(model):
if Path(model).exists():
return Path(model)
elif model in official_models:
return official_models[model]
else:
raise Exception(
f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
)
19 changes: 19 additions & 0 deletions paddlex/inference/new_models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .component import BaseComponent
from .transformer import BaseTransformer
from .predictor import BasePredictor, BasicPredictor
from .result import BaseResult, CVResult
from .batch_sampler import BaseBatchSampler, BatchData
16 changes: 16 additions & 0 deletions paddlex/inference/new_models/base/batch_sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .batch_sampler import BaseBatchSampler
from .batch_data import BatchData
38 changes: 38 additions & 0 deletions paddlex/inference/new_models/base/batch_sampler/batch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class BatchData(object):

def __init__(self, data, num):
self._data = data
self._num = num

@property
def num(self):
return self._num

def get_by_key(self, key):
assert key in self._data, f"{key}, {list(self._data.keys())}"
return self._data[key]

def get_by_idx(self, idx):
assert idx <= self.num
return {k: v[idx] for k, v in self._data.items()}

def update_by_key(self, output):
assert isinstance(output, dict)
for k, v in output.items():
assert isinstance(v, list) and len(v) == self.num
self._data[k] = v
48 changes: 48 additions & 0 deletions paddlex/inference/new_models/base/batch_sampler/batch_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from ..component import BaseComponent
from .batch_data import BatchData


class BaseBatchSampler(BaseComponent):

def __init__(self, batch_size=1):
self._batch_size = batch_size
super().__init__()

@property
def batch_size(self):
return self._batch_size

@batch_size.setter
def batch_size(self, bs):
assert bs > 0
self._batch_size = bs

def __call__(self, *args, **kwargs):
for batch, num in self.apply(*args, **kwargs):
yield BatchData(
{f"{self.name}.{k}": batch[0] for k in self.OUTPUT_KEYS}, num
)

@abstractmethod
def apply(self, *args, **kwargs):
raise NotImplementedError

# def set_outputs(self, outputs):
# assert isinstance(outputs, dict)
# self.outputs = outputs
127 changes: 127 additions & 0 deletions paddlex/inference/new_models/base/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC, abstractmethod
from types import GeneratorType

from ....utils.flags import INFER_BENCHMARK
from ....utils import logging
from ...utils.benchmark import Timer


class Param:
def __init__(self, name, cmpt):
self.name = name
self.cmpt = cmpt

def __repr__(self):
return f"{self.cmpt.name}.{self.name}"


class OutParam(Param):
pass


class InParam(Param):
def fetch(self, param: OutParam):
self.cmpt.set_dep(self, param)


class InOuts:
def __getattr__(self, key):
if key in self._keys:
return self._keys.get(key)
raise AttributeError(
f"'{self._cmpt.name}.{self.__class__.__name__}' object has no attribute '{key}'"
)

def __repr__(self):
_str = ""
for key in self._keys:
param = self._keys[key]
_str += f"{param.cmpt.name}.{param.name}"
return _str


class Outputs(InOuts):
def __init__(self, cmpt):
self._cmpt = cmpt
self._keys = {}
if cmpt.OUTPUT_KEYS:
for key in cmpt.OUTPUT_KEYS:
self._keys[key] = OutParam(key, cmpt)


class Inputs(InOuts):
def __init__(self, cmpt):
self._cmpt = cmpt
self._keys = {}
if cmpt.INPUT_KEYS:
for key in cmpt.INPUT_KEYS:
self._keys[key] = InParam(key, cmpt)

def __iter__(self):
for in_param, out_param in self._cmpt.dependencies:
out_param_str = f"{out_param.cmpt.name}.{out_param.name}"
yield f"{in_param.name}", f"{out_param_str}"

# def __repr__(self):
# _str = ""
# for in_param, out_param in self._dependencies:
# out_param_str = f"{out_param.cmpt.name}.{out_param.name}"
# _str += f"{in_param.cmpt.name}.{in_param.name}: {out_param_str}\t"
# return _str


# class Dependencies:
# def __init__(self):
# pass

# def add(self, )


class BaseComponent(ABC):

INPUT_KEYS = None
OUTPUT_KEYS = None

def __init__(self):
self.name = getattr(self, "NAME", self.__class__.__name__)
self.inputs = Inputs(self)
self.outputs = Outputs(self)
self.dependencies = []

if INFER_BENCHMARK:
self.timer = Timer()
self.apply = self.timer.watch_func(self.apply)

def set_dep(self, in_param, out_param):
self.dependencies.append((in_param, out_param))

@classmethod
def get_input_keys(cls) -> list:
return cls.INPUT_KEYS

@classmethod
def get_output_keys(cls) -> list:
return cls.OUTPUT_KEYS

@abstractmethod
def __call__(self, batch_data):
raise NotImplementedError

@abstractmethod
def apply(self, input):
raise NotImplementedError
16 changes: 16 additions & 0 deletions paddlex/inference/new_models/base/predictor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_predictor import BasePredictor
from .basic_predictor import BasicPredictor
Loading

0 comments on commit fb8c28c

Please sign in to comment.