fix bs and guidance_batches bug for pipeline_flux_img2img
huijuanzh committed Dec 13, 2024
1 parent 2ea5b2f commit 96ac646
Showing 1 changed file with 4 additions and 3 deletions.
```diff
@@ -382,10 +382,11 @@ def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_p

         # Pad guidance if necessary
         if guidance is not None:
+            guidance_batches[-1] = guidance_batches[-1].unsqueeze(1)
             sequence_to_stack = (guidance_batches[-1],) + tuple(
                 torch.zeros_like(guidance_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
             )
-            guidance_batches[-1] = torch.vstack(sequence_to_stack)
+            guidance_batches[-1] = torch.vstack(sequence_to_stack).squeeze(1)

         # Stack batches in the same tensor
         latents_batches = torch.stack(latents_batches)
```
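Why the unsqueeze/squeeze pair: in Flux pipelines the guidance tensor is 1-D, one scalar per sample, so `guidance_batches[-1][0]` is a 0-d tensor and indexing it with `[None, :]` fails. Adding a column dimension first makes each sample a row that `torch.vstack` can pad, and squeezing afterwards restores the original 1-D shape. A minimal sketch of the padding step, with hypothetical values (`guidance_last`, `num_dummy_samples` stand in for the diff's variables):

```python
import torch

guidance_last = torch.tensor([3.5, 3.5, 3.5])  # last, partially filled batch
num_dummy_samples = 2

# With 1-D guidance, guidance_last[0] is 0-d and [None, :] raises an
# IndexError; give each sample its own row first: (3,) -> (3, 1).
guidance_last = guidance_last.unsqueeze(1)
padding = tuple(torch.zeros_like(guidance_last[0][None, :]) for _ in range(num_dummy_samples))
padded = torch.vstack((guidance_last,) + padding).squeeze(1)
print(padded)  # tensor([3.5, 3.5, 3.5, 0.0, 0.0]), back to shape (5,)
```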
```diff
@@ -623,14 +624,14 @@ def __call__(
                 f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                 f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
             )
-        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt)

         # 6. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
             init_image,
             latent_timestep,
-            batch_size * num_images_per_prompt,
+            num_prompts * num_images_per_prompt,
             num_channels_latents,
             height,
             width,
```
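The second hunk addresses the "bs" half of the commit title: `batch_size` is the per-device batch size used for splitting and padding, while `num_prompts` is the number of prompts the caller actually passed, so sizing latents and timesteps by `batch_size` over-allocates whenever the two differ. A rough numeric illustration, with hypothetical values:

```python
# Hypothetical values illustrating the fix (names follow the diff above).
num_prompts = 3            # prompts the caller actually passed
num_images_per_prompt = 2
batch_size = 4             # per-device batch size used for padding/splitting

# Real samples the pipeline must prepare latents/timesteps for:
assert num_prompts * num_images_per_prompt == 6
# The old code sized them by the padded batch instead:
assert batch_size * num_images_per_prompt == 8  # 2 spurious samples
```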
