diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py index 2f8213cc..b7ad9618 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py @@ -386,7 +386,7 @@ def _rewrite( # functions that are attrs of this moodule. Return the new, rewritten module # hierarchy. def rewrite_module(m: nn.Module): - if isinstance(m, jit.ScriptModule): + if isinstance(m, jit.ScriptModule) or m is None: # ScriptModule cannot be rewritten, so bypass it. The issue is it # requires explicitly calling its `__init__()`, calling # `nn.Module.__init__()` in the derived `RewrittenModule` is not diff --git a/tests/fx/models/test_lowering_fx.py b/tests/fx/models/test_lowering_fx.py index 66e16b18..8883312e 100644 --- a/tests/fx/models/test_lowering_fx.py +++ b/tests/fx/models/test_lowering_fx.py @@ -15,7 +15,8 @@ (models.alexnet(), DEFAULT_RTOL, DEFAULT_ATOL), (models.densenet169(), DEFAULT_RTOL, DEFAULT_ATOL), (models.efficientnet_b6(), DEFAULT_RTOL, DEFAULT_ATOL), - (models.googlenet(), DEFAULT_RTOL, DEFAULT_ATOL), + (models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1), + DEFAULT_RTOL, DEFAULT_ATOL), (models.mnasnet1_0(), DEFAULT_RTOL, DEFAULT_ATOL), (models.mobilenet_v2(), DEFAULT_RTOL, DEFAULT_ATOL), (models.mobilenet_v3_large(), DEFAULT_RTOL, DEFAULT_ATOL),