From 96ac64638d14e8c0b3211728c82093bb1536cbbc Mon Sep 17 00:00:00 2001 From: "Zhou, Huijuan" Date: Fri, 13 Dec 2024 02:06:42 -0800 Subject: [PATCH] fix bs and guidance_batches bug for pipeline_flux_img2img --- .../diffusers/pipelines/flux/pipeline_flux_img2img.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/flux/pipeline_flux_img2img.py b/optimum/habana/diffusers/pipelines/flux/pipeline_flux_img2img.py index fda7e5ef7e..4cf3baea90 100644 --- a/optimum/habana/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/optimum/habana/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -382,10 +382,11 @@ def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_p # Pad guidance if necessary if guidance is not None: + guidance_batches[-1]=guidance_batches[-1].unsqueeze(1) sequence_to_stack = (guidance_batches[-1],) + tuple( torch.zeros_like(guidance_batches[-1][0][None, :]) for _ in range(num_dummy_samples) ) - guidance_batches[-1] = torch.vstack(sequence_to_stack) + guidance_batches[-1] = torch.vstack(sequence_to_stack).squeeze(1) # Stack batches in the same tensor latents_batches = torch.stack(latents_batches) @@ -623,14 +624,14 @@ def __call__( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latent_timestep = timesteps[:1].repeat(num_prompts * num_images_per_prompt) # 6. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( init_image, latent_timestep, - batch_size * num_images_per_prompt, + num_prompts * num_images_per_prompt, num_channels_latents, height, width,