Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor handler_helper #286

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions onnx_tf/common/handler_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,32 @@
from . import op_name_to_lower


class DomainHandlerDict(dict):

def __init__(self, domain, unknown_message="", failed_message=""):
self.unknown = {}
self.failed = {}
self._domain = domain
self._unknown_message = unknown_message
self._failed_message = failed_message

def _warn(self, k):
if k in self.unknown:
warnings.warn(
self._unknown_message.format(self._domain, self.unknown.pop(k)))
if k in self.failed:
warnings.warn(
self._failed_message.format(self._domain, k, self.failed.pop(k)))

def __getitem__(self, k):
self._warn(k)
return super(DomainHandlerDict, self).__getitem__(k)

def get(self, k, d=None):
self._warn(k)
return super(DomainHandlerDict, self).get(k, d)


def get_all_frontend_handlers(opset_dict):
""" Get a dict of all frontend handler classes.
e.g. {'domain': {'Abs': Abs handler class}, ...}, }.
Expand All @@ -23,23 +49,27 @@ def get_all_frontend_handlers(opset_dict):
domain = handler.DOMAIN
version = opset_dict[domain]
handler.VERSION = version
domain_handler_dict = handlers.setdefault(
domain,
DomainHandlerDict(
domain or "ai.onnx",
unknown_message="Unknown op {1} in domain `{0}`. "
"Can't check specification by ONNX. "
"Please set should_check flag to False "
"when call make_node method in handler."))

since_version = 1
if handler.ONNX_OP and defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
since_version = defs.get_schema(
handler.ONNX_OP, domain=handler.DOMAIN,
max_inclusive_version=version).since_version
else:
warnings.warn("Unknown op {} in domain `{}`. "
"Can't check specification by ONNX. "
"Please set should_check flag to False "
"when call make_node method in handler.".format(
handler.ONNX_OP or "Undefined", handler.DOMAIN or
"ai.onnx"))
for tf_op in handler.TF_OP:
domain_handler_dict.unknown[tf_op] = handler.ONNX_OP or tf_op
handler.SINCE_VERSION = since_version

for tf_op in handler.TF_OP:
handlers.setdefault(domain, {})[tf_op] = handler
domain_handler_dict[tf_op] = handler
return handlers


Expand All @@ -57,6 +87,13 @@ def get_all_backend_handlers(opset_dict):
domain = handler.DOMAIN
version = opset_dict[domain]
handler.VERSION = version
domain_handler_dict = handlers.setdefault(
domain,
DomainHandlerDict(
domain or "ai.onnx",
failed_message="Fail to get since_version of {1} in domain `{0}` "
"with max_inclusive_version={2}. Set to 1.",
unknown_message="Unknown op {1} in domain `{0}`."))

since_version = 1
if defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
Expand All @@ -66,14 +103,11 @@ def get_all_backend_handlers(opset_dict):
domain=handler.DOMAIN,
max_inclusive_version=version).since_version
except RuntimeError:
warnings.warn("Fail to get since_version of {} in domain `{}` "
"with max_inclusive_version={}. Set to 1.".format(
handler.ONNX_OP, handler.DOMAIN, version))
domain_handler_dict.failed[handler.ONNX_OP] = version
else:
warnings.warn("Unknown op {} in domain `{}`.".format(
handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
domain_handler_dict.unknown[handler.ONNX_OP] = handler.ONNX_OP
handler.SINCE_VERSION = since_version
handlers.setdefault(domain, {})[handler.ONNX_OP] = handler
domain_handler_dict[handler.ONNX_OP] = handler
return handlers


Expand Down