From c4d3c984bad13e9aae5e961e75e7331e512c526d Mon Sep 17 00:00:00 2001
From: Zhong Yi Wan
Date: Fri, 8 Dec 2023 10:13:33 -0800
Subject: [PATCH] Code update

PiperOrigin-RevId: 589169316
---
 swirl_dynamics/lib/diffusion/guidance.py      | 72 +++++++++++++++++--
 swirl_dynamics/lib/diffusion/guidance_test.py | 24 +++++++
 swirl_dynamics/lib/diffusion/samplers.py      | 24 +++----
 3 files changed, 102 insertions(+), 18 deletions(-)

diff --git a/swirl_dynamics/lib/diffusion/guidance.py b/swirl_dynamics/lib/diffusion/guidance.py
index 631d9ee..84fdc5a 100644
--- a/swirl_dynamics/lib/diffusion/guidance.py
+++ b/swirl_dynamics/lib/diffusion/guidance.py
@@ -14,7 +14,7 @@
 """Modules for guidance transforms for denoising functions."""
 
-from collections.abc import Callable, Mapping
+from collections.abc import Callable, Mapping, Sequence
 from typing import Any, Protocol
 
 import flax
 import jax
@@ -23,8 +23,8 @@
 
 Array = jax.Array
 PyTree = Any
-Cond = Mapping[str, PyTree] | None
-DenoiseFn = Callable[[Array, Array, Cond], Array]
+ArrayMapping = Mapping[str, Array]
+DenoiseFn = Callable[[Array, Array, ArrayMapping | None], Array]
 
 
 class Transform(Protocol):
@@ -37,7 +37,7 @@ class Transform(Protocol):
   """
 
   def __call__(
-      self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
+      self, denoise_fn: DenoiseFn, guidance_inputs: ArrayMapping
   ) -> DenoiseFn:
     """Constructs a guided denoising function.
 
@@ -84,11 +84,13 @@ class InfillFromSlices:
   guide_strength: float = 0.5
 
   def __call__(
-      self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
+      self, denoise_fn: DenoiseFn, guidance_inputs: ArrayMapping
   ) -> DenoiseFn:
     """Constructs denoise function guided by values on specified slices."""
 
-    def _guided_denoise(x: Array, sigma: Array, cond: Cond = None) -> Array:
+    def _guided_denoise(
+        x: Array, sigma: Array, cond: ArrayMapping | None = None
+    ) -> Array:
       def constraint(xt: Array) -> tuple[Array, Array]:
         denoised = denoise_fn(xt, sigma, cond)
         error = jnp.sum(
@@ -106,3 +108,61 @@ def constraint(xt: Array) -> tuple[Array, Array]:
       return denoised.at[self.slices].set(guidance_inputs["observed_slices"])
 
     return _guided_denoise
+
+
+@flax.struct.dataclass
+class ClassifierFreeHybrid:
+  """Classifier-free guidance for a hybrid (cond/uncond) denoising model.
+
+  This guidance technique, introduced by Ho and Salimans
+  (https://arxiv.org/abs/2207.12598), aims to improve the quality of denoised
+  images by combining conditional and unconditional denoising outputs.
+  The guided denoise function is given by:
+
+    D̃(x, σ, c) = (1 + w) * D(x, σ, c) - w * D(x, σ, Ø),
+
+  where
+
+  - x: The noisy input.
+  - σ: The noise level.
+  - c: The conditioning information (e.g., class label).
+  - Ø: A special masking condition (typically zeros) indicating unconditional
+    denoising.
+  - w: The guidance strength, controlling the influence of each denoising
+    output. A value of 0 indicates non-guided denoising.
+
+  Attributes:
+    guidance_strength: The strength of guidance (i.e. w). The original paper
+      reports optimal values of 0.1 and 0.3 for 64x64 and 128x128 ImageNet
+      respectively.
+    cond_mask_keys: A collection of keys in the conditions dictionary that will
+      be masked. If `None`, all conditions are masked.
+    cond_mask_value: The values that the conditions will be masked by. This
+      value must be consistent with the masking applied at training.
+  """
+
+  guidance_strength: float = 0.0
+  cond_mask_keys: Sequence[str] | None = None
+  cond_mask_value: float = 0.0
+
+  def __call__(
+      self, denoise_fn: DenoiseFn, guidance_inputs: ArrayMapping
+  ) -> DenoiseFn:
+    """Constructs denoise function with classifier free guidance."""
+
+    def _guided_denoise(x: Array, sigma: Array, cond: ArrayMapping) -> Array:
+      masked_cond = {
+          k: (
+              v  # pylint: disable=g-long-ternary
+              if self.cond_mask_keys is not None
+              and k not in self.cond_mask_keys
+              else jnp.ones_like(v) * self.cond_mask_value
+          )
+          for k, v in cond.items()
+      }
+      uncond_denoised = denoise_fn(x, sigma, masked_cond)
+      return (1 + self.guidance_strength) * denoise_fn(
+          x, sigma, cond
+      ) - self.guidance_strength * uncond_denoised
+
+    return _guided_denoise
diff --git a/swirl_dynamics/lib/diffusion/guidance_test.py b/swirl_dynamics/lib/diffusion/guidance_test.py
index a20abb2..cefdf0a 100644
--- a/swirl_dynamics/lib/diffusion/guidance_test.py
+++ b/swirl_dynamics/lib/diffusion/guidance_test.py
@@ -50,6 +50,30 @@ def _dummy_denoiser(x, sigma, cond=None):
     expected[superresolve.slices] = 0.0
     np.testing.assert_allclose(denoised, expected)
 
+  @parameterized.parameters(
+      {"mask_keys": None, "mask_value": 0, "expected": 13},
+      {"mask_keys": None, "mask_value": 1, "expected": 12},
+      {"mask_keys": ("0", "1", "2"), "mask_value": 0, "expected": 11.6},
+  )
+  def test_classifier_free_hybrid(self, mask_keys, mask_value, expected):
+    cf_hybrid = guidance.ClassifierFreeHybrid(
+        guidance_strength=0.2,
+        cond_mask_keys=mask_keys,
+        cond_mask_value=mask_value,
+    )
+
+    def _dummy_denoiser(x, sigma, cond):
+      del sigma
+      out = jnp.ones_like(x)
+      for v in cond.values():
+        out += v
+      return out
+
+    guided_denoiser = cf_hybrid(_dummy_denoiser, {})
+    cond = {str(v): jnp.array(v) for v in range(5)}
+    denoised = guided_denoiser(jnp.array(0), jnp.array(0.1), cond)
+    self.assertAlmostEqual(denoised, expected, places=5)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/swirl_dynamics/lib/diffusion/samplers.py b/swirl_dynamics/lib/diffusion/samplers.py
index ad5d906..0f089cd 100644
--- a/swirl_dynamics/lib/diffusion/samplers.py
+++ b/swirl_dynamics/lib/diffusion/samplers.py
@@ -26,9 +26,9 @@
 from swirl_dynamics.lib.solvers import sde
 
 Array = jax.Array
-PyTree = Any
-Cond = Mapping[str, PyTree] | None
-DenoiseFn = Callable[[Array, Array, Cond], Array]
+ArrayMapping = Mapping[str, Array]
+DenoiseFn = Callable[[Array, Array, ArrayMapping | None], Array]
+Params = Mapping[str, Any]
 ScoreFn = DenoiseFn
 
 
@@ -47,7 +47,7 @@ def denoiser2score(
 ) -> ScoreFn:
   """Converts a denoiser to the corresponding score function."""
 
-  def _score(x: Array, sigma: Array, cond: Cond = None) -> Array:
+  def _score(x: Array, sigma: Array, cond: ArrayMapping | None = None) -> Array:
     # reference: eq. 74 in Karras et al. (https://arxiv.org/abs/2206.00364).
scale = scheme.scale(scheme.sigma.inverse(sigma)) x_hat = jnp.divide(x, scale) @@ -145,7 +145,7 @@ def generate( def _apply_guidance_transforms( denoise_fn: DenoiseFn, transforms: Sequence[guidance.Transform], - guidance_inputs: Mapping[str, PyTree], + guidance_inputs: Mapping[str, Array], ) -> DenoiseFn: for transform in transforms: denoise_fn = transform(denoise_fn, guidance_inputs) @@ -181,8 +181,8 @@ def generate( num_samples: int, rng: Array, tspan: Array, - cond: Cond = None, - guidance_inputs: Mapping[str, Any] | None = None, + cond: ArrayMapping | None = None, + guidance_inputs: ArrayMapping | None = None, ) -> tuple[Array, dict[str, Array]]: """Generate samples by solving the sampling ODE. @@ -247,7 +247,7 @@ def dynamics(self) -> ode.OdeDynamics: where s(t), σ(t) are the scale and noise schedule of the diffusion scheme. """ - def _dynamics(x: Array, t: Array, params: PyTree) -> Array: + def _dynamics(x: Array, t: Array, params: Params) -> Array: assert not t.ndim, "`t` must be a scalar." denoise_fn = _apply_guidance_transforms( self.denoise_fn, @@ -280,8 +280,8 @@ def generate( num_samples: int, rng: Array, tspan: Array, - cond: Cond = None, - guidance_inputs: Mapping[str, Any] | None = None, + cond: ArrayMapping | None = None, + guidance_inputs: ArrayMapping | None = None, ) -> tuple[Array, dict[str, Array]]: """Generate samples by solving an SDE. @@ -355,7 +355,7 @@ def dynamics(self) -> sde.SdeDynamics: respectively. """ - def _drift(x: Array, t: Array, params: PyTree) -> Array: + def _drift(x: Array, t: Array, params: Params) -> Array: assert not t.ndim, "`t` must be a scalar." denoise_fn = _apply_guidance_transforms( self.denoise_fn, @@ -370,7 +370,7 @@ def _drift(x: Array, t: Array, params: PyTree) -> Array: drift -= 2 * dlog_sigma_dt * s * denoise_fn(x_hat, sigma, params["cond"]) return drift - def _diffusion(x: Array, t: Array, params: PyTree) -> Array: + def _diffusion(x: Array, t: Array, params: Params) -> Array: del x, params assert not t.ndim, "`t` must be a scalar." dsquare_sigma_dt = dsquare_dt(self.scheme.sigma)(t)
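Reviewer note (not part of the patch): a minimal usage sketch of the new
ClassifierFreeHybrid transform added above. The toy denoiser, the conditioning
dictionary, and the expected numbers mirror the unit test in this patch; the
import path is assumed to follow the repository layout shown in the diff.

import jax.numpy as jnp
from swirl_dynamics.lib.diffusion import guidance

# Toy denoiser: returns ones plus the sum of all conditioning values
# (same shape as the unit test's _dummy_denoiser).
def toy_denoiser(x, sigma, cond):
  del sigma
  out = jnp.ones_like(x)
  for v in cond.values():
    out += v
  return out

# w = 0.2; with cond_mask_keys=None every condition is replaced by
# cond_mask_value (0.0) when computing the unconditional branch.
cf_hybrid = guidance.ClassifierFreeHybrid(guidance_strength=0.2)
guided = cf_hybrid(toy_denoiser, guidance_inputs={})

cond = {str(i): jnp.array(float(i)) for i in range(5)}
# Conditional output = 1 + (0+1+2+3+4) = 11; unconditional output = 1.
# Guided output = (1 + 0.2) * 11 - 0.2 * 1 = 13.0.
print(guided(jnp.array(0.0), jnp.array(0.1), cond))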
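The samplers apply a sequence of such transforms in order, as done by
_apply_guidance_transforms in samplers.py above. Continuing the sketch, the
loop below reproduces that composition with a second, purely hypothetical
transform (scaling_transform and the "out_scale" input are illustrative only;
in the library, guidance_inputs are supplied through the sampler's generate
call rather than constructed by hand).

def scaling_transform(denoise_fn, guidance_inputs):
  # Hypothetical transform conforming to the Transform protocol: rescales
  # the denoised output by a factor taken from guidance_inputs.
  def _guided(x, sigma, cond=None):
    return guidance_inputs["out_scale"] * denoise_fn(x, sigma, cond)
  return _guided

# Same composition loop as _apply_guidance_transforms in samplers.py.
transforms = (cf_hybrid, scaling_transform)
guidance_inputs = {"out_scale": jnp.array(0.5)}
denoise_fn = toy_denoiser
for transform in transforms:
  denoise_fn = transform(denoise_fn, guidance_inputs)

print(denoise_fn(jnp.array(0.0), jnp.array(0.1), cond))  # 0.5 * 13.0 = 6.5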