From 916db4df68f534ddfbb9ac33486043f23f1f80f9 Mon Sep 17 00:00:00 2001 From: shivadbhavsar Date: Wed, 13 Nov 2024 20:48:10 -0600 Subject: [PATCH] fix googlenet test case for pyt2.5 --- py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py | 2 +- tests/fx/models/test_lowering_fx.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py index 2f8213cc..b7ad9618 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py @@ -386,7 +386,7 @@ def _rewrite( # functions that are attrs of this moodule. Return the new, rewritten module # hierarchy. def rewrite_module(m: nn.Module): - if isinstance(m, jit.ScriptModule): + if isinstance(m, jit.ScriptModule) or m is None: # ScriptModule cannot be rewritten, so bypass it. The issue is it # requires explicitly calling its `__init__()`, calling # `nn.Module.__init__()` in the derived `RewrittenModule` is not diff --git a/tests/fx/models/test_lowering_fx.py b/tests/fx/models/test_lowering_fx.py index 66e16b18..8883312e 100644 --- a/tests/fx/models/test_lowering_fx.py +++ b/tests/fx/models/test_lowering_fx.py @@ -15,7 +15,8 @@ (models.alexnet(), DEFAULT_RTOL, DEFAULT_ATOL), (models.densenet169(), DEFAULT_RTOL, DEFAULT_ATOL), (models.efficientnet_b6(), DEFAULT_RTOL, DEFAULT_ATOL), - (models.googlenet(), DEFAULT_RTOL, DEFAULT_ATOL), + (models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1), + DEFAULT_RTOL, DEFAULT_ATOL), (models.mnasnet1_0(), DEFAULT_RTOL, DEFAULT_ATOL), (models.mobilenet_v2(), DEFAULT_RTOL, DEFAULT_ATOL), (models.mobilenet_v3_large(), DEFAULT_RTOL, DEFAULT_ATOL),