Skip to content

Commit

Permalink
Merge pull request #272 from phonopy/pypolymlp
Browse files Browse the repository at this point in the history
Separate detail of Phono3py.develop_mlp() into a function in phonopy
  • Loading branch information
atztogo authored Sep 5, 2024
2 parents 853e3ce + 9b23427 commit b2444ae
Showing 1 changed file with 6 additions and 45 deletions.
51 changes: 6 additions & 45 deletions phono3py/api_phono3py.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from phonopy.interface.pypolymlp import (
PypolymlpData,
PypolymlpParams,
develop_mlp_by_pypolymlp,
develop_polymlp,
evalulate_polymlp,
load_polymlp,
Expand Down Expand Up @@ -2204,52 +2205,12 @@ def develop_mlp(
if self._mlp_dataset is None:
raise RuntimeError("MLP dataset is not set.")

if params is not None:
_params = parse_mlp_params(params)
else:
_params = params

if (
_params is not None
and _params.ntrain is not None
and _params.ntest is not None
):
ntrain = _params.ntrain
ntest = _params.ntest
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
train_data = PypolymlpData(
displacements=disps[:ntrain],
forces=forces[:ntrain],
supercell_energies=energies[:ntrain],
)
test_data = PypolymlpData(
displacements=disps[-ntest:],
forces=forces[-ntest:],
supercell_energies=energies[-ntest:],
)
else:
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
n = int(len(disps) * (1 - test_size))
train_data = PypolymlpData(
displacements=disps[:n],
forces=forces[:n],
supercell_energies=energies[:n],
)
test_data = PypolymlpData(
displacements=disps[n:],
forces=forces[n:],
supercell_energies=energies[n:],
)
self._mlp = develop_polymlp(
self._mlp = develop_mlp_by_pypolymlp(
self._mlp_dataset,
self._supercell,
train_data,
test_data,
params=_params,
verbose=self._log_level - 1 > 0,
params=params,
test_size=test_size,
log_level=self._log_level,
)

def load_mlp(self, filename: str = "phono3py.pmlp"):
Expand Down

0 comments on commit b2444ae

Please sign in to comment.