diff --git a/CHANGELOG.md b/CHANGELOG.md
index 964a981d86..4951dfde73 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added gradient clipping to StaticCapture utilities.
 - Bistride Multiscale MeshGraphNet example.
 - FIGConvUNet model and example.
+- TopoDiff model and example.
 
 ### Changed
diff --git a/docs/img/topodiff_doc/grid_topology.png b/docs/img/topodiff_doc/grid_topology.png
new file mode 100644
index 0000000000..3fc9909d9d
Binary files /dev/null and b/docs/img/topodiff_doc/grid_topology.png differ
diff --git a/docs/img/topodiff_doc/topodiff.png b/docs/img/topodiff_doc/topodiff.png
new file mode 100644
index 0000000000..9b8ef23240
Binary files /dev/null and b/docs/img/topodiff_doc/topodiff.png differ
diff --git a/docs/img/topodiff_doc/topology_generated.png b/docs/img/topodiff_doc/topology_generated.png
new file mode 100644
index 0000000000..88bc975f21
Binary files /dev/null and b/docs/img/topodiff_doc/topology_generated.png differ
diff --git a/examples/generative/topodiff/README.md b/examples/generative/topodiff/README.md
new file mode 100644
index 0000000000..13d5e705cc
--- /dev/null
+++ b/examples/generative/topodiff/README.md
@@ -0,0 +1,97 @@
+# TopoDiff
+
+TopoDiff is a conditional diffusion-model-based architecture for performance-aware and
+manufacturability-aware topology optimization. It overcomes common issues of Generative
+Adversarial Networks (GANs) such as training instability, limited generalizability, and
+neglect of manufacturability. TopoDiff introduces a surrogate-model-based guidance
+strategy that actively favors structures with low compliance and good manufacturability.
+
+- Link to the paper: [Link](https://arxiv.org/abs/2208.09591)
+- Link to the project: [Link](https://decode.mit.edu/projects/topodiff/)
+
+![TopoDiff architecture overview](../../../docs/img/topodiff_doc/topodiff.png)
+
+## Dataset
+
+- Link to the complete dataset: [here](https://www.dropbox.com/home/decode_lab/Datasets/Public%20Documents/Topodiff_dataset)
+- Link to the dataset for diffusion model training: [here](https://www.dropbox.com/scl/fi/9jy96a0lyf39wdwc27se7/dataset_1_diff.zip?rlkey=zz0ijw8e5h0hf0fb7qpnbwrgu&st=t7nuh1w7&dl=0)
+- Link to the dataset for regression model training: [here](https://www.dropbox.com/scl/fi/486kmqbghzxuxqewjm9b4/dataset_2_reg.zip?rlkey=gw8yu1lv40tqk192py7wsl9o9&st=ao9yc1rw&dl=0)
+- Link to the dataset for classifier model training: [here](https://www.dropbox.com/scl/fi/486kmqbghzxuxqewjm9b4/dataset_2_reg.zip?rlkey=gw8yu1lv40tqk192py7wsl9o9&st=ao9yc1rw&dl=0)
+
+Download the dataset and set **path_data** in [the config YAML file](conf/config.yaml).
+
+## Instructions
+
+2D topology structures can be generated by TopoDiff, conditioned on the boundary and
+loading conditions. A few examples are shown below:
+
+![Generated topologies](../../../docs/img/topodiff_doc/topology_generated.png)
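+Each training sample consists of a ground-truth topology image and `.npy` constraint
+arrays. A minimal sketch for inspecting one sample (the local `root` path is
+hypothetical; the file prefixes follow `conf/config.yaml`):
+
+```python
+import numpy as np
+from PIL import Image
+
+root = "dataset_1_diff/training_data/"  # hypothetical path to your local copy
+
+# Ground-truth topology, rescaled from 8-bit grayscale to [0, 1]
+topology = np.array(Image.open(root + "gt_topo_0.png")) / 255
+# Volume fraction and stress/strain constraint fields
+pf = np.load(root + "cons_pf_array_0.npy")
+# Load (applied force) component images
+load = np.load(root + "cons_load_array_0.npy")
+print(topology.shape, pf.shape, load.shape)
+```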
+
+### Model training
+
+Before training, `conf/config.yaml` should look like this:
+
+```yaml
+hydra:
+  job:
+    chdir: True
+  run:
+    dir: ./outputs/
+
+path_data: path to the TopoDiff dataset downloaded
+
+epochs: 100000
+batch_size: 64
+lr: 1e-4
+
+classifier_iterations: 30000
+regressor_iterations: 30000
+diffusion_steps: 1000
+
+model_path: ./
+generation_path: ./
+
+model_path_diffusion: path to the .pt file of the diffusion model
+model_path_classifier: path to the .pt file of the classifier for floating material
+
+path_training_data_diffusion: path to /dataset_1_diff/training_data/
+
+path_data_regressor_training: path to /dataset_2_reg/training_data/
+path_data_regressor_validation: path to /dataset_2_reg/validation_data/
+
+path_data_classifier_training: path to /dataset_3_class/training_data/
+path_data_classifier_validation: path to /dataset_3_class/validation_data/
+
+prefix_topology_file: gt_topo_
+prefix_pf_file: cons_pf_array_
+prefix_load_file: cons_load_array_
+```
+
+Run the following commands to train the diffusion model, the classifier for floating
+material, and the regressor for compliance:
+
+```bash
+python examples/generative/topodiff/train.py
+python examples/generative/topodiff/train_classifier.py
+python examples/generative/topodiff/train_regressor.py
+```
+
+### Generation
+
+By default, the generated topologies are conditioned on boundary and loading conditions
+that were not seen during training.
+Run the following command to generate topologies:
+
+```bash
+python examples/generative/topodiff/inference.py
+```
+
+## Citations
+
+To cite this work, please use the following reference:
+
+```bibtex
+@inproceedings{maze2023diffusion,
+  title={Diffusion models beat gans on topology optimization},
+  author={Maz{\'e}, Fran{\c{c}}ois and Ahmed, Faez},
+  booktitle={Proceedings of the AAAI conference on artificial intelligence},
+  volume={37},
+  number={8},
+  pages={9108--9116},
+  year={2023}
+}
+```
+
+## References
+
+- [Diffusion Models Beat GANs on Topology Optimization](https://decode.mit.edu/assets/papers/2022_maze_topodiff.pdf)
+- [TopoDiff Project Page](https://decode.mit.edu/projects/topodiff/)
+- [TopoDiff Dataset](https://www.dropbox.com/home/decode_lab/Datasets/Public%20Documents/Topodiff_dataset)
+- [GitHub](https://github.com/francoismaze/topodiff)
diff --git a/examples/generative/topodiff/conf/config.yaml b/examples/generative/topodiff/conf/config.yaml
new file mode 100644
index 0000000000..4865bd1d04
--- /dev/null
+++ b/examples/generative/topodiff/conf/config.yaml
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+  job:
+    chdir: True
+  run:
+    dir: ./outputs/
+
+# Root of the downloaded TopoDiff dataset; the data paths below derive from it.
+path_data: ./topodiff_dataset/
+
+epochs: 100000
+batch_size: 64
+lr: 1e-4
+
+classifier_iterations: 30000
+regressor_iterations: 30000
+diffusion_steps: 1000
+model_path: ./
+generation_path: ./
+model_path_diffusion: ./outputs/topodiff_model.pt
+model_path_classifier: ./outputs/classifier.pt
+
+path_training_data_diffusion: ./topodiff_dataset/dataset_1_diff/training_data/
+
+path_data_regressor_training: ./topodiff_dataset/dataset_2_reg/training_data/
+path_data_regressor_validation: ./topodiff_dataset/dataset_2_reg/validation_data/
+
+path_data_classifier_training: ./topodiff_dataset/dataset_3_class/training_data/
+path_data_classifier_validation: ./topodiff_dataset/dataset_3_class/validation_data/
+
+prefix_topology_file: gt_topo_
+prefix_pf_file: cons_pf_array_
+prefix_load_file: cons_load_array_
\ No newline at end of file
diff --git a/examples/generative/topodiff/inference.py b/examples/generative/topodiff/inference.py
new file mode 100644
index 0000000000..baebb174cd
--- /dev/null
+++ b/examples/generative/topodiff/inference.py
@@ -0,0 +1,100 @@
+import torch
+import torch.nn.functional as F
+from tqdm import trange
+import numpy as np
+import matplotlib.pyplot as plt
+
+import hydra
+from omegaconf import DictConfig
+
+from modulus.models.topodiff import TopoDiff, Diffusion
+from modulus.models.topodiff import UNetEncoder
+from modulus.launch.logging import PythonLogger
+from utils import load_data_topodiff, load_data
+
+
+@hydra.main(version_base="1.3", config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+
+    logger = PythonLogger("main")  # General python logger
+    logger.log("Job start!")
+
+    # Dummy topologies: generation only needs the constraint channels.
+    topologies = np.random.randn(1800, 64, 64)
+    test_data_path = cfg.path_data + "dataset_1_diff/test_data_level_1/"
+    vfs_stress_strain = load_data(test_data_path, cfg.prefix_pf_file, ".npy", 200, 2000)
+    load_imgs = load_data(test_data_path, cfg.prefix_load_file, ".npy", 200, 2000)
+
+    device = torch.device("cuda:0")
+    model = TopoDiff(64, 6, 1, model_channels=128, attn_resolutions=[16, 8])
+    model.load_state_dict(torch.load(cfg.model_path_diffusion))
+    model.to(device)
+
+    classifier = UNetEncoder(in_channels=1, out_channels=2)
+    classifier.load_state_dict(torch.load(cfg.model_path_classifier))
+    classifier.to(device)
+
+    diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
+    batch_size = cfg.batch_size
+    data = load_data_topodiff(
+        topologies, vfs_stress_strain, load_imgs, batch_size=batch_size, deterministic=False
+    )
+
+    _, cons = next(data)
+    cons = cons.float().to(device)
+
+    n_steps = cfg.diffusion_steps
+
+    xt = torch.randn(batch_size, 1, 64, 64).to(device)
+    # Guidance target: steer samples away from the "floating material" class.
+    floating_labels = torch.tensor([1] * batch_size).long().to(device)
+
+    for i in reversed(trange(n_steps)):
+        with torch.no_grad():
+            t = torch.tensor([i] * batch_size, device=device)
+            noisy = diffusion.p_sample(model, xt, t, cons)
+
+        with torch.enable_grad():
+            xt = xt.detach().requires_grad_(True)
+            logits = classifier(xt, time_steps=t)
+            loss = F.cross_entropy(logits, floating_labels)
+            grad = torch.autograd.grad(loss, xt)[0]
+        xt = xt.detach()
+
+        # DDPM posterior mean computed from the predicted noise
+        xt = 1 / diffusion.alphas[i].sqrt() * (
+            xt - noisy * (1 - diffusion.alphas[i]) / (1 - diffusion.alpha_bars[i]).sqrt()
+        )
+
+        if i > 0:
+            # Stochastic DDPM noise blended with the classifier-guidance gradient
+            z = torch.randn_like(xt).to(device)
+            xt = xt + diffusion.betas[i].sqrt() * (z * 0.8 + 0.2 * grad.float())
+
+    # Map generated samples from [-1, 1] back to [0, 1] topology images
+    result = (xt.cpu().detach().numpy() + 1) / 2
+
+    np.save(cfg.generation_path +
'results_topology.npy', result)
+
+    # Plot a grid of the generated samples (requires batch_size >= 64)
+    fig, axes = plt.subplots(8, 8, figsize=(12, 12), dpi=300)
+
+    for i in range(8):
+        for j in range(8):
+            img = result[i * 8 + j][0]
+            axes[i, j].imshow(img, cmap='gray')
+            axes[i, j].set_xticks([])
+            axes[i, j].set_yticks([])
+
+    plt.savefig(cfg.generation_path + 'grid_topology.png', bbox_inches='tight', pad_inches=0)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/examples/generative/topodiff/train.py b/examples/generative/topodiff/train.py
new file mode 100644
index 0000000000..100168907b
--- /dev/null
+++ b/examples/generative/topodiff/train.py
@@ -0,0 +1,72 @@
+import torch
+from torch.optim import AdamW
+from tqdm import trange
+import numpy as np
+
+import hydra
+from omegaconf import DictConfig
+
+from modulus.models.topodiff import TopoDiff, Diffusion
+from modulus.launch.logging import PythonLogger
+from utils import load_data_topodiff, load_data
+
+
+@hydra.main(version_base="1.3", config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+
+    logger = PythonLogger("main")  # General python logger
+    logger.log("Job start!")
+
+    device = torch.device('cuda:0')
+    model = TopoDiff(64, 6, 1, model_channels=128, attn_resolutions=[16, 8]).to(device)
+    diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
+
+    topologies = load_data(cfg.path_training_data_diffusion, cfg.prefix_topology_file, '.png', 0, 30000)
+    vfs_stress_strain = load_data(cfg.path_training_data_diffusion, cfg.prefix_pf_file, '.npy', 0, 30000)
+    load_imgs = load_data(cfg.path_training_data_diffusion, cfg.prefix_load_file, '.npy', 0, 30000)
+
+    batch_size = cfg.batch_size
+    data = load_data_topodiff(
+        topologies, vfs_stress_strain, load_imgs, batch_size=batch_size, deterministic=False
+    )
+
+    lr = cfg.lr
+    optimizer = AdamW(model.parameters(), lr=lr)
+    logger.log("Start training!")
+
+    prog = trange(cfg.epochs)
+
+    for step in prog:
+
+        tops, cons = next(data)
+
+        tops = tops.float().to(device)
+        cons = cons.float().to(device)
+
+        loss = diffusion.train_loss(model, tops, cons)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if step % 100 == 0:
+            logger.info("step: %d, loss: %.5f" % (step, loss.item()))
+
+    torch.save(model.state_dict(), cfg.model_path + "topodiff_model.pt")
+    logger.info("Training completed!")
+
+
+# ----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
+
+# ----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/examples/generative/topodiff/train_classifier.py b/examples/generative/topodiff/train_classifier.py
new file mode 100644
index 0000000000..66eaf6b8b7
--- /dev/null
+++ b/examples/generative/topodiff/train_classifier.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn.functional as F
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LinearLR
+import numpy as np
+
+import hydra
+from omegaconf import DictConfig
+
+from modulus.models.topodiff import Diffusion
+from modulus.models.topodiff import UNetEncoder
+from modulus.launch.logging import PythonLogger
+from utils import load_data_classifier
+
+
+@hydra.main(version_base="1.3", config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+
+    logger = PythonLogger("main")  # General python logger
+    logger.log("Start running")
+
+    train_img, train_labels = load_data_classifier(cfg.path_data_classifier_training)
+    valid_img, valid_labels = load_data_classifier(cfg.path_data_classifier_validation)
+    # Rescale images from [0, 1] to [-1, 1] once, here.
+    train_img = 2 * train_img - 1
+    valid_img = 2 * valid_img - 1
+
+    device = torch.device('cuda:0')
+
+    classifier = UNetEncoder(in_channels=1, out_channels=2).to(device)
+
+    diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
+
+    batch_size = cfg.batch_size
+
+    lr = cfg.lr
+    optimizer = AdamW(classifier.parameters(), lr=lr)
+    scheduler = LinearLR(optimizer, start_factor=1, end_factor=0.001, total_iters=cfg.classifier_iterations)
+
+    for i in range(cfg.classifier_iterations + 1):
+        # Get a random batch from the training data (already in [-1, 1])
+        idx = np.random.choice(len(train_img), batch_size, replace=False)
+        batch = torch.tensor(train_img[idx]).float().unsqueeze(1).to(device)
+        batch_labels = torch.tensor(train_labels[idx]).long().to(device)
+
+        # Noise the batch at a random diffusion step
+        t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
+        batch = diffusion.q_sample(batch, t)
+        logits = classifier(batch, time_steps=t)
+
+        loss = F.cross_entropy(logits, batch_labels)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        if i % 100 == 0:
+            with torch.no_grad():
+                idx = np.random.choice(len(valid_img), batch_size, replace=False)
+                batch = torch.tensor(valid_img[idx]).float().unsqueeze(1).to(device)
+                batch_labels = torch.tensor(valid_labels[idx]).long().to(device)
+
+                # Sample diffusion steps and get noised images
+                t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
+                batch = diffusion.q_sample(batch, t)
+
+                # Forward pass
+                logits = classifier(batch, time_steps=t)
+
+                # Compute accuracy
+                predicted_labels = torch.argmax(logits, dim=1)
+                correct_predictions = (predicted_labels == batch_labels).sum().item()
+                accuracy = correct_predictions / batch_size
+
+                logger.info("step: %d, loss: %.5f, validation accuracy: %.3f" % (i, loss.item(), accuracy))
+
+    torch.save(classifier.state_dict(), cfg.model_path + "classifier.pt")
+
+    logger.info("Job done!")
+
+
+# ----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
+
+# ----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/examples/generative/topodiff/train_regressor.py b/examples/generative/topodiff/train_regressor.py
new file mode 100644
index 0000000000..0b7563c9ab
--- /dev/null
+++ b/examples/generative/topodiff/train_regressor.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LinearLR
+import numpy as np
+
+import hydra
+from omegaconf import DictConfig
+
+from modulus.models.topodiff import Diffusion
+from modulus.models.topodiff import UNetEncoder
+from modulus.launch.logging import PythonLogger
+from utils import load_data_regressor
+
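+# Note: the regressor is a surrogate that predicts (scaled) compliance from a
+# noised topology together with its physical-field and load constraints. At
+# sampling time its gradient can be used to guide the diffusion model toward
+# low-compliance structures, as described in the TopoDiff paper.
+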
+@hydra.main(version_base="1.3", config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+
+    logger = PythonLogger("main")  # General python logger
+    logger.log("Start running")
+
+    topologies, load_imgs, vfs_stress_strain, labels = load_data_regressor(cfg.path_data_regressor_training)
+    topologies = topologies * 2 - 1  # Normalize the image range to [-1, 1]
+
+    device = torch.device('cuda:0')
+
+    in_channels = 6
+    regressor = UNetEncoder(in_channels=in_channels, out_channels=1).to(device)
+
+    diffusion = Diffusion(n_steps=cfg.diffusion_steps, device=device)
+
+    batch_size = cfg.batch_size
+
+    lr = cfg.lr
+    optimizer = AdamW(regressor.parameters(), lr=lr)
+    scheduler = LinearLR(optimizer, start_factor=1, end_factor=0.001, total_iters=cfg.regressor_iterations)
+
+    loss_fn = nn.MSELoss()
+    for i in range(cfg.regressor_iterations + 1):
+
+        # Get a random batch from the training data (topologies already in [-1, 1])
+        idx = np.random.choice(len(topologies), batch_size, replace=False)
+        batch = torch.tensor(topologies[idx]).float().unsqueeze(1).to(device)  # B x 1 x 64 x 64
+        batch_pf = torch.tensor(vfs_stress_strain[idx]).float().permute(0, 3, 1, 2).to(device)
+        batch_load = torch.tensor(load_imgs[idx]).float().permute(0, 3, 1, 2).to(device)
+
+        batch_labels = torch.tensor(labels[idx]).float().to(device).unsqueeze(1)
+
+        # Noise the topology channel, then stack it with the constraint channels
+        t = torch.randint(0, cfg.diffusion_steps, (batch.shape[0],)).to(device)
+        batch = diffusion.q_sample(batch, t)
+
+        batch = torch.cat((batch, batch_pf, batch_load), dim=1)
+
+        pred = regressor(batch, time_steps=t)
+
+        loss = loss_fn(pred, batch_labels)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        if i % 100 == 0:
+            logger.info("step: %d, loss: %.5f" % (i, loss.item()))
+
+    torch.save(regressor.state_dict(), cfg.model_path + "regressor.pt")
+    logger.info("Job done!")
+
+
+# ----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
+
+# ----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/examples/generative/topodiff/utils.py b/examples/generative/topodiff/utils.py
new file mode 100644
index 0000000000..59ce01f967
--- /dev/null
+++ b/examples/generative/topodiff/utils.py
@@ -0,0 +1,170 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+
+
+class DiffusionDataset_topodiff(Dataset):
+    """Dataset pairing normalized topologies with their 5-channel constraint maps."""
+
+    def __init__(self, topologies, vfs_stress_strain, load_im):
+
+        image_size = topologies.shape[1]
+
+        self.topologies = topologies
+        self.vfs_stress_strain = vfs_stress_strain
+        self.image_size = image_size
+        self.load_im = load_im
+
+    def __len__(self):
+        return self.topologies.shape[0]
+
+    def __getitem__(self, idx):
+
+        cons = np.zeros((5, self.image_size, self.image_size))
+
+        # Volume fraction, stress, and strain fields
+        cons[0] = self.vfs_stress_strain[idx][:, :, 0]
+        cons[1] = self.vfs_stress_strain[idx][:, :, 1]
+        cons[2] = self.vfs_stress_strain[idx][:, :, 2]
+        # x and y load component images
+        cons[3] = self.load_im[idx][:, :, 0]
+        cons[4] = self.load_im[idx][:, :, 1]
+
+        # Topologies are rescaled from [0, 1] to [-1, 1]
+        return np.expand_dims(self.topologies[idx], 0) * 2 - 1, cons
+
+
+def load_data_topodiff(
+    topologies, vfs_stress_strain, load_im, batch_size, deterministic=False
+):
+    """Yield (topology, constraints) batches indefinitely."""
+
+    dataset = DiffusionDataset_topodiff(topologies, vfs_stress_strain, load_im)
+
+    if deterministic:
+        loader = DataLoader(
+            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
+        )
+    else:
+        loader = DataLoader(
+            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
+        )
+    while True:
+        yield from loader
+
+
+def load_data(root, prefix, file_format, num_file_start=0, num_file_end=30000):
+    """
+    root: path to the folder of training data
+    prefix: file prefix of the ground-truth topology, boundary condition, and stress/strain files
+    file_format: .npy for the conditions; .png for the ground-truth topologies
+    """
+    data_array = []
+
+    for i in range(num_file_start, num_file_end):
+        file = f'{root}{prefix}{i}{file_format}'
+        if file_format == '.npy':
+            data_array.append(np.load(file))
+        elif file_format == '.png':
+            data_array.append(np.array(Image.open(file)) / 255)
+        else:
+            raise NotImplementedError
+
+    return np.array(data_array).astype(np.float64)
+
+
+def load_data_regressor(root):
+    """Load topologies, load images, physical fields, and compliance labels for regressor training."""
+
+    file_list = os.listdir(root)
+    idx_list = []
+    for file in file_list:
+        if file.startswith('gt_topo_'):
+            idx = int(file.split('.')[0][8:])
+            idx_list.append(idx)
+    idx_list.sort()
+
+    topology_array, load_array, pf_array = [], [], []
+    for i in idx_list:
+        topology_array.append(np.array(Image.open(root + "gt_topo_" + str(i) + '.png')) / 255)
+        load_array.append(np.load(root + "cons_load_array_" + str(i) + '.npy'))
+        pf_array.append(np.load(root + "cons_pf_array_" + str(i) + '.npy'))
+
+    labels = np.load(root + 'deflections_scaled_diff.npy')
+    return (
+        np.array(topology_array).astype(np.float64),
+        np.array(load_array).astype(np.float64),
+        np.array(pf_array).astype(np.float64),
+        labels[idx_list],
+    )
+
+
+def load_data_classifier(root):
+    """
+    root: path to the folder of training data; expects img_<idx>.png files and a
+    labels.npy array indexed by <idx>.
+    """
+    file_list = os.listdir(root)
+    labels = np.load(root + 'labels.npy')
+    image_list = []
+    label_list = []
+    for file in file_list:
+        if file.startswith('img_'):
+            idx = int(file.split('.')[0][4:])
+            image = Image.open(root + file)
+            image_list.append(np.array(image) / 255)
+            label_list.append(labels[idx])
+
+    return np.array(image_list).astype(np.float64), np.array(label_list).astype(np.float64)
\ No newline at end of file
diff --git a/modulus/models/topodiff/__init__.py b/modulus/models/topodiff/__init__.py
new file mode 100644
index 0000000000..ae9f66a7f2
--- /dev/null
+++ b/modulus/models/topodiff/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .topodiff import *
+from .utils import *
+from .diffusion import *
diff --git a/modulus/models/topodiff/diffusion.py b/modulus/models/topodiff/diffusion.py
new file mode 100644
index 0000000000..3c69e4a81c
--- /dev/null
+++ b/modulus/models/topodiff/diffusion.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
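+
+"""Minimal DDPM-style diffusion helpers for TopoDiff: linear beta schedule,
+forward noising (``q_sample``), and the noise-prediction training loss."""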
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Diffusion:
+    """Linear-schedule DDPM utilities (noising, training loss, eps-model call)."""
+
+    def __init__(self, n_steps=1000, min_beta=10**-4, max_beta=0.02, device='cpu'):
+
+        self.n_steps = n_steps
+        self.device = device
+
+        self.betas = torch.linspace(min_beta, max_beta, self.n_steps).to(device)
+        self.alphas = 1 - self.betas
+        self.alpha_bars = torch.cumprod(self.alphas, 0).to(device)
+        self.alpha_bars_prev = F.pad(self.alpha_bars[:-1], [1, 0], 'constant', 0)
+        self.posterior_variance = self.betas * (1. - self.alpha_bars_prev) / (1. - self.alpha_bars)
+
+        self.loss = nn.MSELoss()
+
+    def q_sample(self, x0, t, noise=None):
+        """Diffuse x0 to step t with Gaussian noise."""
+        if noise is None:
+            noise = torch.randn_like(x0).to(self.device)
+
+        alpha_bars = self.alpha_bars[t]
+
+        x = (
+            alpha_bars.sqrt()[:, None, None, None] * x0
+            + (1 - alpha_bars).sqrt()[:, None, None, None] * noise
+        )
+
+        return x
+
+    def p_sample(self, model, xt, t, cons):
+        """Predict the noise in xt with the conditional eps-model."""
+        return model(xt, cons, t)
+
+    def train_loss(self, model, x0, cons):
+        """Standard noise-prediction (epsilon) MSE loss."""
+        b, c, w, h = x0.shape
+        noise = torch.randn_like(x0).to(self.device)
+
+        t = torch.randint(0, self.n_steps, (b,)).to(self.device)
+
+        xt = self.q_sample(x0, t, noise)
+
+        pred_noise = self.p_sample(model, xt, t, cons)
+
+        return self.loss(pred_noise, noise)
\ No newline at end of file
diff --git a/modulus/models/topodiff/topodiff.py b/modulus/models/topodiff/topodiff.py
new file mode 100644
index 0000000000..6f350ece48
--- /dev/null
+++ b/modulus/models/topodiff/topodiff.py
@@ -0,0 +1,333 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Model architectures used in the paper "Diffusion Models Beat GANs on Image Synthesis".
+"""
+
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.functional import silu
+
+from ..diffusion import (
+    Conv2d,
+    GroupNorm,
+    Linear,
+    PositionalEmbedding,
+    UNetBlock,
+)
+from ..meta import ModelMetaData
+from ..module import Module
+
+
+@dataclass
+class MetaData(ModelMetaData):
+    name: str = "TopoDiff"
+    # Optimization
+    jit: bool = False
+    cuda_graphs: bool = False
+    amp_cpu: bool = False
+    amp_gpu: bool = True
+    # Inference
+    onnx: bool = True
+    # Physics informed
+    var_dim: int = 1
+    func_torch: bool = False
+    auto_grad: bool = False
+
+
+class TopoDiff(Module):
+    """
+    Reimplementation of the ADM architecture, a U-Net variant with optional
+    self-attention, conditioned on constraint channels.
+
+    This model supports conditional and unconditional setups, as well as several
+    options for various internal architectural choices such as encoder and decoder
+    type, embedding type, etc., making it flexible and adaptable to different tasks
+    and configurations.
+
+    Parameters:
+    -----------
+    img_resolution : int
+        The resolution of the input/output image.
+    in_channels : int
+        Number of channels of the concatenated input (the topology image plus the
+        constraint channels, which are concatenated in the forward pass).
+    out_channels : int
+        Number of channels in the output image.
+    label_dim : int, optional
+        Number of class labels; 0 indicates an unconditional model. By default 0.
+    augment_dim : int, optional
+        Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
+    model_channels : int, optional
+        Base multiplier for the number of channels across the network. By default 128.
+    channel_mult : List[int], optional
+        Per-resolution multipliers for the number of channels. By default [1, 1, 1, 1].
+    channel_mult_emb : int, optional
+        Multiplier for the dimensionality of the embedding vector. By default 4.
+    num_blocks : int, optional
+        Number of residual blocks per resolution. By default 2.
+    attn_resolutions : List[int], optional
+        Resolutions at which self-attention layers are applied. By default [16, 8].
+    dropout : float, optional
+        Dropout probability applied to intermediate activations. By default 0.10.
+    label_dropout : float, optional
+        Dropout probability of class labels for classifier-free guidance. By default 0.0.
+
+    Note:
+    -----
+    Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat gans on image
+    synthesis. Advances in neural information processing systems, 34, pp.8780-8794.
+
+    Note:
+    -----
+    Adapted from the original implementation by Dhariwal and Nichol, available at
+    https://github.com/openai/guided-diffusion
+
+    Example:
+    --------
+    >>> model = TopoDiff(img_resolution=64, in_channels=6, out_channels=1)
+    >>> tops = torch.randn([1, 1, 64, 64])
+    >>> cons = torch.randn([1, 5, 64, 64])
+    >>> timesteps = torch.randint(0, 1000, (1,))
+    >>> output = model(tops, cons, timesteps)
+    >>> output.shape
+    torch.Size([1, 1, 64, 64])
+    """
+
+    def __init__(
+        self,
+        img_resolution: int,
+        in_channels: int,
+        out_channels: int,
+        label_dim: int = 0,
+        augment_dim: int = 0,
+        model_channels: int = 128,
+        channel_mult: List[int] = [1, 1, 1, 1],
+        channel_mult_emb: int = 4,
+        num_blocks: int = 2,
+        attn_resolutions: List[int] = [16, 8],
+        dropout: float = 0.10,
+        label_dropout: float = 0.0,
+    ):
+        super().__init__(meta=MetaData())
+        self.label_dropout = label_dropout
+        emb_channels = model_channels * channel_mult_emb
+        init = dict(
+            init_mode="kaiming_uniform",
+            init_weight=np.sqrt(1 / 3),
+            init_bias=np.sqrt(1 / 3),
+        )
+        init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
+        block_kwargs = dict(
+            emb_channels=emb_channels,
+            channels_per_head=64,
+            dropout=dropout,
+            init=init,
+            init_zero=init_zero,
+        )
+
+        # Mapping.
+        self.map_noise = PositionalEmbedding(num_channels=model_channels)
+        self.map_augment = (
+            Linear(
+                in_features=augment_dim,
+                out_features=model_channels,
+                bias=False,
+                **init_zero,
+            )
+            if augment_dim
+            else None
+        )
+        self.map_layer0 = Linear(
+            in_features=model_channels, out_features=emb_channels, **init
+        )
+        self.map_layer1 = Linear(
+            in_features=emb_channels, out_features=emb_channels, **init
+        )
+        self.map_label = (
+            Linear(
+                in_features=label_dim,
+                out_features=emb_channels,
+                bias=False,
+                init_mode="kaiming_normal",
+                init_weight=np.sqrt(label_dim),
+            )
+            if label_dim
+            else None
+        )
+
+        # Encoder.
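+        # Each level halves the spatial resolution (img_resolution >> level) and
+        # applies `num_blocks` residual UNet blocks, enabling self-attention at
+        # the resolutions listed in `attn_resolutions`.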
+ self.enc = torch.nn.ModuleDict() + cout = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + skips = [block.out_channels for block in self.enc.values()] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + self.out_norm = GroupNorm(num_channels=cout) + self.out_conv = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, cons, timesteps): + # Mapping. + emb = self.map_noise(timesteps) + + emb = silu(self.map_layer0(emb)) + emb = self.map_layer1(emb) + emb = silu(emb) + + x = torch.cat([x, cons], dim=1) + # Encoder. + skips = [] + for block in self.enc.values(): + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. 
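+        # U-Net skip connections: whenever a decoder block expects more channels
+        # than x currently carries, concatenate the matching encoder activation.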
+        for block in self.dec.values():
+            if x.shape[1] != block.in_channels:
+                x = torch.cat([x, skips.pop()], dim=1)
+            x = block(x, emb)
+        x = self.out_conv(silu(self.out_norm(x)))
+        return x
+
+
+class UNetEncoder(Module):
+    """
+    U-Net-style encoder with timestep embeddings, used as the TopoDiff guidance
+    network: a classifier for floating material (out_channels=2) or a compliance
+    regressor (out_channels=1).
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        model_channels: int = 128,
+        num_res_blocks: int = 4,
+        channel_mult: tuple = (1, 1, 1, 1),
+        channel_mult_emb: int = 4,
+        attention_resolutions: tuple = (16, 8),
+        dropout=0,
+        output_prob=False):
+
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.dropout = dropout
+        self.map_noise = PositionalEmbedding(num_channels=model_channels)
+        self.output_prob = output_prob
+
+        ch = int(model_channels * channel_mult[0])
+        self.conv = Conv2d(in_channels=in_channels, out_channels=ch, kernel=3)
+
+        emb_channels = model_channels * channel_mult_emb
+        self.time_embed = nn.Sequential(
+            Linear(in_features=model_channels, out_features=emb_channels),
+            nn.SiLU(),
+            Linear(in_features=emb_channels, out_features=emb_channels)
+        )
+
+        ds = 1  # cumulative downsampling factor
+        self.encoder = nn.ModuleList()
+        for level, mult in enumerate(channel_mult):
+            attention = ds in attention_resolutions
+            for i in range(num_res_blocks):
+
+                # Downsample on the last block of every level except the final one
+                down = (i == num_res_blocks - 1 and level != len(channel_mult) - 1)
+
+                layer = UNetBlock(in_channels=ch,
+                                  out_channels=int(mult * model_channels),
+                                  emb_channels=emb_channels,
+                                  down=down,
+                                  attention=attention)
+
+                self.encoder.append(layer)
+                ch = int(mult * model_channels)
+            ds *= 2
+
+        self.middle = nn.ModuleList([
+            UNetBlock(in_channels=ch, out_channels=ch, emb_channels=emb_channels, attention=True),
+            UNetBlock(in_channels=ch, out_channels=ch, emb_channels=emb_channels, attention=False)])
+
+        self.out = nn.Sequential(
+            Linear(in_features=ch, out_features=2048),
+            GroupNorm(num_channels=2048),
+            nn.ReLU(),
+            nn.Dropout(p=dropout),
+            nn.Linear(in_features=2048, out_features=self.out_channels),
+        )
+
+        if self.output_prob:
+            self.out.append(nn.Sigmoid())
+
+    def forward(self, x, time_steps):
+        """
+        param x: an [N x C x H x W] Tensor of inputs
+        param time_steps: a 1-D batch of timesteps
+        return: an [N x out_channels] Tensor of outputs (logits or regression values)
+        """
+        emb = self.time_embed(self.map_noise(time_steps))
+
+        h = self.conv(x)
+
+        for m in self.encoder:
+            h = m(h, emb)
+
+        for m in self.middle:
+            h = m(h, emb)
+
+        # Global average pool over the spatial dimensions, then project
+        return self.out(h.mean(dim=(2, 3)))
\ No newline at end of file
diff --git a/modulus/models/topodiff/utils.py b/modulus/models/topodiff/utils.py
new file mode 100644
index 0000000000..266ecfc3ca
--- /dev/null
+++ b/modulus/models/topodiff/utils.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
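+
+"""Dataset and infinite-dataloader helpers for TopoDiff training."""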
+
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+
+
+class DatasetTopoDiff(Dataset):
+    """Dataset pairing normalized topologies with stress, strain, load, and
+    volume-fraction constraint channels."""
+
+    def __init__(self, topologies, stress, strain, load_im, constraints):
+
+        self.topologies = topologies
+        self.constraints = constraints
+        self.image_size = topologies.shape[1]
+
+        self.stress = stress
+        self.strain = strain
+        self.load_im = load_im
+
+    def __len__(self):
+        return self.topologies.shape[0]
+
+    def __getitem__(self, idx):
+
+        cons = self.constraints[idx]
+
+        vol_frac = cons['VOL_FRAC']
+
+        cons = np.zeros((5, self.image_size, self.image_size))
+
+        cons[0] = self.stress[idx]
+        cons[1] = self.strain[idx]
+        cons[2] = self.load_im[idx][:, :, 0]
+        cons[3] = self.load_im[idx][:, :, 1]
+        cons[4] = np.ones((self.image_size, self.image_size)) * vol_frac
+
+        # Topologies are rescaled from [0, 1] to [-1, 1]
+        return np.expand_dims(self.topologies[idx], 0) * 2 - 1, cons
+
+
+def load_data_topodiff(topologies, constraints, stress, strain, load_img, batch_size, deterministic=False):
+    """Yield (topology, constraints) batches indefinitely."""
+    dataset = DatasetTopoDiff(topologies, stress, strain, load_img, constraints)
+
+    if deterministic:
+        loader = DataLoader(
+            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
+        )
+    else:
+        loader = DataLoader(
+            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
+        )
+
+    while True:
+        yield from loader
\ No newline at end of file
diff --git a/test/models/data/topodiff_output.pth b/test/models/data/topodiff_output.pth
new file mode 100644
index 0000000000..be585f52a5
Binary files /dev/null and b/test/models/data/topodiff_output.pth differ
diff --git a/test/models/topodiff/test_topodiff.py b/test/models/topodiff/test_topodiff.py
new file mode 100644
index 0000000000..9a38dc2a74
--- /dev/null
+++ b/test/models/topodiff/test_topodiff.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa: E402
+import os
+import random
+import sys
+
+import numpy as np
+import pytest
+import torch
+
+script_path = os.path.abspath(__file__)
+sys.path.append(os.path.join(os.path.dirname(script_path), ".."))
+
+import common
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_topodiff_forward(device):
+    """Test topodiff forward pass"""
+
+    from modulus.models.topodiff import TopoDiff
+
+    torch.manual_seed(0)
+    np.random.seed(0)
+    # Construct TopoDiff model
+    model = TopoDiff(img_resolution=64,
+                     in_channels=6,
+                     out_channels=1).to(device)
+
+    bsize = 4
+    nsteps = 1000  # diffusion steps
+    tops = torch.randn(bsize, 1, 64, 64).to(device)
+    cons = torch.randn(bsize, 5, 64, 64).to(device)
+    timesteps = torch.randint(0, nsteps, (bsize,)).to(device)
+    out = model(tops, cons, timesteps)
+
+    assert out.shape == (bsize, 1, 64, 64)
+    assert common.validate_forward_accuracy(model, (tops, cons, timesteps))
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_topodiff_constructor(device):
+    """Test topodiff constructor with randomized channel counts"""
+
+    from modulus.models.topodiff import TopoDiff
+
+    torch.manual_seed(0)
+    np.random.seed(0)
+    # Construct TopoDiff model
+    args = {
+        "img_resolution": 64,
+        "in_channels": random.randint(1, 16),
+        "cond_channels": random.randint(1, 16),
+        "out_channels": random.randint(1, 16),
+    }
+    model = TopoDiff(img_resolution=args["img_resolution"],
+                     in_channels=args["in_channels"] + args["cond_channels"],
+                     out_channels=args["out_channels"]).to(device)
+
+    bsize = 4
+    nsteps = 1000  # diffusion steps
+    tops = torch.randn(bsize, args["in_channels"], 64, 64).to(device)
+    cons = torch.randn(bsize, args["cond_channels"], 64, 64).to(device)
+    timesteps = torch.randint(0, nsteps, (bsize,)).to(device)
+    out = model(tops, cons, timesteps)
+
+    assert out.shape == (bsize, args["out_channels"], args["img_resolution"], args["img_resolution"])
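+
+
+# A sketch of a forward-shape test for the guidance encoder; it assumes the
+# UNetEncoder usage from the example training scripts (1-channel input for the
+# floating-material classifier, 2 output logits).
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_unet_encoder_forward(device):
+    """Test UNetEncoder forward pass (shape only)"""
+
+    from modulus.models.topodiff import UNetEncoder
+
+    torch.manual_seed(0)
+    encoder = UNetEncoder(in_channels=1, out_channels=2).to(device)
+
+    bsize = 4
+    imgs = torch.randn(bsize, 1, 64, 64).to(device)
+    timesteps = torch.randint(0, 1000, (bsize,)).to(device)
+    logits = encoder(imgs, time_steps=timesteps)
+
+    assert logits.shape == (bsize, 2)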