Skip to content

Commit

Permalink
fix googlenet test case for pyt2.5 (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Nov 14, 2024
1 parent 8d95e63 commit 75fdc3e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion py/torch_migraphx/fx/tracer/acc_tracer/acc_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _rewrite(
# functions that are attrs of this moodule. Return the new, rewritten module
# hierarchy.
def rewrite_module(m: nn.Module):
if isinstance(m, jit.ScriptModule):
if isinstance(m, jit.ScriptModule) or m is None:
# ScriptModule cannot be rewritten, so bypass it. The issue is it
# requires explicitly calling its `__init__()`, calling
# `nn.Module.__init__()` in the derived `RewrittenModule` is not
Expand Down
3 changes: 2 additions & 1 deletion tests/fx/models/test_lowering_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
(models.alexnet(), DEFAULT_RTOL, DEFAULT_ATOL),
(models.densenet169(), DEFAULT_RTOL, DEFAULT_ATOL),
(models.efficientnet_b6(), DEFAULT_RTOL, DEFAULT_ATOL),
(models.googlenet(), DEFAULT_RTOL, DEFAULT_ATOL),
(models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1),
DEFAULT_RTOL, DEFAULT_ATOL),
(models.mnasnet1_0(), DEFAULT_RTOL, DEFAULT_ATOL),
(models.mobilenet_v2(), DEFAULT_RTOL, DEFAULT_ATOL),
(models.mobilenet_v3_large(), DEFAULT_RTOL, DEFAULT_ATOL),
Expand Down

0 comments on commit 75fdc3e

Please sign in to comment.