From e90504395a6549d24e6a881c499f633c1071a49c Mon Sep 17 00:00:00 2001 From: 0x00-pl <0x00.pl@gmail.com> Date: Wed, 28 Aug 2024 00:33:49 +0800 Subject: [PATCH] update test --- tests/test_torch_compile.py | 1 - tests/test_torch_compile_aot_backward.py | 7 ++++--- tests/test_torch_compile_mutiple_output.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index 946b2e9..8f8a915 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -12,5 +12,4 @@ def test_torch_compile_forward(): def test_torch_compile_backward(): model = SimpleNN() compiled_model = torch.compile(model, backend="aot_eager") - check_torch_compile_backward(model, compiled_model) diff --git a/tests/test_torch_compile_aot_backward.py b/tests/test_torch_compile_aot_backward.py index afb1674..a182dd6 100644 --- a/tests/test_torch_compile_aot_backward.py +++ b/tests/test_torch_compile_aot_backward.py @@ -6,7 +6,6 @@ from tests.module_pool.simple_nn import SimpleNN -@make_boxed_compiler def custom_compiler(gm: torch.fx.GraphModule, example_inputs): print("Using custom compiler!") gm.graph.print_tabular() @@ -15,11 +14,13 @@ def custom_compiler(gm: torch.fx.GraphModule, example_inputs): return gm.forward -def test_torch_dump_compile_backward(): +def test_torch_dump_compile(): # 初始化模型、损失函数和优化器 model = SimpleNN() - aot_backend = aot_autograd(fw_compiler=custom_compiler, bw_compiler=custom_compiler) # 在backward时也使用自定义编译器 + boxed_compiler = make_boxed_compiler(custom_compiler) # 使用boxed_compiler包装自定义编译器, 解决aot_autograd的内存释放问题. + aot_backend = aot_autograd(fw_compiler=boxed_compiler, bw_compiler=boxed_compiler) # 在backward时也使用自定义编译器 compiled_model = torch.compile(model, backend=aot_backend) + criterion = nn.MSELoss() optimizer = optim.SGD(compiled_model.parameters(), lr=0.01) diff --git a/tests/test_torch_compile_mutiple_output.py b/tests/test_torch_compile_mutiple_output.py index bf78659..af6533b 100644 --- a/tests/test_torch_compile_mutiple_output.py +++ b/tests/test_torch_compile_mutiple_output.py @@ -1,16 +1,19 @@ import torch from torch._dynamo.backends.common import aot_autograd -from torch._functorch._aot_autograd.utils import make_boxed_compiler +from torch._functorch.aot_autograd import make_boxed_compiler -from plai.pl_torch_compiler import dummy_compiler, dump_compiler, plnn_compiler -from tests.module_pool.simple_nn import SimpleNN, check_torch_compile_forward, check_torch_compile_backward +from plai.pl_torch_compiler import plnn_compiler -def test_torch_plnn_compile_mutiple_output(): - model = lambda x: torch.max(x, dim=0)[0] +def test_torch_plnn_compile_multiple_output(): + model = (lambda x: torch.max(x, dim=0)) custom_compiler = plnn_compiler.CustomCompiler() - aot_backend = aot_autograd(fw_compiler=make_boxed_compiler(custom_compiler), bw_compiler=None) + aot_backend = aot_autograd(fw_compiler=make_boxed_compiler(custom_compiler)) compiled_model = torch.compile(model, backend=aot_backend) - check_torch_compile_forward(model, compiled_model) + input_data = torch.randn(1, 10) + expected_output = model(input_data) + actual_output = compiled_model(input_data) + assert torch.allclose(expected_output.values, actual_output.values) + assert torch.allclose(expected_output.indices, actual_output.indices) print('dump compile forward:') print(custom_compiler.graph)