Skip to content

Commit

Permalink
Add units (#378)
Browse files Browse the repository at this point in the history
* Save units for all calculations

* Fix ASE unit definition

* Test units

* Use Base units for MD stats

---------

Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
ElliottKasoar and oerc0122 authored Jan 14, 2025
1 parent b727ff3 commit 43c8632
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 23 deletions.
32 changes: 32 additions & 0 deletions janus_core/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
from janus_core.helpers.struct_io import input_structs
from janus_core.helpers.utils import FileNameMixin, none_to_dict, set_log_tracker

UNITS = {
"energy": "eV",
"forces": "ev/Ang",
"stress": "ev/Ang^3",
"hessian": "ev/Ang^2",
"time": "fs",
"real_time": "s",
"temperature": "K",
"pressure": "GPa",
"momenta": "(eV*u)^0.5",
"density": "g/cm^3",
"volume": "Ang^3",
}


class BaseCalculation(FileNameMixin):
"""
Expand Down Expand Up @@ -209,3 +223,21 @@ def __init__(
self.tracker = config_tracker(
self.logger, self.track_carbon, **self.tracker_kwargs
)

def _set_info_units(
self, keys: Sequence[str] = ("energy", "forces", "stress")
) -> None:
"""
Save units to structure info.
Parameters
----------
keys : Sequence
Keys for which to add units to structure info. Default is
("energy", "forces", "stress").
"""
if isinstance(self.struct, Sequence):
for image in self.struct:
image.info["units"] = {key: UNITS[key] for key in keys}
else:
self.struct.info["units"] = {key: UNITS[key] for key in keys}
2 changes: 2 additions & 0 deletions janus_core/calculations/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def run(self) -> EoSResults:
Dictionary containing equation of state ASE object, and the fitted minimum
bulk modulus, volume, and energy.
"""
self._set_info_units()

if self.minimize:
if self.logger:
self.logger.info("Minimising initial structure")
Expand Down
2 changes: 2 additions & 0 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def run(self) -> None:
if self.tracker:
self.tracker.start_task("Geometry optimization")

self._set_info_units()

converged = self.dyn.run(fmax=self.fmax, steps=self.steps)

# Calculate current maximum force
Expand Down
61 changes: 40 additions & 21 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any
from warnings import warn

from ase import Atoms, units
from ase import Atoms
from ase.geometry.analysis import Analysis
from ase.io import read
from ase.md.langevin import Langevin
Expand All @@ -23,9 +23,11 @@
ZeroRotation,
)
from ase.md.verlet import VelocityVerlet
from ase.units import create_units
import numpy as np
import yaml

