Skip to content

Commit

Permalink
Merge pull request #9 from mitkotak/pt2
Browse files Browse the repository at this point in the history
Adding `torch.compile` compatiblity
  • Loading branch information
laserkelvin authored Oct 24, 2024
2 parents 728b107 + 24bcaf2 commit ccfe80b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
8 changes: 6 additions & 2 deletions scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import numpy as np
import pandas as pd
import e3nn
from e3nn.o3._spherical_harmonics import _spherical_harmonics

from equitriton.sph_harm.bindings import *
Expand Down Expand Up @@ -82,11 +83,14 @@ def e3nn_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: i
joint_tensor[..., 1].contiguous(),
joint_tensor[..., 2].contiguous(),
)
output = _spherical_harmonics(l_max, x, y, z)
e3nn.set_optimization_defaults(jit_script_fx=False)
output = torch.compile(_spherical_harmonics, fullgraph=True, mode="max-autotune")(l_max, x, y, z)
output.backward(gradient=torch.ones_like(output))
# delete references to ensure memory gets cleared
del output
del joint_tensor
e3nn.set_optimization_defaults(jit_script_fx=True) # Turn it back on to avoid any issues



@benchmark(num_steps=args.num_steps, warmup_fraction=args.warmup_fraction)
Expand Down Expand Up @@ -131,4 +135,4 @@ def triton_benchmark(tensor_shape: list[int], device: str | torch.device, l_max:
all_data.append(joint_results)

df = pd.DataFrame(all_data)
df.to_csv(f"{args.device}_lmax{args.l_max}_results.csv", index=False)
df.to_csv(f"{args.device}_lmax{args.l_max}_results.csv", index=False)
5 changes: 4 additions & 1 deletion scripts/measure_numerical_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
import numpy as np
import e3nn
from e3nn.o3._spherical_harmonics import _spherical_harmonics

from equitriton.sph_harm.bindings import *
Expand Down Expand Up @@ -73,7 +74,8 @@ def compare_e3nn_triton(
joint_tensor[..., 1].contiguous(),
joint_tensor[..., 2].contiguous(),
)
e3nn_output = _spherical_harmonics(l_max, x, y, z)
e3nn.set_optimization_defaults(jit_script_fx=False)
e3nn_output = torch.compile(_spherical_harmonics, fullgraph=True, mode="max-autotune")(l_max, x, y, z)
e3nn_output.backward(gradient=torch.ones_like(e3nn_output))
e3nn_grad = joint_tensor.grad.detach().clone()
joint_tensor.grad = None
Expand All @@ -95,6 +97,7 @@ def compare_e3nn_triton(
# delete intermediate tensors to make sure we don't leak
del e3nn_output
del triton_output
e3nn.set_optimization_defaults(jit_script_fx=True) # Turn it back on to avoid any issues
return (signed_fwd_error, signed_bwd_error)


Expand Down
5 changes: 4 additions & 1 deletion scripts/profile_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
from torch.profiler import record_function
import e3nn
from e3nn.o3._spherical_harmonics import _spherical_harmonics

from equitriton.sph_harm.bindings import *
Expand Down Expand Up @@ -74,13 +75,15 @@ def e3nn_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: i
joint_tensor[..., 1].contiguous(),
joint_tensor[..., 2].contiguous(),
)
e3nn.set_optimization_defaults(jit_script_fx=False)
with record_function("forward"):
output = _spherical_harmonics(l_max, x, y, z)
output = torch.compile(_spherical_harmonics, fullgraph=True, mode="max-autotune")(l_max, x, y, z)
with record_function("backward"):
output.backward(gradient=torch.ones_like(output))
# delete references to ensure memory gets cleared
del output
del joint_tensor
e3nn.set_optimization_defaults(jit_script_fx=True) # Turn it back on to avoid any issues


@profile(
Expand Down

0 comments on commit ccfe80b

Please sign in to comment.