Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589169316
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Dec 8, 2023
1 parent 26c2b2b commit c4d3c98
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 18 deletions.
72 changes: 66 additions & 6 deletions swirl_dynamics/lib/diffusion/guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Modules for guidance transforms for denoising functions."""

from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Protocol

import flax
Expand All @@ -23,8 +23,8 @@

Array = jax.Array
PyTree = Any
Cond = Mapping[str, PyTree] | None
DenoiseFn = Callable[[Array, Array, Cond], Array]
ArrayMapping = Mapping[str, Array]
DenoiseFn = Callable[[Array, Array, ArrayMapping | None], Array]


class Transform(Protocol):
Expand All @@ -37,7 +37,7 @@ class Transform(Protocol):
"""

def __call__(
self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
self, denoise_fn: DenoiseFn, guidance_inputs: ArrayMapping
) -> DenoiseFn:
"""Constructs a guided denoising function.
Expand Down Expand Up @@ -84,11 +84,13 @@ class InfillFromSlices:
guide_strength: float = 0.5

def __call__(
self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
self, denoise_fn: DenoiseFn, guidance_inputs: ArrayMapping
) -> DenoiseFn:
"""Constructs denoise function guided by values on specified slices."""

def _guided_denoise(x: Array, sigma: Array, cond: Cond = None) -> Array:
def _guided_denoise(
x: Array, sigma: Array, cond: ArrayMapping | None = None
) -> Array:
def constraint(xt: Array) -> tuple[Array, Array]:
denoised = denoise_fn(xt, sigma, cond)
error = jnp.sum(
Expand All @@ -106,3 +108,61 @@ def constraint(xt: Array) -> tuple[Array, Array]:
return denoised.at[self.slices].set(guidance_inputs["observed_slices"])

return _guided_denoise


@flax.struct.dataclass
class ClassifierFreeHybrid:
"""Classifier-free guidance for a hybrid (cond/uncond) denoising model.
This guidance technique, introduced by Ho and Salimans
(https://arxiv.org/abs/2207.12598), aims to improve the quality of denoised
images by combining conditional and unconditional denoising outputs.
The guided denoise function is given by:
D̃(x, σ, c) = (1 + w) * D(x, σ, c) - w * D(x, σ, Ø),
where
- x: The noisy input.
- σ: The noise level.
- c: The conditioning information (e.g., class label).
- Ø: A special masking condition (typically zeros) indicating unconditional
denoising.
- w: The guidance strength, controlling the influence of each denoising
output. A value of 0 indicates non-guided denoising.
Attributes:
guidance_strength: The strength of guidance (i.e. w). The original paper
reports optimal values of 0.1 and 0.3 for 64x64 and 128x128 ImageNet
respectively.
cond_mask_keys: A collection of keys in the conditions dictionary that will
be masked. If `None`, all conditions are masked.
cond_mask_value: The values that the conditions will be masked by. This
value must be consistent with the masking applied at training.
"""

guidance_strength: float = 0.0
cond_mask_keys: Sequence[str] | None = None
cond_mask_value: float = 0.0

def __call__(
self, denoise_fn: DenoiseFn, guidance_inputs: ArrayMapping
) -> DenoiseFn:
"""Constructs denoise function with classifier free guidance."""

def _guided_denoise(x: Array, sigma: Array, cond: ArrayMapping) -> Array:
masked_cond = {
k: (
v # pylint: disable=g-long-ternary
if self.cond_mask_keys is not None
and k not in self.cond_mask_keys
else jnp.ones_like(v) * self.cond_mask_value
)
for k, v in cond.items()
}
uncond_denoised = denoise_fn(x, sigma, masked_cond)
return (1 + self.guidance_strength) * denoise_fn(
x, sigma, cond
) - self.guidance_strength * uncond_denoised

return _guided_denoise
24 changes: 24 additions & 0 deletions swirl_dynamics/lib/diffusion/guidance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ def _dummy_denoiser(x, sigma, cond=None):
expected[superresolve.slices] = 0.0
np.testing.assert_allclose(denoised, expected)

@parameterized.parameters(
{"mask_keys": None, "mask_value": 0, "expected": 13},
{"mask_keys": None, "mask_value": 1, "expected": 12},
{"mask_keys": ("0", "1", "2"), "mask_value": 0, "expected": 11.6},
)
def test_classifier_free_hybrid(self, mask_keys, mask_value, expected):
cf_hybrid = guidance.ClassifierFreeHybrid(
guidance_strength=0.2,
cond_mask_keys=mask_keys,
cond_mask_value=mask_value,
)

def _dummy_denoiser(x, sigma, cond):
del sigma
out = jnp.ones_like(x)
for v in cond.values():
out += v
return out

guided_denoiser = cf_hybrid(_dummy_denoiser, {})
cond = {str(v): jnp.array(v) for v in range(5)}
denoised = guided_denoiser(jnp.array(0), jnp.array(0.1), cond)
self.assertAlmostEqual(denoised, expected, places=5)


if __name__ == "__main__":
absltest.main()
24 changes: 12 additions & 12 deletions swirl_dynamics/lib/diffusion/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from swirl_dynamics.lib.solvers import sde

Array = jax.Array
PyTree = Any
Cond = Mapping[str, PyTree] | None
DenoiseFn = Callable[[Array, Array, Cond], Array]
ArrayMapping = Mapping[str, Array]
DenoiseFn = Callable[[Array, Array, ArrayMapping | None], Array]
Params = Mapping[str, Any]
ScoreFn = DenoiseFn


Expand All @@ -47,7 +47,7 @@ def denoiser2score(
) -> ScoreFn:
"""Converts a denoiser to the corresponding score function."""

def _score(x: Array, sigma: Array, cond: Cond = None) -> Array:
def _score(x: Array, sigma: Array, cond: ArrayMapping | None = None) -> Array:
# reference: eq. 74 in Karras et al. (https://arxiv.org/abs/2206.00364).
scale = scheme.scale(scheme.sigma.inverse(sigma))
x_hat = jnp.divide(x, scale)
Expand Down Expand Up @@ -145,7 +145,7 @@ def generate(
def _apply_guidance_transforms(
denoise_fn: DenoiseFn,
transforms: Sequence[guidance.Transform],
guidance_inputs: Mapping[str, PyTree],
guidance_inputs: Mapping[str, Array],
) -> DenoiseFn:
for transform in transforms:
denoise_fn = transform(denoise_fn, guidance_inputs)
Expand Down Expand Up @@ -181,8 +181,8 @@ def generate(
num_samples: int,
rng: Array,
tspan: Array,
cond: Cond = None,
guidance_inputs: Mapping[str, Any] | None = None,
cond: ArrayMapping | None = None,
guidance_inputs: ArrayMapping | None = None,
) -> tuple[Array, dict[str, Array]]:
"""Generate samples by solving the sampling ODE.
Expand Down Expand Up @@ -247,7 +247,7 @@ def dynamics(self) -> ode.OdeDynamics:
where s(t), σ(t) are the scale and noise schedule of the diffusion scheme.
"""

def _dynamics(x: Array, t: Array, params: PyTree) -> Array:
def _dynamics(x: Array, t: Array, params: Params) -> Array:
assert not t.ndim, "`t` must be a scalar."
denoise_fn = _apply_guidance_transforms(
self.denoise_fn,
Expand Down Expand Up @@ -280,8 +280,8 @@ def generate(
num_samples: int,
rng: Array,
tspan: Array,
cond: Cond = None,
guidance_inputs: Mapping[str, Any] | None = None,
cond: ArrayMapping | None = None,
guidance_inputs: ArrayMapping | None = None,
) -> tuple[Array, dict[str, Array]]:
"""Generate samples by solving an SDE.
Expand Down Expand Up @@ -355,7 +355,7 @@ def dynamics(self) -> sde.SdeDynamics:
respectively.
"""

def _drift(x: Array, t: Array, params: PyTree) -> Array:
def _drift(x: Array, t: Array, params: Params) -> Array:
assert not t.ndim, "`t` must be a scalar."
denoise_fn = _apply_guidance_transforms(
self.denoise_fn,
Expand All @@ -370,7 +370,7 @@ def _drift(x: Array, t: Array, params: PyTree) -> Array:
drift -= 2 * dlog_sigma_dt * s * denoise_fn(x_hat, sigma, params["cond"])
return drift

def _diffusion(x: Array, t: Array, params: PyTree) -> Array:
def _diffusion(x: Array, t: Array, params: Params) -> Array:
del x, params
assert not t.ndim, "`t` must be a scalar."
dsquare_sigma_dt = dsquare_dt(self.scheme.sigma)(t)
Expand Down

0 comments on commit c4d3c98

Please sign in to comment.