diff --git a/README.md b/README.md index 4a2cb4e..0c0383b 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ Neutral prompt is an a1111 webui extension that adds alternative composable diff ## Features +- Now compatible wih [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)! - [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword - saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons) - semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247) diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index a8605ab..0de579d 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -7,6 +7,8 @@ import sys import textwrap +from modules_forge.forge_sampler import cond_from_a1111_to_patched_ldm + def combine_denoised_hijack( x_out: torch.Tensor, @@ -49,8 +51,7 @@ def get_webui_denoised( sliced_batch_cond_indices.append(sliced_cond_indices) sliced_batch_x_out.extend(sliced_x_out) - sliced_batch_x_out += list(uncond) - sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0) + sliced_batch_x_out = torch.stack(sliced_batch_x_out + list(uncond), dim=0) return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale) @@ -207,32 +208,169 @@ def get_salience(vector: torch.Tensor) -> torch.Tensor: def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor: - k = int(torch.numel(vector) * (1 - k_ratio)) + k = int(vector.numel() * (1 - k_ratio)) top_k, _ = torch.kthvalue(torch.abs(torch.flatten(vector)), k) - return vector * (torch.abs(vector) >= top_k).to(vector.dtype) + return vector * (vector.abs() >= top_k).to(vector.dtype) -sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get( - module=sd_samplers, - hijacker_attribute='__neutral_prompt_hijacker', - on_uninstall=script_callbacks.on_script_unloaded, -) +try: + from ldm_patched.modules import samplers + from modules_forge import forge_sampler + forge = True +except ImportError: + forge = False -@sd_samplers_hijacker.hijack('create_sampler') -def create_sampler_hijack(name: str, model, original_function): - sampler = original_function(name, model) - if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'): - if global_state.is_enabled: - warn_unsupported_sampler() +if forge: + forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get( + module=forge_sampler, + hijacker_attribute='__forge_sampler_hijacker', + on_uninstall=script_callbacks.on_script_unloaded, + ) + samplers_hijacker = hijacker.ModuleHijacker.install_or_get( + module=samplers, + hijacker_attribute='__samplers_hijacker', + on_uninstall=script_callbacks.on_script_unloaded, + ) + + + @forge_sampler_hijacker.hijack('forge_sample') + def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_function): + if not global_state.is_enabled: + return original_function(self, denoiser_params, cond_scale, cond_composition) + + self.inner_model.inner_model.forge_objects.unet.model_options['cond_composition'] = cond_composition + self.inner_model.inner_model.forge_objects.unet.model_options['uncond'] = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond) + + return original_function(self, denoiser_params, cond_scale, cond_composition) - return sampler - sampler.model_wrap_cfg.combine_denoised = functools.partial( - combine_denoised_hijack, - original_function=sampler.model_wrap_cfg.combine_denoised + def sampling_function_hijack(model, x, timestep, uncond, cond, cond_scale, model_options, seed, original_function): + if not global_state.is_enabled or not global_state.prompt_exprs: + return original_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed) + + prompt = global_state.prompt_exprs[0] + original_strengths, new_strengths = prompt.accept(ForgeStrengthOverride(), cond, 0, False) + + model_options['neutral_prompt_override'] = True + model_options['original_strengths'] = original_strengths + return original_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed) + + + class ForgeStrengthOverride: + def visit_leaf_prompt( + self, + that: neutral_prompt_parser.LeafPrompt, + cond: List[dict], + index: int, + is_parent_aux: bool, + ) -> tuple: + original_strength = cond[index].get('strength', 1.0) + new_strength = original_strength * float(not is_parent_aux and that.conciliation is None) + cond[index]['strength'] = new_strength + return [original_strength], [new_strength] + + def visit_composite_prompt( + self, + that: neutral_prompt_parser.CompositePrompt, + cond: List[dict], + index: int, + is_parent_aux: bool, + ) -> tuple: + original_strengths = [] + new_strengths = [] + + for child in that.children: + child_original_strengths, child_new_strengths = child.accept(ForgeStrengthOverride(), cond, index, is_parent_aux or that.conciliation is not None) + original_strengths.extend(child_original_strengths) + new_strengths.extend(child_new_strengths) + + index += child.accept(neutral_prompt_parser.FlatSizeVisitor()) + + return original_strengths, new_strengths + + + samplers_hijacker.hijack('sampling_function')(sampling_function_hijack) + forge_sampler_hijacker.hijack('sampling_function')(sampling_function_hijack) + + + @samplers_hijacker.hijack('calc_cond_uncond_batch') + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options, original_function): + if not global_state.is_enabled or not 'neutral_prompt_override' in model_options.keys(): + return original_function(model, cond, uncond, x_in, timestep, model_options) + + cond_composition = model_options['cond_composition'] + original_strengths = model_options['original_strengths'] + if uncond is None: + uncond = model_options['uncond'] + + for i in range(len(cond)): + cond[i]['strength'] = original_strengths[i] + + cond = cond.copy() + cond.extend(uncond) + discard_last = (len(cond) % 2) == 1 + if discard_last: + cond.append(cond[-1]) + + for i in range(len(cond)): + cond[i] = cond[i].copy() + cond[i]['strength'] = 1.0 + + denoised_latents = [ + denoised + for first_cond, second_cond in zip(cond[::2], cond[1::2]) + for denoised in original_function(model, [first_cond], [second_cond], x_in, timestep, model_options) + ] + + if discard_last: + denoised_latents = denoised_latents[:-1] + + # B, C, H, W + denoised_uncond = denoised_latents[-1] + + # N, B, C, H, W + denoised_in_conds = torch.stack(denoised_latents[:-1], dim=0) + denoised_in_conds = denoised_in_conds.transpose(0, 1).reshape(-1, *denoised_in_conds.shape[2:]) + + # N, 1, 1, 1, 1 + denoised_cond = denoised_uncond.clone() + for batch_i in range(denoised_uncond.shape[0]): + prompt = global_state.prompt_exprs[batch_i] + args = CombineDenoiseArgs(denoised_in_conds, denoised_uncond[batch_i], cond_composition[batch_i]) + cond_delta = prompt.accept(CondDeltaVisitor(), args, 0) + aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0) + denoised_cond[batch_i] += cond_delta + aux_cond_delta + + # consume 'neutral_prompt_override' before returning, in case another extension calls the method + # outside of CFG sampling; for example: extensions-builtin/sd_forge_sag + del model_options['neutral_prompt_override'] + + return denoised_cond, denoised_uncond +else: + sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get( + module=sd_samplers, + hijacker_attribute='__neutral_prompt_hijacker', + on_uninstall=script_callbacks.on_script_unloaded, ) - return sampler + + + @sd_samplers_hijacker.hijack('create_sampler') + def create_sampler_hijack(name: str, model, original_function): + sampler = original_function(name, model) + + if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'): + if global_state.is_enabled: + warn_unsupported_sampler() + + return sampler + + sampler.model_wrap_cfg.combine_denoised = functools.partial( + combine_denoised_hijack, + original_function=sampler.model_wrap_cfg.combine_denoised + ) + + return sampler def warn_unsupported_sampler(): diff --git a/lib_neutral_prompt/ui.py b/lib_neutral_prompt/ui.py index e3a452b..c10a85b 100644 --- a/lib_neutral_prompt/ui.py +++ b/lib_neutral_prompt/ui.py @@ -4,6 +4,14 @@ import gradio as gr import dataclasses +try: + # import a forge-specific module + from modules_forge import forge_sampler + del forge_sampler + forge = True +except ImportError: + forge = False + txt2img_prompt_textbox = None img2img_prompt_textbox = None @@ -29,7 +37,7 @@ class AccordionInterface: def __post_init__(self): self.is_rendered = False - self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0) + self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0, visible=not forge, interactive=not forge) self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)') self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1) self.aux_prompt_type = gr.Dropdown(label='Prompt type', choices=list(prompt_types.keys()), value=next(iter(prompt_types.keys())), tooltip=prompt_types_tooltip, elem_id=self.get_elem_id('formatter_prompt_type'))