diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 69718bda82f2..c61680fe97f5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -829,7 +829,8 @@ def __call__( guess_mode: bool = False, img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, img2img_strength: float = 1.0, - controlnet_strength: float = 1.0, + controlnet_start: float = 1.0, + controlnet_end: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -1082,7 +1083,9 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - if t >= controlnet_strength * 1000: + # if t <= controlnet_start * 1000: + if i / len(timesteps) >= controlnet_start and i / len(timesteps) <= controlnet_end: + print("using controlnet", t) noise_pred = self.unet( latent_model_input, t, @@ -1092,6 +1095,7 @@ def __call__( mid_block_additional_residual=mid_block_res_sample, ).sample else: + print("not using controlnet", t) noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index a1f7575ef2af..f9f61761f2a2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -770,7 +770,9 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, - controlnet_strength: float = 1.0, + controlnet_end: float = 0.0, + controlnet_start: float = 1., + # controlnet_strength: float = None, #for legacy support with old sdxl original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -896,6 +898,8 @@ def __call__( control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] + print(f"controlnet_end = {controlnet_end}") + print(f"controlnet_start = {controlnet_start}") # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -1112,7 +1116,9 @@ def __call__( # predict the noise residual - if t >= controlnet_strength * 1000: + # if t <= controlnet_start * 1000 and t >= controlnet_end * 1000: + if i / len(timesteps) >= controlnet_start and i / len(timesteps) <= controlnet_end: + print("using cnet at timetsep" , t) noise_pred = self.unet( latent_model_input, t,