Skip to content

Commit

Permalink
fix FlUX.1_dev guidance_batches bug for pad case in _split_inputs_int…
Browse files Browse the repository at this point in the history
…o_batches
  • Loading branch information
huijuanzh committed Dec 13, 2024
1 parent f8b496a commit 2ea5b2f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optimum/habana/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,11 @@ def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_p

# Pad guidance if necessary
if guidance is not None:
guidance_batches[-1]=guidance_batches[-1].unsqueeze(1)
sequence_to_stack = (guidance_batches[-1],) + tuple(
torch.zeros_like(guidance_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
)
guidance_batches[-1] = torch.vstack(sequence_to_stack)
guidance_batches[-1] = torch.vstack(sequence_to_stack).squeeze(1)

# Stack batches in the same tensor
latents_batches = torch.stack(latents_batches)
Expand Down

0 comments on commit 2ea5b2f

Please sign in to comment.