From 9917a887f50a31663ddaa6dc944a28810f727db7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 30 Oct 2024 10:14:07 -0700
Subject: [PATCH] maintain same behavior as an mlp when no batch dim

---
 setup.py                  | 2 +-
 x_transformers/neo_mlp.py | 8 ++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 5a989f5a..f33234e0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.0',
+  version = '1.42.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/neo_mlp.py b/x_transformers/neo_mlp.py
index 687a1804..5e882467 100644
--- a/x_transformers/neo_mlp.py
+++ b/x_transformers/neo_mlp.py
@@ -94,6 +94,11 @@ def forward(
         x,
         return_embeds = False
     ):
+        no_batch = x.ndim == 1
+
+        if no_batch:
+            x = rearrange(x, '... -> 1 ...')
+
         batch = x.shape[0]

         fouriered_input = self.random_fourier(x)
@@ -121,6 +126,9 @@ def forward(
         output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
         output = output + self.to_output_bias

+        if no_batch:
+            output = rearrange(output, '1 ... -> ...')
+
         if not return_embeds:
             return output
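
Note (not part of the patch): below is a minimal, self-contained sketch of the pattern the patch applies, not the actual NeoMLP code. The `BatchOnlyMLP` class, its constructor arguments, and the shapes used are illustrative assumptions; only the no-batch handling mirrors the added lines above. A 1-D input is lifted to a batch of one, run through the usual batched path, then squeezed back, so the module behaves like a plain MLP on a single unbatched sample.

```python
# Illustrative sketch only -- BatchOnlyMLP is a hypothetical stand-in for a
# module whose forward() assumes a leading batch dimension, as NeoMLP does.
import torch
from torch import nn
from einops import rearrange

class BatchOnlyMLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_in, dim_hidden),
            nn.ReLU(),
            nn.Linear(dim_hidden, dim_out)
        )

    def forward(self, x):
        no_batch = x.ndim == 1                    # detect an unbatched (d,) input

        if no_batch:
            x = rearrange(x, '... -> 1 ...')      # lift to shape (1, d)

        out = self.net(x)                         # batched path, unchanged

        if no_batch:
            out = rearrange(out, '1 ... -> ...')  # squeeze back to shape (d,)

        return out

mlp = BatchOnlyMLP(5, 64, 7)
assert mlp(torch.randn(3, 5)).shape == (3, 7)    # batched input works as before
assert mlp(torch.randn(5)).shape == (7,)         # unbatched input now returns an unbatched output
```

The design choice is to detect and restore the missing batch dimension at the boundaries of forward(), so the internals can keep assuming a batch dimension while callers get MLP-like behavior for single samples.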