flatten splash attention axes
blahBlahhhJ committed May 12, 2024
1 parent d71e9ca commit 54f407e
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/levanter/models/attention.py
@@ -707,10 +707,21 @@ def _tpu_splash_attention(
         raise ValueError(f"Embedding axes must be the same for q, k, and v: {q_class['D']} != {k_class['D']}")
 
     def _physical_axis_for_binning(d):
-        b_out = tuple(ax for ax in pspec_for_axis(d["B"]) if ax is not None) or None
-        h_out = tuple(ax for ax in pspec_for_axis(d["H"]) if ax is not None) or None
-        s_out = tuple(ax for ax in pspec_for_axis(d["S"]) if ax is not None) or None
-        d_out = tuple(ax for ax in pspec_for_axis(d["D"]) if ax is not None) or None
+        def flatten(axes):
+            if axes is None:
+                return axes
+            result = []
+            for ax in axes:
+                if isinstance(ax, tuple):
+                    result += list(ax)
+                else:
+                    result.append(ax)
+            return tuple(result)
+
+        b_out = flatten(tuple(ax for ax in pspec_for_axis(d["B"]) if ax is not None) or None)
+        h_out = flatten(tuple(ax for ax in pspec_for_axis(d["H"]) if ax is not None) or None)
+        s_out = flatten(tuple(ax for ax in pspec_for_axis(d["S"]) if ax is not None) or None)
+        d_out = flatten(tuple(ax for ax in pspec_for_axis(d["D"]) if ax is not None) or None)
 
         return PartitionSpec(b_out, h_out, s_out, d_out)
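
The change routes every entry of the physical PartitionSpec through a small flatten helper, so that a logical axis mapped to multiple physical mesh axes (which shows up as a nested tuple of mesh axis names) contributes those names as one flat tuple instead. Below is a minimal, standalone sketch of that flattening behavior; the mesh axis names ("data", "replica", "model") are hypothetical examples, not taken from the commit.

# Standalone sketch of the flatten helper added in this commit, using
# hypothetical mesh axis names ("data", "replica", "model").
from jax.sharding import PartitionSpec


def flatten(axes):
    """Flatten one level of nested tuples in a PartitionSpec entry."""
    if axes is None:
        return axes
    result = []
    for ax in axes:
        if isinstance(ax, tuple):
            result += list(ax)  # splice the nested mesh axes in place
        else:
            result.append(ax)
    return tuple(result)


# A logical batch axis sharded over two mesh axes arrives as a nested tuple;
# flatten turns it into a flat tuple of mesh axis names.
print(flatten((("data", "replica"),)))  # -> ('data', 'replica')
print(flatten(("model",)))              # -> ('model',)
print(flatten(None))                    # -> None

# The flattened entries are then used to build the physical PartitionSpec,
# mirroring the return statement in _physical_axis_for_binning.
spec = PartitionSpec(flatten((("data", "replica"),)), flatten(("model",)), None, None)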

