From 93ce565c9c0740cf711c0314874a5ebbd7983627 Mon Sep 17 00:00:00 2001 From: allaffa Date: Fri, 20 Sep 2024 23:31:14 -0400 Subject: [PATCH] jarvis-dft example completed --- examples/jarvis_dft/download_dataset.sh | 12 + .../generate_dictionaries_pure_elements.py | 250 ++++++++ examples/jarvis_dft/jarvis_energy.json | 58 ++ examples/jarvis_dft/train.py | 571 ++++++++++++++++++ examples/jarvis_dft/uncompress_folder.py | 39 ++ 5 files changed, 930 insertions(+) create mode 100644 examples/jarvis_dft/download_dataset.sh create mode 100644 examples/jarvis_dft/generate_dictionaries_pure_elements.py create mode 100644 examples/jarvis_dft/jarvis_energy.json create mode 100644 examples/jarvis_dft/train.py create mode 100644 examples/jarvis_dft/uncompress_folder.py diff --git a/examples/jarvis_dft/download_dataset.sh b/examples/jarvis_dft/download_dataset.sh new file mode 100644 index 000000000..81b8aa686 --- /dev/null +++ b/examples/jarvis_dft/download_dataset.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# URL to download the zip file from +URL="https://figshare.com/ndownloader/articles/6815699/versions/10" + +# Directory where the file will be saved +OUTPUT_DIR="dataset/JARVIS-DFT" + +mkdir -p "$OUTPUT_DIR" + +# Use curl to follow redirects and download the file +curl -L -o "$OUTPUT_DIR/6815699.zip" "$URL" diff --git a/examples/jarvis_dft/generate_dictionaries_pure_elements.py b/examples/jarvis_dft/generate_dictionaries_pure_elements.py new file mode 100644 index 000000000..09f131a5b --- /dev/null +++ b/examples/jarvis_dft/generate_dictionaries_pure_elements.py @@ -0,0 +1,250 @@ +def generate_dictionary_bulk_energies(): + + energy_bulk_metal = { + "H": 0.0, + "He": 0.0, + "Li": 0.0, + "Be": 0.0, + "B": 0.0, + "C": 0.0, + "N": 0.0, + "O": 0.0, + "F": 0.0, + "Ne": 0.0, + "Na": 0.0, + "Mg": 0.0, + "Al": 0.0, + "Si": 0.0, + "P": 0.0, + "S": 0.0, + "Cl": 0.0, + "Ar": 0.0, + "K": 0.0, + "Ca": 0.0, + "Sc": 0.0, + "Ti": 0.0, + "V": 0.0, + "Cr": 0.0, + "Mn": 0.0, + "Fe": 0.0, + "Co": 0.0, + "Ni": 0.0, + "Cu": 0.0, + "Zn": 0.0, + "Ga": 0.0, + "Ge": 0.0, + "As": 0.0, + "Se": 0.0, + "Br": 0.0, + "Kr": 0.0, + "Rb": 0.0, + "Sr": 0.0, + "Y": 0.0, + "Zr": 0.0, + "Nb": 0.0, + "Mo": 0.0, + "Tc": 0.0, + "Ru": 0.0, + "Rh": 0.0, + "Pd": 0.0, + "Ag": 0.0, + "Cd": 0.0, + "In": 0.0, + "Sn": 0.0, + "Sb": 0.0, + "Te": 0.0, + "I": 0.0, + "Xe": 0.0, + "Cs": 0.0, + "Ba": 0.0, + "La": 0.0, + "Ce": 0.0, + "Pr": 0.0, + "Nd": 0.0, + "Pm": 0.0, + "Sm": 0.0, + "Eu": 0.0, + "Gd": 0.0, + "Tb": 0.0, + "Dy": 0.0, + "Ho": 0.0, + "Er": 0.0, + "Tm": 0.0, + "Yb": 0.0, + "Lu": 0.0, + "Hf": 0.0, + "Ta": 0.0, + "W": 0.0, + "Re": 0.0, + "Os": 0.0, + "Ir": 0.0, + "Pt": 0.0, + "Au": 0.0, + "Hg": 0.0, + "Tl": 0.0, + "Pb": 0.0, + "Bi": 0.0, + "Po": 0.0, + "At": 0.0, + "Rn": 0.0, + "Fr": 0.0, + "Ra": 0.0, + "Ac": 0.0, + "Th": 0.0, + "Pa": 0.0, + "U": 0.0, + "Np": 0.0, + "Pu": 0.0, + "Am": 0.0, + "Cm": 0.0, + "Bk": 0.0, + "Cf": 0.0, + "Es": 0.0, + "Fm": 0.0, + "Md": 0.0, + "No": 0.0, + "Lr": 0.0, + "Rf": 0.0, + "Db": 0.0, + "Sg": 0.0, + "Bh": 0.0, + "Hs": 0.0, + "Mt": 0.0, + "Ds": 0.0, + "Rg": 0.0, + "Cn": 0.0, + "Nh": 0.0, + "Fl": 0.0, + "Mc": 0.0, + "Lv": 0.0, + "Ts": 0.0, + "Og": 0.0, + } + + return energy_bulk_metal + + +def generate_dictionary_elements(): + + periodic_table_atomic_numbers = { + 1: "H", + 2: "He", + 3: "Li", + 4: "Be", + 5: "B", + 6: "C", + 7: "N", + 8: "O", + 9: "F", + 10: "Ne", + 11: "Na", + 12: "Mg", + 13: "Al", + 14: "Si", + 15: "P", + 16: "S", + 17: "Cl", + 18: "Ar", + 19: "K", + 20: "Ca", + 21: "Sc", + 22: "Ti", + 23: "V", + 24: 
"Cr", + 25: "Mn", + 26: "Fe", + 27: "Co", + 28: "Ni", + 29: "Cu", + 30: "Zn", + 31: "Ga", + 32: "Ge", + 33: "As", + 34: "Se", + 35: "Br", + 36: "Kr", + 37: "Rb", + 38: "Sr", + 39: "Y", + 40: "Zr", + 41: "Nb", + 42: "Mo", + 43: "Tc", + 44: "Ru", + 45: "Rh", + 46: "Pd", + 47: "Ag", + 48: "Cd", + 49: "In", + 50: "Sn", + 51: "Sb", + 52: "Te", + 53: "I", + 54: "Xe", + 55: "Cs", + 56: "Ba", + 57: "La", + 58: "Ce", + 59: "Pr", + 60: "Nd", + 61: "Pm", + 62: "Sm", + 63: "Eu", + 64: "Gd", + 65: "Tb", + 66: "Dy", + 67: "Ho", + 68: "Er", + 69: "Tm", + 70: "Yb", + 71: "Lu", + 72: "Hf", + 73: "Ta", + 74: "W", + 75: "Re", + 76: "Os", + 77: "Ir", + 78: "Pt", + 79: "Au", + 80: "Hg", + 81: "Tl", + 82: "Pb", + 83: "Bi", + 84: "Po", + 85: "At", + 86: "Rn", + 87: "Fr", + 88: "Ra", + 89: "Ac", + 90: "Th", + 91: "Pa", + 92: "U", + 93: "Np", + 94: "Pu", + 95: "Am", + 96: "Cm", + 97: "Bk", + 98: "Cf", + 99: "Es", + 100: "Fm", + 101: "Md", + 102: "No", + 103: "Lr", + 104: "Rf", + 105: "Db", + 106: "Sg", + 107: "Bh", + 108: "Hs", + 109: "Mt", + 110: "Ds", + 111: "Rg", + 112: "Cn", + 113: "Nh", + 114: "Fl", + 115: "Mc", + 116: "Lv", + 117: "Ts", + 118: "Og", + } + + return periodic_table_atomic_numbers diff --git a/examples/jarvis_dft/jarvis_energy.json b/examples/jarvis_dft/jarvis_energy.json new file mode 100644 index 000000000..1357a5e65 --- /dev/null +++ b/examples/jarvis_dft/jarvis_energy.json @@ -0,0 +1,58 @@ +{ + "Verbosity": { + "level": 2 + }, + "NeuralNetwork": { + "Architecture": { + "model_type": "EGNN", + "equivariance": true, + "radius": 5.0, + "max_neighbours": 100000, + "num_gaussians": 50, + "envelope_exponent": 5, + "int_emb_size": 64, + "basis_emb_size": 8, + "out_emb_size": 128, + "num_after_skip": 2, + "num_before_skip": 1, + "num_radial": 6, + "num_spherical": 7, + "num_filters": 126, + "edge_features": ["length"], + "hidden_dim": 50, + "num_conv_layers": 3, + "output_heads": { + "graph":{ + "num_sharedlayers": 2, + "dim_sharedlayers": 50, + "num_headlayers": 2, + "dim_headlayers": [50,25] + } + }, + "task_weights": [1.0] + }, + "Variables_of_interest": { + "input_node_features": [0, 1, 2, 3], + "output_names": ["energy"], + "output_index": [0], + "output_dim": [1], + "type": ["graph"] + }, + "Training": { + "num_epoch": 50, + "perc_train": 0.8, + "loss_function_type": "mae", + "batch_size": 32, + "continue": 0, + "Optimizer": { + "type": "AdamW", + "learning_rate": 1e-3 + } + } + }, + "Visualization": { + "plot_init_solution": true, + "plot_hist_solution": false, + "create_plots": true + } +} \ No newline at end of file diff --git a/examples/jarvis_dft/train.py b/examples/jarvis_dft/train.py new file mode 100644 index 000000000..ac07eec6f --- /dev/null +++ b/examples/jarvis_dft/train.py @@ -0,0 +1,571 @@ +import bz2 + +import os, json +import logging +import sys +from mpi4py import MPI +import argparse + +import random +import numpy as np + +import torch +from torch_geometric.data import Data + +from torch_geometric.transforms import Distance, Spherical, LocalCartesian + +import hydragnn +from hydragnn.utils.time_utils import Timer +from hydragnn.utils.model import print_model +from hydragnn.utils.abstractbasedataset import AbstractBaseDataset +from hydragnn.utils.distdataset import DistDataset +from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset +from hydragnn.preprocess.utils import gather_deg +from hydragnn.preprocess.utils import RadiusGraph, RadiusGraphPBC +from hydragnn.preprocess.load_data import split_dataset + +import hydragnn.utils.tracer as tr +from 
hydragnn.utils.print_utils import iterate_tqdm, log
+
+from generate_dictionaries_pure_elements import (
+    generate_dictionary_bulk_energies,
+    generate_dictionary_elements,
+)
+
+try:
+    from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset
+except ImportError:
+    pass
+
+from hydragnn.utils import nsplit
+
+
+def info(*args, logtype="info", sep=" "):
+    getattr(logging, logtype)(sep.join(map(str, args)))
+
+
+periodic_table = generate_dictionary_elements()
+
+# Reverse the dictionary so the element symbols become keys and the atomic numbers become values
+reversed_dict_periodic_table = {value: key for key, value in periodic_table.items()}
+
+# transform_coordinates = Spherical(norm=False, cat=False)
+# transform_coordinates = LocalCartesian(norm=False, cat=False)
+transform_coordinates = Distance(norm=False, cat=False)
+
+
+class JARVIS_DFT(AbstractBaseDataset):
+    def __init__(self, dirpath, var_config, energy_per_atom=True, dist=False):
+        super().__init__()
+
+        self.dist = dist
+        # Default to a single process; overridden below when running distributed,
+        # so that process_file_content works in both modes.
+        self.world_size = 1
+        self.rank = 0
+        if self.dist:
+            assert torch.distributed.is_initialized()
+            self.world_size = torch.distributed.get_world_size()
+            self.rank = torch.distributed.get_rank()
+
+        self.energy_per_atom = energy_per_atom
+
+        self.radius_graph = RadiusGraph(5.0, loop=False, max_num_neighbors=50)
+
+        subdirpath = os.path.join(dirpath, "JARVIS-DFT")
+
+        # List all files in the directory
+        all_files = os.listdir(subdirpath)
+
+        # Keep only the JARVIS-DFT JSON files (names starting with "jdft" and ending with ".json")
+        json_files = [
+            file
+            for file in all_files
+            if (file.startswith("jdft") and file.endswith(".json"))
+        ]
+
+        for filename in json_files:
+            self.process_file_content(os.path.join(subdirpath, filename))
+
+    def get_data_dict(self, computed_entry_dict):
+        """
+        Process a JARVIS-DFT entry dictionary to extract the structure
+        and the target properties.
+ """ + + data_object = None + + def get_forces_array_from_structure(structure): + forces = [site["properties"]["forces"] for site in structure["sites"]] + return np.array(forces) + + entry_id = computed_entry_dict["jid"] + + pos = None + cell = None + natoms = None + atomic_numbers = None + formation_energy_tensor = None + + if "atoms" in computed_entry_dict: + structure = computed_entry_dict["atoms"] + + try: + pos = torch.tensor(computed_entry_dict["atoms"]["coords"]).to( + torch.float32 + ) + assert ( + pos.shape[1] == 3 + ), "pos tensor does not have 3 coordinates per atom" + assert pos.shape[0] > 0, "pos tensor does not have any atoms" + except: + print(f"Structure {entry_id} does not have positional sites") + return data_object + natoms = torch.IntTensor([pos.shape[0]]) + + try: + cell = torch.tensor(structure["lattice_mat"]).to(torch.float32) + except: + print(f"Structure {entry_id} does not have cell") + return data_object + + try: + atomic_numbers = ( + torch.tensor( + [ + [ + reversed_dict_periodic_table[item] + for item in structure["elements"] + ] + ] + ) + .to(torch.float32) + .t() + ) + assert ( + pos.shape[0] == atomic_numbers.shape[0] + ), f"pos.shape[0]:{pos.shape[0]} does not match with atomic_numbers.shape[0]:{atomic_numbers.shape[0]}" + except: + print(f"Structure {entry_id} does not have positional atomic numbers") + return data_object + + try: + formation_energy_tensor = torch.tensor( + computed_entry_dict["formation_energy_peratom"] * natoms + ).unsqueeze(1) + except: + print(f"Structure {entry_id} does not have formation energy per atom") + return data_object + + formation_energy_per_atom_tensor = ( + torch.tensor(computed_entry_dict["formation_energy_peratom"]) + .unsqueeze(0) + .unsqueeze(1) + ) + + elif "final_str" in computed_entry_dict: + try: + pos = torch.tensor( + [item["xyz"] for item in computed_entry_dict["final_str"]["sites"]] + ) + assert ( + pos.shape[1] == 3 + ), "pos tensor does not have 3 coordinates per atom" + assert pos.shape[0] > 0, "pos tensor does not have any atoms" + except: + print(f"Structure {entry_id} does not have positional sites") + return data_object + + try: + atomic_numbers = torch.tensor( + [ + reversed_dict_periodic_table[item["species"][0]["element"]] + for item in computed_entry_dict["final_str"]["sites"] + ] + ).unsqueeze(1) + assert ( + pos.shape[0] == atomic_numbers.shape[0] + ), f"pos.shape[0]:{pos.shape[0]} does not match with atomic_numbers.shape[0]:{atomic_numbers.shape[0]}" + except: + return data_object + + try: + formation_energy_tensor = torch.tensor( + computed_entry_dict["final_str"]["formation_energy_peratom"] + * natoms + ).unsqueeze(1) + except: + print(f"Structure {entry_id} does not have formation energy per atom") + return data_object + + formation_energy_per_atom_tensor = ( + torch.tensor(computed_entry_dict["formation_energy_peratom"]) + .unsqueeze(0) + .unsqueeze(1) + ) + + # total_energy_tensor = None + # try: + # total_energy_tensor = torch.tensor(computed_entry_dict["optb88vdw_total_energy"]).unsqueeze(1) + # except: + # print(f"Structure {entry_id} does not have total energy per atom") + # return data_object + # total_energy_per_atom_tensor = torch.tensor(computed_entry_dict["optb88vdw_total_energy"]).unsqueeze(1)/natoms + + # band_gap = None + # try: + # band_gap=computed_entry_dict["optb88vdw_bandgap"] + # except: + # print(f"Structure {entry_id} does not have band_gap_ind") + # return data_object + # band_gap_tensor = torch.tensor(band_gap_ind).unsqueeze(1) + + # energy_above_hull = None + # try: + 
# energy_above_hull=computed_entry_dict["ehull"] + # except: + # print(f"Structure {entry_id} does not have e_above_hull") + # return data_object + # energy_above_hull_tensor = torch.tensor(energy_above_hull).unsqueeze(1) + + data_object = Data( + pos=pos, + cell=cell, + atomic_numbers=atomic_numbers, + # forces=forces, + # entry_id=entry_id, + natoms=natoms, + # total_energy=total_energy_tensor, + # total_energy_per_atom=total_energy_per_atom_tensor, + formation_energy=formation_energy_tensor.float(), + formation_energy_per_atom=formation_energy_per_atom_tensor.float(), + # energy_above_hull=energy_above_hull, + # magmoms=torch.tensor(magmoms_numpy).float(), + # total_mag=total_mag, + # dos_ef=dos_ef, + # band_gap_ind=band_gap_ind, + ) + + if self.energy_per_atom: + data_object.y = data_object.formation_energy_per_atom + else: + data_object.y = data_object.formation_energy + + data_object.x = torch.cat([data_object.atomic_numbers, data_object.pos], dim=1) + + data_object = self.radius_graph(data_object) + data_object = transform_coordinates(data_object) + + return data_object + + def process_file_content(self, filepath): + """ + Download a file from a dataset of the Alexandria database with the respective index + and write it to the LMDB file with the respective index. + + Parameters + ---------- + filepath : int + path of JSON file + """ + try: + with open(filepath, "r") as f: + data = json.load( + f + ) # Reads the JSON content from the file and converts it to a Python object + local_data = list(nsplit(data, self.world_size))[self.rank] + + computed_entry_dict = [ + self.get_data_dict(entry) + for entry in iterate_tqdm( + local_data, + desc=f"Processing file {filepath}", + verbosity_level=2, + ) + ] + + # remove None elements + filtered_computed_entry_dict = [ + x for x in computed_entry_dict if x is not None + ] + + random.shuffle(filtered_computed_entry_dict) + self.dataset.extend(filtered_computed_entry_dict) + except json.decoder.JSONDecodeError as e: + print(f"Error decoding JSON file: {filepath}, error: {str(e)}") + return None # Return None if the file is not a valid JSON + + def len(self): + return len(self.dataset) + + def get(self, idx): + return self.dataset[idx] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--sampling", type=float, help="sampling ratio", default=None) + parser.add_argument( + "--preonly", + action="store_true", + help="preprocess only (no training)", + ) + parser.add_argument( + "--inputfile", help="input file", type=str, default="jarvis_energy.json" + ) + parser.add_argument( + "--energy_per_atom", + help="option to normalize energy by number of atoms", + type=bool, + default=True, + ) + parser.add_argument("--ddstore", action="store_true", help="ddstore dataset") + parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None) + parser.add_argument("--shmem", action="store_true", help="shmem") + parser.add_argument("--log", help="log name") + parser.add_argument("--batch_size", type=int, help="batch_size", default=None) + parser.add_argument("--everyone", action="store_true", help="gptimer") + parser.add_argument("--modelname", help="model name") + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--adios", + help="Adios dataset", + action="store_const", + dest="format", + const="adios", + ) + group.add_argument( + "--pickle", + help="Pickle dataset", + action="store_const", + dest="format", + const="pickle", 
+ ) + parser.set_defaults(format="adios") + args = parser.parse_args() + + graph_feature_names = ["energy"] + graph_feature_dims = [1] + node_feature_names = ["atomic_number", "cartesian_coordinates", "forces"] + node_feature_dims = [1, 3, 3] + dirpwd = os.path.dirname(os.path.abspath(__file__)) + datadir = os.path.join(dirpwd, "dataset") + ################################################################################################################## + input_filename = os.path.join(dirpwd, args.inputfile) + ################################################################################################################## + # Configurable run choices (JSON file that accompanies this example script). + with open(input_filename, "r") as f: + config = json.load(f) + verbosity = config["Verbosity"]["level"] + var_config = config["NeuralNetwork"]["Variables_of_interest"] + var_config["graph_feature_names"] = graph_feature_names + var_config["graph_feature_dims"] = graph_feature_dims + var_config["node_feature_names"] = node_feature_names + var_config["node_feature_dims"] = node_feature_dims + + if args.batch_size is not None: + config["NeuralNetwork"]["Training"]["batch_size"] = args.batch_size + + ################################################################################################################## + # Always initialize for multi-rank training. + comm_size, rank = hydragnn.utils.setup_ddp() + ################################################################################################################## + + comm = MPI.COMM_WORLD + + ## Set up logging + logging.basicConfig( + level=logging.INFO, + format="%%(levelname)s (rank %d): %%(message)s" % (rank), + datefmt="%H:%M:%S", + ) + + log_name = "JARVIS_DFT" if args.log is None else args.log + hydragnn.utils.setup_log(log_name) + writer = hydragnn.utils.get_summary_writer(log_name) + + log("Command: {0}\n".format(" ".join([x for x in sys.argv])), rank=0) + + modelname = "JARVIS_DFT" if args.modelname is None else args.modelname + if args.preonly: + ## local data + total = JARVIS_DFT( + os.path.join(datadir), + var_config, + energy_per_atom=args.energy_per_atom, + dist=True, + ) + ## This is a local split + trainset, valset, testset = split_dataset( + dataset=total, + perc_train=0.9, + stratify_splitting=False, + ) + print(rank, "Local splitting: ", len(trainset), len(valset), len(testset)) + + deg = gather_deg(trainset) + config["pna_deg"] = deg + + setnames = ["trainset", "valset", "testset"] + + ## adios + if args.format == "adios": + fname = os.path.join( + os.path.dirname(__file__), "./dataset/%s.bp" % modelname + ) + adwriter = AdiosWriter(fname, comm) + adwriter.add("trainset", trainset) + adwriter.add("valset", valset) + adwriter.add("testset", testset) + # adwriter.add_global("minmax_node_feature", total.minmax_node_feature) + # adwriter.add_global("minmax_graph_feature", total.minmax_graph_feature) + adwriter.add_global("pna_deg", deg) + adwriter.save() + + ## pickle + elif args.format == "pickle": + basedir = os.path.join( + os.path.dirname(__file__), "dataset", "%s.pickle" % modelname + ) + attrs = dict() + attrs["pna_deg"] = deg + SimplePickleWriter( + trainset, + basedir, + "trainset", + # minmax_node_feature=total.minmax_node_feature, + # minmax_graph_feature=total.minmax_graph_feature, + use_subdir=True, + attrs=attrs, + ) + SimplePickleWriter( + valset, + basedir, + "valset", + # minmax_node_feature=total.minmax_node_feature, + # minmax_graph_feature=total.minmax_graph_feature, + use_subdir=True, + ) + 
SimplePickleWriter( + testset, + basedir, + "testset", + # minmax_node_feature=total.minmax_node_feature, + # minmax_graph_feature=total.minmax_graph_feature, + use_subdir=True, + ) + sys.exit(0) + + tr.initialize() + tr.disable() + timer = Timer("load_data") + timer.start() + + if args.format == "adios": + info("Adios load") + assert not (args.shmem and args.ddstore), "Cannot use both ddstore and shmem" + opt = { + "preload": False, + "shmem": args.shmem, + "ddstore": args.ddstore, + "ddstore_width": args.ddstore_width, + } + fname = os.path.join(os.path.dirname(__file__), "./dataset/%s.bp" % modelname) + trainset = AdiosDataset(fname, "trainset", comm, **opt, var_config=var_config) + valset = AdiosDataset(fname, "valset", comm, **opt, var_config=var_config) + testset = AdiosDataset(fname, "testset", comm, **opt, var_config=var_config) + elif args.format == "pickle": + info("Pickle load") + basedir = os.path.join( + os.path.dirname(__file__), "dataset", "%s.pickle" % modelname + ) + trainset = SimplePickleDataset( + basedir=basedir, label="trainset", var_config=var_config + ) + valset = SimplePickleDataset( + basedir=basedir, label="valset", var_config=var_config + ) + testset = SimplePickleDataset( + basedir=basedir, label="testset", var_config=var_config + ) + # minmax_node_feature = trainset.minmax_node_feature + # minmax_graph_feature = trainset.minmax_graph_feature + pna_deg = trainset.pna_deg + if args.ddstore: + opt = {"ddstore_width": args.ddstore_width} + trainset = DistDataset(trainset, "trainset", comm, **opt) + valset = DistDataset(valset, "valset", comm, **opt) + testset = DistDataset(testset, "testset", comm, **opt) + # trainset.minmax_node_feature = minmax_node_feature + # trainset.minmax_graph_feature = minmax_graph_feature + trainset.pna_deg = pna_deg + else: + raise NotImplementedError("No supported format: %s" % (args.format)) + + info( + "trainset,valset,testset size: %d %d %d" + % (len(trainset), len(valset), len(testset)) + ) + + if args.ddstore: + os.environ["HYDRAGNN_AGGR_BACKEND"] = "mpi" + os.environ["HYDRAGNN_USE_ddstore"] = "1" + + (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] + ) + + config = hydragnn.utils.update_config(config, train_loader, val_loader, test_loader) + ## Good to sync with everyone right after DDStore setup + comm.Barrier() + + hydragnn.utils.save_config(config, log_name) + + timer.stop() + + model = hydragnn.models.create_model_config( + config=config["NeuralNetwork"], + verbosity=verbosity, + ) + model = hydragnn.utils.get_distributed_model(model, verbosity) + + # Print details of neural network architecture + print_model(model) + + learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"] + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001 + ) + + hydragnn.utils.load_existing_model_config( + model, config["NeuralNetwork"]["Training"], optimizer=optimizer + ) + + ################################################################################################################## + + hydragnn.train.train_validate_test( + model, + optimizer, + train_loader, + val_loader, + test_loader, + writer, + scheduler, + config["NeuralNetwork"], + log_name, + verbosity, + create_plots=False, + ) + + hydragnn.utils.save_model(model, optimizer, log_name) + 
hydragnn.utils.print_timers(verbosity) + + if tr.has("GPTLTracer"): + import gptl4py as gp + + eligible = rank if args.everyone else 0 + if rank == eligible: + gp.pr_file(os.path.join("logs", log_name, "gp_timing.p%d" % rank)) + gp.pr_summary_file(os.path.join("logs", log_name, "gp_timing.summary")) + gp.finalize() + sys.exit(0) diff --git a/examples/jarvis_dft/uncompress_folder.py b/examples/jarvis_dft/uncompress_folder.py new file mode 100644 index 000000000..7e0019f35 --- /dev/null +++ b/examples/jarvis_dft/uncompress_folder.py @@ -0,0 +1,39 @@ +import zipfile +import os + +# Define the output directory and zip file path +output_dir = "dataset/JARVIS-DFT" +zip_file_path = os.path.join(output_dir, "6815699.zip") + +# Function to unzip a file and maintain the directory structure +def unzip_file(zip_file, extract_to): + # Check if the zip file exists + if os.path.exists(zip_file): + print(f"Extracting {zip_file} to {extract_to} ...") + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(extract_to) # Extract the current zip file + print(f"Extraction complete for {zip_file}.") + else: + print(f"Zip file {zip_file} not found!") + + +# Function to recursively unzip nested zip files while maintaining the original structure +def unzip_nested_zips(directory): + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".zip"): + file_path = os.path.join(root, file) + extract_dir = ( + root # Extract the zip file in the same directory it resides + ) + print(f"Found nested zip: {file_path}, extracting to {extract_dir}...") + unzip_file( + file_path, extract_dir + ) # Unzip the nested zip in the same folder + + +# Call the function to unzip the main zip file +unzip_file(zip_file_path, output_dir) + +# Call the function to recursively unzip nested zip files +unzip_nested_zips(output_dir)
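-- 

For reference, a typical end-to-end run of this example is sketched below. The scripts and flags are the ones added in this patch; the `mpirun -n 4` launch line is illustrative and depends on the local MPI environment.

    cd examples/jarvis_dft
    bash download_dataset.sh                # fetch 6815699.zip into dataset/JARVIS-DFT
    python uncompress_folder.py             # unpack the archive, including nested zips
    mpirun -n 4 python train.py --preonly   # preprocess into dataset/JARVIS_DFT.bp (ADIOS, the default format)
    mpirun -n 4 python train.py             # train the EGNN model configured in jarvis_energy.json

Passing `--pickle` to both `train.py` invocations switches the intermediate dataset from ADIOS to pickle files under dataset/JARVIS_DFT.pickle.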