maintain same behavior as an mlp when no batch dim
lucidrains committed Oct 30, 2024
1 parent 1ccccaa commit 9917a88
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.0',
+  version = '1.42.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
8 changes: 8 additions & 0 deletions x_transformers/neo_mlp.py
@@ -94,6 +94,11 @@ def forward(
         x,
         return_embeds = False
     ):
+        no_batch = x.ndim == 1
+
+        if no_batch:
+            x = rearrange(x, '... -> 1 ...')
+
         batch = x.shape[0]
 
         fouriered_input = self.random_fourier(x)
@@ -121,6 +126,9 @@
         output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
         output = output + self.to_output_bias
 
+        if no_batch:
+            output = rearrange(output, '1 ... -> ...')
+
         if not return_embeds:
             return output
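For context: with this change, NeoMLP's forward pass detects a 1-D input, temporarily prepends a batch dimension of 1, and strips it from the output again, so the module behaves like a plain MLP on unbatched tensors. A minimal usage sketch follows; the NeoMLP constructor arguments shown are assumptions for illustration, not a documented signature.

    # minimal sketch; the constructor kwargs below are assumed for illustration
    import torch
    from x_transformers.neo_mlp import NeoMLP

    mlp = NeoMLP(dim_in = 5, dim_hidden = 16, dim_out = 7)  # hypothetical kwargs

    batched = mlp(torch.randn(3, 5))  # batched input -> batched output, shape (3, 7)
    single = mlp(torch.randn(5))      # 1-d input -> 1-d output, shape (7,), as of this commit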