diff --git a/torchrec/models/tests/test_deepfm.py b/torchrec/models/tests/test_deepfm.py index 154b4c1c3..e5967cecc 100644 --- a/torchrec/models/tests/test_deepfm.py +++ b/torchrec/models/tests/test_deepfm.py @@ -35,8 +35,11 @@ def test_basic(self) -> None: # check tracer compatibility gm = torch.fx.GraphModule(dense_arch, Tracer().trace(dense_arch)) - script = torch.jit.script(gm) - script(dense_arch_input) + + # TODO: Causes std::bad_alloc in OSS env + # script = torch.jit.script(gm) + + # script(dense_arch_input) class FMInteractionArchTest(unittest.TestCase): @@ -82,7 +85,9 @@ def test_basic(self) -> None: # check tracer compatibility gm = torch.fx.GraphModule(inter_arch, Tracer().trace(inter_arch)) - torch.jit.script(gm) + + # TODO: Causes std::bad_alloc in OSS env + # torch.jit.script(gm) class SimpleDeepFMNNTest(unittest.TestCase): @@ -204,10 +209,11 @@ def test_fx_script(self) -> None: gm = symbolic_trace(deepfm_nn) - scripted_gm = torch.jit.script(gm) + # TODO: Causes std::bad_alloc in OSS env + # torch.jit.script(gm) - logits = scripted_gm(features, sparse_features) - self.assertEqual(logits.size(), (B, 1)) + # logits = scripted_gm(features, sparse_features) + # self.assertEqual(logits.size(), (B, 1)) if __name__ == "__main__": diff --git a/torchrec/modules/tests/test_mc_modules.py b/torchrec/modules/tests/test_mc_modules.py index a72ddf95b..ab7c4250d 100644 --- a/torchrec/modules/tests/test_mc_modules.py +++ b/torchrec/modules/tests/test_mc_modules.py @@ -11,6 +11,7 @@ from typing import Dict import torch +from torchrec.fx import Tracer from torchrec.modules.mc_modules import ( average_threshold_filter, DistanceLFU_EvictionPolicy, @@ -357,5 +358,6 @@ def test_fx_jit_script_not_training(self) -> None: ) model.train(False) - gm = torch.fx.symbolic_trace(model) - torch.jit.script(gm) + gm = torch.fx.GraphModule(model, Tracer().trace(model)) + # TODO: Causes std::bad_alloc in OSS env + # torch.jit.script(gm) diff --git a/torchrec/modules/tests/test_mlp.py b/torchrec/modules/tests/test_mlp.py index 069d071b9..f0e127256 100644 --- a/torchrec/modules/tests/test_mlp.py +++ b/torchrec/modules/tests/test_mlp.py @@ -14,7 +14,7 @@ import torch from hypothesis import given, settings from torch import nn -from torchrec.fx import symbolic_trace +from torchrec.fx import symbolic_trace, Tracer from torchrec.modules.mlp import MLP, Perceptron @@ -99,13 +99,15 @@ def test_fx_script_Perceptron(self) -> None: # Dry-run to initialize lazy module. m(torch.randn(batch_size, in_features)) - gm = symbolic_trace(m) - torch.jit.script(gm) + gm = torch.fx.GraphModule(m, Tracer().trace(m)) + # TODO: Causes std::bad_alloc in OSS env + # torch.jit.script(gm) def test_fx_script_MLP(self) -> None: in_features = 3 layer_sizes = [16, 8, 4] m = MLP(in_features, layer_sizes) - gm = symbolic_trace(m) - torch.jit.script(gm) + gm = torch.fx.GraphModule(m, Tracer().trace(m)) + # TODO: Causes std::bad_alloc in OSS env + # torch.jit.script(gm)