Skip to content

Commit

Permalink
add cnet start / end, make them a percentage of num_inference_steps a…
Browse files Browse the repository at this point in the history
…s opposed to using value of (potentially nonlinear) timesteps
  • Loading branch information
ErwannMillon committed Aug 11, 2023
1 parent 2f8cd5f commit 27c29b3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ def __call__(
guess_mode: bool = False,
img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
img2img_strength: float = 1.0,
controlnet_strength: float = 1.0,
controlnet_start: float = 1.0,
controlnet_end: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -1082,7 +1083,9 @@ def __call__(
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

# predict the noise residual
if t >= controlnet_strength * 1000:
# if t <= controlnet_start * 1000:
if i / len(timesteps) >= controlnet_start and i / len(timesteps) <= controlnet_end:
print("using controlnet", t)
noise_pred = self.unet(
latent_model_input,
t,
Expand All @@ -1092,6 +1095,7 @@ def __call__(
mid_block_additional_residual=mid_block_res_sample,
).sample
else:
print("not using controlnet", t)
noise_pred = self.unet(
latent_model_input,
t,
Expand Down
10 changes: 8 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,9 @@ def __call__(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
controlnet_strength: float = 1.0,
controlnet_end: float = 0.0,
controlnet_start: float = 1.,
# controlnet_strength: float = None, #for legacy support with old sdxl
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
Expand Down Expand Up @@ -896,6 +898,8 @@ def __call__(
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
control_guidance_end
]
print(f"controlnet_end = {controlnet_end}")
print(f"controlnet_start = {controlnet_start}")

# 1. Check inputs. Raise error if not correct
self.check_inputs(
Expand Down Expand Up @@ -1112,7 +1116,9 @@ def __call__(

# predict the noise residual

if t >= controlnet_strength * 1000:
# if t <= controlnet_start * 1000 and t >= controlnet_end * 1000:
if i / len(timesteps) >= controlnet_start and i / len(timesteps) <= controlnet_end:
print("using cnet at timetsep" , t)
noise_pred = self.unet(
latent_model_input,
t,
Expand Down

0 comments on commit 27c29b3

Please sign in to comment.