Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00-pl committed Dec 6, 2024
1 parent 40d63c8 commit e905043
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
1 change: 0 additions & 1 deletion tests/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ def test_torch_compile_forward():
def test_torch_compile_backward():
model = SimpleNN()
compiled_model = torch.compile(model, backend="aot_eager")

check_torch_compile_backward(model, compiled_model)
7 changes: 4 additions & 3 deletions tests/test_torch_compile_aot_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from tests.module_pool.simple_nn import SimpleNN


@make_boxed_compiler
def custom_compiler(gm: torch.fx.GraphModule, example_inputs):
print("Using custom compiler!")
gm.graph.print_tabular()
Expand All @@ -15,11 +14,13 @@ def custom_compiler(gm: torch.fx.GraphModule, example_inputs):
return gm.forward


def test_torch_dump_compile_backward():
def test_torch_dump_compile():
# 初始化模型、损失函数和优化器
model = SimpleNN()
aot_backend = aot_autograd(fw_compiler=custom_compiler, bw_compiler=custom_compiler) # 在backward时也使用自定义编译器
boxed_compiler = make_boxed_compiler(custom_compiler) # 使用boxed_compiler包装自定义编译器, 解决aot_autograd的内存释放问题.
aot_backend = aot_autograd(fw_compiler=boxed_compiler, bw_compiler=boxed_compiler) # 在backward时也使用自定义编译器
compiled_model = torch.compile(model, backend=aot_backend)

criterion = nn.MSELoss()
optimizer = optim.SGD(compiled_model.parameters(), lr=0.01)

Expand Down
17 changes: 10 additions & 7 deletions tests/test_torch_compile_mutiple_output.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import torch
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.utils import make_boxed_compiler
from torch._functorch.aot_autograd import make_boxed_compiler

from plai.pl_torch_compiler import dummy_compiler, dump_compiler, plnn_compiler
from tests.module_pool.simple_nn import SimpleNN, check_torch_compile_forward, check_torch_compile_backward
from plai.pl_torch_compiler import plnn_compiler


def test_torch_plnn_compile_mutiple_output():
model = lambda x: torch.max(x, dim=0)[0]
def test_torch_plnn_compile_multiple_output():
model = (lambda x: torch.max(x, dim=0))
custom_compiler = plnn_compiler.CustomCompiler()
aot_backend = aot_autograd(fw_compiler=make_boxed_compiler(custom_compiler), bw_compiler=None)
aot_backend = aot_autograd(fw_compiler=make_boxed_compiler(custom_compiler))
compiled_model = torch.compile(model, backend=aot_backend)
check_torch_compile_forward(model, compiled_model)
input_data = torch.randn(1, 10)
expected_output = model(input_data)
actual_output = compiled_model(input_data)
assert torch.allclose(expected_output.values, actual_output.values)
assert torch.allclose(expected_output.indices, actual_output.indices)
print('dump compile forward:')
print(custom_compiler.graph)

0 comments on commit e905043

Please sign in to comment.