diff --git a/CHANGELOG.md b/CHANGELOG.md index a79f924e82..93824e625a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Bistride Multiscale MeshGraphNet example. - FIGConvUNet model and example. - The Transolver model. +- The XAeroNet model. - Incoporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and cross entropy loss. diff --git a/docs/img/xaeronet_s_results.png b/docs/img/xaeronet_s_results.png new file mode 100644 index 0000000000..0af1fbb2d5 Binary files /dev/null and b/docs/img/xaeronet_s_results.png differ diff --git a/docs/img/xaeronet_v_results.png b/docs/img/xaeronet_v_results.png new file mode 100644 index 0000000000..d4c4ee38f2 Binary files /dev/null and b/docs/img/xaeronet_v_results.png differ diff --git a/examples/cfd/xaeronet/README.md b/examples/cfd/xaeronet/README.md new file mode 100644 index 0000000000..f42ec2f539 --- /dev/null +++ b/examples/cfd/xaeronet/README.md @@ -0,0 +1,164 @@ +# XAeroNet: Scalable Neural Models for External Aerodynamics + +XAeroNet is a collection of scalable models for large-scale external +aerodynamic evaluations. It consists of two models, XAeroNet-S and XAeroNet-V for +surface and volume predictions, respectively. + +## Problem overview + +External aerodynamics plays a crucial role in the design and optimization of vehicles, +aircraft, and other transportation systems. Accurate predictions of aerodynamic +properties such as drag, pressure distribution, and airflow characteristics are +essential for improving fuel efficiency, vehicle stability, and performance. +Traditional approaches, such as computational fluid dynamics (CFD) simulations, +are computationally expensive and time-consuming, especially when evaluating multiple +design iterations or large datasets. + +XAeroNet addresses these challenges by leveraging neural network-based surrogate +models to provide fast, scalable, and accurate predictions for both surface-level +and volume-level aerodynamic properties. By using the DrivAerML dataset, which +contains high-fidelity CFD data for a variety of vehicle geometries, XAeroNet aims +to significantly reduce the computational cost while maintaining high prediction +accuracy. The two models in XAeroNet—XAeroNet-S for surface predictions and XAeroNet-V +for volume predictions—enable rapid aerodynamic evaluations across different design +configurations, making it easier to incorporate aerodynamic considerations early in +the design process. + +## Model Overview and Architecture + +### XAeroNet-S + +XAeroNet-S is a scalable MeshGraphNet model that partitions large input graphs into +smaller subgraphs to reduce training memory overhead. Halo regions are added to these +subgraphs to prevent message-passing truncations at the boundaries. Gradient aggregation +is employed to accumulate gradients from each partition before updating the model parameters. +This approach ensures that training on partitions is equivalent to training on the entire +graph in terms of model updates and accuracy. Additionally, XAeroNet-S does not rely on +simulation meshes for training and inference, overcoming a significant limitation of +GNN models in simulation tasks. + +The input to the training pipeline is STL files, from which the model samples a point cloud +on the surface. It then constructs a connectivity graph by linking the N nearest neighbors. 
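
The snippet below is a rough sketch of this graph-construction step for a single
resolution level (the point count and `node_degree` are illustrative); the full
multi-resolution implementation lives in `surface/preprocessor.py`:

```python
# Toy kNN connectivity graph, mirroring the construction in surface/preprocessor.py
import numpy as np
import torch
import dgl
from sklearn.neighbors import NearestNeighbors

points = np.random.rand(1000, 3)  # stand-in for the point cloud sampled from the STL
node_degree = 6                   # neighbors per node (cfg.node_degree)

# k + 1 neighbors because each point is returned as its own nearest neighbor
nbrs = NearestNeighbors(n_neighbors=node_degree + 1, algorithm="ball_tree").fit(points)
_, indices = nbrs.kneighbors(points)

src = torch.tensor([i for i in range(len(points)) for _ in range(node_degree)])
dst = torch.tensor(indices[:, 1:].flatten())

graph = dgl.graph((src, dst))
graph = dgl.add_self_loop(dgl.to_bidirected(dgl.to_simple(dgl.remove_self_loop(graph))))
graph.ndata["coordinates"] = torch.tensor(points, dtype=torch.float32)
```
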
+This method also supports multi-mesh setups, where point clouds with different resolutions
+are generated, their connectivity graphs are created, and all are superimposed. The Metis
+library is used to partition the graph for efficient training.
+
+For the XAeroNet-S model, STL files are used to generate point clouds and establish graph
+connectivity. Additionally, the .vtp files are used to interpolate the solution fields onto
+the point clouds.
+
+### XAeroNet-V
+
+XAeroNet-V is a scalable 3D UNet model with attention gates, designed to partition large
+voxel grids into smaller sub-grids to reduce memory overhead during training. Halo regions
+are added to these partitions to avoid convolution truncations at the boundaries.
+Gradient aggregation is used to accumulate gradients from each partition before updating
+the model parameters, ensuring that training on partitions is equivalent to training on
+the entire voxel grid in terms of model updates and accuracy. Additionally, XAeroNet-V
+incorporates a continuity constraint as an additional loss term during training to
+enhance model interpretability.
+
+For the XAeroNet-V model, the .vtu files are used to interpolate the volumetric
+solution fields onto a voxel grid, while the .stl files are utilized to compute
+the signed distance field (SDF) and its derivatives on the voxel grid.
+
+## Dataset
+
+We trained our models using the DrivAerML dataset from the [CAE ML Dataset collection](https://caemldatasets.org/drivaerml/).
+This high-fidelity, open-source (CC-BY-SA) public dataset is specifically designed
+for automotive aerodynamics research. It comprises 500 parametrically morphed variants
+of the widely utilized DrivAer notchback generic vehicle. Mesh generation and scale-resolving
+computational fluid dynamics (CFD) simulations were executed using consistent and validated
+automatic workflows that represent the industrial state-of-the-art. Geometries and comprehensive
+aerodynamic data are published in open-source formats. For more technical details about this
+dataset, please refer to their [paper](https://arxiv.org/pdf/2408.11969).
+
+## Training the XAeroNet-S model
+
+To train the XAeroNet-S model, follow these steps:
+
+1. Download the DrivAerML dataset using the provided `download_aws_dataset.sh` script.
+
+2. Navigate to the `surface` folder.
+
+3. Specify the configurations in `conf/config.yaml`. Make sure the path to the dataset
+   is specified correctly.
+
+4. Run `combine_stl_solids.py`. The STL files in the DrivAerML dataset consist of multiple
+   solids. These should be combined into a single solid to properly generate a surface point
+   cloud using the Modulus Tessellation geometry module.
+
+5. Run `preprocessor.py`. This will prepare and save the partitioned graphs.
+
+6. Create a folder for the validation partitions (`validation_partitions_path` in
+   `conf/config.yaml`, `validation_partitions` by default), and move the samples you wish
+   to use for validation to that folder.
+
+7. Run `compute_stats.py` to compute the global mean and standard deviation from the
+   training samples.
+
+8. Run `train.py` to start the training.
+
+9. Download the validation results (saved in the form of point clouds in `.vtp` format),
+   and visualize them in ParaView.
+
+![XAeroNet-S Validation results for the sample #500.](../../../docs/img/xaeronet_s_results.png)
+
+## Training the XAeroNet-V model
+
+To train the XAeroNet-V model, follow these steps:
+
+1. Download the DrivAerML dataset using the provided `download_aws_dataset.sh` script.
+
+2. Navigate to the `volume` folder.
+
+3. Specify the configurations in `conf/config.yaml`. Make sure the path to the dataset
+   is specified correctly.
+
+4. Run `preprocessing.py`. This will prepare and save the voxel grids.
+
+5. Create a `drivaer_aws_h5_validation` folder, and move the samples you wish to
+   use for validation to that folder.
+
+6. Run `compute_stats.py` to compute the global mean and standard deviation from
+   the training samples.
+
+7. Run `train.py` to start the training. Partitioning is performed prior to training.
+
+8. Download the validation results (saved in the form of voxel grids in `.vti` format),
+   and visualize them in ParaView.
+
+![XAeroNet-V Validation results.](../../../docs/img/xaeronet_v_results.png)
+
+## Logging
+
+We mainly use TensorBoard for logging training and validation losses, as well as
+the learning rate during training. You can also optionally use Weights & Biases to
+log training metrics. To visualize TensorBoard running in a
+Docker container on a remote server from your local desktop, follow these steps:
+
+1. **Expose the Port in Docker:**
+   Expose port 6006 in the Docker container by including
+   `-p 6006:6006` in your `docker run` command.
+
+2. **Launch TensorBoard:**
+   Start TensorBoard within the Docker container:
+
+   ```bash
+   tensorboard --logdir=/path/to/logdir --port=6006
+   ```
+
+3. **Set Up SSH Tunneling:**
+   Create an SSH tunnel to forward port 6006 from the remote server to your local machine:
+
+   ```bash
+   ssh -L 6006:localhost:6006 <username>@<remote-server-ip>
+   ```
+
+   Replace `<username>` with your SSH username and `<remote-server-ip>` with the IP address
+   of your remote server. You can use a different port if necessary.
+
+4. **Access TensorBoard:**
+   Open your web browser and navigate to `http://localhost:6006` to view TensorBoard.
+
+**Note:** Ensure the remote server’s firewall allows connections on port `6006`
+and that your local machine’s firewall allows outgoing connections.
diff --git a/examples/cfd/xaeronet/cleanup_corrupted_downloads.sh b/examples/cfd/xaeronet/cleanup_corrupted_downloads.sh
new file mode 100755
index 0000000000..052f5b0c00
--- /dev/null
+++ b/examples/cfd/xaeronet/cleanup_corrupted_downloads.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# This is a Bash script designed to identify and remove corrupted files after downloading the AWS DrivAer dataset.
+# The script defines two functions: check_and_remove_corrupted_extension and check_all_runs.
+# The check_and_remove_corrupted_extension function checks for files in a given directory that have extra characters after their extension.
+# If such a file is found, it is considered corrupted, and the function removes it.
+# The check_all_runs function iterates over all directories in a specified local directory (LOCAL_DIR), checking for corrupted files with the extensions ".vtu", ".stl", and ".vtp".
+# The script begins the cleanup process by calling the check_all_runs function. The target directory for this operation is set as "./drivaer_data_full".
+
+# Set the local directory to check the files
+LOCAL_DIR="./drivaer_data_full" # <--- This is the directory where the files are downloaded.
+
+# Function to check if a file has extra characters after the extension and remove it
+check_and_remove_corrupted_extension() {
+    local dir=$1
+    local base_filename=$2
+    local extension=$3
+
+    # Find any files with extra characters after the extension
+    for file in "$dir/$base_filename"$extension*; do
+        if [[ -f "$file" && "$file" != "$dir/$base_filename$extension" ]]; then
+            echo "Corrupted file detected: $file (extra characters after extension), removing it."
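            # Removing it lets download_aws_dataset.sh fetch a clean copy on the next run, since that script only checks for the expected file name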
+ rm "$file" + fi + done +} + +# Function to go over all the run directories and check files +check_all_runs() { + for RUN_DIR in "$LOCAL_DIR"/run_*; do + echo "Checking folder: $RUN_DIR" + + # Check for corrupted .vtu files + base_vtu="volume_${RUN_DIR##*_}" + check_and_remove_corrupted_extension "$RUN_DIR" "$base_vtu" ".vtu" + + # Check for corrupted .stl files + base_stl="drivaer_${RUN_DIR##*_}" + check_and_remove_corrupted_extension "$RUN_DIR" "$base_stl" ".stl" + + # Check for corrupted .vtp files + base_stl="drivaer_${RUN_DIR##*_}" + check_and_remove_corrupted_extension "$RUN_DIR" "$base_stl" ".vtp" + done +} + +# Start checking +check_all_runs diff --git a/examples/cfd/xaeronet/download_aws_dataset.sh b/examples/cfd/xaeronet/download_aws_dataset.sh new file mode 100755 index 0000000000..96f558fbff --- /dev/null +++ b/examples/cfd/xaeronet/download_aws_dataset.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# This Bash script downloads the AWS DrivAer files from the Amazon S3 bucket to a local directory. +# Only the volume files (.vtu), STL files (.stl), and VTP files (.vtp) are downloaded. +# It uses a function, download_run_files, to check for the existence of three specific files (".vtu", ".stl", ".vtp") in a run directory. +# If a file doesn't exist, it's downloaded from the S3 bucket. If it does exist, the download is skipped. +# The script runs multiple downloads in parallel, both within a single run and across multiple runs. +# It also includes checks to prevent overloading the system by limiting the number of parallel downloads. + +# Set the local directory to download the files +LOCAL_DIR="./drivaer_data_full" # <--- This is the directory where the files will be downloaded. + +# Set the S3 bucket and prefix +S3_BUCKET="caemldatasets" +S3_PREFIX="drivaer/dataset" + +# Create the local directory if it doesn't exist +mkdir -p "$LOCAL_DIR" + +# Function to download files for a specific run +download_run_files() { + local i=$1 + RUN_DIR="run_$i" + RUN_LOCAL_DIR="$LOCAL_DIR/$RUN_DIR" + + # Create the run directory if it doesn't exist + mkdir -p "$RUN_LOCAL_DIR" + + # Check if the .vtu file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/volume_$i.vtu" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/volume_$i.vtu" "$RUN_LOCAL_DIR/" & + else + echo "File volume_$i.vtu already exists, skipping download." + fi + + # Check if the .stl file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/drivaer_$i.stl" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/drivaer_$i.stl" "$RUN_LOCAL_DIR/" & + else + echo "File drivaer_$i.stl already exists, skipping download." + fi + + # Check if the .vtp file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/boundary_$i.vtp" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/boundary_$i.vtp" "$RUN_LOCAL_DIR/" & + else + echo "File boundary_$i.vtp already exists, skipping download." 
+ fi + + wait # Ensure that both files for this run are downloaded before moving to the next run +} + +# Loop through the run folders and download the files +for i in $(seq 1 500); do + download_run_files "$i" & + + # Limit the number of parallel jobs to avoid overloading the system + if (( $(jobs -r | wc -l) >= 8 )); then + wait -n # Wait for the next background job to finish before starting a new one + fi +done + +# Wait for all remaining background jobs to finish +wait diff --git a/examples/cfd/xaeronet/requirements.txt b/examples/cfd/xaeronet/requirements.txt new file mode 100644 index 0000000000..0fa0219dea --- /dev/null +++ b/examples/cfd/xaeronet/requirements.txt @@ -0,0 +1 @@ +trimesh==4.5.0 diff --git a/examples/cfd/xaeronet/surface/combine_stl_solids.py b/examples/cfd/xaeronet/surface/combine_stl_solids.py new file mode 100644 index 0000000000..a8b4be276c --- /dev/null +++ b/examples/cfd/xaeronet/surface/combine_stl_solids.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides functionality to convert STL files with multiple solids +to another STL file with a single combined solid. It includes support for +processing multiple files in parallel with progress tracking. 
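
At its core the conversion is a trimesh concatenation; a minimal standalone sketch
(the file names are illustrative) is:

    import trimesh

    mesh = trimesh.load_mesh("drivaer_1.stl")
    if isinstance(mesh, trimesh.Scene):
        mesh = trimesh.util.concatenate(list(mesh.geometry.values()))
    mesh.export("drivaer_1_single_solid.stl")

In this module the same logic is applied to every STL file found under `cfg.data_path`.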
+""" + +import os +import trimesh +import hydra + +from multiprocessing import Pool +from tqdm import tqdm +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + + +def process_stl_file(task): + stl_path = task + + # Load the STL file using trimesh + mesh = trimesh.load_mesh(stl_path) + + # If the STL file contains multiple solids (as a Scene object) + if isinstance(mesh, trimesh.Scene): + # Extract all geometries (solids) from the scene + meshes = list(mesh.geometry.values()) + + # Combine all the solids into a single mesh + combined_mesh = trimesh.util.concatenate(meshes) + else: + # If it's a single solid, no need to combine + combined_mesh = mesh + + # Prepare the output file path (next to the original file) + base_name, ext = os.path.splitext(stl_path) + output_file_path = to_absolute_path(f"{base_name}_single_solid{ext}") + + # Save the new combined mesh as an STL file + combined_mesh.export(output_file_path) + + return f"Processed: {stl_path} -> {output_file_path}" + + +def process_directory(data_path, num_workers=16): + """Process all STL files in the given directory using multiprocessing with progress tracking.""" + tasks = [] + for root, _, files in os.walk(data_path): + stl_files = [f for f in files if f.endswith(".stl")] + for stl_file in stl_files: + stl_path = os.path.join(root, stl_file) + + # Add the STL file to the tasks list (no need for output dir, saving next to the original) + tasks.append(stl_path) + + # Use multiprocessing to process the tasks with progress tracking + with Pool(num_workers) as pool: + for _ in tqdm( + pool.imap_unordered(process_stl_file, tasks), + total=len(tasks), + desc="Processing STL Files", + unit="file", + ): + pass + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + # Process the directory with multiple STL files + process_directory( + to_absolute_path(cfg.data_path), num_workers=cfg.num_preprocess_workers + ) + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/xaeronet/surface/compute_stats.py b/examples/cfd/xaeronet/surface/compute_stats.py new file mode 100644 index 0000000000..1d7c34bfdf --- /dev/null +++ b/examples/cfd/xaeronet/surface/compute_stats.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code processes partitioned graph data stored in .bin files to compute global +mean and standard deviation,for various node and edge data fields. It identifies +all .bin files in a directory, processes each file to accumulate statistics for +specific fields (like coordinates and pressure), and then aggregates the results +across all files. The code supports parallel processing to handle multiple files +simultaneously, speeding up the computation. Finally, the global statistics are +saved to a JSON file. 
+""" + +import os +import json +import numpy as np +import dgl +import hydra + +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + + +def find_bin_files(data_path): + """ + Finds all .bin files in the specified directory. + """ + return [ + os.path.join(data_path, f) for f in os.listdir(data_path) if f.endswith(".bin") + ] + + +def process_file(bin_file): + """ + Processes a single .bin file containing graph partitions to compute the mean, mean of squares, and count for each variable. + """ + graphs, _ = dgl.load_graphs(bin_file) + + # Initialize dictionaries to accumulate stats + node_fields = ["coordinates", "normals", "area", "pressure", "shear_stress"] + edge_fields = ["x"] + + field_means = {} + field_square_means = {} + counts = {} + + # Initialize stats accumulation for each partitioned graph + for field in node_fields + edge_fields: + field_means[field] = 0 + field_square_means[field] = 0 + counts[field] = 0 + + # Loop through each partition in the file + for graph in graphs: + # Process node data + for field in node_fields: + if field in graph.ndata: + data = graph.ndata[field].numpy() + + if data.ndim == 1: + data = np.expand_dims(data, axis=-1) + + # Compute mean, mean of squares, and count for each partition + field_mean = np.mean(data, axis=0) + field_square_mean = np.mean(data**2, axis=0) + count = data.shape[0] + + # Accumulate stats across partitions + field_means[field] += field_mean * count + field_square_means[field] += field_square_mean * count + counts[field] += count + else: + print(f"Warning: Node field '{field}' not found in {bin_file}") + + # Process edge data + for field in edge_fields: + if field in graph.edata: + data = graph.edata[field].numpy() + + field_mean = np.mean(data, axis=0) + field_square_mean = np.mean(data**2, axis=0) + count = data.shape[0] + + field_means[field] += field_mean * count + field_square_means[field] += field_square_mean * count + counts[field] += count + else: + print(f"Warning: Edge field '{field}' not found in {bin_file}") + + return field_means, field_square_means, counts + + +def aggregate_results(results): + """ + Aggregates the results from all files to compute global mean and standard deviation. + """ + total_mean = {} + total_square_mean = {} + total_count = {} + + # Initialize totals with zeros for each field + for field in results[0][0].keys(): + total_mean[field] = 0 + total_square_mean[field] = 0 + total_count[field] = 0 + + # Accumulate weighted sums and counts + for field_means, field_square_means, counts in results: + for field in field_means: + total_mean[field] += field_means[field] + total_square_mean[field] += field_square_means[field] + total_count[field] += counts[field] + + # Compute global mean and standard deviation + global_mean = {} + global_std = {} + + for field in total_mean: + global_mean[field] = total_mean[field] / total_count[field] + variance = (total_square_mean[field] / total_count[field]) - ( + global_mean[field] ** 2 + ) + global_std[field] = np.sqrt( + np.maximum(variance, 0) + ) # Ensure no negative variance due to rounding errors + + return global_mean, global_std + + +def compute_global_stats(bin_files, num_workers=4): + """ + Computes the global mean and standard deviation for each field across all .bin files + using parallel processing. 
+ """ + with ProcessPoolExecutor(max_workers=num_workers) as executor: + results = list( + tqdm( + executor.map(process_file, bin_files), + total=len(bin_files), + desc="Processing BIN Files", + unit="file", + ) + ) + + # Aggregate the results from all files + global_mean, global_std = aggregate_results(results) + + return global_mean, global_std + + +def save_stats_to_json(mean, std_dev, output_file): + """ + Saves the global mean and standard deviation to a JSON file. + """ + stats = { + "mean": { + k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in mean.items() + }, + "std_dev": { + k: v.tolist() if isinstance(v, np.ndarray) else v + for k, v in std_dev.items() + }, + } + + with open(output_file, "w") as f: + json.dump(stats, f, indent=4) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + + data_path = to_absolute_path( + cfg.partitions_path + ) # Directory containing the .bin graph files with partitions + output_file = to_absolute_path(cfg.stats_file) # File to save the global statistics + # Find all .bin files in the directory + bin_files = find_bin_files(data_path) + + # Compute global statistics with parallel processing + global_mean, global_std = compute_global_stats( + bin_files, num_workers=cfg.num_preprocess_workers + ) + + # Save statistics to a JSON file + save_stats_to_json(global_mean, global_std, output_file) + + # Print the results + print("Global Mean:", global_mean) + print("Global Standard Deviation:", global_std) + print(f"Statistics saved to {output_file}") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/xaeronet/surface/conf/config.yaml b/examples/cfd/xaeronet/surface/conf/config.yaml new file mode 100644 index 0000000000..0b62ab9015 --- /dev/null +++ b/examples/cfd/xaeronet/surface/conf/config.yaml @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +hydra: + job: + chdir: true + name: XAeroNetS + run: + dir: ./outputs/${hydra:job.name} + +# ┌───────────────────────────────────────────┐ +# │ Data Preprocessing │ +# └───────────────────────────────────────────┘ + +num_nodes: [100000, 200000, 400000] # Number of nodes in the graphs +node_degree: 6 # Degree of the nodes in the graphs +num_partitions: 3 # Number of partitions for each graph +data_path: /data/drivaer_aws/drivaer_data_full # Path to the raw data +num_preprocess_workers: 32 # Number of workers for data preprocessing +save_point_clouds: false # Save point clouds for the preprocessed data + +# ┌───────────────────────────────────────────┐ +# │ Model Configuration │ +# └───────────────────────────────────────────┘ + +num_message_passing_layers: 15 # Number of message passing layers +hidden_dim: 512 # Hidden dimension of the model +activation: silu # Activation function + +# ┌───────────────────────────────────────────┐ +# │ Training Configuration │ +# └───────────────────────────────────────────┘ + +partitions_path: partitions # Path to the partitions (.bin files) +validation_partitions_path: validation_partitions # Path to the validation partitions (.bin files) +stats_file: global_stats.json # Path to the global statistics (.json file) +checkpoint_filename: model_checkpoint.pth # Filename of the model checkpoint +num_epochs: 2000 # Number of epochs +start_lr: 0.001 # Initial learning rate (cos annealing schedule is used) +end_lr: 0.000001 # Final learning rate (cos annealing schedule is used) +save_checkpoint_freq: 5 # Frequency of saving the model checkpoint +validation_freq: 50 # Frequency of validation + +# ┌───────────────────────────────────────────┐ +# │ Performance Optimization │ +# └───────────────────────────────────────────┘ + +use_concat_trick: true # Use the concatenation trick +checkpoint_segments: 3 # Number of segments for the activation checkpointing +enable_cudnn_benchmark: true # Enable cudnn benchmark diff --git a/examples/cfd/xaeronet/surface/dataloader.py b/examples/cfd/xaeronet/surface/dataloader.py new file mode 100644 index 0000000000..dee82a02b5 --- /dev/null +++ b/examples/cfd/xaeronet/surface/dataloader.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This code defines a custom dataset class GraphDataset for loading and normalizing +graph partition data stored in .bin files. The dataset is initialized with a list +of file paths and global mean and standard deviation for node and edge attributes. +It normalizes node data (like coordinates, normals, pressure) and edge data based +on these statistics before returning the processed graph partitions and a corresponding +label (extracted from the file name). 
The code also provides a function create_dataloader +to create a data loader for efficient batch loading with configurable parameters such as +batch size, shuffle, and prefetching options. +""" + +import json +import torch +from torch.utils.data import Dataset +import os +import sys +import dgl +from dgl.dataloading import GraphDataLoader + +# Get the absolute path to the parent directory +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +from utils import find_bin_files + + +class GraphDataset(Dataset): + """ + Custom dataset class for loading + + Parameters: + ---------- + file_list (list of str): List of paths to .bin files containing partitions. + mean (np.ndarray): Global mean for normalization. + std (np.ndarray): Global standard deviation for normalization. + """ + + def __init__(self, file_list, mean, std): + self.file_list = file_list + self.mean = mean + self.std = std + + # Store normalization stats as tensors + self.coordinates_mean = torch.tensor(mean["coordinates"]) + self.coordinates_std = torch.tensor(std["coordinates"]) + self.normals_mean = torch.tensor(mean["normals"]) + self.normals_std = torch.tensor(std["normals"]) + self.area_mean = torch.tensor(mean["area"]) + self.area_std = torch.tensor(std["area"]) + self.pressure_mean = torch.tensor(mean["pressure"]) + self.pressure_std = torch.tensor(std["pressure"]) + self.shear_stress_mean = torch.tensor(mean["shear_stress"]) + self.shear_stress_std = torch.tensor(std["shear_stress"]) + self.edge_x_mean = torch.tensor(mean["x"]) + self.edge_x_std = torch.tensor(std["x"]) + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + file_path = self.file_list[idx] + + # Extract the ID from the file name + file_name = os.path.basename(file_path) + # Assuming file format is "graph_partitions_.bin" + run_id = file_name.split("_")[-1].split(".")[0] # Extract the run ID + + # Load the partitioned graphs from the .bin file + graphs, _ = dgl.load_graphs(file_path) + + # Process each partition (graph) + normalized_partitions = [] + for graph in graphs: + # Normalize node data + graph.ndata["coordinates"] = ( + graph.ndata["coordinates"] - self.coordinates_mean + ) / self.coordinates_std + graph.ndata["normals"] = ( + graph.ndata["normals"] - self.normals_mean + ) / self.normals_std + graph.ndata["area"] = (graph.ndata["area"] - self.area_mean) / self.area_std + graph.ndata["pressure"] = ( + graph.ndata["pressure"] - self.pressure_mean + ) / self.pressure_std + graph.ndata["shear_stress"] = ( + graph.ndata["shear_stress"] - self.shear_stress_mean + ) / self.shear_stress_std + + # Normalize edge data + if "x" in graph.edata: + graph.edata["x"] = ( + graph.edata["x"] - self.edge_x_mean + ) / self.edge_x_std + + normalized_partitions.append(graph) + + return normalized_partitions, run_id + + +def create_dataloader( + file_list, + mean, + std, + batch_size=1, + shuffle=False, + use_ddp=True, + drop_last=True, + num_workers=4, + pin_memory=True, + prefetch_factor=2, +): + """ + Creates a DataLoader for the GraphDataset with prefetching. + + Args: + file_list (list of str): List of paths to .bin files. + mean (np.ndarray): Global mean for normalization. + std (np.ndarray): Global standard deviation for normalization. + batch_size (int): Number of samples per batch. + num_workers (int): Number of worker processes for data loading. + pin_memory (bool): If True, the data loader will copy tensors into CUDA pinned memory. 
+ + Returns: + DataLoader: Configured DataLoader for the dataset. + """ + dataset = GraphDataset(file_list, mean, std) + dataloader = GraphDataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + use_ddp=use_ddp, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + ) + return dataloader + + +if __name__ == "__main__": + data_path = "partitions" + stats_file = "global_stats.json" + + # Load global statistics + with open(stats_file, "r") as f: + stats = json.load(f) + mean = stats["mean"] + std = stats["std_dev"] + + # Find all .bin files in the directory + file_list = find_bin_files(data_path) + + # Create DataLoader + dataloader = create_dataloader( + file_list, + mean, + std, + batch_size=1, + prefetch_factor=None, + use_ddp=False, + num_workers=1, + ) + + # Example usage + for batch_partitions, label in dataloader: + for graph in batch_partitions: + print(graph) + print(label) diff --git a/examples/cfd/xaeronet/surface/preprocessor.py b/examples/cfd/xaeronet/surface/preprocessor.py new file mode 100644 index 0000000000..98667640b0 --- /dev/null +++ b/examples/cfd/xaeronet/surface/preprocessor.py @@ -0,0 +1,329 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code processes mesh data from .stl and .vtp files to create partitioned +graphs for large scale training. It first converts meshes to triangular format +and extracts surface triangles, vertices, and relevant attributes such as pressure +and shear stress. Using nearest neighbors, the code interpolates these attributes +for a sampled boundary of points, and constructs a graph based on these points, with +node features like coordinates, normals, pressure, and shear stress, as well as edge +features representing relative displacement. The graph is partitioned into subgraphs, +and the partitions are saved. The code supports parallel processing to handle multiple +samples simultaneously, improving efficiency. Additionally, it provides an option to +save the point cloud of each graph for visualization purposes. 
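Each run is saved to `partitions/graph_partitions_<run_id>.bin` as a list of DGL
graphs (one graph per partition); these are the files that `dataloader.py` and
`train.py` later consume.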
+""" + +import os +import vtk +import pyvista as pv +import numpy as np +import torch +import dgl +import hydra + +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor +from sklearn.neighbors import NearestNeighbors +from dgl.data.utils import save_graphs +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + +from modulus.datapipes.cae.readers import read_vtp +from modulus.sym.geometry.tessellation import Tessellation + + +def convert_to_triangular_mesh( + polydata, write=False, output_filename="surface_mesh_triangular.vtu" +): + """Converts a vtkPolyData object to a triangular mesh.""" + tet_filter = vtk.vtkDataSetTriangleFilter() + tet_filter.SetInputData(polydata) + tet_filter.Update() + + tet_mesh = pv.wrap(tet_filter.GetOutput()) + + if write: + tet_mesh.save(output_filename) + + return tet_mesh + + +def extract_surface_triangles(tet_mesh): + """Extracts the surface triangles from a triangular mesh.""" + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tet_mesh) + surface_filter.Update() + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise ValueError("Face is not a triangle") + + return triangle_indices + + +def fetch_mesh_vertices(mesh): + """Fetches the vertices of a mesh.""" + points = mesh.GetPoints() + num_points = points.GetNumberOfPoints() + vertices = [points.GetPoint(i) for i in range(num_points)] + return vertices + + +def add_edge_features(graph): + """ + Add relative displacement and displacement norm as edge features to the graph. + The calculations are done using the 'pos' attribute in the + node data of each graph. The resulting edge features are stored in the 'x' attribute + in the edge data of each graph. + + This method will modify the graph in-place. + + Returns + ------- + dgl.DGLGraph + Graph with updated edge features. + """ + + pos = graph.ndata.get("coordinates") + if pos is None: + raise ValueError( + "'coordinates' does not exist in the node data of one or more graphs." + ) + + row, col = graph.edges() + row = row.long() + col = col.long() + + disp = pos[row] - pos[col] + disp_norm = torch.linalg.norm(disp, dim=-1, keepdim=True) + graph.edata["x"] = torch.cat((disp, disp_norm), dim=-1) + + return graph + + +# Define this function outside of any local scope so it can be pickled +def run_task(params): + """Wrapper function to unpack arguments for process_run.""" + return process_run(*params) + + +def process_partition(graph, num_partitions, halo_hops): + """ + Helper function to partition a single graph and include node and edge features. 
+ """ + # Perform the partitioning + partitioned = dgl.metis_partition( + graph, k=num_partitions, extra_cached_hops=halo_hops, reshuffle=True + ) + + # For each partition, restore node and edge features + partition_list = [] + for _, subgraph in partitioned.items(): + subgraph.ndata["coordinates"] = graph.ndata["coordinates"][ + subgraph.ndata[dgl.NID] + ] + subgraph.ndata["normals"] = graph.ndata["normals"][subgraph.ndata[dgl.NID]] + subgraph.ndata["area"] = graph.ndata["area"][subgraph.ndata[dgl.NID]] + subgraph.ndata["pressure"] = graph.ndata["pressure"][subgraph.ndata[dgl.NID]] + subgraph.ndata["shear_stress"] = graph.ndata["shear_stress"][ + subgraph.ndata[dgl.NID] + ] + if "x" in graph.edata: + subgraph.edata["x"] = graph.edata["x"][subgraph.edata[dgl.EID]] + + partition_list.append(subgraph) + + return partition_list + + +def process_run( + run_path, point_list, node_degree, num_partitions, halo_hops, save_point_cloud=False +): + """Process a single run directory to generate a multi-level graph and apply partitioning.""" + run_id = os.path.basename(run_path).split("_")[-1] + + stl_file = os.path.join(run_path, f"drivaer_{run_id}_single_solid.stl") + vtp_file = os.path.join(run_path, f"boundary_{run_id}.vtp") + + # Path to save the list of partitions + partition_file_path = to_absolute_path(f"partitions/graph_partitions_{run_id}.bin") + + if os.path.exists(partition_file_path): + print(f"Partitions for run {run_id} already exist. Skipping...") + return + + if not os.path.exists(stl_file) or not os.path.exists(vtp_file): + print(f"Warning: Missing files for run {run_id}. Skipping...") + return + + try: + # Load the STL and VTP files + obj = Tessellation.from_stl(stl_file, airtight=False) + surface_mesh = read_vtp(vtp_file) + surface_mesh = convert_to_triangular_mesh(surface_mesh) + surface_vertices = fetch_mesh_vertices(surface_mesh) + surface_mesh = surface_mesh.cell_data_to_point_data() + node_attributes = surface_mesh.point_data + pressure_ref = node_attributes["pMeanTrim"] + shear_stress_ref = node_attributes["wallShearStressMeanTrim"] + + # Sort the list of points in ascending order + sorted_points = sorted(point_list) + + # Initialize arrays to store all points, normals, and areas + all_points = np.empty((0, 3)) + all_normals = np.empty((0, 3)) + all_areas = np.empty((0, 1)) + edge_sources = [] + edge_destinations = [] + + # Precompute the nearest neighbors for surface vertices + nbrs_surface = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit( + surface_vertices + ) + + for num_points in sorted_points: + # Sample the boundary points for the current level + boundary = obj.sample_boundary(num_points) + points = np.concatenate( + [boundary["x"], boundary["y"], boundary["z"]], axis=1 + ) + normals = np.concatenate( + [boundary["normal_x"], boundary["normal_y"], boundary["normal_z"]], + axis=1, + ) + area = boundary["area"] + + # Concatenate new points with the previous ones + all_points = np.vstack([all_points, points]) + all_normals = np.vstack([all_normals, normals]) + all_areas = np.vstack([all_areas, area]) + + # Construct edges for the combined point cloud at this level + nbrs_points = NearestNeighbors( + n_neighbors=node_degree + 1, algorithm="ball_tree" + ).fit(all_points) + _, indices_within = nbrs_points.kneighbors(all_points) + src_within = [i for i in range(len(all_points)) for _ in range(node_degree)] + dst_within = indices_within[:, 1:].flatten() + + # Add the within-level edges + edge_sources.extend(src_within) + edge_destinations.extend(dst_within) + + # 
Now, compute pressure and shear stress for the final combined point cloud + _, indices = nbrs_surface.kneighbors(all_points) + indices = indices.flatten() + + pressure = pressure_ref[indices] + shear_stress = shear_stress_ref[indices] + + except Exception as e: + print(f"Error processing run {run_id}: {e}. Skipping this run...") + return + + try: + # Create the final graph with multi-level edges + graph = dgl.graph((edge_sources, edge_destinations)) + graph = dgl.remove_self_loop(graph) + graph = dgl.to_simple(graph) + graph = dgl.to_bidirected(graph, copy_ndata=True) + graph = dgl.add_self_loop(graph) + + graph.ndata["coordinates"] = torch.tensor(all_points, dtype=torch.float32) + graph.ndata["normals"] = torch.tensor(all_normals, dtype=torch.float32) + graph.ndata["area"] = torch.tensor(all_areas, dtype=torch.float32) + graph.ndata["pressure"] = torch.tensor(pressure, dtype=torch.float32).unsqueeze( + -1 + ) + graph.ndata["shear_stress"] = torch.tensor(shear_stress, dtype=torch.float32) + graph = add_edge_features(graph) + + # Partition the graph + partitioned_graphs = process_partition(graph, num_partitions, halo_hops) + + # Save the partitions + save_graphs(partition_file_path, partitioned_graphs) + + if save_point_cloud: + point_cloud = pv.PolyData(graph.ndata["coordinates"].numpy()) + point_cloud["coordinates"] = graph.ndata["coordinates"].numpy() + point_cloud["normals"] = graph.ndata["normals"].numpy() + point_cloud["area"] = graph.ndata["area"].numpy() + point_cloud["pressure"] = graph.ndata["pressure"].numpy() + point_cloud["shear_stress"] = graph.ndata["shear_stress"].numpy() + point_cloud.save(f"point_clouds/point_cloud_{run_id}.vtp") + + except Exception as e: + print( + f"Error while constructing graph or saving data for run {run_id}: {e}. Skipping this run..." + ) + return + + +def process_all_runs( + base_path, + num_points, + node_degree, + num_partitions, + halo_hops, + num_workers=16, + save_point_cloud=False, +): + """Process all runs in the base directory in parallel.""" + + run_dirs = [ + os.path.join(base_path, d) + for d in os.listdir(base_path) + if d.startswith("run_") and os.path.isdir(os.path.join(base_path, d)) + ] + + tasks = [ + (run_dir, num_points, node_degree, num_partitions, halo_hops, save_point_cloud) + for run_dir in run_dirs + ] + + with ProcessPoolExecutor(max_workers=num_workers) as pool: + for _ in tqdm( + pool.map(run_task, tasks), + total=len(tasks), + desc="Processing Runs", + unit="run", + ): + pass + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + process_all_runs( + base_path=to_absolute_path(cfg.data_path), + num_points=cfg.num_nodes, + node_degree=cfg.node_degree, + num_partitions=cfg.num_partitions, + halo_hops=cfg.num_message_passing_layers, + num_workers=cfg.num_preprocess_workers, + save_point_cloud=cfg.save_point_clouds, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/xaeronet/surface/train.py b/examples/cfd/xaeronet/surface/train.py new file mode 100644 index 0000000000..19ac0d0bb0 --- /dev/null +++ b/examples/cfd/xaeronet/surface/train.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed training pipeline for training MeshGraphNet at scale, +which operates on partitioned graph data for the AWS drivaer dataset. It includes +loading partitioned graphs from .bin files, normalizing node and edge features using +precomputed statistics, and training the model in parallel using DistributedDataParallel +across multiple GPUs. The training loop involves computing predictions for each graph +partition, calculating loss, and updating model parameters using mixed precision. +Periodic checkpointing is performed to save the model, optimizer state, and training +progress. Validation is also conducted every few epochs, where predictions are compared +against ground truth values, and results are saved as point clouds. The code logs training +and validation metrics to TensorBoard and optionally integrates with Weights and Biases for +experiment tracking. +""" + +import os +import sys +import json +import dgl +import pyvista as pv +import torch +import hydra +import numpy as np +from hydra.utils import to_absolute_path +from torch.nn.parallel import DistributedDataParallel +import torch.optim as optim +from torch.cuda.amp import GradScaler +from torch.utils.tensorboard import SummaryWriter +from omegaconf import DictConfig + +from modulus.distributed import DistributedManager +from modulus.launch.logging import initialize_wandb +from modulus.models.meshgraphnet import MeshGraphNet + +# Get the absolute path to the parent directory +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +from dataloader import create_dataloader +from utils import ( + find_bin_files, + save_checkpoint, + load_checkpoint, + count_trainable_params, +) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + + # Enable cuDNN auto-tuner + torch.backends.cudnn.benchmark = cfg.enable_cudnn_benchmark + + # Instantiate the distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + print(f"Rank {dist.rank} of {dist.world_size}") + + # Instantiate the writers + if dist.rank == 0: + writer = SummaryWriter(log_dir="tensorboard") + initialize_wandb( + project="aws_drivaer", + entity="Modulus", + name="aws_drivaer", + mode="disabled", + group="group", + save_code=True, + ) + + # AMP Configs + amp_dtype = torch.bfloat16 + amp_device = "cuda" + + # Find all .bin files in the directory + train_dataset = find_bin_files(to_absolute_path(cfg.partitions_path)) + valid_dataset = find_bin_files(to_absolute_path(cfg.validation_partitions_path)) + + # Prepare the stats + with open(to_absolute_path(cfg.stats_file), "r") as f: + stats = json.load(f) + mean = stats["mean"] + std = stats["std_dev"] + + # Create DataLoader + train_dataloader = create_dataloader( + train_dataset, + mean, + std, + batch_size=1, + prefetch_factor=None, + use_ddp=True, + num_workers=4, + ) + # graphs is a list of graphs, each graph is a list of partitions + graphs = [graph_partitions for graph_partitions, _ in train_dataloader] + + if dist.rank == 0: + 
validation_dataloader = create_dataloader( + valid_dataset, + mean, + std, + batch_size=1, + prefetch_factor=None, + use_ddp=False, + num_workers=4, + ) + validation_graphs = [ + graph_partitions for graph_partitions, _ in validation_dataloader + ] + validation_ids = [id[0] for _, id in validation_dataloader] + print(f"Training dataset size: {len(graphs)*dist.world_size}") + print(f"Validation dataset size: {len(validation_dataloader)}") + + ###################################### + # Training # + ###################################### + + # Initialize model + model = MeshGraphNet( + input_dim_nodes=24, + input_dim_edges=4, + output_dim=4, + processor_size=cfg.num_message_passing_layers, + aggregation="sum", + hidden_dim_node_encoder=cfg.hidden_dim, + hidden_dim_edge_encoder=cfg.hidden_dim, + hidden_dim_node_decoder=cfg.hidden_dim, + mlp_activation_fn=cfg.activation, + do_concat_trick=cfg.use_concat_trick, + num_processor_checkpoint_segments=cfg.checkpoint_segments, + ).to(device) + print(f"Number of trainable parameters: {count_trainable_params(model)}") + + # DistributedDataParallel wrapper + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=True, + ) + + # Optimizer and scheduler + optimizer = optim.Adam(model.parameters(), lr=0.001) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=2000, eta_min=1e-6 + ) + scaler = GradScaler() + print("Instantiated the model and optimizer") + + # Check if there's a checkpoint to resume from + start_epoch, _ = load_checkpoint( + model, optimizer, scaler, scheduler, cfg.checkpoint_filename + ) + + # Training loop + print("Training started") + for epoch in range(start_epoch, cfg.num_epochs): + model.train() + total_loss = 0 + for i in range(len(graphs)): + optimizer.zero_grad() + subgraphs = graphs[i] # Get the partitions of the graph + for j in range(cfg.num_partitions): + with torch.autocast(amp_device, enabled=True, dtype=amp_dtype): + part = subgraphs[j].to(device) + ndata = torch.cat( + ( + part.ndata["coordinates"], + part.ndata["normals"], + torch.sin(2 * np.pi * part.ndata["coordinates"]), + torch.cos(2 * np.pi * part.ndata["coordinates"]), + torch.sin(4 * np.pi * part.ndata["coordinates"]), + torch.cos(4 * np.pi * part.ndata["coordinates"]), + torch.sin(8 * np.pi * part.ndata["coordinates"]), + torch.cos(8 * np.pi * part.ndata["coordinates"]), + ), + dim=1, + ) + pred = model(ndata, part.edata["x"], part) + pred_filtered = pred[part.ndata["inner_node"].bool(), :] + target = torch.cat( + (part.ndata["pressure"], part.ndata["shear_stress"]), dim=1 + ) + target_filtered = target[part.ndata["inner_node"].bool()] + loss = ( + torch.mean((pred_filtered - target_filtered) ** 2) + / cfg.num_partitions + ) + total_loss += loss.item() + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 32.0) + scaler.step(optimizer) + scaler.update() + scheduler.step() + + # Log the training loss + if dist.rank == 0: + current_lr = optimizer.param_groups[0]["lr"] + print( + f"Epoch {epoch+1}, Learning Rate: {current_lr}, Total Loss: {total_loss / len(graphs)}" + ) + writer.add_scalar("training_loss", total_loss / len(graphs), epoch) + writer.add_scalar("learning_rate", current_lr, epoch) + + # Save checkpoint periodically + if (epoch) % cfg.save_checkpoint_freq == 0: + if 
dist.world_size > 1: + torch.distributed.barrier() + if dist.rank == 0: + save_checkpoint( + model, + optimizer, + scaler, + scheduler, + epoch + 1, + loss.item(), + cfg.checkpoint_filename, + ) + + ###################################### + # Validation # + ###################################### + + if dist.rank == 0 and epoch % cfg.validation_freq == 0: + valid_loss = 0 + + for i in range(len(validation_graphs)): + # Placeholder to accumulate predictions and node features for the full graph's nodes + num_nodes = sum( + [subgraph.num_nodes() for subgraph in validation_graphs[i]] + ) + + # Initialize accumulators for predictions and node features + pressure_pred = torch.zeros( + (num_nodes, 1), dtype=torch.float32, device=device + ) + shear_stress_pred = torch.zeros( + (num_nodes, 3), dtype=torch.float32, device=device + ) + pressure_true = torch.zeros( + (num_nodes, 1), dtype=torch.float32, device=device + ) + shear_stress_true = torch.zeros( + (num_nodes, 3), dtype=torch.float32, device=device + ) + coordinates = torch.zeros( + (num_nodes, 3), dtype=torch.float32, device=device + ) + normals = torch.zeros( + (num_nodes, 3), dtype=torch.float32, device=device + ) + area = torch.zeros((num_nodes, 1), dtype=torch.float32, device=device) + + # Accumulate predictions and node features from all partitions + for j in range(cfg.num_partitions): + part = validation_graphs[i][j].to(device) + + # Get node features (coordinates and normals) + ndata = torch.cat( + ( + part.ndata["coordinates"], + part.ndata["normals"], + torch.sin(2 * np.pi * part.ndata["coordinates"]), + torch.cos(2 * np.pi * part.ndata["coordinates"]), + torch.sin(4 * np.pi * part.ndata["coordinates"]), + torch.cos(4 * np.pi * part.ndata["coordinates"]), + torch.sin(8 * np.pi * part.ndata["coordinates"]), + torch.cos(8 * np.pi * part.ndata["coordinates"]), + ), + dim=1, + ) + + with torch.no_grad(): + with torch.autocast(amp_device, enabled=True, dtype=amp_dtype): + pred = model(ndata, part.edata["x"], part) + pred_filtered = pred[part.ndata["inner_node"].bool()] + target = torch.cat( + (part.ndata["pressure"], part.ndata["shear_stress"]), + dim=1, + ) + target_filtered = target[part.ndata["inner_node"].bool()] + loss = ( + torch.mean((pred_filtered - target_filtered) ** 2) + / cfg.num_partitions + ) + valid_loss += loss.item() + + # Store the predictions based on the original node IDs (using `dgl.NID`) + original_nodes = part.ndata[dgl.NID] + inner_original_nodes = original_nodes[ + part.ndata["inner_node"].bool() + ] + + # Accumulate the predictions + pressure_pred[inner_original_nodes] = ( + pred_filtered[:, 0:1].clone().to(torch.float32) + ) + shear_stress_pred[inner_original_nodes] = ( + pred_filtered[:, 1:].clone().to(torch.float32) + ) + + # Accumulate the ground truth + pressure_true[inner_original_nodes] = ( + target_filtered[:, 0:1].clone().to(torch.float32) + ) + shear_stress_true[inner_original_nodes] = ( + target_filtered[:, 1:].clone().to(torch.float32) + ) + + # Accumulate the node features + coordinates[original_nodes] = ( + part.ndata["coordinates"].clone().to(torch.float32) + ) + normals[original_nodes] = ( + part.ndata["normals"].clone().to(torch.float32) + ) + area[original_nodes] = ( + part.ndata["area"].clone().to(torch.float32) + ) + + # Denormalize predictions and node features using the global stats + pressure_pred_denorm = ( + pressure_pred.cpu() * torch.tensor(std["pressure"]) + ) + torch.tensor(mean["pressure"]) + shear_stress_pred_denorm = ( + shear_stress_pred.cpu() * torch.tensor(std["shear_stress"]) + 
) + torch.tensor(mean["shear_stress"]) + pressure_true_denorm = ( + pressure_true.cpu() * torch.tensor(std["pressure"]) + ) + torch.tensor(mean["pressure"]) + shear_stress_true_denorm = ( + shear_stress_true.cpu() * torch.tensor(std["shear_stress"]) + ) + torch.tensor(mean["shear_stress"]) + coordinates_denorm = ( + coordinates.cpu() * torch.tensor(std["coordinates"]) + ) + torch.tensor(mean["coordinates"]) + normals_denorm = ( + normals.cpu() * torch.tensor(std["normals"]) + ) + torch.tensor(mean["normals"]) + area_denorm = (area.cpu() * torch.tensor(std["area"])) + torch.tensor( + mean["area"] + ) + + # Save the full point cloud after accumulating all partition predictions + # Create a PyVista PolyData object for the point cloud + point_cloud = pv.PolyData(coordinates_denorm.numpy()) + point_cloud["coordinates"] = coordinates_denorm.numpy() + point_cloud["normals"] = normals_denorm.numpy() + point_cloud["area"] = area_denorm.numpy() + point_cloud["pressure_pred"] = pressure_pred_denorm.numpy() + point_cloud["shear_stress_pred"] = shear_stress_pred_denorm.numpy() + point_cloud["pressure_true"] = pressure_true_denorm.numpy() + point_cloud["shear_stress_true"] = shear_stress_true_denorm.numpy() + + # Save the point cloud + point_cloud.save(f"point_cloud_{validation_ids[i]}.vtp") + + print( + f"Epoch {epoch+1}, Validation Error: {valid_loss / len(validation_graphs)}" + ) + writer.add_scalar( + "validation_loss", valid_loss / len(validation_graphs), epoch + ) + + # Save final checkpoint + if dist.world_size > 1: + torch.distributed.barrier() + if dist.rank == 0: + save_checkpoint( + model, + optimizer, + scaler, + scheduler, + cfg.num_epochs, + loss.item(), + "final_model_checkpoint.pth", + ) + print("Training complete") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/xaeronet/utils.py b/examples/cfd/xaeronet/utils.py new file mode 100644 index 0000000000..e9c9064625 --- /dev/null +++ b/examples/cfd/xaeronet/utils.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch + +""" +This code provides utilities for file handling, checkpoint management, +model evaluation, and loss computation. It includes functions to find .bin +and .h5 files, save and load model checkpoints (including model state, +optimizer, scaler, and scheduler), and count the number of trainable parameters +in a model. Additionally, it offers a custom loss function to calculate the +continuity loss for a velocity field using central difference approximations +and a signed distance field (SDF) mask to restrict the computation to valid +regions. +""" + + +def find_bin_files(data_path): + """ + Finds all .bin files in the specified directory. 
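
    The search is non-recursive (unlike `find_h5_files` below), since the partition
    files are written flat into a single directory.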
+ """ + return [ + os.path.join(data_path, f) for f in os.listdir(data_path) if f.endswith(".bin") + ] + + +def find_h5_files(directory): + """ + Recursively finds all .h5 files in the given directory. + """ + h5_files = [] + for root, _, files in os.walk(directory): + for file in files: + if file.endswith(".h5"): + h5_files.append(os.path.join(root, file)) + return h5_files + + +def save_checkpoint(model, optimizer, scaler, scheduler, epoch, loss, filename): + checkpoint = { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scaler": scaler.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "loss": loss, + } + torch.save(checkpoint, filename) + print(f"Checkpoint saved: {filename}") + + +def load_checkpoint(model, optimizer, scaler, scheduler, filename): + if os.path.isfile(filename): + checkpoint = torch.load(filename) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + scaler.load_state_dict(checkpoint["scaler"]) + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + epoch = checkpoint["epoch"] + loss = checkpoint["loss"] + print(f"Checkpoint loaded: {filename}") + return epoch, loss + else: + print(f"No checkpoint found at {filename}") + return 0, None + + +def count_trainable_params(model: torch.nn.Module) -> int: + """Count the number of trainable parameters in a model. + + Args: + model (torch.nn.Module): Model to count parameters of. + + Returns: + int: Number of trainable parameters. + """ + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def calculate_continuity_loss(u, sdf): + """Calculate the continuity residual of a velocity field on a uniform grid.""" + i, j, k = u.shape[2], u.shape[3], u.shape[4] + + # First-order central difference approximations (up to a constant) + u__x = u[:, 0, 2:i, 1:-1, 1:-1] - u[:, 0, 0 : i - 2, 1:-1, 1:-1] + v__y = u[:, 1, 1:-1, 2:j, 1:-1] - u[:, 1, 1:-1, 0 : j - 2, 1:-1] + w__z = u[:, 2, 1:-1, 1:-1, 2:k] - u[:, 2, 1:-1, 1:-1, 0 : k - 2] + + sdf = sdf[:, 1:-1, 1:-1, 1:-1] + mask = (sdf > 0).squeeze() + + residual = u__x + v__y + w__z + + return torch.mean(residual[:, mask] ** 2) diff --git a/examples/cfd/xaeronet/volume/compute_stats.py b/examples/cfd/xaeronet/volume/compute_stats.py new file mode 100644 index 0000000000..895764ce3b --- /dev/null +++ b/examples/cfd/xaeronet/volume/compute_stats.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code processes voxel data stored in .h5 files to compute global +mean and standard deviation, for various data fields. It identifies +all .h5 files in a directory, processes each file to accumulate statistics for +specific fields (like coordinates and pressure), and then aggregates the results +across all files. 
The code supports parallel processing to handle multiple files +simultaneously, speeding up the computation. Finally, the global statistics are +saved to a JSON file. +""" + +import os +import sys +import h5py +import numpy as np +import json +import hydra + +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor +from hydra import to_absolute_path +from omegaconf import DictConfig + +# Get the absolute path to the parent directory +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +from utils import find_h5_files + + +def process_file(h5_file): + """ + Processes a single .h5 file to compute the sum, sum of squares, and count for each variable. + """ + print(h5_file) + with h5py.File(h5_file, "r") as hf: + data = hf["data"][:] + nan_mask = np.isnan(data) + sum_data = np.mean(data, axis=(1, 2, 3), where=~nan_mask) + sum_squares = np.mean(data**2, axis=(1, 2, 3), where=~nan_mask) + + return sum_data, sum_squares + + +def aggregate_results(results): + """ + Aggregates the results from all files to compute global mean and standard deviation. + """ + total_sum = None + total_sum_squares = None + total_count = 0 + + for sum_data, sum_squares in results: + if total_sum is None: + total_sum = np.zeros(sum_data.shape) + total_sum_squares = np.zeros(sum_squares.shape) + + total_sum += sum_data + total_sum_squares += sum_squares + total_count += 1 + + global_mean = total_sum / total_count + global_variance = (total_sum_squares / total_count) - (global_mean**2) + global_std = np.sqrt(global_variance) + + return global_mean, global_std + + +def compute_global_stats(h5_files, num_workers=4): + """ + Computes the global mean and standard deviation for each variable across all .h5 files + using parallel processing. + """ + with ProcessPoolExecutor(max_workers=num_workers) as executor: + results = list( + tqdm( + executor.map(process_file, h5_files), + total=len(h5_files), + desc="Processing H5 Files", + unit="file", + ) + ) + + # Aggregate the results from all files + global_mean, global_std = aggregate_results(results) + + return global_mean, global_std + + +def save_stats_to_json(mean, std_dev, output_file): + """ + Saves the global mean and standard deviation to a JSON file. 
+ """ + stats = { + "mean": mean.tolist(), # Convert numpy arrays to lists + "std_dev": std_dev.tolist(), # Convert numpy arrays to lists + } + + with open(output_file, "w") as f: + json.dump(stats, f, indent=4) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + + data_path = to_absolute_path( + cfg.partitions_path + ) # Directory containing the .bin graph files with partitions + output_file = to_absolute_path(cfg.stats_file) # File to save the global statistics + + # Find all .h5 files in the directory + h5_files = find_h5_files(data_path) + + # Compute global statistics with parallel processing + global_mean, global_std = compute_global_stats( + h5_files, num_workers=cfg.num_preprocess_workers + ) + + # Save statistics to a JSON file + save_stats_to_json(global_mean, global_std, output_file) + + # Print the results + print("Global Mean:", global_mean) + print("Global Standard Deviation:", global_std) + print(f"Statistics saved to {output_file}") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/xaeronet/volume/conf/config.yaml b/examples/cfd/xaeronet/volume/conf/config.yaml new file mode 100644 index 0000000000..4cd0cbbc40 --- /dev/null +++ b/examples/cfd/xaeronet/volume/conf/config.yaml @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +hydra: + job: + chdir: true + name: XAeroNetV + run: + dir: ./outputs/${hydra:job.name} + +# ┌───────────────────────────────────────────┐ +# │ Data Preprocessing │ +# └───────────────────────────────────────────┘ + +num_voxels_x: 700 # Number of voxels in x direction +num_voxels_y: 256 # Number of voxels in y direction +num_voxels_z: 128 # Number of voxels in z direction +spacing: 0.015 # Spacing between the voxels (unit is meters) +grid_origin_x: -3.08 # Origin of the grid in x direction +grid_origin_y: -1.92 # Origin of the grid in y direction +grid_origin_z: -0.32 # Origin of the grid in z direction +num_partitions: 7 # Number of partitions for each voxel grid +partition_width: 100 # Width of each partition (in x-direction only) +halo_width: 40 # Width of the halo region (in x-direction only) +data_path: /data/drivaer_aws/drivaer_data_full # Path to the raw data +num_preprocess_workers: 32 # Number of workers for data preprocessing +save_vti: false # Save a .vti file for the preprocessed voxel data + +# ┌───────────────────────────────────────────┐ +# │ Model Configuration │ +# └───────────────────────────────────────────┘ + +initial_hidden_dim: 64 # Hidden dimension in the first level +activation: gelu # Activation function +use_attn_gate: true # Use attention gate +attn_intermediate_channels: 256 # Intermediate channels in the attention gate + +# ┌───────────────────────────────────────────┐ +# │ Training Configuration │ +# └───────────────────────────────────────────┘ + +h5_path: drivaer_aws_h5 # Path to the h5 files containing the voxel grids for training +validation_h5_path: drivaer_aws_h5_validation # Path to the h5 files containing the voxel grids for validation +stats_file: global_stats.json # Path to the global statistics (.json file) +checkpoint_filename: model_checkpoint.pth # Filename of the model checkpoint +num_epochs: 2000 # Number of epochs +start_lr: 0.00015 # Initial learning rate (cos annealing schedule is used) +end_lr: 0.0000005 # Final learning rate (cos annealing schedule is used) +save_checkpoint_freq: 5 # Frequency of saving the model checkpoint +validation_freq: 50 # Frequency of validation +continuity_lambda: 0.05 # Continuity loss weight + +# ┌───────────────────────────────────────────┐ +# │ Performance Optimization │ +# └───────────────────────────────────────────┘ + +gradient_checkpointing: true # use activation checkpointing +enable_cudnn_benchmark: true # Enable cudnn benchmark \ No newline at end of file diff --git a/examples/cfd/xaeronet/volume/dataloader.py b/examples/cfd/xaeronet/volume/dataloader.py new file mode 100644 index 0000000000..f3e737a16d --- /dev/null +++ b/examples/cfd/xaeronet/volume/dataloader.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
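The geometry keys in `conf/config.yaml` above fully determine the physical extent of the voxel grid and how it is tiled into partitions along x. A quick sketch of that relationship, assuming it is run from the `volume` folder so the relative config path resolves, and assuming the origin keys give the minimum corner of the grid (as the comments suggest):

```python
from omegaconf import OmegaConf

# Load the volume config directly; only the geometry keys are read here,
# so the hydra-specific entries are simply left unresolved.
cfg = OmegaConf.load("conf/config.yaml")

for axis, n, origin in [
    ("x", cfg.num_voxels_x, cfg.grid_origin_x),
    ("y", cfg.num_voxels_y, cfg.grid_origin_y),
    ("z", cfg.num_voxels_z, cfg.grid_origin_z),
]:
    # The grid covers [origin, origin + n * spacing] along this axis.
    print(f"{axis}: {n} voxels -> [{origin:.2f} m, {origin + n * cfg.spacing:.2f} m]")

# The x-partitions used during training must tile the grid exactly.
assert cfg.num_partitions * cfg.partition_width == cfg.num_voxels_x
```

With the values above this gives roughly −3.08 m to 7.42 m in x, −1.92 m to 1.92 m in y, and −0.32 m to 1.60 m in z, and the seven 100-voxel partitions tile the 700 x-voxels exactly.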
+ +""" +This code defines a custom dataset class H5Dataset for loading and +normalizing data stored in .h5 files. The dataset is initialized with +a list of file paths and global statistics (mean and standard deviation) +for normalizing the data. The data is normalized using z-score normalization, +and NaN values can be replaced with zeros. The code also provides a function +create_dataloader to create a PyTorch DataLoader for efficient batch loading +with configurable parameters such as batch size, number of workers, and +prefetching. This setup is ideal for handling large datasets stored in .h5 +files while leveraging parallel data loading for efficiency. +""" + +import os +import sys +import json +import h5py +import torch +import numpy as np + +from torch.utils.data import Dataset, DataLoader + + +# Get the absolute path to the parent directory +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +from utils import find_h5_files + + +class H5Dataset(Dataset): # TODO: Use a Dali datapipe for better performance + + """ + Custom dataset class for loading + + Parameters: + ---------- + file_list (list of str): List of paths to .h5 files. + mean (np.ndarray): Global mean for normalization. + std (np.ndarray): Global standard deviation for normalization. + """ + + def __init__(self, file_list, mean, std, nan_to_0=True): + self.file_list = file_list + self.mean = mean + self.std = std + self.nan_to_0 = nan_to_0 + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + file_path = self.file_list[idx] + with h5py.File(file_path, "r") as hf: + data = hf["data"][:] + + # Normalize data using z-score + data = (data - self.mean[:, None, None, None]) / self.std[:, None, None, None] + + # Convert to PyTorch tensor + data_tensor = torch.tensor(data, dtype=torch.float32) + + # Replace nan with zeros + if self.nan_to_0: + data_tensor = torch.nan_to_num(data_tensor, nan=0.0) + + return data_tensor + + +def create_dataloader( + file_list, + mean, + std, + sampler, + nan_to_0=True, + batch_size=1, + num_workers=4, + pin_memory=True, + prefetch_factor=2, +): + """ + Creates a DataLoader for the H5Dataset with prefetching. + + Args: + file_list (list of str): List of paths to .h5 files. + mean (np.ndarray): Global mean for normalization. + std (np.ndarray): Global standard deviation for normalization. + batch_size (int): Number of samples per batch. + num_workers (int): Number of worker processes for data loading. + pin_memory (bool): If True, the data loader will copy tensors into CUDA pinned memory. + prefetch_factor (int): Number of samples to prefetch. + + Returns: + DataLoader: Configured DataLoader for the dataset. 
+ """ + dataset = H5Dataset(file_list, mean, std, nan_to_0) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + sampler=sampler, + ) + return dataloader + + +if __name__ == "__main__": + data_path = "drivaer_aws_h5" + stats_file = "global_stats.json" + + # Load global statistics + with open(stats_file, "r") as f: + stats = json.load(f) + mean = np.array(stats["mean"]) + std = np.array(stats["std_dev"]) + + # Find all .h5 files in the directory + file_list = find_h5_files(data_path) + + # Create DataLoader + dataloader = create_dataloader( + file_list, mean, std, nan_to_0=True, batch_size=2, num_workers=1 + ) + + # Example usage + for batch in dataloader: + print(batch.shape) # Print batch shape diff --git a/examples/cfd/xaeronet/volume/partition.py b/examples/cfd/xaeronet/volume/partition.py new file mode 100644 index 0000000000..cc627b6de4 --- /dev/null +++ b/examples/cfd/xaeronet/volume/partition.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code partitions large data batches into smaller sub-batches, while +managing boundary regions (halos) and applying filters for physics-based +computations. It provides a function to partition a single batch and handle +inner and halo regions, and uses parallel processing to efficiently partition +multiple batches from a data loader simultaneously. This setup is particularly +useful for distributed or large-scale computations where handling boundaries +between partitions is critical for accuracy and performance. 
+""" + +import numpy as np +import concurrent.futures + + +def partition_batch(batch, num_partitions, partition_width, halo_width): + # Preallocate data list and filter list + data = [None] * num_partitions + filter = np.zeros((num_partitions, partition_width + 2 * halo_width), dtype=bool) + phys_filter = np.zeros( + (num_partitions, partition_width + 2 * halo_width), dtype=bool + ) + + # Handle first partition + data[0] = batch[:, :, 0 : partition_width + 2 * halo_width, :, :] + + # Handle middle partitions + for i in range(1, num_partitions - 1): + start_idx = i * partition_width - halo_width + end_idx = (i + 1) * partition_width + halo_width + data[i] = batch[:, :, start_idx:end_idx, :, :] + + # Handle last partition + data[num_partitions - 1] = batch[ + :, :, (num_partitions - 1) * partition_width - 2 * halo_width :, :, : + ] + + # Create filter for inner nodes + filter[0, 0:partition_width] = True + filter[1 : num_partitions - 1, halo_width : partition_width + halo_width] = True + filter[num_partitions - 1, 2 * halo_width :] = True + + # Create padded filters for physics loss + phys_filter[0, 0 : partition_width + 1] = True + phys_filter[ + 1 : num_partitions - 1, halo_width - 1 : partition_width + halo_width + 1 + ] = True + phys_filter[num_partitions - 1, 2 * halo_width - 1 :] = True + + return data, filter, phys_filter + + +# Function to process each batch (partitioning and filtering) +def process_batch(batch, num_partitions, partition_width, halo_width): + data_i, filter_i, phys_filter_i = partition_batch( + batch, num_partitions, partition_width, halo_width + ) + return data_i, filter_i, phys_filter_i, batch + + +# Efficient processing of valid_dataloader batches using parallelism +def parallel_partitioning( + dataloader, num_partitions=7, partition_width=100, halo_width=40 +): + data, filter, phys_filter, batch = [], [], [], [] + + # Use ThreadPoolExecutor for CPU-bound tasks + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + + # Submit tasks in parallel + for batch_i in dataloader: + futures.append( + executor.submit( + process_batch, batch_i, num_partitions, partition_width, halo_width + ) + ) + + # Collect results as they complete + for future in concurrent.futures.as_completed(futures): + data_i, filter_i, phys_filter_i, batch_i = future.result() + data.append(data_i) + filter.append(filter_i) + phys_filter.append(phys_filter_i) + batch.append(batch_i) + + print("Partitioning completed") + return data, filter, phys_filter, batch diff --git a/examples/cfd/xaeronet/volume/preprocessor.py b/examples/cfd/xaeronet/volume/preprocessor.py new file mode 100644 index 0000000000..0b8c037b7e --- /dev/null +++ b/examples/cfd/xaeronet/volume/preprocessor.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +This code processes mesh data from .vtu and .stl (or .vtp) files to create +voxel grids for large-scale simulations. The process involves converting +unstructured grids (from .vtu files) into voxel grids, extracting surface +triangles and vertices from the mesh files, and calculating the signed distance +field (SDF) and its derivatives (DSDF). The SDF is computed using the mesh surface +and the voxel grid. The resulting data, which includes voxel vertices, SDF, DSDF, +velocity (U), and pressure (p), is saved in an HDF5 format for training. The code +supports multiprocessing to process multiple files concurrently and can optionally +save the voxel grids as .vti files for debugging or visualization. +""" + +import vtk +import pyvista as pv +import numpy as np +import h5py +import os +import hydra + +from multiprocessing import Pool +from tqdm import tqdm +from pyvista.core import _vtk_core as _vtk +from vtk import vtkDataSetTriangleFilter +from modulus.datapipes.cae.readers import read_vtp, read_vtu, read_stl +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + +from sdf import signed_distance_field + + +def unstructured2voxel( + unstructured_grid, grid_size, bounds, write=False, output_filename="image.vti" +): + """Converts an unstructured grid to a voxel grid (structured grid) using resampling.""" + resampler = vtk.vtkResampleToImage() + resampler.AddInputDataObject(unstructured_grid) + resampler.UseInputBoundsOff() + resampler.SetSamplingDimensions(*grid_size) + + if not bounds: + bounds = unstructured_grid.GetBounds() + resampler.SetSamplingBounds(bounds) + + resampler.Update() + voxel_grid = resampler.GetOutput() + + if write: + writer = vtk.vtkXMLImageDataWriter() + writer.SetFileName(output_filename) + writer.SetInputData(voxel_grid) + writer.Write() + + return voxel_grid + + +def convert_to_triangular_mesh( + polydata, write=False, output_filename="surface_mesh_triangular.vtu" +): + """Converts a vtkPolyData object to a triangular mesh.""" + tet_filter = vtkDataSetTriangleFilter() + tet_filter.SetInputData(polydata) + tet_filter.Update() + + tet_mesh = pv.wrap(tet_filter.GetOutput()) + + if write: + tet_mesh.save(output_filename) + + return tet_mesh + + +def extract_surface_triangles(tet_mesh): + """Extracts the surface triangles from a triangular mesh.""" + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tet_mesh) + surface_filter.Update() + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise ValueError("Face is not a triangle") + + return triangle_indices + + +def fetch_mesh_vertices(mesh): + """Fetches the vertices of a mesh.""" + points = mesh.GetPoints() + num_points = points.GetNumberOfPoints() + vertices = [points.GetPoint(i) for i in range(num_points)] + return vertices + + +def get_cell_centers(voxel_grid): + """Extracts the cell centers from a voxel grid.""" + cell_centers_filter = vtk.vtkCellCenters() + cell_centers_filter.SetInputData(voxel_grid) + cell_centers_filter.Update() + + cell_centers = cell_centers_filter.GetOutput() + points = cell_centers.GetPoints() + centers = [points.GetPoint(i) for i in range(points.GetNumberOfPoints())] + return np.array(centers) + + +def process_file(task): + """Process a single pair of VTU and VTP/STL files and save the output.""" + ( + vtu_path, + surface_mesh_path, + grid_size, + bounds, + 
output_dir, + surface_mesh_file_format, + save_vti, + ) = task + + vtu_mesh = read_vtu(vtu_path) + + grid_size_expanded = tuple( + s + 1 for s in grid_size + ) # Add 1 to each dimension for the voxel grid + voxel_grid = unstructured2voxel(vtu_mesh, grid_size_expanded, bounds) + if surface_mesh_file_format == "vtp": + surface_mesh = read_vtp(surface_mesh_path) + surface_mesh = convert_to_triangular_mesh(surface_mesh) + else: + surface_mesh = read_stl(surface_mesh_path) + triangle_indices = extract_surface_triangles(surface_mesh) + surface_vertices = fetch_mesh_vertices(surface_mesh) + volume_vertices = get_cell_centers(voxel_grid) + + sdf, dsdf = signed_distance_field( + surface_vertices, triangle_indices, volume_vertices, include_hit_points=True + ) + + sdf = sdf.numpy() + dsdf = dsdf.numpy() + dsdf = -(dsdf - volume_vertices) + dsdf = dsdf / np.linalg.norm(dsdf, axis=1, keepdims=True) + + voxel_grid = pv.wrap(voxel_grid).point_data_to_cell_data() + data = voxel_grid.cell_data + U = _vtk.vtk_to_numpy(data["UMeanTrim"]) + p = _vtk.vtk_to_numpy(data["pMeanTrim"]) + + # Reshape the arrays according to the voxel grid dimensions + volume_vertices = np.transpose(volume_vertices) + sdf = np.expand_dims(sdf, axis=0) + U = np.transpose(U) + p = np.expand_dims(p, axis=0) + volume_vertices = volume_vertices.reshape(3, *grid_size, order="F") + sdf = sdf.reshape(1, *grid_size, order="F") + dsdf = np.transpose(dsdf) + dsdf = dsdf.reshape(3, *grid_size, order="F") + U = U.reshape(3, *grid_size, order="F") + p = p.reshape(1, *grid_size, order="F") + + # Create a merged array maintaining the voxel shape + merged_array = np.concatenate([volume_vertices, sdf, dsdf, U, p], axis=0) + os.makedirs(output_dir, exist_ok=True) + output_filename = os.path.join( + output_dir, os.path.basename(vtu_path).replace(".vtu", ".h5") + ) + + with h5py.File(output_filename, "w") as hf: + hf.create_dataset("data", data=merged_array) + + # Optionally save voxel grid as .vti for debugging + if save_vti: + voxel_grid.cell_data["SDF"] = sdf.flatten(order="F") + voxel_grid.cell_data["DSDFx"] = dsdf[0].flatten(order="F") + voxel_grid.cell_data["DSDFy"] = dsdf[1, :].flatten(order="F") + voxel_grid.cell_data["DSDFz"] = dsdf[2, :].flatten(order="F") + vti_filename = os.path.join( + output_dir, os.path.basename(vtu_path).replace(".vtu", ".vti") + ) + voxel_grid.save(vti_filename) + + +def process_directory( + data_path, + output_base_path, + grid_size, + bounds=None, + surface_mesh_file_format="stl", + num_workers=16, + save_vti=False, +): + """Process all VTU and VTP files in the given directory using multiprocessing with progress tracking.""" + tasks = [] + for root, _, files in os.walk(data_path): + vtu_files = [f for f in files if f.endswith(".vtu")] + for vtu_file in vtu_files: + vtu_path = os.path.join(root, vtu_file) + if surface_mesh_file_format == "vtp": + surface_mesh_path = vtu_path.replace(".vtu", ".vtp") + elif surface_mesh_file_format == "stl": + vtu_id = vtu_file[len("volume_") : -len(".vtu")] # Extract the ID part + surface_mesh_file = f"drivaer_{vtu_id}.stl" + surface_mesh_path = os.path.join(root, surface_mesh_file) + else: + raise ValueError( + f"Unsupported surface mesh file format: {surface_mesh_file_format}" + ) + + if os.path.exists(surface_mesh_path): + relative_path = os.path.relpath(root, data_path) + output_dir = os.path.join(output_base_path, relative_path) + tasks.append( + ( + vtu_path, + surface_mesh_path, + grid_size, + bounds, + output_dir, + surface_mesh_file_format, + save_vti, + ) + ) + else: + print( 
+ f"Warning: Corresponding surface mesh file not found for {vtu_path}" + ) + + # Use multiprocessing to process the tasks with progress tracking + with Pool(num_workers) as pool: + # Use imap_unordered to process tasks as they complete + for _ in tqdm( + pool.imap_unordered(process_file, tasks), + total=len(tasks), + desc="Processing Files", + unit="file", + ): + pass + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + process_directory( + to_absolute_path(cfg.data_path), + to_absolute_path(cfg.h5_path), + (cfg.num_voxels_x, cfg.num_voxels_y, cfg.num_voxels_z), + (cfg.grid_origin_x, cfg.grid_origin_y, cfg.grid_origin_z), + surface_mesh_file_format="stl", + num_workers=cfg.num_preprocess_workers, + save_vti=cfg.save_vti, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/xaeronet/volume/sdf.py b/examples/cfd/xaeronet/volume/sdf.py new file mode 100644 index 0000000000..2ccc15d776 --- /dev/null +++ b/examples/cfd/xaeronet/volume/sdf.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F401 + +import warp as wp +from numpy.typing import NDArray + + +@wp.kernel +def _bvh_query_distance( + mesh: wp.uint64, + points: wp.array(dtype=wp.vec3f), + max_dist: wp.float32, + sdf: wp.array(dtype=wp.float32), + sdf_hit_point: wp.array(dtype=wp.vec3f), + sdf_hit_point_id: wp.array(dtype=wp.int32), +): + + """ + Computes the signed distance from each point in the given array `points` + to the mesh represented by `mesh`,within the maximum distance `max_dist`, + and stores the result in the array `sdf`. + It is different from the `signed_distance_field` in `modulus.utils.sdf` + as it uses the winding number method to compute the signed distance. + + Parameters: + mesh (wp.uint64): The identifier of the mesh. + points (wp.array): An array of 3D points for which to compute the + signed distance. + max_dist (wp.float32): The maximum distance within which to search + for the closest point on the mesh. + sdf (wp.array): An array to store the computed signed distances. + sdf_hit_point (wp.array): An array to store the computed hit points. + sdf_hit_point_id (wp.array): An array to store the computed hit point ids. 
+ + Returns: + None + """ + tid = wp.tid() + + res = wp.mesh_query_point_sign_winding_number(mesh, points[tid], max_dist) + + mesh_ = wp.mesh_get(mesh) + + p0 = mesh_.points[mesh_.indices[3 * res.face + 0]] + p1 = mesh_.points[mesh_.indices[3 * res.face + 1]] + p2 = mesh_.points[mesh_.indices[3 * res.face + 2]] + + p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2 + + sdf[tid] = res.sign * wp.abs(wp.length(points[tid] - p_closest)) + sdf_hit_point[tid] = p_closest + sdf_hit_point_id[tid] = res.face + + +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: NDArray[float], + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = False, +) -> wp.array: + """ + Computes the signed distance field (SDF) for a given mesh and input points. + + Parameters: + ---------- + mesh_vertices (list[tuple[float, float, float]]): List of vertices defining the mesh. + mesh_indices (list[tuple[int, int, int]]): List of indices defining the triangles of the mesh. + input_points (list[tuple[float, float, float]]): List of input points for which to compute the SDF. + max_dist (float, optional): Maximum distance within which to search for + the closest point on the mesh. Default is 1e8. + include_hit_points (bool, optional): Whether to include hit points in + the output. Default is False. + include_hit_points_id (bool, optional): Whether to include hit point + IDs in the output. Default is False. + + Returns: + ------- + wp.array: An array containing the computed signed distance field. + + Example: + ------- + >>> mesh_vertices = [(0, 0, 0), (1, 0, 0), (0, 1, 0)] + >>> mesh_indices = np.array((0, 1, 2)) + >>> input_points = [(0.5, 0.5, 0.5)] + >>> signed_distance_field(mesh_vertices, mesh_indices, input_points).numpy() + Module ... + array([0.5], dtype=float32) + """ + + wp.init() + mesh = wp.Mesh( + wp.array(mesh_vertices, dtype=wp.vec3), wp.array(mesh_indices, dtype=wp.int32) + ) + + sdf_points = wp.array(input_points, dtype=wp.vec3) + sdf = wp.zeros(shape=sdf_points.shape, dtype=wp.float32) + sdf_hit_point = wp.zeros(shape=sdf_points.shape, dtype=wp.vec3f) + sdf_hit_point_id = wp.zeros(shape=sdf_points.shape, dtype=wp.int32) + + wp.launch( + kernel=_bvh_query_distance, + dim=len(sdf_points), + inputs=[mesh.id, sdf_points, max_dist, sdf, sdf_hit_point, sdf_hit_point_id], + ) + + if include_hit_points and include_hit_points_id: + return (sdf, sdf_hit_point, sdf_hit_point_id) + elif include_hit_points: + return (sdf, sdf_hit_point) + elif include_hit_points_id: + return (sdf, sdf_hit_point_id) + else: + return sdf diff --git a/examples/cfd/xaeronet/volume/train.py b/examples/cfd/xaeronet/volume/train.py new file mode 100644 index 0000000000..86a05579f6 --- /dev/null +++ b/examples/cfd/xaeronet/volume/train.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed training pipeline for training UNet at scale, +which operates on partitioned voxel girds for the AWS drivaer dataset. It includes +loading voxels grids from h5 files, partitioning them, normalizing node and edge features using +precomputed statistics, and training the model in parallel using DistributedDataParallel +across multiple GPUs. The training loop involves computing predictions for each +partition, calculating loss, and updating model parameters using mixed precision. +Periodic checkpointing is performed to save the model, optimizer state, and training +progress. Validation is also conducted every few epochs, where predictions are compared +against ground truth values, and results are saved as point clouds. The code logs training +and validation metrics to TensorBoard and optionally integrates with Weights and Biases for +experiment tracking. +""" + +import os +import sys +import pyvista as pv +import torch +import numpy as np +import torch.optim as optim +import matplotlib.pyplot as plt +from modulus.launch.logging import initialize_wandb +import json +import wandb as wb +import hydra + +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel +from modulus.distributed import DistributedManager +from modulus.models.unet import UNet +from torch.cuda.amp import GradScaler +from torch.utils.tensorboard import SummaryWriter +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + +from dataloader import create_dataloader +from partition import parallel_partitioning + +# Get the absolute path to the parent directory +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +from utils import ( + find_h5_files, + save_checkpoint, + load_checkpoint, + count_trainable_params, + calculate_continuity_loss, +) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + + # Enable cuDNN auto-tuner + torch.backends.cudnn.benchmark = cfg.enable_cudnn_benchmark + + # Instantiate the distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + print(f"Rank {dist.rank} of {dist.world_size}") + + # Instantiate the writers + if dist.rank == 0: + writer = SummaryWriter(log_dir="tensorboard") + initialize_wandb( + project="aws_drivaer", + entity="Modulus", + name="aws_drivaer", + mode="disabled", + group="group", + save_code=True, + ) + + # AMP Configs + amp_dtype = torch.float16 # UNet does not work with bfloat16 + amp_device = "cuda" + + # Find all .h5 files in the directory + train_dataset = find_h5_files(to_absolute_path(cfg.h5_path)) + valid_dataset = find_h5_files(to_absolute_path(cfg.validation_h5_path)) + + # Prepare the stats + with open(to_absolute_path(cfg.stats_file), "r") as f: + stats = json.load(f) + mean = np.array(stats["mean"]) + std = np.array(stats["std_dev"]) + mean_tensor = torch.from_numpy(mean).to(device) + std_tensor = torch.from_numpy(std).to(device) + + # Create DataLoader + sampler = DistributedSampler( + train_dataset, + num_replicas=dist.world_size, + rank=dist.rank, + shuffle=True, + drop_last=True, + ) + train_dataloader = create_dataloader( + train_dataset, mean, std, batch_size=1, num_workers=1, sampler=sampler + ) + valid_dataloader = create_dataloader( + valid_dataset, mean, std, batch_size=1, 
num_workers=1, sampler=None + ) + print(f"Training dataset size: {len(train_dataloader)*dist.world_size}") + print(f"Validation dataset size: {len(valid_dataloader)}") + + # Partitioning + print("Partitioning started") + data, filter, phys_filter, _ = parallel_partitioning( + train_dataloader, + num_partitions=cfg.num_partitions, + partition_width=cfg.partition_width, + halo_width=cfg.halo_width, + ) + vdata, vfilter, _, vbatch = parallel_partitioning( + valid_dataloader, + num_partitions=cfg.num_partitions, + partition_width=cfg.partition_width, + halo_width=cfg.halo_width, + ) + print("Partitioning completed") + + ###################################### + # Training # + ###################################### + + # Initialize model, loss function, and optimizer + h = cfg.initial_hidden_dim + model = UNet( + in_channels=25, + out_channels=4, + model_depth=3, + feature_map_channels=[h, h, 2 * h, 2 * h, 8 * h, 8 * h], + num_conv_blocks=2, + kernel_size=3, + stride=1, + conv_activation=cfg.activation, + padding=1, + padding_mode="replicate", + pooling_type="MaxPool3d", + pool_size=2, + normalization="layernorm", + use_attn_gate=cfg.use_attn_gate, + attn_decoder_feature_maps=[8 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=cfg.attn_intermediate_channels, + gradient_checkpointing=cfg.gradient_checkpointing, + ).to(device) + print(f"Number of trainable parameters: {count_trainable_params(model)}") + + # DistributedDataParallel wrapper + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=True, + ) + + # Optimizer and scheduler + optimizer = optim.Adam(model.parameters(), lr=cfg.start_lr) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=cfg.num_epochs, eta_min=cfg.end_lr + ) + scaler = GradScaler() + print("Instantiated the model and optimizer") + + # Check if there's a checkpoint to resume from + start_epoch, _ = load_checkpoint( + model, optimizer, scaler, scheduler, cfg.checkpoint_filename + ) + + # Training loop + print("Training started") + for epoch in range(start_epoch, cfg.num_epochs): + model.train() + total_loss_data = 0 + total_loss_continuity = 0 + for i in range(len(data)): + optimizer.zero_grad() + + for idx, part in enumerate( + data[i] + ): # (x, y, z, sdf, dsdf_dx, dsdf_dy, dsdf_dz, u_x, u_y, u_z, p) + with torch.autocast(amp_device, enabled=True, dtype=amp_dtype): + part = part.to(device) + inp = torch.cat( + [ + part[:, 0:7], + torch.sin(np.pi * part[:, 0:3]), + torch.cos(np.pi * part[:, 0:3]), + torch.sin(2 * np.pi * part[:, 0:3]), + torch.cos(2 * np.pi * part[:, 0:3]), + torch.sin(4 * np.pi * part[:, 0:3]), + torch.cos(4 * np.pi * part[:, 0:3]), + ], + dim=1, + ) + pred = model(inp) + pred_filtered = pred[:, :, list(filter[i][idx])] + data_filtered = part[:, 7:, list(filter[i][idx])] + sdf_filtered = part[:, 3, list(filter[i][idx])] + sdf_filtered_denormalized = sdf_filtered * std[3] + mean[3] + mask = (sdf_filtered_denormalized > 0).squeeze() + pred_masked = pred_filtered[:, :, mask] + data_masked = data_filtered[:, :, mask] + loss_data = torch.mean((pred_masked - data_masked) ** 2) / len( + data[0] + ) + total_loss_data += loss_data.item() + pred_phys_filtered = pred[:, :, list(phys_filter[i][idx])] + pred_phys_filtered_denormalized = ( + pred_phys_filtered[:, 0:3] + * std_tensor[None, 7:10, None, 
None, None] + + mean_tensor[None, 7:10, None, None, None] + ) + sdf_phys_filtered = part[:, 3, list(phys_filter[i][idx])] + sdf_phys_filtered_denormalized = ( + sdf_phys_filtered * std[3] + mean[3] + ) + loss_continuity = calculate_continuity_loss( + pred_phys_filtered_denormalized, sdf_phys_filtered_denormalized + ) / len(data[0]) + total_loss_continuity += loss_continuity.item() + loss = loss_data * cfg.continuity_lambda * loss_continuity + scaler.scale(loss).backward() + + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 32.0) + scaler.step(optimizer) + scaler.update() + # Update scheduler after each epoch + scheduler.step() + + if dist.rank == 0: + current_lr = optimizer.param_groups[0]["lr"] + print( + f"Epoch {epoch+1}, Learning Rate: {current_lr}, Data Loss: {total_loss_data/len(data)}, Continuity Loss: {cfg.continuity_lambda * total_loss_continuity/len(data)}, Total Loss: {(total_loss_data + cfg.continuity_lambda * total_loss_continuity) / len(data)}" + ) + writer.add_scalar("training_data_loss", total_loss_data / len(data), epoch) + writer.add_scalar( + "training_continuity_loss", + cfg.continuity_lambda * total_loss_continuity / len(data), + epoch, + ) + writer.add_scalar( + "total_loss", + (total_loss_data + cfg.continuity_lambda * total_loss_continuity) + / len(data), + epoch, + ) + writer.add_scalar("learning_rate", current_lr, epoch) + # wb.log({"training loss": total_loss / len(data), "learning_rate": current_lr}, step=epoch) + + # Save checkpoint periodically + if (epoch) % cfg.save_checkpoint_frequency == 0: + if dist.world_size > 1: + torch.distributed.barrier() + if dist.rank == 0: + save_checkpoint( + model, + optimizer, + scaler, + scheduler, + epoch + 1, + loss.item(), + cfg.checkpoint_filename, + ) + + ###################################### + # Validation # + ###################################### + + if dist.rank == 0 and (epoch) % cfg.validation_freq == 0: + with torch.no_grad(): + with torch.autocast(amp_device, enabled=True, dtype=amp_dtype): + valid_loss = 0 + for i in range(len(vdata)): + pred_list = [] + for idx, part in enumerate(vdata[i]): + part = part.to(device) + inp = torch.cat( + [ + part[:, 0:7], + torch.sin(np.pi * part[:, 0:3]), + torch.cos(np.pi * part[:, 0:3]), + torch.sin(2 * np.pi * part[:, 0:3]), + torch.cos(2 * np.pi * part[:, 0:3]), + torch.sin(4 * np.pi * part[:, 0:3]), + torch.cos(4 * np.pi * part[:, 0:3]), + ], + dim=1, + ) + pred = model(inp) + pred_filtered = pred[:, :, list(vfilter[i][idx])] + data_filtered = part[:, 7:, list(vfilter[i][idx])] + sdf_filtered = part[:, 3, list(vfilter[i][idx])] + sdf_filtered_denormalized = sdf_filtered * std[3] + mean[3] + mask = (sdf_filtered_denormalized > 0).squeeze() + pred_masked = pred_filtered[:, :, mask] + data_masked = data_filtered[:, :, mask] + pred_list.append(pred_filtered) + pred = torch.cat(pred_list, dim=2) + err = torch.mean((pred_masked - data_masked) ** 2) / len( + vdata[0] + ) + valid_loss += err + pred = pred.to(torch.float32).cpu().numpy() + print(f"Epoch {epoch+1}, Validation Error: {valid_loss/len(vdata)}") + writer.add_scalar("validation_loss", valid_loss / len(vdata), epoch) + wb.log({"Validation Error": valid_loss / len(vdata)}, step=epoch) + + # Define the dimensions and grid spacing + x_dim, y_dim, z_dim = cfg.num_voxels_x, cfg.num_voxels_y, cfg.num_voxels_z + dims = np.array( + [x_dim, y_dim, z_dim] + ) # The number of voxels in each direction + spacing = (cfg.spacing, cfg.spacing, cfg.spacing) # Grid spacing + + # Create a uniform grid + grid 
= pv.ImageData() + grid.dimensions = dims + grid.spacing = spacing # Spacing between grid points + cbatch = vbatch[-1].to(device) + cbatch = cbatch.clone().cpu().numpy() + grid.origin = (cfg.grid_origin_x, cfg.grid_origin_y, cfg.grid_origin_z) + + # Add the scalar data to the grid (flatten the array as point data) + # TODO denormalize the data + grid.point_data["p"] = pred[:, -1].squeeze().flatten(order="F") + grid.point_data["true_p"] = cbatch[:, -1].squeeze().flatten(order="F") + grid.point_data["u_z"] = pred[:, -2].squeeze().flatten(order="F") + grid.point_data["true_u_z"] = cbatch[:, -2].squeeze().flatten(order="F") + grid.point_data["u_y"] = pred[:, -3].squeeze().flatten(order="F") + grid.point_data["true_u_y"] = cbatch[:, -3].squeeze().flatten(order="F") + grid.point_data["u_x"] = pred[:, -4].squeeze().flatten(order="F") + grid.point_data["true_u_x"] = cbatch[:, -4].squeeze().flatten(order="F") + grid.point_data["dsdf_dz"] = cbatch[:, -5].squeeze().flatten(order="F") + grid.point_data["dsdf_dy"] = cbatch[:, -6].squeeze().flatten(order="F") + grid.point_data["dsdf_dx"] = cbatch[:, -7].squeeze().flatten(order="F") + grid.point_data["sdf"] = cbatch[:, -8].squeeze().flatten(order="F") + grid.point_data["z"] = cbatch[:, -9].squeeze().flatten(order="F") + grid.point_data["y"] = cbatch[:, -10].squeeze().flatten(order="F") + grid.point_data["x"] = cbatch[:, -11].squeeze().flatten(order="F") + + # Save the grid to a .vti file + grid.save("output.vti") + print("Saved the vti file") + + # Save final checkpoint + if dist.world_size > 1: + torch.distributed.barrier() + if dist.rank == 0: + save_checkpoint( + model, + optimizer, + scaler, + scheduler, + cfg.num_epochs, + loss.item(), + "final_model_checkpoint.pth", + ) + print("Training complete") + + +if __name__ == "__main__": + main() diff --git a/modulus/datapipes/cae/readers.py b/modulus/datapipes/cae/readers.py index 1c6abc07ec..b083f2e50d 100644 --- a/modulus/datapipes/cae/readers.py +++ b/modulus/datapipes/cae/readers.py @@ -132,6 +132,43 @@ def read_cgns(file_path: str) -> Any: return _extract_unstructured_grid(multi_block) +def read_stl(file_path: str) -> vtk.vtkPolyData: + """ + Read an STL file and return the polydata. + + Parameters + ---------- + file_path : str + Path to the STL file. + + Returns + ------- + vtkPolyData + The polydata read from the STL file. 
+ """ + # Check if file exists + if not os.path.exists(file_path): + raise FileNotFoundError(f"{file_path} does not exist.") + + # Check if file has .stl extension + if not file_path.endswith(".stl"): + raise ValueError(f"Expected a .stl file, got {file_path}") + + # Create an STL reader + reader = vtk.vtkSTLReader() + reader.SetFileName(file_path) + reader.Update() + + # Get the polydata + polydata = reader.GetOutput() + + # Check if polydata is valid + if polydata is None: + raise ValueError(f"Failed to read polydata from {file_path}") + + return polydata + + def _extract_unstructured_grid( multi_block: vtk.vtkMultiBlockDataSet, ) -> vtk.vtkUnstructuredGrid: diff --git a/modulus/models/unet/unet.py b/modulus/models/unet/unet.py index dd1ea30e61..b437c31f00 100644 --- a/modulus/models/unet/unet.py +++ b/modulus/models/unet/unet.py @@ -20,9 +20,45 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from transformer_engine import pytorch as te -from ..meta import ModelMetaData -from ..module import Module +from modulus.models.meta import ModelMetaData +from modulus.models.module import Module + + +class ReshapedLayerNorm(te.LayerNorm): + + """ + A modified LayerNorm that reshapes and transposes the input tensor before + applying layer normalization, then restores the original shape after normalization. + + This is useful when layer normalization is required over multiple dimensions + while preserving the original spatial structure of the input. + + Parameters: + ---------- + normalized_shape (int or list/tuple of ints): Input shape from an expected input of size. If a single integer is used, + it is treated as a singleton list. + eps (float, optional): A value added to the denominator for numerical stability. Default is 1e-5. + elementwise_affine (bool, optional): Whether to learn affine parameters (scale and shift). Default is True. + + Returns: + ------- + torch.Tensor: The input tensor after applying reshaped layer normalization. + """ + + def __init__( + self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True + ): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = x.shape + x = x.view(shape[0], shape[1], -1).transpose(1, 2).contiguous() + x = super().forward(x) + x = x.transpose(1, 2).contiguous().view(shape) + return x class ConvBlock(nn.Module): @@ -98,6 +134,8 @@ def __init__( self.norm = nn.GroupNorm(**norm_args) elif normalization == "batchnorm": self.norm = nn.BatchNorm3d(out_channels) + elif normalization == "layernorm": + self.norm = ReshapedLayerNorm(out_channels) else: raise ValueError( f"Normalization type '{normalization}' is not supported." @@ -188,6 +226,8 @@ def __init__( self.norm = nn.GroupNorm(**norm_args) elif normalization == "batchnorm": self.norm = nn.BatchNorm3d(out_channels) + elif normalization == "layernorm": + self.norm = ReshapedLayerNorm(out_channels) else: raise ValueError( f"Normalization type '{normalization}' is not supported." @@ -261,6 +301,52 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.pooling(x) +class AttentionBlock(nn.Module): + """ + Attention block for the skip connections using LayerNorm instead of BatchNorm. + + Parameters: + ---------- + F_g (int): Number of channels in the decoder's features (query). + F_l (int): Number of channels in the encoder's features (key/value). 
+ F_int (int): Number of intermediate channels (reduction in feature maps before attention computation). + + Returns: + ------- + torch.Tensor: The attended skip feature map. + """ + + def __init__(self, F_g, F_l, F_int): + super().__init__() + # The attention mechanism reduces the feature maps to F_int channels + self.W_g = nn.Sequential( + nn.Conv3d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), + ReshapedLayerNorm(F_int), + ) + + self.W_x = nn.Sequential( + nn.Conv3d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), + ReshapedLayerNorm(F_int), + ) + + self.psi = nn.Sequential( + nn.Conv3d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), + ReshapedLayerNorm(1), + nn.Sigmoid(), + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, g, x): + # g is the decoder's upsampled features (query) + # x is the encoder's skip connection features (key/value) + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1 + x1) + psi = self.psi(psi) + return x * psi # element-wise multiplication with attention mask + + class EncoderBlock(nn.Module): """ An encoder block that sequentially applies multiple convolutional blocks followed by a pooling operation, aggregating features at multiple scales. @@ -284,11 +370,17 @@ def __init__( self, in_channels: int, feature_map_channels: List[int], + kernel_size: Union[int, tuple] = 3, + stride: Union[int, tuple] = 1, model_depth: int = 4, num_conv_blocks: int = 2, activation: Optional[str] = "relu", + padding: int = 1, + padding_mode: str = "zeros", pooling_type: str = "AvgPool3d", pool_size: int = 2, + normalization: Optional[str] = "groupnorm", + normalization_args: Optional[dict] = None, ): super().__init__() @@ -306,7 +398,13 @@ def __init__( ConvBlock( in_channels=current_channels, out_channels=feature_map_channels[depth * num_conv_blocks + i], + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode, activation=activation, + normalization=normalization, + normalization_args=normalization_args, ) ) current_channels = feature_map_channels[depth * num_conv_blocks + i] @@ -346,10 +444,16 @@ def __init__( self, out_channels: int, feature_map_channels: List[int], + kernel_size: Union[int, tuple] = 3, + stride: Union[int, tuple] = 1, model_depth: int = 3, num_conv_blocks: int = 2, conv_activation: Optional[str] = "relu", conv_transpose_activation: Optional[str] = None, + padding: int = 1, + padding_mode: str = "zeros", + normalization: Optional[str] = "groupnorm", + normalization_args: Optional[dict] = None, ): super().__init__() @@ -380,7 +484,13 @@ def __init__( ConvBlock( in_channels=current_channels, out_channels=feature_map_channels[depth * num_conv_blocks + i], + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode, activation=conv_activation, + normalization=normalization, + normalization_args=normalization_args, ) ) current_channels = feature_map_channels[depth * num_conv_blocks + i] @@ -390,6 +500,10 @@ def __init__( ConvBlock( in_channels=current_channels, out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode, activation=None, normalization=None, ) @@ -444,6 +558,8 @@ def __init__( self, in_channels: int, out_channels: int, + kernel_size: Union[int, tuple] = 3, + stride: Union[int, tuple] = 1, model_depth: int = 5, feature_map_channels: List[int] = [ 64, @@ -460,20 +576,37 @@ def __init__( num_conv_blocks: int = 2, conv_activation: Optional[str] = "relu", conv_transpose_activation: 
Optional[str] = None, + padding: int = 1, + padding_mode: str = "zeros", pooling_type: str = "MaxPool3d", pool_size: int = 2, + normalization: Optional[str] = "groupnorm", + normalization_args: Optional[dict] = None, + use_attn_gate: bool = False, + attn_decoder_feature_maps=None, + attn_feature_map_channels=None, + attn_intermediate_channels=None, + gradient_checkpointing: bool = True, ): super().__init__(meta=MetaData()) + self.use_attn_gate = use_attn_gate + self.gradient_checkpointing = gradient_checkpointing # Construct the encoder self.encoder = EncoderBlock( in_channels=in_channels, feature_map_channels=feature_map_channels, + kernel_size=kernel_size, + stride=stride, model_depth=model_depth, num_conv_blocks=num_conv_blocks, activation=conv_activation, + padding=padding, + padding_mode=padding_mode, pooling_type=pooling_type, pool_size=pool_size, + normalization=normalization, + normalization_args=normalization_args, ) # Construct the decoder @@ -483,30 +616,62 @@ def __init__( self.decoder = DecoderBlock( out_channels=out_channels, feature_map_channels=decoder_feature_maps, + kernel_size=kernel_size, + stride=stride, model_depth=model_depth - 1, num_conv_blocks=num_conv_blocks, conv_activation=conv_activation, conv_transpose_activation=conv_transpose_activation, + padding=padding, + padding_mode=padding_mode, + normalization=normalization, + normalization_args=normalization_args, ) + # Initialize attention blocks for each skip connection + if self.use_attn_gate: + self.attention_blocks = nn.ModuleList( + [ + AttentionBlock( + F_g=attn_decoder_feature_maps[i], + F_l=attn_feature_map_channels[i], + F_int=attn_intermediate_channels, + ) + for i in range(model_depth - 1) + ] + ) + + def checkpointed_forward(self, layer, x): + """Wrapper to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing: + return checkpoint.checkpoint(layer, x, use_reentrant=False) + return layer(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: skip_features = [] # Encoding path for layer in self.encoder.layers: if isinstance(layer, Pool3d): skip_features.append(x) - x = layer(x) + # Apply checkpointing if enabled + x = self.checkpointed_forward(layer, x) # Decoding path - skip_features = skip_features[::-1] # Reverse - concats = 0 # keep track of the number of concats + skip_features = skip_features[::-1] # Reverse the skip features + concats = 0 # Track number of concats for layer in self.decoder.layers: if isinstance(layer, ConvTranspose): - x = layer(x) - x = torch.cat([x, skip_features[concats]], dim=1) + x = self.checkpointed_forward(layer, x) + if self.use_attn_gate: + # Apply attention to the skip connection + skip_att = self.attention_blocks[concats](x, skip_features[concats]) + x = torch.cat([x, skip_att], dim=1) + else: + x = torch.cat([x, skip_features[concats]], dim=1) concats += 1 else: - x = layer(x) + # Apply checkpointing for other layers + x = self.checkpointed_forward(layer, x) return x
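The `ReshapedLayerNorm` introduced in `unet.py` above normalizes each voxel's channel vector by flattening the spatial dimensions, moving channels last, applying LayerNorm, and then restoring the (N, C, D, H, W) layout. A minimal sketch of that reshape pattern, using `torch.nn.LayerNorm` in place of the Transformer Engine implementation so it runs without `transformer_engine` installed:

```python
import torch
import torch.nn as nn


def reshaped_layernorm(x: torch.Tensor, norm: nn.LayerNorm) -> torch.Tensor:
    """Apply LayerNorm over the channel axis of an (N, C, D, H, W) tensor."""
    n, c = x.shape[0], x.shape[1]
    y = x.view(n, c, -1).transpose(1, 2).contiguous()  # (N, D*H*W, C)
    y = norm(y)                                        # normalize over C
    return y.transpose(1, 2).contiguous().view(x.shape)


x = torch.randn(2, 64, 8, 8, 8)  # e.g. an encoder feature map
out = reshaped_layernorm(x, nn.LayerNorm(64))

assert out.shape == x.shape
# Each spatial location now has zero-mean, unit-variance channels:
print(out.mean(dim=1).abs().max(), out.std(dim=1, unbiased=False).mean())
```

This is also why `normalization="layernorm"` can be dropped into `ConvBlock` and the attention gates without changing the 5-D tensor layout the rest of the UNet expects.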