Commit 6dce504

Merge pull request #781 from ACEsuit/develop

ilyes319 authored Jan 14, 2025
2 parents 84cf7d0 + c7c0229
Showing 6 changed files with 292 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -28,6 +28,7 @@
- [MACE-OFF: Transferable Organic Force Fields](#mace-off-transferable-organic-force-fields)
- [Example usage in ASE](#example-usage-in-ase-1)
- [Finetuning foundation models](#finetuning-foundation-models)
- [Latest recommended foundation models](#latest-recommended-foundation-models)
- [Caching](#caching)
- [Development](#development)
- [References](#references)
@@ -273,6 +274,14 @@ atoms.calc = calc
print(atoms.get_potential_energy())
```

### Latest Recommended Foundation Models

| Model Name | Elements Covered | Training Dataset | Level of Theory | Target System | Model Size | Required MACE Version | Notes | License |
|------------|------------------|------------------|-----------------|---------------|------------|-----------------------|-------|---------|
| MACE-MP-0 | 89 | MPTrj | DFT (PBE+U) | Materials | [small](https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model), [medium](https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model), [large](https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model)| >=v0.3.6 | Initial release of foundation model. | MIT |
| MACE-MPA-0 | 89 | MPTrj + sAlex | DFT (PBE+U) | Materials | [medium](https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model) | >=v0.3.10 | Improved accuracy for materials, improved high pressure stability. | MIT |
| MACE-OFF23 | 10 | SPICE v1 | DFT (wB97M+D3) | Organic Chemistry | [small](https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model), [medium](https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_medium.model), [large](https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model)| >=v0.3.6 | Initial release covering neutral organic chemistry. | ASL |
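
For quick reference, here is a minimal sketch of loading one of these models as an ASE calculator via the `mace_mp` helper (the same entry point the new tests in this PR use); the water molecule is only a placeholder system:

```python
from ase.build import molecule
from mace.calculators import mace_mp

# Load the small MACE-MP-0 foundation model as an ASE calculator.
calc = mace_mp(model="small", device="cpu")

# Placeholder system for illustration.
atoms = molecule("H2O")
atoms.calc = calc
print(atoms.get_potential_energy())
```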

### Finetuning foundation models

To finetune one of the mace-mp-0 foundation models, you can use the `mace_run_train` script with the extra argument `--foundation_model=model_type`. For example, to finetune the small model on a new dataset, you can use:
13 changes: 13 additions & 0 deletions mace/cli/run_train.py
@@ -380,6 +380,8 @@ def run(args: argparse.Namespace) -> None:
for head_config in head_configs:
all_atomic_numbers.update(head_config.atomic_numbers)
z_table = AtomicNumberTable(sorted(list(all_atomic_numbers)))
if args.foundation_model_elements and model_foundation:
z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist()))
logging.info(f"Atomic Numbers used: {z_table.zs}")

# Atomic energies
@@ -434,6 +436,16 @@ def run(args: argparse.Namespace) -> None:
for z in z_table.zs
}

# Padding atomic energies if keeping all elements of the foundation model
if args.foundation_model_elements and model_foundation:
atomic_energies_dict_padded = {}
for head_name, head_energies in atomic_energies_dict.items():
energy_head_padded = {}
for z in z_table.zs:
energy_head_padded[z] = head_energies.get(z, 0.0)
atomic_energies_dict_padded[head_name] = energy_head_padded
atomic_energies_dict = atomic_energies_dict_padded

if args.model == "AtomicDipolesMACE":
atomic_energies = None
dipole_only = True
@@ -634,6 +646,7 @@ def run(args: argparse.Namespace) -> None:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None
print("MODEL", model)
tools.train(
model=model,
loss_fn=loss_fn,
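
The padding block added above can be read as a standalone transformation: every element retained from the foundation model receives an isolated-atom energy, defaulting to 0.0 when the fine-tuning data supplies none. A minimal sketch with invented values:

```python
# Elements kept from the foundation model (stand-in for z_table.zs).
z_table_zs = [1, 6, 8]  # H, C, O

# Per-head E0s from the fine-tuning data; carbon is missing (values invented).
atomic_energies_dict = {"Default": {1: -13.6, 8: -432.1}}

# Same padding as above, written as a dict comprehension.
atomic_energies_dict = {
    head: {z: energies.get(z, 0.0) for z in z_table_zs}
    for head, energies in atomic_energies_dict.items()
}
print(atomic_energies_dict)  # {'Default': {1: -13.6, 6: 0.0, 8: -432.1}}
```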
1 change: 1 addition & 0 deletions mace/cli/select_head.py
@@ -34,6 +34,7 @@ def main():
args = parser.parse_args()

model = torch.load(args.model_file)
torch.set_default_dtype(next(model.parameters()).dtype)

if args.list_heads:
print("Available heads:")
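
The one-line fix above pins torch's global default dtype to that of the loaded model, so tensors created later in the script cannot silently mix float32 and float64. A self-contained sketch of the idiom, with a plain `nn.Linear` standing in for the loaded MACE model:

```python
import torch

# Stand-in for the model loaded by select_head; MACE models are often float64.
model = torch.nn.Linear(4, 4).double()

# Align the global default dtype with the model's parameters.
torch.set_default_dtype(next(model.parameters()).dtype)
print(torch.get_default_dtype())  # torch.float64
```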
6 changes: 6 additions & 0 deletions mace/tools/arg_parser.py
@@ -403,6 +403,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str,
default=None,
)
parser.add_argument(
"--foundation_model_elements",
help="Keep all elements of the foundation model during fine-tuning",
type=str2bool,
default=False,
)
parser.add_argument(
"--keep_isolated_atoms",
help="Keep isolated atoms in the dataset, useful for transfer learning",
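
Since the new `--foundation_model_elements` argument is parsed with `str2bool`, it is passed as `--foundation_model_elements=True` on the command line. A minimal sketch of driving `mace_run_train` with the new flag, in the same subprocess style as the tests in this PR (the training file, model name, and epoch count are placeholders):

```python
import subprocess

# Placeholder file names; only the flag semantics are taken from this PR.
subprocess.run(
    [
        "mace_run_train",
        "--name=MACE_finetuned",
        "--foundation_model=small",
        "--train_file=train.xyz",
        "--foundation_model_elements=True",  # keep all foundation elements
        "--max_num_epochs=5",
    ],
    check=True,
)
```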
19 changes: 17 additions & 2 deletions mace/tools/model_script_utils.py
@@ -7,6 +7,7 @@
from mace import modules
from mace.tools.finetuning_utils import load_foundations_elements
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.utils import AtomicNumberTable


def configure_model(
@@ -43,8 +44,22 @@ def configure_model(
logging.info("Loading FOUNDATION model")
model_config_foundation = extract_config_mace_model(model_foundation)
model_config_foundation["atomic_energies"] = atomic_energies
model_config_foundation["atomic_numbers"] = z_table.zs
model_config_foundation["num_elements"] = len(z_table)

if args.foundation_model_elements:
foundation_z_table = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
model_config_foundation["atomic_numbers"] = foundation_z_table.zs
model_config_foundation["num_elements"] = len(foundation_z_table)
z_table = foundation_z_table
logging.info(
f"Using all elements from foundation model: {foundation_z_table.zs}"
)
else:
model_config_foundation["atomic_numbers"] = z_table.zs
model_config_foundation["num_elements"] = len(z_table)
logging.info(f"Using filtered elements: {z_table.zs}")

args.max_L = model_config_foundation["hidden_irreps"].lmax

if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE":
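
The branch above swaps the data-derived `z_table` for one built from the foundation model's full element list. A minimal sketch of that `AtomicNumberTable` round-trip, with an invented element list standing in for `model_foundation.atomic_numbers` (a tensor in practice):

```python
from mace.tools.utils import AtomicNumberTable

# Invented stand-in for model_foundation.atomic_numbers.
foundation_atomic_numbers = [1, 6, 7, 8]

foundation_z_table = AtomicNumberTable(
    [int(z) for z in foundation_atomic_numbers]
)
print(foundation_z_table.zs)  # [1, 6, 7, 8]
```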
247 changes: 246 additions & 1 deletion tests/test_run_train.py
@@ -10,7 +10,7 @@
import torch
from ase.atoms import Atoms

from mace.calculators.mace import MACECalculator
from mace.calculators import MACECalculator, mace_mp

try:
import cuequivariance as cue # pylint: disable=unused-import
@@ -1051,3 +1051,248 @@ def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs):
0.5574042201042175,
]
assert np.allclose(Es, ref_Es, atol=1e-1)


def test_run_train_foundation_elements(tmp_path, fitting_configs):

ase.io.write(tmp_path / "fit.xyz", fitting_configs)

base_params = {
"name": "MACE",
"checkpoints_dir": str(tmp_path),
"model_dir": str(tmp_path),
"train_file": tmp_path / "fit.xyz",
"loss": "weighted",
"foundation_model": "small",
"hidden_irreps": "128x0e",
"r_max": 6.0,
"default_dtype": "float64",
"max_num_epochs": 5,
"num_radial_basis": 10,
"interaction_first": "RealAgnosticResidualInteractionBlock",
"multiheads_finetuning": False,
}

# Run environment setup
run_env = os.environ.copy()
sys.path.insert(0, str(Path(__file__).parent.parent))
run_env["PYTHONPATH"] = ":".join(sys.path)

# First run: without foundation_model_elements (default behavior)
mace_params = base_params.copy()
cmd = (
sys.executable
+ " "
+ str(run_train)
+ " "
+ " ".join(
[
(f"--{k}={v}" if v is not None else f"--{k}")
for k, v in mace_params.items()
]
)
)
p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

# Load model and check elements
model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu")
filtered_elements = set(int(z) for z in model_filtered.atomic_numbers)
assert filtered_elements == {1, 8} # Only H and O should be present

# Second run: with foundation_model_elements
mace_params = base_params.copy()
mace_params["name"] = "MACE_all_elements"
mace_params["foundation_model_elements"] = True # Flag-only argument
cmd = (
sys.executable
+ " "
+ str(run_train)
+ " "
+ " ".join(
[
(f"--{k}={v}" if v is not None else f"--{k}")
for k, v in mace_params.items()
]
)
)
p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

# Load model and check elements
model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu")
all_elements = set(int(z) for z in model_all.atomic_numbers)

# Get elements from foundation model for comparison
calc = mace_mp(model="small", device="cpu")
foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers)

# Check that all foundation model elements are preserved
assert all_elements == foundation_elements
assert len(all_elements) > len(filtered_elements)

# Check that both models can make predictions
at = fitting_configs[2].copy()

# Test filtered model
calc_filtered = MACECalculator(
model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64"
)
at.calc = calc_filtered
e1 = at.get_potential_energy()

# Test all-elements model
calc_all = MACECalculator(
model_paths=tmp_path / "MACE_all_elements.model",
device="cpu",
default_dtype="float64",
)
at.calc = calc_all
e2 = at.get_potential_energy()

# Energies should be different since the models are trained differently,
# but both should give reasonable results
assert np.isfinite(e1)
assert np.isfinite(e2)


def test_run_train_foundation_elements_multihead(tmp_path, fitting_configs):

fitting_configs_dft = []
fitting_configs_mp2 = []
for i, c in enumerate(fitting_configs):
if i in (0, 1):
c_dft = c.copy()
c_dft.info["head"] = "DFT"
fitting_configs_dft.append(c_dft)
c_mp2 = c.copy()
c_mp2.info["head"] = "MP2"
fitting_configs_mp2.append(c_mp2)
if i % 2 == 0:
c_copy = c.copy()
c_copy.info["head"] = "DFT"
fitting_configs_dft.append(c_copy)
else:
c_copy = c.copy()
c_copy.info["head"] = "MP2"
fitting_configs_mp2.append(c_copy)

ase.io.write(tmp_path / "fit_dft.xyz", fitting_configs_dft)
ase.io.write(tmp_path / "fit_mp2.xyz", fitting_configs_mp2)

# Create multihead configuration
heads = {
"DFT": {"train_file": f"{str(tmp_path)}/fit_dft.xyz"},
"MP2": {"train_file": f"{str(tmp_path)}/fit_mp2.xyz"},
}
yaml_str = "heads:\n"
for key, value in heads.items():
yaml_str += f" {key}:\n"
for sub_key, sub_value in value.items():
yaml_str += f" {sub_key}: {sub_value}\n"
config_file = tmp_path / "config.yaml"
with open(config_file, "w", encoding="utf-8") as file:
file.write(yaml_str)

base_params = {
"name": "MACE",
"checkpoints_dir": str(tmp_path),
"model_dir": str(tmp_path),
"config": str(config_file),
"loss": "weighted",
"foundation_model": "small",
"hidden_irreps": "128x0e",
"r_max": 6.0,
"default_dtype": "float64",
"max_num_epochs": 5,
"num_radial_basis": 10,
"interaction_first": "RealAgnosticResidualInteractionBlock",
"force_mh_ft_lr": True,
"batch_size": 1,
"num_samples_pt": 50,
"subselect_pt": "random",
"valid_fraction": 0.1,
"valid_batch_size": 1,
}

# Run environment setup
run_env = os.environ.copy()
sys.path.insert(0, str(Path(__file__).parent.parent))
run_env["PYTHONPATH"] = ":".join(sys.path)

# First run: without foundation_model_elements (default behavior)
mace_params = base_params.copy()
cmd = (
sys.executable
+ " "
+ str(run_train)
+ " "
+ " ".join(
[
(f"--{k}={v}" if v is not None else f"--{k}")
for k, v in mace_params.items()
]
)
)
p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

# Load model and check elements
model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu")
filtered_elements = set(int(z) for z in model_filtered.atomic_numbers)
assert filtered_elements == {1, 8} # Only H and O should be present
assert len(model_filtered.heads) == 3 # pt_head + DFT + MP2

# Second run: with foundation_model_elements
mace_params = base_params.copy()
mace_params["name"] = "MACE_all_elements"
mace_params["foundation_model_elements"] = True
cmd = (
sys.executable
+ " "
+ str(run_train)
+ " "
+ " ".join(
[
(f"--{k}={v}" if v is not None else f"--{k}")
for k, v in mace_params.items()
]
)
)
p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

# Load model and check elements
model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu")
all_elements = set(int(z) for z in model_all.atomic_numbers)

# Get elements from foundation model for comparison
calc = mace_mp(model="small", device="cpu")
foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers)

# Check that all foundation model elements are preserved
assert all_elements == foundation_elements
assert len(all_elements) > len(filtered_elements)
assert len(model_all.heads) == 3 # pt_head + DFT + MP2

# Check that both models can make predictions
at = fitting_configs_dft[2].copy()

# Test filtered model
calc_filtered = MACECalculator(
model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64"
)
at.calc = calc_filtered
e1 = at.get_potential_energy()

# Test all-elements model
calc_all = MACECalculator(
model_paths=tmp_path / "MACE_all_elements.model",
device="cpu",
default_dtype="float64",
)
at.calc = calc_all
e2 = at.get_potential_energy()

assert np.isfinite(e1)
assert np.isfinite(e2)
