diff --git a/phono3py/api_phono3py.py b/phono3py/api_phono3py.py index b42d4a16..c1d8da07 100644 --- a/phono3py/api_phono3py.py +++ b/phono3py/api_phono3py.py @@ -58,6 +58,7 @@ from phonopy.interface.pypolymlp import ( PypolymlpData, PypolymlpParams, + develop_mlp_by_pypolymlp, develop_polymlp, evalulate_polymlp, load_polymlp, @@ -2204,52 +2205,12 @@ def develop_mlp( if self._mlp_dataset is None: raise RuntimeError("MLP dataset is not set.") - if params is not None: - _params = parse_mlp_params(params) - else: - _params = params - - if ( - _params is not None - and _params.ntrain is not None - and _params.ntest is not None - ): - ntrain = _params.ntrain - ntest = _params.ntest - disps = self._mlp_dataset["displacements"] - forces = self._mlp_dataset["forces"] - energies = self._mlp_dataset["supercell_energies"] - train_data = PypolymlpData( - displacements=disps[:ntrain], - forces=forces[:ntrain], - supercell_energies=energies[:ntrain], - ) - test_data = PypolymlpData( - displacements=disps[-ntest:], - forces=forces[-ntest:], - supercell_energies=energies[-ntest:], - ) - else: - disps = self._mlp_dataset["displacements"] - forces = self._mlp_dataset["forces"] - energies = self._mlp_dataset["supercell_energies"] - n = int(len(disps) * (1 - test_size)) - train_data = PypolymlpData( - displacements=disps[:n], - forces=forces[:n], - supercell_energies=energies[:n], - ) - test_data = PypolymlpData( - displacements=disps[n:], - forces=forces[n:], - supercell_energies=energies[n:], - ) - self._mlp = develop_polymlp( + self._mlp = develop_mlp_by_pypolymlp( + self._mlp_dataset, self._supercell, - train_data, - test_data, - params=_params, - verbose=self._log_level - 1 > 0, + params=params, + test_size=test_size, + log_level=self._log_level, ) def load_mlp(self, filename: str = "phono3py.pmlp"):