diff --git a/README.md b/README.md index 9a6d8b0d..c7e91f5e 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ - [MACE-OFF: Transferable Organic Force Fields](#mace-off-transferable-organic-force-fields) - [Example usage in ASE](#example-usage-in-ase-1) - [Finetuning foundation models](#finetuning-foundation-models) + - [Latest recommended foundation models](#latest-recommended-foundation-models) - [Caching](#caching) - [Development](#development) - [References](#references) @@ -273,6 +274,14 @@ atoms.calc = calc print(atoms.get_potential_energy()) ``` +### Latest Recommended Foundation Models + +| Model Name | Elements Covered | Training Dataset | Level of Theory | Target System | Model Size | GitHub Release | Notes | License | +|-------------------|------------------|------------------|-----------------------|----------------------|---------------------|----------------|-------------------------------------------------------|---------| +| MACE-MP-0 | 89 | MPTrj | DFT (PBE+U) | Materials | [small](https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model), [medium](https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model), [large](https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model)| >=v0.3.6 | Initial release of foundation model. | MIT | +| MACE-MPA-0 | 89 | MPTrj + sAlex | DFT (PBE+U) | Materials | [medium](https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model) | >=v0.3.10 | Improved accuracy for materials, improved high pressure stability. | MIT | +| MACE-OFF23 | 10 | SPICE v1 | DFT (wB97M+D3) | Organic Chemistry | [small](https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model), [medium](https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_medium.model), [large](https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model)| >=v0.3.6 | Initial release covering neutral organic chemistry. | ASL | + ### Finetuning foundation models To finetune one of the mace-mp-0 foundation model, you can use the `mace_run_train` script with the extra argument `--foundation_model=model_type`. For example to finetune the small model on a new dataset, you can use: diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1c0898b7..b2a9c7bc 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -380,6 +380,8 @@ def run(args: argparse.Namespace) -> None: for head_config in head_configs: all_atomic_numbers.update(head_config.atomic_numbers) z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) + if args.foundation_model_elements and model_foundation: + z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist())) logging.info(f"Atomic Numbers used: {z_table.zs}") # Atomic energies @@ -434,6 +436,16 @@ def run(args: argparse.Namespace) -> None: for z in z_table.zs } + # Padding atomic energies if keeping all elements of the foundation model + if args.foundation_model_elements and model_foundation: + atomic_energies_dict_padded = {} + for head_name, head_energies in atomic_energies_dict.items(): + energy_head_padded = {} + for z in z_table.zs: + energy_head_padded[z] = head_energies.get(z, 0.0) + atomic_energies_dict_padded[head_name] = energy_head_padded + atomic_energies_dict = atomic_energies_dict_padded + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True @@ -634,6 +646,7 @@ def run(args: argparse.Namespace) -> None: distributed_model = DDP(model, device_ids=[local_rank]) else: distributed_model = None + print("MODEL", model) tools.train( model=model, loss_fn=loss_fn, diff --git a/mace/cli/select_head.py b/mace/cli/select_head.py index de0e6935..661bf466 100644 --- a/mace/cli/select_head.py +++ b/mace/cli/select_head.py @@ -34,6 +34,7 @@ def main(): args = parser.parse_args() model = torch.load(args.model_file) + torch.set_default_dtype(next(model.parameters()).dtype) if args.list_heads: print("Available heads:") diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index e4e90a10..7f67f38a 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -403,6 +403,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str, default=None, ) + parser.add_argument( + "--foundation_model_elements", + help="Keep all elements of the foundation model during fine-tuning", + type=str2bool, + default=False, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index d937446c..92b61146 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -7,6 +7,7 @@ from mace import modules from mace.tools.finetuning_utils import load_foundations_elements from mace.tools.scripts_utils import extract_config_mace_model +from mace.tools.utils import AtomicNumberTable def configure_model( @@ -43,8 +44,22 @@ def configure_model( logging.info("Loading FOUNDATION model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_energies"] = atomic_energies - model_config_foundation["atomic_numbers"] = z_table.zs - model_config_foundation["num_elements"] = len(z_table) + + if args.foundation_model_elements: + foundation_z_table = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + model_config_foundation["atomic_numbers"] = foundation_z_table.zs + model_config_foundation["num_elements"] = len(foundation_z_table) + z_table = foundation_z_table + logging.info( + f"Using all elements from foundation model: {foundation_z_table.zs}" + ) + else: + model_config_foundation["atomic_numbers"] = z_table.zs + model_config_foundation["num_elements"] = len(z_table) + logging.info(f"Using filtered elements: {z_table.zs}") + args.max_L = model_config_foundation["hidden_irreps"].lmax if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 2b56c10b..db33ddf4 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -10,7 +10,7 @@ import torch from ase.atoms import Atoms -from mace.calculators.mace import MACECalculator +from mace.calculators import MACECalculator, mace_mp try: import cuequivariance as cue # pylint: disable=unused-import @@ -1051,3 +1051,248 @@ def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs): 0.5574042201042175, ] assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_foundation_elements(tmp_path, fitting_configs): + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + base_params = { + "name": "MACE", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "train_file": tmp_path / "fit.xyz", + "loss": "weighted", + "foundation_model": "small", + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", + "max_num_epochs": 5, + "num_radial_basis": 10, + "interaction_first": "RealAgnosticResidualInteractionBlock", + "multiheads_finetuning": False, + } + + # Run environment setup + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + # First run: without foundation_model_elements (default behavior) + mace_params = base_params.copy() + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") + filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) + assert filtered_elements == {1, 8} # Only H and O should be present + + # Second run: with foundation_model_elements + mace_params = base_params.copy() + mace_params["name"] = "MACE_all_elements" + mace_params["foundation_model_elements"] = True # Flag-only argument + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") + all_elements = set(int(z) for z in model_all.atomic_numbers) + + # Get elements from foundation model for comparison + calc = mace_mp(model="small", device="cpu") + foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) + + # Check that all foundation model elements are preserved + assert all_elements == foundation_elements + assert len(all_elements) > len(filtered_elements) + + # Check that both models can make predictions + at = fitting_configs[2].copy() + + # Test filtered model + calc_filtered = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + at.calc = calc_filtered + e1 = at.get_potential_energy() + + # Test all-elements model + calc_all = MACECalculator( + model_paths=tmp_path / "MACE_all_elements.model", + device="cpu", + default_dtype="float64", + ) + at.calc = calc_all + e2 = at.get_potential_energy() + + # Energies should be different since the models are trained differently, + # but both should give reasonable results + assert np.isfinite(e1) + assert np.isfinite(e2) + + +def test_run_train_foundation_elements_multihead(tmp_path, fitting_configs): + + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + if i % 2 == 0: + c_copy = c.copy() + c_copy.info["head"] = "DFT" + fitting_configs_dft.append(c_copy) + else: + c_copy = c.copy() + c_copy.info["head"] = "MP2" + fitting_configs_mp2.append(c_copy) + + ase.io.write(tmp_path / "fit_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_mp2.xyz", fitting_configs_mp2) + + # Create multihead configuration + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + config_file = tmp_path / "config.yaml" + with open(config_file, "w", encoding="utf-8") as file: + file.write(yaml_str) + + base_params = { + "name": "MACE", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "config": str(config_file), + "loss": "weighted", + "foundation_model": "small", + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", + "max_num_epochs": 5, + "num_radial_basis": 10, + "interaction_first": "RealAgnosticResidualInteractionBlock", + "force_mh_ft_lr": True, + "batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", + "valid_fraction": 0.1, + "valid_batch_size": 1, + } + + # Run environment setup + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + # First run: without foundation_model_elements (default behavior) + mace_params = base_params.copy() + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") + filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) + assert filtered_elements == {1, 8} # Only H and O should be present + assert len(model_filtered.heads) == 3 # pt_head + DFT + MP2 + + # Second run: with foundation_model_elements + mace_params = base_params.copy() + mace_params["name"] = "MACE_all_elements" + mace_params["foundation_model_elements"] = True + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") + all_elements = set(int(z) for z in model_all.atomic_numbers) + + # Get elements from foundation model for comparison + calc = mace_mp(model="small", device="cpu") + foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) + + # Check that all foundation model elements are preserved + assert all_elements == foundation_elements + assert len(all_elements) > len(filtered_elements) + assert len(model_all.heads) == 3 # pt_head + DFT + MP2 + + # Check that both models can make predictions + at = fitting_configs_dft[2].copy() + + # Test filtered model + calc_filtered = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + at.calc = calc_filtered + e1 = at.get_potential_energy() + + # Test all-elements model + calc_all = MACECalculator( + model_paths=tmp_path / "MACE_all_elements.model", + device="cpu", + default_dtype="float64", + ) + at.calc = calc_all + e2 = at.get_potential_energy() + + assert np.isfinite(e1) + assert np.isfinite(e2)