Skip to content

Commit

Permalink
Refactoring of calculation using pypolymlp
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Sep 4, 2024
1 parent dd4c43a commit 9fa805b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 44 deletions.
41 changes: 31 additions & 10 deletions phono3py/api_phono3py.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,16 +2208,37 @@ def develop_mlp(
else:
_params = params

disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
n = int(len(disps) * (1 - test_size))
train_data = PypolymlpData(
displacements=disps[:n], forces=forces[:n], supercell_energies=energies[:n]
)
test_data = PypolymlpData(
displacements=disps[n:], forces=forces[n:], supercell_energies=energies[n:]
)
if _params.ntrain is not None and _params.ntest is not None:
ntrain = _params.ntrain
ntest = _params.ntest
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
train_data = PypolymlpData(
displacements=disps[:ntrain],
forces=forces[:ntrain],
supercell_energies=energies[:ntrain],
)
test_data = PypolymlpData(
displacements=disps[-ntest:],
forces=forces[-ntest:],
supercell_energies=energies[-ntest:],
)
else:
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
n = int(len(disps) * (1 - test_size))
train_data = PypolymlpData(
displacements=disps[:n],
forces=forces[:n],
supercell_energies=energies[:n],
)
test_data = PypolymlpData(
displacements=disps[n:],
forces=forces[n:],
supercell_energies=energies[n:],
)
self._mlp = develop_polymlp(
self._supercell,
train_data,
Expand Down
48 changes: 19 additions & 29 deletions phono3py/cui/create_force_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,28 @@ def create_phono3py_force_constants(
if settings.read_fc3:
_read_phono3py_fc3(phono3py, symmetrize_fc3r, input_filename, log_level)
else: # fc3 from FORCES_FC3 or ph3py_yaml
_read_dataset_fc3(
dataset = _read_dataset_fc3(
phono3py,
ph3py_yaml,
phono3py_yaml_filename,
settings.cutoff_pair_distance,
calculator,
settings.use_pypolymlp,
settings.mlp_params,
settings.displacement_distance,
settings.random_displacements,
settings.random_seed,
log_level,
)

if settings.use_pypolymlp:
phono3py.mlp_dataset = dataset
run_pypolymlp_to_compute_forces(
phono3py,
settings.mlp_params,
displacement_distance=settings.displacement_distance,
number_of_snapshots=settings.random_displacements,
random_seed=settings.random_seed,
log_level=log_level,
)
else:
phono3py.dataset = dataset

phono3py.produce_fc3(
symmetrize_fc3r=symmetrize_fc3r,
is_compact_fc=settings.is_compact_fc,
Expand Down Expand Up @@ -214,7 +223,7 @@ def parse_forces(
fc_type: Literal["fc3", "phonon_fc2"] = "fc3",
calculator: Optional[str] = None,
log_level=0,
):
) -> dict:
"""Read displacements and forces.
Physical units of displacements and forces are converted following the
Expand Down Expand Up @@ -454,13 +463,8 @@ def _read_dataset_fc3(
phono3py_yaml_filename: Optional[str],
cutoff_pair_distance: Optional[float],
calculator: Optional[str],
use_pypolymlp: bool,
mlp_params: Union[str, dict, PypolymlpParams],
displacement_distance: Optional[float],
number_of_snapshots: Optional[int],
random_seed: Optional[int],
log_level: int,
):
) -> dict:
"""Read or calculate fc3.
Note
Expand Down Expand Up @@ -496,18 +500,7 @@ def _read_dataset_fc3(
# from _get_type2_dataset
file_exists(e.filename, log_level=log_level)

if use_pypolymlp:
phono3py.mlp_dataset = dataset
run_pypolymlp_to_compute_forces(
phono3py,
mlp_params,
displacement_distance=displacement_distance,
number_of_snapshots=number_of_snapshots,
random_seed=random_seed,
log_level=log_level,
)
else:
phono3py.dataset = dataset
return dataset


def run_pypolymlp_to_compute_forces(
Expand All @@ -529,8 +522,6 @@ def run_pypolymlp_to_compute_forces(
for k, v in asdict(parse_mlp_params(mlp_params)).items():
if v is not None:
print(f" {k}: {v}")
if log_level > 1:
print("")
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)

Expand Down Expand Up @@ -579,6 +570,7 @@ def run_pypolymlp_to_compute_forces(
raise RuntimeError("Displacements are not set. Run generate_displacements.")

ph3py.evaluate_mlp()
ph3py.save("phono3py_mlp_eval_dataset.yaml")


def run_pypolymlp_to_compute_phonon_forces(
Expand All @@ -601,8 +593,6 @@ def run_pypolymlp_to_compute_phonon_forces(
for k, v in asdict(parse_mlp_params(mlp_params)).items():
if v is not None:
print(f" {k}: {v}")
if log_level > 1:
print("")
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)

Expand Down
7 changes: 2 additions & 5 deletions phono3py/cui/phono3py_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,13 @@ def finalize_phono3py(

_physical_units = get_default_physical_units(phono3py.calculator)

write_force_sets = phono3py.mlp is not None
_write_displacements = write_displacements or phono3py.mlp is not None

ph3py_yaml = Phono3pyYaml(
configuration=confs_dict,
calculator=phono3py.calculator,
physical_units=_physical_units,
settings={
"force_sets": write_force_sets,
"displacements": _write_displacements,
"force_sets": False,
"displacements": write_displacements,
},
)
ph3py_yaml.set_phonon_info(phono3py)
Expand Down

0 comments on commit 9fa805b

Please sign in to comment.