Skip to content

Commit

Permalink
Add support for exception list in EXIRATenDialectVerifierBase (pytorc…
Browse files Browse the repository at this point in the history
…h#3481)

Summary:
Pull Request resolved: pytorch#3481

Adding support for an exception list in EXIRATenDialectVerifierBase to support the implementation of `to_edge_transform_and_lower`. We'll pass in the list of ops that have been registered to not be decomposed into this verifier so that it skips these ops.

Reviewed By: larryliu0820

Differential Revision: D56560549

fbshipit-source-id: 88e7f52d8ba97b9caf11aac76fecfed4a0602217
  • Loading branch information
tarun292 authored and facebook-github-bot committed May 28, 2024
1 parent 9d4727d commit c3d1680
Showing 1 changed file with 44 additions and 22 deletions.
66 changes: 44 additions & 22 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
class EXIRATenDialectVerifierBase(Verifier):
dialect = "OLD_EXIR_ATEN_DISABLED"

def __init__(
self, exception_list: Optional[List[torch._ops.OpOverload]] = None
) -> None:
super().__init__()
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
torch.fx.GraphModule,
Expand All @@ -74,23 +80,33 @@ def __call__(self, *args, **kwargs):
class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"

def _get_exception_list(self) -> List[torch._ops.OpOverload]:
exception_list = [
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
]
exception_list += self._exception_list

return exception_list

def check_valid_op(self, op):
if isinstance(op, OpOverload):
# TODO These special ops should be removable easily.
if op.namespace in (
"quantized_decomposed",
"boltnn_nimble",
"nimble",
"quantized",
"dim_order_ops",
) or op in (
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.aten._upsample_bilinear2d_aa.default,
torch.ops.aten.quantize_per_tensor.default,
torch.ops.aten.dequantize.self,
torch.ops.aten.max.default, # TODO(T188268054)
torch.ops.aten.min.default, # TODO(T188268054)
torch.ops.aten.full_like.default, # TODO(T183507359)
if (
op.namespace
in [
"quantized_decomposed",
"boltnn_nimble",
"nimble",
"quantized",
"dim_order_ops",
]
or op in self._get_exception_list()
):
return
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
Expand Down Expand Up @@ -150,6 +166,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
def EXIREdgeDialectVerifier( # noqa: C901
edge_compile_config: Optional[EdgeCompileConfig] = None,
class_only: bool = False,
exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
class _EXIREdgeDialectVerifier(Verifier):
dialect = "EDGE"
Expand All @@ -161,13 +178,14 @@ def __init__(self) -> None:
self.check_edge_ops = _edge_compile_config._use_edge_ops
self.use_dim_order = not _edge_compile_config._skip_dim_order

self.aten_op_verifier = EXIRATenDialectVerifier()
self.aten_op_verifier = EXIRATenDialectVerifier(exception_list)
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

if self.check_edge_ops:
self.check_valid_op = self.check_valid_edge_op
else:
self.check_valid_op = self.check_valid_aten_op
self._exception_list = exception_list if exception_list else []

def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (
Expand All @@ -183,13 +201,17 @@ def allowed_op_types(self):
def check_valid_edge_op(self, op):
if not self.enable:
return
if op in [
operator.getitem,
torch.ops.aten.sym_size.int,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
]:
if (
op
in [
operator.getitem,
torch.ops.aten.sym_size.int,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten._assert_scalar.default,
]
+ self._exception_list
):
return

if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
Expand Down

0 comments on commit c3d1680

Please sign in to comment.