[associative_scan] Fixing shape checks (pytorch#141698)
This PR fixes the shape checks performed in the associative_scan operation.
Previously, all input leaves were required to have the same shape. With this PR, each output leaf of the combine_fn only needs to match the shape of its corresponding input leaf; the input leaves no longer need to share a common shape among themselves.
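
As an illustration, here is a minimal sketch of the usage this change enables, modeled on the test added below. The combine function and shapes are illustrative assumptions, not part of the PR, and (as in the test's skip conditions) running it may require a CUDA/Triton-capable build:

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def combine_fn(x, y):
    # Corresponding leaves of x and y match in shape, but the two leaves
    # themselves are shaped differently (a state and a reduced output).
    state_x, out_x = x
    state_y, out_y = y
    return state_x + state_y, out_x + out_y  # elementwise add is associative

state = torch.randn(4, 8, 16, 7)  # (batch, hidden, length, dstate)
out = torch.randn(4, 8, 16)       # (batch, hidden, length) -- a different shape
# Scan over dim=2 (length). Before this PR the call was rejected because the
# two leaves do not share a shape; now each combine_fn output only has to
# match the slice of its corresponding input leaf along the scan dimension.
result = associative_scan(combine_fn, (state, out), 2, combine_mode="generic")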

Pull Request resolved: pytorch#141698
Approved by: https://github.com/ydwu4
bohnstingl authored and pytorchmergebot committed Dec 3, 2024
1 parent 3ab4a28 commit 871b93b
Showing 2 changed files with 90 additions and 8 deletions.
89 changes: 88 additions & 1 deletion test/functorch/test_control_flow.py
@@ -113,11 +113,20 @@ def mul(x: torch.Tensor, y: torch.Tensor):
def div(x: torch.Tensor, y: torch.Tensor):
return x / y

- def s5_operator(x, y):
+ def s5_operator(x: torch.Tensor, y: torch.Tensor):
A_i, Bu_i = x
A_j, Bu_j = y
return A_j * A_i, A_j * Bu_i + Bu_j

def different_input_size_operator(x: torch.Tensor, y: torch.Tensor):
x_o, dA_o, dB_o, C_o, y_o = x
x_n, dA_n, dB_n, C_n, y_n = y

x_new = x_n + x_o
y_new = torch.einsum("bdn,bn->bd", x_new, C_n)

return x_new, dA_n + 0.0, dB_n + 0.0, C_n + 0.0, y_new

def tuple_fct(x, y):
return (x[0] + y[0], x[1] * y[1])

@@ -144,6 +153,8 @@ def non_pointwise(x: torch.Tensor, y: torch.Tensor):
fct = div
elif name == "s5_operator":
fct = s5_operator
elif name == "different_input_size_operator":
fct = different_input_size_operator
elif name == "tuple_fct":
fct = tuple_fct
elif name == "complex_pointwise":
@@ -3253,6 +3264,82 @@ def test_associative_scan_binary_operator(
inputs=elements,
)

@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
def test_associative_scan_different_input_size(self, compile_mode, reverse, device):
batch = 5
hidden_dim = 3
length = 10
dstate = 7

deltaA = torch.randn(
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
)
deltaB_u = torch.randn(
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
)
C = torch.randn((batch, dstate, length), requires_grad=True, device=device)
x = torch.randn(
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
)
y = torch.randn((batch, hidden_dim, length), requires_grad=True, device=device)
elements = (x, deltaA, deltaB_u, C, y)

kwargs = {
"dim": 2,
"reverse": reverse,
"compile_mode": compile_mode,
"combine_fn": get_scan_combine_fn("different_input_size_operator", True),
"combine_mode": "generic",
}
kwargs_fake = self._prepare_fake_kwargs(kwargs)
self._run_test(
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=elements,
)

@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
def test_associative_scan_different_input_size_wrong_dim(self):
batch = 5
hidden_dim = 3
length = 10
dstate = 7

deltaA = torch.randn(
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
)
deltaB_u = torch.randn(
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
)
C = torch.randn((batch, dstate, length), device=torch.device("cuda"))
x = torch.randn(
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
)
y = torch.randn(
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
)
elements = (x, deltaA, deltaB_u, C, y)

with self.assertRaisesRegex(
# Should be
# ValueError,
# "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0"
torch._dynamo.exc.Unsupported,
"Observed exception.*",
):
out = associative_scan(
get_scan_combine_fn("different_input_size_operator", True),
elements,
3,
combine_mode="pointwise",
)

@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
def test_associative_scan_sparse_tensor(self):
9 changes: 2 additions & 7 deletions torch/_higher_order_ops/associative_scan.py
@@ -152,7 +152,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
raise ValueError("xs leaves must be a Tensor")
if any(x.is_sparse for x in leaves):
raise ValueError("xs leaves must dense Tensors, consider using `to_dense()`")
- if any(x.ndim < dim for x in leaves):
+ if any(x.ndim <= dim for x in leaves):
raise ValueError(
"All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0"
)
@@ -166,15 +166,10 @@ def add(x: torch.Tensor, y: torch.Tensor):

ndim = leaves[0].ndim
dim = utils.canonicalize_dim(ndim, dim)
- shape = leaves[0].shape

- for x in leaves[1:]:
- assert x.shape == shape, "All xs tensors must have the same shape"

# Call the combine_fn with only a slice along the scan dim
# and check whether the output leaves have the same slice dimensions
sliced_leaves = [first_slice_copy(leaf, dim) for leaf in leaves]
- sliced_shape = sliced_leaves[0].shape

out = combine_fn(
pytree.tree_unflatten(sliced_leaves, spec),
@@ -186,7 +181,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
"The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input"
)
if any(
- x.shape != sliced_shape
+ x.shape != x_sliced.shape
or x.dtype != x_sliced.dtype
or x.device != x_sliced.device
or x.stride() != x_sliced.stride()
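
In effect, the updated check compares each combine_fn output leaf against a slice of its own input leaf instead of against one shared shape. A rough per-leaf sketch of that expectation (illustrative only, not the implementation; leaves, out_leaves, and dim stand for the flattened inputs, the combine_fn outputs, and the scan dimension):

for in_leaf, out_leaf in zip(leaves, out_leaves):
    expected = in_leaf.select(dim, 0)  # one slice along the scan dimension
    # Shape, dtype, and device must match per leaf; the real check also
    # compares strides against a sliced copy of the input leaf.
    assert out_leaf.shape == expected.shape
    assert out_leaf.dtype == expected.dtype
    assert out_leaf.device == expected.device

The first hunk also tightens the rank check from '<' to '<=': for a leaf with ndim dimensions, the valid non-negative scan dimensions are 0..ndim-1, so dim == ndim is now rejected, which the new wrong-dim test exercises.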
