diff --git a/setup.py b/setup.py index 00bea0c..b43756c 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.2.2', + version = '0.2.3', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index 07f3d8c..d944eda 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -383,10 +383,12 @@ def forward( # rotary embeddings - main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long) - register_positions = torch.arange(self.num_register_tokens, device = self.device, dtype = torch.long) - register_positions -= 10000 - positions = torch.cat((register_positions, main_positions)) + positions = seq_len + + if self.has_register_tokens: + main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long) + register_positions = torch.full((self.num_register_tokens,), -10000, device = self.device, dtype = torch.long) + positions = torch.cat((register_positions, main_positions)) rotary_emb = self.rotary_emb(positions)