from janus_core.calculations.base import UNITS as JANUS_UNITS
from janus_core.calculations.base import BaseCalculation
from janus_core.calculations.geom_opt import GeomOpt
from janus_core.helpers.janus_types import (
Expand All @@ -43,6 +45,7 @@
from janus_core.processing.correlator import Correlation
from janus_core.processing.post_process import compute_rdf, compute_vaf

units = create_units("2014")
DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol


Expand Down Expand Up @@ -525,7 +528,7 @@ def _set_info(self) -> None:
"""Set time in fs, current dynamics step, and density to info."""
time = (self.offset * self.timestep + self.dyn.get_time()) / units.fs
step = self.offset + self.dyn.nsteps
self.dyn.atoms.info["time_fs"] = time
self.dyn.atoms.info["time"] = time
self.dyn.atoms.info["step"] = step
try:
density = (
Expand Down Expand Up @@ -772,7 +775,7 @@ def get_stats(self) -> dict[str, float]:
return {
"Step": self.dyn.atoms.info["step"],
"Real_Time": real_time.total_seconds(),
"Time": self.dyn.atoms.info["time_fs"],
"Time": self.dyn.atoms.info["time"],
"Epot/N": e_pot,
"EKin/N": e_kin,
"T": current_temp,
Expand Down Expand Up @@ -800,21 +803,21 @@ def unit_info(self) -> dict[str, str]:
"""
return {
"Step": None,
"Real_Time": "s",
"Time": "fs",
"Epot/N": "eV",
"EKin/N": "eV",
"T": "K",
"ETot/N": "eV",
"Density": "g/cm^3",
"Volume": "A^3",
"P": "GPa",
"Pxx": "GPa",
"Pyy": "GPa",
"Pzz": "GPa",
"Pyz": "GPa",
"Pxz": "GPa",
"Pxy": "GPa",
"Real_Time": JANUS_UNITS["real_time"],
"Time": JANUS_UNITS["time"],
"Epot/N": JANUS_UNITS["energy"],
"EKin/N": JANUS_UNITS["energy"],
"T": JANUS_UNITS["temperature"],
"ETot/N": JANUS_UNITS["energy"],
"Density": JANUS_UNITS["density"],
"Volume": JANUS_UNITS["volume"],
"P": JANUS_UNITS["pressure"],
"Pxx": JANUS_UNITS["pressure"],
"Pyy": JANUS_UNITS["pressure"],
"Pzz": JANUS_UNITS["pressure"],
"Pyz": JANUS_UNITS["pressure"],
"Pxz": JANUS_UNITS["pressure"],
"Pxy": JANUS_UNITS["pressure"],
}

@property
Expand Down Expand Up @@ -1024,6 +1027,19 @@ def _write_restart(self) -> None:

def run(self) -> None:
"""Run molecular dynamics simulation and/or temperature ramp."""
unit_keys = (
"energy",
"forces",
"stress",
"time",
"real_time",
"temperature",
"pressure",
"density",
"momenta",
)
self._set_info_units(unit_keys)

if not self.restart:
if self.minimize:
self._optimize_structure()
Expand Down Expand Up @@ -1265,7 +1281,10 @@ def unit_info(self) -> dict[str, str]:
dict[str, str]
Units attached to statistical properties.
"""
return super().unit_info | {"Target_P": "GPa", "Target_T": "K"}
return super().unit_info | {
"Target_P": JANUS_UNITS["pressure"],
"Target_T": JANUS_UNITS["temperature"],
}

@property
def default_formats(self) -> dict[str, str]:
Expand Down Expand Up @@ -1362,7 +1381,7 @@ def unit_info(self) -> dict[str, str]:
dict[str, str]
Units attached to statistical properties.
"""
return super().unit_info | {"Target_T": "K"}
return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]}

@property
def default_formats(self) -> dict[str, str]:
Expand Down Expand Up @@ -1505,7 +1524,7 @@ def unit_info(self) -> dict[str, str]:
dict[str, str]
Units attached to statistical properties.
"""
return super().unit_info | {"Target_T": "K"}
return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]}

@property
def default_formats(self) -> dict[str, str]:
Expand Down
2 changes: 2 additions & 0 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def calc_force_constants(
if self.tracker:
self.tracker.start_task("Phonon calculation")

self._set_info_units()

cell = self._ASE_to_PhonopyAtoms(self.struct)

if len(self.supercell) == 3:
Expand Down
2 changes: 2 additions & 0 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ def run(self) -> CalcResults:
if self.tracker:
self.tracker.start_task("Single point")

self._set_info_units(self.properties)

if "energy" in self.properties:
self.results["energy"] = self._get_potential_energy()
if "forces" in self.properties:
Expand Down
5 changes: 4 additions & 1 deletion janus_core/processing/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from ase import Atoms, units
from ase import Atoms
from ase.units import create_units

if TYPE_CHECKING:
from janus_core.helpers.janus_types import SliceLike

from janus_core.helpers.utils import slicelike_to_startstopstep

units = create_units("2014")


# pylint: disable=too-few-public-methods
class Observable(ABC):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_geomopt_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,32 @@ def test_no_carbon(tmp_path):
with open(summary_path, encoding="utf8") as file:
geomopt_summary = yaml.safe_load(file)
assert "emissions" not in geomopt_summary


def test_units(tmp_path):
"""Test correct units are saved."""
results_path = tmp_path / "NaCl-opt.extxyz"
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
[
"geomopt",
"--struct",
DATA_PATH / "NaCl.cif",
"--out",
results_path,
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 0

atoms = read(results_path)
expected_units = {"energy": "eV", "forces": "ev/Ang", "stress": "ev/Ang^3"}
assert "units" in atoms.info
for prop, units in expected_units.items():
assert atoms.info["units"][prop] == units
19 changes: 18 additions & 1 deletion tests/test_md_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,23 @@ def test_md(ensemble):
assert "momenta" in atoms.arrays
assert "masses" in atoms.arrays

expected_units = {
"time": "fs",
"real_time": "s",
"energy": "eV",
"forces": "ev/Ang",
"stress": "ev/Ang^3",
"temperature": "K",
"density": "g/cm^3",
"momenta": "(eV*u)^0.5",
}
if ensemble in ("nvt", "nvt-nh"):
expected_units["pressure"] = "GPa"

assert "units" in atoms.info
for prop, units in expected_units.items():
assert atoms.info["units"][prop] == units

finally:
final_path.unlink(missing_ok=True)
restart_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -165,7 +182,7 @@ def test_log(tmp_path):
assert len(lines) == 22

# Test constant volume
assert lines[0].split(" | ")[8] == "Volume [A^3]"
assert lines[0].split(" | ")[8] == "Volume [Ang^3]"
init_volume = float(lines[1].split()[8])
final_volume = float(lines[-1].split()[8])
assert init_volume == 179.406144
Expand Down
6 changes: 6 additions & 0 deletions tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def test_singlepoint():
assert "system_name" in atoms.info
assert atoms.info["system_name"] == "NaCl"

expected_units = {"energy": "eV", "forces": "ev/Ang", "stress": "ev/Ang^3"}
assert "units" in atoms.info
for prop, units in expected_units.items():
assert atoms.info["units"][prop] == units

clear_log_handlers()


Expand Down Expand Up @@ -399,6 +404,7 @@ def test_hessian(tmp_path):
assert "mace_mp_hessian" in atoms.info
assert "mace_stress" not in atoms.info
assert atoms.info["mace_mp_hessian"].shape == (24, 8, 3)
assert atoms.info["units"]["hessian"] == "ev/Ang^2"


def test_no_carbon(tmp_path):
Expand Down

0 comments on commit 43c8632

Please sign in to comment.