Commit 8a5f5f2

Code update

PiperOrigin-RevId: 554633003

The swirl_dynamics Authors committed Aug 7, 2023
1 parent 8e9d341 commit 8a5f5f2
Showing 2 changed files with 208 additions and 0 deletions.
120 changes: 120 additions & 0 deletions swirl_dynamics/lib/networks/convnets.py
@@ -0,0 +1,120 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convolution-based modules for training dynamical systems.
1) PeriodicConvNetModel
"""
from collections.abc import Callable
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

Array = jax.Array


class DilatedBlock(nn.Module):
  """Implements a Dilated ConvNet block.

  The dilation rate of the convolutions first ascends from 1 to
  2**(num_levels - 1) and then descends back to 1, which grows the receptive
  field exponentially with the number of levels while preserving resolution.

  Attributes:
    num_channels: Number of output channels of each convolution.
    num_levels: Number of distinct dilation levels; level i uses rate 2**i.
    kernel_size: Size of the kernel of each convolution.
    act_fun: Activation function applied after each convolution.
    dtype: Type of input/layers.
    padding: Type of padding added to the convolutional layers.
  """

  num_channels: int
  num_levels: int
  kernel_size: tuple[int, ...]
  act_fun: Callable[[jax.Array], jax.Array] | Any = nn.relu
  dtype: Any = jnp.float32
  padding: str = "CIRCULAR"

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    x = inputs.astype(self.dtype)

    # Dilation orders in ascending then descending order, e.g.,
    # [0, 1, 2, 3, 2, 1, 0] for num_levels = 4.
    dilation_order = list(range(self.num_levels))
    dilation_order += list(range(self.num_levels - 1))[::-1]
    for i in dilation_order:
      x = nn.Conv(
          features=self.num_channels,
          kernel_size=self.kernel_size,
          kernel_dilation=2**i,
          padding=self.padding,
          dtype=self.dtype,
      )(x)
      x = self.act_fun(x)
    return x
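
# Example usage of DilatedBlock (a minimal sketch; the shapes below are
# illustrative, not part of this commit):
#
#   block = DilatedBlock(num_channels=48, kernel_size=(5,), num_levels=4)
#   x = jnp.ones((2, 512, 1))  # (batch, spatial, channels)
#   params = block.init(jax.random.PRNGKey(0), x)["params"]
#   y = block.apply({"params": params}, x)  # y.shape == (2, 512, 48)
#
# With num_levels=4, the dilation rates are 1, 2, 4, 8, 4, 2, 1.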


class PeriodicConvNetModel(nn.Module):
  """Periodic ConvNet model.

  Simple convolutional model with skip connections and dilated blocks, based
  on the paper: Learned Coarse Models for Efficient Turbulence Simulation.

  Attributes:
    latent_dim: Dimension of the latent space in the processor.
    num_levels: Number of dilated convolutions; the larger the number of
      levels, the larger the receptive field of the network.
    num_processors: Number of dilated blocks.
    encoder_kernel_size: Size of the kernel in the conv layer for the encoder.
    decoder_kernel_size: Size of the kernel in the conv layer for the decoder.
    processor_kernel_size: Size of the kernel in the conv layers inside the
      dilated convolutional blocks.
    act_fun: Activation function applied after each dilated block.
    norm_layer: Normalization layer applied after each dilated block.
    dtype: Type of input/layers.
    padding: Type of padding added to the convolutional layers, depending on
      the geometry of the underlying problem.
    is_input_residual: Whether to use a global skip connection, so the
      architecture resembles a Forward Euler integration rule.
  """

  latent_dim: int = 48
  num_levels: int = 4
  num_processors: int = 4
  encoder_kernel_size: tuple[int, ...] = (5,)
  decoder_kernel_size: tuple[int, ...] = (5,)
  processor_kernel_size: tuple[int, ...] = (5,)
  act_fun: Callable[[jax.Array], jax.Array] | Any = nn.relu
  norm_layer: Callable[[jax.Array], jax.Array] = lambda x: x  # default is ID
  dtype: Any = jnp.float32
  padding: str = "CIRCULAR"
  is_input_residual: bool = True

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    x = inputs.astype(self.dtype)

    # Encoder to latent dimension (typically larger than the input dimension).
    latent_x = nn.Conv(
        features=self.latent_dim,
        kernel_size=self.encoder_kernel_size,
        padding=self.padding,
        dtype=self.dtype,
    )(x)

    # Processor: dilated blocks with per-block skip connections.
    for _ in range(self.num_processors):
      y = DilatedBlock(
          num_channels=self.latent_dim,
          kernel_size=self.processor_kernel_size,
          num_levels=self.num_levels,
          act_fun=self.act_fun,
          padding=self.padding,
          dtype=self.dtype,
      )(latent_x)
      y = self.norm_layer(y)
      latent_x += y

    # Decoder back to the input dimension.
    x = nn.Conv(
        features=inputs.shape[-1],
        kernel_size=self.decoder_kernel_size,
        padding=self.padding,
        dtype=self.dtype,
    )(latent_x)

    # Last (global) skip connection.
    if self.is_input_residual:
      x += inputs

    return x
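
Because the global skip connection makes one forward pass act like a single
Forward Euler step, the model can be rolled out autoregressively. Below is a
minimal sketch of such a rollout (the step count, shapes, and the step helper
are illustrative assumptions, not part of this commit):

import jax
from swirl_dynamics.lib.networks import convnets

model = convnets.PeriodicConvNetModel()
rng = jax.random.PRNGKey(0)
u0 = jax.random.normal(rng, (1, 512, 1))  # (batch, spatial, channels)
params = model.init(rng, u0)['params']

def step(u, _):
  # One application of the model advances the state by one time step.
  u_next = model.apply({'params': params}, u)
  return u_next, u_next

# Roll out 10 steps; trajectory has shape (10, 1, 512, 1).
_, trajectory = jax.lax.scan(step, u0, xs=None, length=10)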
88 changes: 88 additions & 0 deletions swirl_dynamics/lib/networks/convnets_test.py
@@ -0,0 +1,88 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Convolution-based Modules."""


from absl.testing import absltest
from absl.testing import parameterized
import jax
from swirl_dynamics.lib.networks import convnets


class DilatedBlockTest(parameterized.TestCase):
  """Testing the DilatedBlock building block of the PeriodicConvNetModel."""

  @parameterized.named_parameters(
      (f':input_dim={i}', i) for i in ((512,), (64, 64))
  )
  def test_output_shapes(
      self,
      input_dim=(512,),
      kernel_size=5,
      num_levels=4,
      num_channels=48,
      batch_size=2,
      input_channels=1,
  ):
    d_block = convnets.DilatedBlock(
        num_channels=num_channels,
        kernel_size=(kernel_size,),
        num_levels=num_levels,
    )
    rng = jax.random.PRNGKey(0)
    x = jax.random.normal(rng, (batch_size,) + input_dim + (input_channels,))
    params = d_block.init(rng, x)['params']
    y = d_block.apply({'params': params}, x)
    self.assertEqual(y.shape, (batch_size,) + input_dim + (num_channels,))


class PeriodicConvNetModelTest(parameterized.TestCase):
  """Testing the PeriodicConvNetModel."""

  @parameterized.named_parameters(
      (f':input_dim={i}', i) for i in ((512,), (64, 64))
  )
  def test_output_shapes(
      self,
      input_dim=(512,),
      batch_size=2,
      input_channels=3,
      latent_dim=48,
      num_levels=4,
      num_processors=4,
      encoder_kernel_size=(5,),
      decoder_kernel_size=(5,),
      processor_kernel_size=(5,),
  ):
    test_model = convnets.PeriodicConvNetModel(
        latent_dim=latent_dim,
        num_levels=num_levels,
        num_processors=num_processors,
        encoder_kernel_size=encoder_kernel_size,
        decoder_kernel_size=decoder_kernel_size,
        processor_kernel_size=processor_kernel_size,
    )
    rng = jax.random.PRNGKey(0)
    x = jax.random.normal(rng, (batch_size,) + input_dim + (input_channels,))
    params = test_model.init(rng, x)['params']
    y = test_model.apply({'params': params}, x)
    self.assertEqual(y.shape, (batch_size,) + input_dim + (input_channels,))


if __name__ == '__main__':
  absltest.main()
