
Commit

Code update
PiperOrigin-RevId: 669111124
Forgotten authored and The swirl_dynamics Authors committed Aug 30, 2024
1 parent 13be32c commit 057c93c
Showing 5 changed files with 228 additions and 0 deletions.
159 changes: 159 additions & 0 deletions swirl_dynamics/lib/diffusion/unets.py
@@ -22,6 +22,7 @@
import numpy as np
from swirl_dynamics.lib import layers


Array = jax.Array
Initializer = nn.initializers.Initializer
PrecisionLike = (
@@ -50,6 +51,7 @@ class AdaptiveScale(nn.Module):
see e.g. https://arxiv.org/abs/2105.05233, and for the
more general FiLM technique see https://arxiv.org/abs/1709.07871.
"""

act_fun: Callable[[Array], Array] = nn.swish
precision: PrecisionLike = None
dtype: jnp.dtype = jnp.float32
@@ -452,6 +454,139 @@ def __call__(self, x: Array, cond: dict[str, Array]):
return merge_channel_cond(x, proc_cond)


class MergeEmdCond(nn.Module):
"""Base class for merging conditional inputs as embeddings."""

def __call__(self, emb: Array, cond: dict[str, Array], is_training: bool):
pass


class EmbConvMerge(MergeEmdCond):
"""Compute conditional inputs through interpolation and convolutions.
We resize the conditional inputs to match the spatial shape of the main input
and then pass them through a nonlinearity and a ConvLayer. The output is then
mixied with the embedding from the Fourier embedding.
Attributes:
embed_dim: The output channel dimension.
latent_dim: The latent dimension of the embedding.
downsample_ratio: Ratio for the downsampling of the embedding.
interp_shape: The shape to which the conditional inputs are resized.
kernel_size: The convolutional kernel size.
resize_method: The interpolation method employed by `jax.image.resize`.
padding: The padding method of all convolutions.
num_heads: Number of heads in the attention block.
normalize_qk: Whether to normalize the query and key vectors in the
attention block.
"""

embed_dim: int
latent_dim: int
kernel_size: Sequence[int]
downsample_ratio: Sequence[int]
interp_shape: Sequence[int]
resize_method: str = "cubic"
padding: str = "CIRCULAR"
num_heads: int = 128
normalize_qk: bool = True
precision: PrecisionLike = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32

@nn.compact
def __call__(self, emb: Array, cond: dict[str, Array], is_training: bool):
"""Merges conditional inputs along the channel dimension.
Fields with spatial shape differing from the sample `x` are reshaped to
match it. Then, all conditional fields are passed through a nonlinearity,
a ConvLayer, and then concatenated with the main input along the last axis.
Args:
emb: Embedding coming from the kernel_dim = x.ndim - 2.
cond: A dictionary of conditional inputs. Those with keys that start with
"channel:" are processed here while all others are omitted.
is_training: Whether the model is in training mode.
Returns:
Embedding merged with channel conditions.
"""

if emb.shape[-1] != self.embed_dim:
raise ValueError(
f"Number of channels in the embedding ({emb.shape[-1]}) must "
"match the number of channels in the output "
f"{self.embed_dim})."
)

value_temp = []

# Extract fields, resize and concatenate.
for key, value in sorted(cond.items()):
# TODO: Change the prefix to "merge_embed:".
if key.startswith("channel:"):
# Enforcing prefix in the key.
value = layers.FilteredResize(
output_size=self.interp_shape,
kernel_size=self.kernel_size,
method=self.resize_method,
padding=self.padding,
precision=self.precision,
dtype=self.dtype,
param_dtype=self.param_dtype,
name=f"resize_embedding_{key}",
)(value)

value_temp.append(value)

value = jnp.concatenate(value_temp, axis=-1)

kernel_dim = value.ndim - 2
# Downsample the embedding.
num_levels = len(self.downsample_ratio)
for level in range(num_levels):
value = nn.swish(nn.LayerNorm()(value))
value = layers.DownsampleConv(
features=self.latent_dim,
ratios=(self.downsample_ratio[level],) * kernel_dim,
kernel_init=default_init(1.0),
precision=self.precision,
dtype=self.dtype,
param_dtype=self.param_dtype,
name=f"level_{level}.embedding_downsample_conv",
)(value)

# Add a self-attention block.
b, _, _, c = value.shape
value = AttentionBlock(
num_heads=self.num_heads,
precision=self.precision,
dtype=self.dtype,
normalize_qk=self.normalize_qk,
param_dtype=self.param_dtype,
name="cond_embedding.attention_block",
)(value.reshape(b, -1, c), is_training=is_training)

value = nn.Dense(
features=self.embed_dim,
precision=self.precision,
dtype=self.dtype,
param_dtype=self.param_dtype,
)(value.reshape(b, -1))
value = nn.swish(value)

# Concatenate the noise and conditional embedding.
emb = jnp.concatenate([emb, value], axis=-1)
emb = nn.Dense(
features=self.embed_dim,
precision=self.precision,
dtype=self.dtype,
param_dtype=self.param_dtype,
)(emb)

return emb


class DStack(nn.Module):
"""Downsampling stack.
@@ -682,6 +817,8 @@ class UNet(nn.Module):
cond_resize_method: str = "bilinear"
cond_embed_dim: int = 128
cond_merging_fn: type[MergeChannelCond] = InterpConvMerge
cond_embed_fn: type[nn.Module] | None = None
cond_embed_kwargs: dict[str, jax.typing.ArrayLike] | None = None
precision: PrecisionLike = None
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
@@ -743,6 +880,28 @@ def __call__(
)(x, cond)

emb = FourierEmbedding(dims=self.noise_embed_dim)(sigma)
# Incorporating the embedding from the conditional inputs.
if self.cond_embed_fn:
if self.cond_embed_kwargs is None:
# For backward compatibility.
# TODO: Remove this once the configs are updated.
cond_embed_kwargs = dict(latent_dim=32, num_heads=32)
else:
cond_embed_kwargs = self.cond_embed_kwargs

emb = self.cond_embed_fn(
embed_dim=self.noise_embed_dim,
latent_dim=cond_embed_kwargs["latent_dim"],
num_heads=cond_embed_kwargs["num_heads"],
kernel_size=(3,) * kernel_dim,
interp_shape=x.shape[:-1],
downsample_ratio=self.downsample_ratio,
padding=self.padding,
precision=self.precision,
dtype=self.dtype,
param_dtype=self.param_dtype,
)(emb, cond, is_training=is_training)

skips = DStack(
num_channels=self.num_channels,
num_res_blocks=len(self.num_channels) * (self.num_blocks,),
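For reference, a minimal standalone sketch of how the new EmbConvMerge module might be exercised. The shapes, RNG seeds, and hyperparameter values below are illustrative assumptions rather than part of this commit; latent_dim and num_heads mirror the backward-compatible defaults picked in UNet.__call__, and interp_shape mirrors how UNet passes x.shape[:-1].

import jax
import jax.numpy as jnp
from swirl_dynamics.lib.diffusion import unets

# Illustrative inputs: a (1, 128) noise embedding and one coarser conditional field.
emb = jax.random.normal(jax.random.PRNGKey(0), (1, 128))
cond = {"channel:mean": jax.random.normal(jax.random.PRNGKey(1), (1, 8, 4, 2))}

merge = unets.EmbConvMerge(
    embed_dim=128,              # must match emb.shape[-1], per the check in __call__
    latent_dim=32,
    kernel_size=(3, 3),
    downsample_ratio=(2, 2, 2),
    interp_shape=(1, 16, 8),    # batch + spatial, mirroring interp_shape=x.shape[:-1] in UNet
    num_heads=32,
)
params = merge.init(jax.random.PRNGKey(2), emb, cond, is_training=False)
out = merge.apply(params, emb, cond, is_training=False)  # expected shape: (1, 128)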
41 changes: 41 additions & 0 deletions swirl_dynamics/lib/diffusion/unets_test.py
@@ -170,6 +170,47 @@ def test_preconditioned_merging_functions(self):
)
self.assertEqual(out.shape, x.shape)

def test_preconditioned_embedding_functions(self):
x_dims = (1, 16, 8, 3)
c_dims = (1, 8, 4, 6)
x = jax.random.normal(jax.random.PRNGKey(42), x_dims)
cond = {
"channel:cond1": jax.random.normal(jax.random.PRNGKey(42), c_dims),
}
sigma = jnp.array(0.5)
model = diffusion.unets.PreconditionedDenoiser(
out_channels=x_dims[-1],
num_channels=(4, 8, 12),
downsample_ratio=(2, 2, 2),
num_blocks=2,
num_heads=4,
sigma_data=1.0,
use_position_encoding=False,
cond_embed_dim=32,
cond_resize_method="cubic",
cond_embed_fn=diffusion.unets.EmbConvMerge,
)
variables = model.init(
jax.random.PRNGKey(42), x=x, sigma=sigma, cond=cond, is_training=True
)
# Check shape dict so that err message is easier to read when things break.
shape_dict = jax.tree.map(jnp.shape, variables["params"])
self.assertIn("EmbConvMerge_0", shape_dict)
self.assertIn(
"level_0.embedding_downsample_conv", shape_dict["EmbConvMerge_0"]
)
self.assertIn(
"cond_embedding.attention_block", shape_dict["EmbConvMerge_0"]
)
self.assertIn(
"resize_embedding_channel:cond1", shape_dict["EmbConvMerge_0"]
)

out = jax.jit(functools.partial(model.apply, is_training=True))(
variables, x, sigma, cond
)
self.assertEqual(out.shape, x.shape)


if __name__ == "__main__":
absltest.main()
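The test above relies on the backward-compatible defaults (latent_dim=32, num_heads=32) that UNet.__call__ picks when cond_embed_kwargs is None. Below is a hedged sketch of overriding them explicitly; the values are illustrative, and only the latent_dim and num_heads entries are read by UNet.__call__ (the remaining EmbConvMerge arguments are derived from the input shapes and existing UNet attributes).

from swirl_dynamics.lib import diffusion

model = diffusion.unets.PreconditionedDenoiser(
    out_channels=3,
    num_channels=(4, 8, 12),
    downsample_ratio=(2, 2, 2),
    num_blocks=2,
    num_heads=4,
    sigma_data=1.0,
    use_position_encoding=False,
    cond_embed_dim=32,
    cond_resize_method="cubic",
    cond_embed_fn=diffusion.unets.EmbConvMerge,
    # Illustrative override; latent_dim should be divisible by num_heads for the
    # attention block inside EmbConvMerge.
    cond_embed_kwargs=dict(latent_dim=64, num_heads=16),
)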
@@ -27,6 +27,7 @@
from ml_collections import config_flags
import numpy as np
from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.lib.diffusion import unets
from swirl_dynamics.lib.solvers import ode as ode_solvers
from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
from swirl_dynamics.projects.debiasing.rectified_flow import evaluation_metrics as metrics
@@ -349,6 +350,13 @@ def read_normalized_stats(

def build_model(config):
"""Builds the model from config file."""

if "conditional_embedding" in config and config.conditional_embedding:
logging.info("Using conditional embedding")
cond_embed_fn = unets.EmbConvMerge
else:
cond_embed_fn = None

flow_model = models.RescaledUnet(
out_channels=config.out_channels,
num_channels=config.num_channels,
@@ -362,6 +370,7 @@ def build_model(config):
use_position_encoding=config.use_position_encoding,
num_heads=config.num_heads,
normalize_qk=config.normalize_qk,
cond_embed_fn=cond_embed_fn,
)

model = models.ConditionalReFlowModel(
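For reference, a minimal sketch of how the new flag might appear in an ml_collections config consumed by build_model. The function name and surrounding fields are illustrative assumptions; only conditional_embedding is read by the guard above.

import ml_collections

def get_config():
  config = ml_collections.ConfigDict()
  # When True, build_model passes unets.EmbConvMerge as cond_embed_fn.
  config.conditional_embedding = True
  # ... remaining model fields (out_channels, num_channels, etc.) stay as before.
  return config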
10 changes: 10 additions & 0 deletions swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py
@@ -26,6 +26,7 @@
import ml_collections
from ml_collections import config_flags
import numpy as np
from swirl_dynamics.lib.diffusion import unets
from swirl_dynamics.lib.solvers import ode as ode_solvers
from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
from swirl_dynamics.projects.debiasing.rectified_flow import models
@@ -251,6 +252,14 @@ def read_normalized_stats(

def build_model(config):
"""Builds the model from config file."""

# Adding the conditional embedding for the FiLM layer.
if "conditional_embedding" in config and config.conditional_embedding:
logging.info("Using conditional embedding")
cond_embed_fn = unets.EmbConvMerge
else:
cond_embed_fn = None

flow_model = models.RescaledUnet(
out_channels=config.out_channels,
num_channels=config.num_channels,
@@ -263,6 +272,7 @@ def build_model(config):
resize_to_shape=config.resize_to_shape,
use_position_encoding=config.use_position_encoding,
num_heads=config.num_heads,
cond_embed_fn=cond_embed_fn,
normalize_qk=config.normalize_qk,
)

@@ -25,6 +25,7 @@
from ml_collections import config_flags
import optax
from orbax import checkpoint
from swirl_dynamics.lib.diffusion import unets
from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
from swirl_dynamics.projects.debiasing.rectified_flow import models
from swirl_dynamics.projects.debiasing.rectified_flow import trainers
@@ -239,6 +240,13 @@ def main(argv):
dtype = jax.numpy.float32
param_dtype = jax.numpy.float32

# Adding the conditional embedding for the FiLM layer.
if "conditional_embedding" in config and config.conditional_embedding:
logging.info("Using conditional embedding")
cond_embed_fn = unets.EmbConvMerge
else:
cond_embed_fn = None

# Setting up the neural network for the flow model.
flow_model = models.RescaledUnet(
out_channels=config.out_channels,
@@ -253,6 +261,7 @@
use_position_encoding=config.use_position_encoding,
num_heads=config.num_heads,
normalize_qk=config.normalize_qk,
cond_embed_fn=cond_embed_fn,
dtype=dtype,
param_dtype=param_dtype,
)
