-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 554633003
- Loading branch information
The swirl_dynamics Authors
committed
Aug 7, 2023
1 parent
8e9d341
commit 8a5f5f2
Showing
2 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright 2023 The swirl_dynamics Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Convolution-based modules for training dynamical systems. | ||
1) PeriodicConvNetModel | ||
""" | ||
from collections.abc import Callable | ||
from typing import Any | ||
|
||
import flax.linen as nn | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
Array = jax.Array | ||
|
||
|
||
class DilatedBlock(nn.Module): | ||
"""Implements a Dilated ConvNet block.""" | ||
|
||
num_channels: int | ||
num_levels: int | ||
kernel_size: tuple[int, ...] | ||
act_fun: Callable[[jax.Array], jax.Array] | Any = nn.relu | ||
dtype: Any = jnp.float32 | ||
padding: str = "CIRCULAR" | ||
|
||
@nn.compact | ||
def __call__(self, inputs: Array) -> Array: | ||
|
||
x = inputs.astype(self.dtype) | ||
|
||
# Ascending and descending order | ||
dilation_order = list(range(self.num_levels)) | ||
dilation_order += list(range(self.num_levels-1))[::-1] | ||
for i in dilation_order: | ||
x = nn.Conv(features=self.num_channels, | ||
kernel_size=self.kernel_size, | ||
kernel_dilation=2**i, | ||
padding=self.padding, | ||
dtype=self.dtype)(x) | ||
x = self.act_fun(x) | ||
return x | ||
|
||
|
||
class PeriodicConvNetModel(nn.Module): | ||
"""Periodic ConvNet model. | ||
Simple convolutional model with skip connections, and dilated blocks. Based on | ||
the paper: Learned Coarse Models for Efficient Turbulence Simulation. | ||
Attributes: | ||
latent_dim: Dimension of the latent space in the processor. | ||
num_levels: Number of dilated convolutions, the larger the number of levels, | ||
the bigger receptives field the network will have. | ||
num_processors: Number of dilated blocks. | ||
encoder_kernel_size: Size of the kernel in the conv layer for the encoder. | ||
decoder_kernel_size: Size of the kernel in the conv layer for the decoder. | ||
processor_kernel_size: Size of the kernel in the conv layers inside the | ||
dilated convolutional blocks. | ||
act_fun: Activation function to be after each dilated block. | ||
norm_layer: Normalization layer to be applied after each dilated block. | ||
dtype: Type of input/layers. | ||
padding: Type of padding added to the convolutional layers depending on the | ||
geometry of underlying problem. | ||
is_input_residual: Boolean to use a global skip connection, so the | ||
architecture is similar to a Forward Euler integration rule. | ||
""" | ||
latent_dim: int = 48 | ||
num_levels: int = 4 | ||
num_processors: int = 4 | ||
encoder_kernel_size: tuple[int, ...] = (5,) | ||
decoder_kernel_size: tuple[int, ...] = (5,) | ||
processor_kernel_size: tuple[int, ...] = (5,) | ||
act_fun: Callable[[jax.Array], jax.Array] | Any = nn.relu | ||
norm_layer: Callable[[jax.Array], jax.Array] = lambda x: x # default is ID | ||
dtype: Any = jnp.float32 | ||
padding: str = "CIRCULAR" | ||
is_input_residual: bool = True | ||
|
||
@nn.compact | ||
def __call__(self, inputs: Array) -> Array: | ||
|
||
x = inputs.astype(self.dtype) | ||
|
||
# Encoder to latent dimension (larger than regular dimension). | ||
latent_x = nn.Conv(features=self.latent_dim, | ||
kernel_size=self.encoder_kernel_size, | ||
padding=self.padding, | ||
dtype=self.dtype)(x) | ||
|
||
for _ in range(self.num_processors): | ||
y = DilatedBlock(num_channels=self.latent_dim, | ||
kernel_size=self.processor_kernel_size, | ||
num_levels=self.num_levels, | ||
act_fun=self.act_fun)(latent_x) | ||
y = self.norm_layer(y) | ||
latent_x += y | ||
|
||
x = nn.Conv(features=inputs.shape[-1], | ||
kernel_size=self.decoder_kernel_size, | ||
padding=self.padding, | ||
dtype=self.dtype)(latent_x) | ||
|
||
# Last skip connection. | ||
if self.is_input_residual: | ||
x += inputs | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2023 The swirl_dynamics Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for the Convolution-based Modules.""" | ||
|
||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
import jax | ||
from swirl_dynamics.lib.networks import convnets | ||
|
||
|
||
class DilatedBlockTest(parameterized.TestCase): | ||
"""Testing the DilatedBlock building block of the PeriodicConvNetModel.""" | ||
|
||
@parameterized.named_parameters( | ||
(f':input_dim={i}', i) | ||
for i in ((512,), (64, 64)) | ||
) | ||
def test_output_shapes( | ||
self, | ||
input_dim=(512,), | ||
kernel_size=5, | ||
num_levels=4, | ||
num_channels=48, | ||
batch_size=2, | ||
input_channels=1, | ||
): | ||
d_block = convnets.DilatedBlock( | ||
num_channels=num_channels, | ||
kernel_size=(kernel_size,), | ||
num_levels=num_levels | ||
) | ||
rng = jax.random.PRNGKey(0) | ||
x = jax.random.normal(rng, ((batch_size,) + input_dim + (input_channels,))) | ||
params = d_block.init(rng, x)['params'] | ||
y = d_block.apply({'params': params}, x) | ||
self.assertEqual(y.shape, ((batch_size,) + input_dim + (num_channels,))) | ||
|
||
|
||
class PeriodicConvNetModelTest(parameterized.TestCase): | ||
"""Testing the PeriodicConvNetModel.""" | ||
|
||
@parameterized.named_parameters( | ||
(f':input_dim={i}', i) | ||
for i in ((512,), (64, 64)) | ||
) | ||
def test_output_shapes( | ||
self, | ||
input_dim=(512,), | ||
batch_size=2, | ||
input_channels=3, | ||
latent_dim=48, | ||
num_levels=4, | ||
num_processors=4, | ||
encoder_kernel_size=(5,), | ||
decoder_kernel_size=(5,), | ||
processor_kernel_size=(5,) | ||
|
||
): | ||
test_model = convnets.PeriodicConvNetModel( | ||
latent_dim=latent_dim, | ||
num_levels=num_levels, | ||
num_processors=num_processors, | ||
encoder_kernel_size=encoder_kernel_size, | ||
decoder_kernel_size=decoder_kernel_size, | ||
processor_kernel_size=processor_kernel_size, | ||
) | ||
rng = jax.random.PRNGKey(0) | ||
x = jax.random.normal(rng, ((batch_size,) + input_dim + (input_channels,))) | ||
params = test_model.init(rng, x)['params'] | ||
y = test_model.apply({'params': params}, x) | ||
self.assertEqual(y.shape, ((batch_size,) + input_dim + (input_channels,))) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |