diff --git a/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py b/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py index f8afeaa559..dbf6a2be8e 100644 --- a/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py +++ b/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py @@ -367,10 +367,11 @@ def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_p # Pad guidance if necessary if guidance is not None: + guidance_batches[-1]=guidance_batches[-1].unsqueeze(1) sequence_to_stack = (guidance_batches[-1],) + tuple( torch.zeros_like(guidance_batches[-1][0][None, :]) for _ in range(num_dummy_samples) ) - guidance_batches[-1] = torch.vstack(sequence_to_stack) + guidance_batches[-1] = torch.vstack(sequence_to_stack).squeeze(1) # Stack batches in the same tensor latents_batches = torch.stack(latents_batches)