Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
MythicalChu authored Oct 4, 2024
0 parents commit 73f7c8b
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
3 changes: 3 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
109 changes: 109 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch

try:
from comfy.model_patcher import ModelPatcher

# #BACKEND = "ComfyUI"
except ImportError:
try:
from ldm_patched.modules.model_patcher import ModelPatcher

# #BACKEND = "reForge"
except ImportError:
from backend.patcher.base import ModelPatcher

# #BACKEND = "Forge"

class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0

def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average

def project( v0: torch.Tensor, v1: torch.Tensor,):
dtype = v0.dtype
castToCpu=["privateuseone:0"]
device = v0.device # .double() causes problems on DML, on the line v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) with "parameter error" probably because v1 gets corruped on the .double() attempt
if device in castToCpu:
v0 = v0.to("cpu")
v1 = v1.to("cpu")
v0, v1 = v0.double(), v1.double()
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel

#return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
if device in castToCpu:
v0_parallel = v0_parallel.to(device)
v0_orthogonal = v0_orthogonal.to(device)
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)

def normalized_guidance( pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, momentum_buffer: MomentumBuffer = None, eta: float = 1.0, norm_threshold: float = 0.0,):
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
#print(" momentum: ", momentum_buffer.momentum, " running_average: ", momentum_buffer.running_average)
momentum_buffer.update(diff)
#print(" new running_average: ", momentum_buffer.running_average)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update

return pred_guided

class APG_ImYourCFGNow:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
"momentum": ("FLOAT", {"default": -0.5, "min": -1.5, "max": 0.5, "step": 0.1, "round": 0.01}),
"norm_threshold": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.1, "round": 0.01}),
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1, "round": 0.01}),
},
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "model_patches/unet"

def patch(
self,
model: ModelPatcher,
scale: float = 5.0,
momentum: float = -0.5,
norm_threshold: float = 0.0,
eta: float = 1.0,
):

momentum_buffer = MomentumBuffer(momentum)

def apg_function(args):
cond = args["cond"]
uncond = args["uncond"]

return normalized_guidance(cond, uncond, scale, momentum_buffer, eta, norm_threshold)

m = model.clone()
m.set_model_sampler_cfg_function(apg_function, momentum_buffer)

return (m,)


NODE_CLASS_MAPPINGS = {
"APG_ImYourCFGNow": APG_ImYourCFGNow,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"APG_ImYourCFGNow": "APG_ImYourCFGNow",
}

0 comments on commit 73f7c8b

Please sign in to comment.