Add the implementation of TopoDiff to modulus/models/diffusion #584

Open · wants to merge 24 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added gradient clipping to StaticCapture utilities.
- Bistride Multiscale MeshGraphNet example.
- FIGConvUNet model and example.
- TopoDiff model and example.

### Changed

Binary file added docs/img/topodiff_doc/grid_topology.png
Binary file added docs/img/topodiff_doc/topodiff.png
Binary file added docs/img/topodiff_doc/topology_generated.png
97 changes: 97 additions & 0 deletions examples/generative/topodiff/README.md
@@ -0,0 +1,97 @@
# TopoDiff
We propose TopoDiff, a conditional diffusion-model-based architecture for performance-aware and manufacturability-aware topology optimization. It overcomes common issues of Generative Adversarial Networks (GANs), such as training instability, limited generalizability, and neglected manufacturability. TopoDiff introduces a surrogate-model-based guidance strategy that actively favors structures with low compliance and good manufacturability; a minimal sketch of this guidance step follows the architecture figure below.
- Link to the paper: [Link](https://arxiv.org/abs/2208.09591)
- Link to the project: [Link](https://decode.mit.edu/projects/topodiff/)

<p align="center">
<img src="../../../docs/img/topodiff_doc/topodiff.png" width="840" />
</p>
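
TopoDiff steers sampling by nudging each reverse-diffusion step with gradients from surrogate models (a compliance regressor and a classifier that detects floating material). The snippet below is a minimal, illustrative sketch of one such guided denoising step under a standard DDPM parameterization; `eps_model`, `floating_classifier`, and the schedule tensors are hypothetical stand-ins rather than the Modulus API, and the actual sampling loop lives in [`inference.py`](inference.py).

```python
import torch
import torch.nn.functional as F


def guided_step(eps_model, floating_classifier, xt, i, cond,
                alphas, alpha_bars, betas, guidance_scale=0.2):
    """One classifier-guided reverse-diffusion step (illustrative sketch only)."""
    # Per-sample timestep tensor for the current step i
    t = torch.full((xt.shape[0],), i, device=xt.device, dtype=torch.long)

    with torch.no_grad():
        eps = eps_model(xt, t, cond)  # predicted noise, conditioned on loads/BCs

    with torch.enable_grad():
        x_in = xt.detach().requires_grad_(True)
        logits = floating_classifier(x_in, t)  # e.g. two classes for floating material
        target = torch.ones(xt.shape[0], dtype=torch.long, device=xt.device)
        grad = torch.autograd.grad(F.cross_entropy(logits, target), x_in)[0]

    # Standard DDPM posterior mean, then a guidance nudge scaled by the noise schedule
    mean = (xt - (1 - alphas[i]) / (1 - alpha_bars[i]).sqrt() * eps) / alphas[i].sqrt()
    if i > 0:
        return mean + betas[i].sqrt() * guidance_scale * grad
    return mean
```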

## Dataset
- Link to the complete dataset: [here](https://www.dropbox.com/home/decode_lab/Datasets/Public%20Documents/Topodiff_dataset)
- Link to the dataset for diffusion model training: [here](https://www.dropbox.com/scl/fi/9jy96a0lyf39wdwc27se7/dataset_1_diff.zip?rlkey=zz0ijw8e5h0hf0fb7qpnbwrgu&st=t7nuh1w7&dl=0)
- Link to the dataset for regression model training: [here](https://www.dropbox.com/scl/fi/486kmqbghzxuxqewjm9b4/dataset_2_reg.zip?rlkey=gw8yu1lv40tqk192py7wsl9o9&st=ao9yc1rw&dl=0)
- Link to the dataset for classifier model training: [here](https://www.dropbox.com/scl/fi/486kmqbghzxuxqewjm9b4/dataset_2_reg.zip?rlkey=gw8yu1lv40tqk192py7wsl9o9&st=ao9yc1rw&dl=0)


Download the dataset and set **path_data** in [the config YAML file](conf/config.yaml).

## Instructions
TopoDiff generates 2D topology structures conditioned on the boundary and loading conditions. A few examples are shown below:
<p align="center">
<img src="../../../docs/img/topodiff_doc/topology_generated.png" width="840" />
</p>

### Model training
Before training the model, `config.yaml` should look like this:
```yaml
hydra:
  job:
    chdir: True
  run:
    dir: ./outputs/

path_data: path to the Topodiff dataset downloaded

epochs: 100000
batch_size: 128
lr: 1e-4

classifier_iterations: 20000
regressor_iterations: 20000
diffusion_steps: 1000

generation_path: ./

model_path_diffusion: path to the pt file of the diffusion model
model_path_classifier: path to the pt file of the classifier for floating material

path_training_data_diffusion: path to the /dataset_1_diff/training_data/

path_data_regressor_training: path to the /dataset_2_reg/training_data/
path_data_regressor_validation: path to the /dataset_2_reg/validation_data/

path_data_classifier_training: path to the /dataset_3_class/training_data/
path_data_classifier_validation: path to the /dataset_3_class/validation_data/
```

Run the following commands to train the diffusion model, the classifier for floating material, and the regressor for compliance:
```Bash
python examples/generative/topodiff/train.py
python examples/generative/topodiff/train_classifier.py
python examples/generative/topodiff/train_regressor.py
```
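
The training scripts read their settings through Hydra, so any key in `conf/config.yaml` can also be overridden on the command line instead of editing the file; the values below are illustrative only:
```Bash
python examples/generative/topodiff/train.py batch_size=32 lr=2e-4
```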
### Generation
By default, the generated topologies are conditioned on boundary and loading conditions that were not seen during training.
Run the following command to generate topologies:
```Bash
python examples/generative/topodiff/inference.py
```
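
Inference writes the generated samples to `results_topology.npy` under `generation_path` as an array of shape `(batch_size, 1, 64, 64)`, along with a `grid_topology.png` overview (see `inference.py`). A quick way to inspect a single sample:
```python
import numpy as np
import matplotlib.pyplot as plt

# Array saved by inference.py: shape (batch_size, 1, 64, 64)
samples = np.load("results_topology.npy")

plt.imshow(samples[0, 0], cmap="gray")
plt.axis("off")
plt.savefig("sample_topology.png", bbox_inches="tight")
```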




## Citations
To cite this work, please use the following reference:

```bibtex
@inproceedings{maze2023diffusion,
title={Diffusion Models Beat {GANs} on Topology Optimization},
author={Maz{\'e}, Fran{\c{c}}ois and Ahmed, Faez},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={37},
number={8},
pages={9108--9116},
year={2023}
}
```

## References
- [Diffusion Models Beat GANs on Topology Optimization](https://decode.mit.edu/assets/papers/2022_maze_topodiff.pdf)
- [TopoDiff Project Page](https://decode.mit.edu/projects/topodiff/)
- [TopoDiff Dataset](https://www.dropbox.com/home/decode_lab/Datasets/Public%20Documents/Topodiff_dataset)
- [TopoDiff GitHub Repository](https://github.com/francoismaze/topodiff)

49 changes: 49 additions & 0 deletions examples/generative/topodiff/conf/config.yaml
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
  job:
    chdir: True
  run:
    dir: ./outputs/


epochs: 100000
batch_size: 64
lr: 1e-4

classifier_iterations: 30000
regressor_iterations: 30000
diffusion_steps: 1000
model_path: ./
generation_path: ./
model_path_diffusion: /home/turbo/Qian/modulus/modulus/outputs/model_100000.pt
model_path_classifier: /home/turbo/Qian/modulus/modulus/outputs/classifier.pt


path_training_data_diffusion: /home/turbo/Qian/dataset_1_diff/training_data/

path_data_regressor_training: /home/turbo/Qian/dataset_2_reg/training_data/
path_data_regressor_validation: /home/turbo/Qian/dataset_2_reg/validation_data/

path_data_classifier_training: /home/turbo/Qian/dataset_3_class/training_data/
path_data_classifier_validation: /home/turbo/Qian/dataset_3_class/validation_data/



prefix_topology_file: gt_topo_
prefix_pf_file: cons_pf_array_
prefix_load_file: cons_load_array_
100 changes: 100 additions & 0 deletions examples/generative/topodiff/inference.py
@@ -0,0 +1,100 @@
import torch
from torch.optim import AdamW
import torch.nn.functional as F
from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt


import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig

from modulus.models.topodiff import TopoDiff, Diffusion
from modulus.models.topodiff import UNetEncoder
from modulus.launch.logging import (
    PythonLogger,
    initialize_wandb
)
from utils import load_data_topodiff, load_data


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:

    logger = PythonLogger("main")  # General Python Logger
    logger.log("Job start!")

    # Placeholder topologies; only the (unseen) boundary and loading conditions are used below
    topologies = np.random.randn(1800, 64, 64)
    vfs_stress_strain = load_data('/home/turbo/Qian/dataset_1_diff/test_data_level_1/', cfg.prefix_pf_file, '.npy', 200, 2000)
    load_imgs = load_data('/home/turbo/Qian/dataset_1_diff/test_data_level_1/', cfg.prefix_load_file, '.npy', 200, 2000)

    device = torch.device('cuda:1')
    model = TopoDiff(64, 6, 1, model_channels=128, attn_resolutions=[16, 8])
    model.load_state_dict(torch.load(cfg.model_path_diffusion))
    model.to(device)

    classifier = UNetEncoder(in_channels=1, out_channels=2)
    classifier.load_state_dict(torch.load(cfg.model_path_classifier))
    classifier.to(device)

    diffusion = Diffusion(n_steps=1000, device=device)
    batch_size = cfg.batch_size
    data = load_data_topodiff(
        topologies, vfs_stress_strain, load_imgs, batch_size=batch_size, deterministic=False
    )

    _, cons = next(data)

    cons = cons.float().to(device)

    n_steps = 1000

    xt = torch.randn(batch_size, 1, 64, 64).to(device)
    floating_labels = torch.tensor([1] * batch_size).long().to(device)

    # Reverse diffusion: denoise from pure noise, guided by the floating-material classifier
    for i in reversed(trange(n_steps)):
        with torch.no_grad():
            t = torch.tensor([i] * batch_size, device=device)
            noisy = diffusion.p_sample(model, xt, t, cons)

        with torch.enable_grad():
            xt.requires_grad_(True)
            logits = classifier(xt, time_steps=t)
            loss = F.cross_entropy(logits, floating_labels)

            # Gradient of the floating-material loss w.r.t. the current sample
            grad = torch.autograd.grad(loss, xt)[0]

        xt = 1 / diffusion.alphas[i].sqrt() * (xt - noisy * (1 - diffusion.alphas[i]) / (1 - diffusion.alpha_bars[i]).sqrt())

        if i > 0:
            # Guidance term scaled by the noise schedule (z stays zero, so only the classifier gradient contributes)
            z = torch.zeros_like(xt).to(device)
            xt = xt + diffusion.betas[i].sqrt() * (z * 0.8 + 0.2 * grad.float())

    result = (xt.cpu().detach().numpy() + 1) * 2

    np.save(cfg.generation_path + 'results_topology.npy', result)

    # plot images for the generated samples
    fig, axes = plt.subplots(8, 8, figsize=(12, 6), dpi=300)

    for i in range(8):
        for j in range(8):
            img = result[i * 4 + j][0]
            axes[i, j].imshow(img, cmap='gray')
            axes[i, j].set_xticks([])
            axes[i, j].set_yticks([])

    plt.xticks([])  # Remove x-axis ticks
    plt.yticks([])  # Remove y-axis ticks
    plt.gca().xaxis.set_visible(False)  # Optionally hide x-axis
    plt.gca().yaxis.set_visible(False)  # Optionally hide y-axis

    plt.savefig(cfg.generation_path + 'grid_topology.png', bbox_inches='tight', pad_inches=0)


if __name__ == "__main__":
    main()

72 changes: 72 additions & 0 deletions examples/generative/topodiff/train.py
@@ -0,0 +1,72 @@
import torch
from torch.optim import AdamW
from tqdm import trange
import numpy as np
import time, os


import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig

from modulus.models.topodiff import TopoDiff, Diffusion
from modulus.models.topodiff import UNetEncoder
from modulus.launch.logging import (
    PythonLogger,
    initialize_wandb
)
from utils import load_data_topodiff, load_data


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:

    logger = PythonLogger("main")  # General Python Logger
    logger.log("Job start!")

    device = torch.device('cuda:0')
    model = TopoDiff(64, 6, 1, model_channels=128, attn_resolutions=[16, 8]).to(device)
    diffusion = Diffusion(n_steps=1000, device=device)

    # Ground-truth topologies plus the physical-field and load arrays used as conditioning
    topologies = load_data(cfg.path_training_data_diffusion, cfg.prefix_topology_file, '.png', 0, 30000)
    vfs_stress_strain = load_data(cfg.path_training_data_diffusion, cfg.prefix_pf_file, '.npy', 0, 30000)
    load_imgs = load_data(cfg.path_training_data_diffusion, cfg.prefix_load_file, '.npy', 0, 30000)

    batch_size = cfg.batch_size
    data = load_data_topodiff(
        topologies, vfs_stress_strain, load_imgs, batch_size=batch_size, deterministic=False
    )

    lr = cfg.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    logger.log("Start training!")

    prog = trange(cfg.epochs)

    for step in prog:

        tops, cons = next(data)

        tops = tops.float().to(device)
        cons = cons.float().to(device)

        # Diffusion training loss for this batch of topologies and conditions
        losses = diffusion.train_loss(model, tops, cons)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if step % 100 == 0:
            logger.info("epoch: %d, loss: %.5f" % (step, losses.item()))

    torch.save(model.state_dict(), cfg.model_path + "topodiff_model.pt")
    logger.info("Training completed!")


# ----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

# ----------------------------------------------------------------------------