Skip to content

Commit

Permalink
Refactor CUI for using pypolymlp
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Jan 15, 2025
1 parent 4ac2d78 commit 2f91fea
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 276 deletions.
62 changes: 32 additions & 30 deletions phono3py/cui/create_force_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ def run_pypolymlp_to_compute_forces(
displacement_distance: Optional[float] = None,
number_of_snapshots: Optional[int] = None,
random_seed: Optional[int] = None,
prepare_dataset: bool = False,
mlp_filename: str = "phono3py.pmlp",
log_level: int = 0,
):
Expand Down Expand Up @@ -546,43 +547,44 @@ def run_pypolymlp_to_compute_forces(
if log_level:
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)

if displacement_distance is None:
_displacement_distance = 0.001
else:
_displacement_distance = displacement_distance
if prepare_dataset:
if displacement_distance is None:
_displacement_distance = 0.001
else:
_displacement_distance = displacement_distance

if log_level:
if number_of_snapshots:
print("Generate random displacements")
if log_level:
if number_of_snapshots:
print("Generate random displacements")
print(
" Twice of number of snapshots will be generated "
"for plus-minus displacements."
)
else:
print("Generate displacements")
print(
" Twice of number of snapshots will be generated "
"for plus-minus displacements."
)
else:
print("Generate displacements")
print(
f" Displacement distance: {_displacement_distance:.5f}".rstrip("0").rstrip(
"."
f" Displacement distance: {_displacement_distance:.5f}".rstrip(
"0"
).rstrip(".")
)
ph3py.generate_displacements(
distance=_displacement_distance,
is_plusminus=True,
number_of_snapshots=number_of_snapshots,
random_seed=random_seed,
)
ph3py.generate_displacements(
distance=_displacement_distance,
is_plusminus=True,
number_of_snapshots=number_of_snapshots,
random_seed=random_seed,
)

if log_level:
print(
f"Evaluate forces in {ph3py.displacements.shape[0]} supercells "
"by pypolymlp",
flush=True,
)
if log_level:
print(
f"Evaluate forces in {ph3py.displacements.shape[0]} supercells "
"by pypolymlp",
flush=True,
)

if ph3py.supercells_with_displacements is None:
raise RuntimeError("Displacements are not set. Run generate_displacements.")
if ph3py.supercells_with_displacements is None:
raise RuntimeError("Displacements are not set. Run generate_displacements.")

ph3py.evaluate_mlp()
ph3py.evaluate_mlp()


def run_pypolymlp_to_compute_phonon_forces(
Expand Down
191 changes: 87 additions & 104 deletions phono3py/cui/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,29 +340,34 @@ def load(
log_level=log_level,
)

read_fc = set_dataset_and_force_constants(
load_fc2_and_fc3(
ph3py, fc3_filename=fc3_filename, fc2_filename=fc2_filename, log_level=log_level
)
load_dataset_and_phonon_dataset(
ph3py,
ph3py_yaml,
fc3_filename=fc3_filename,
fc2_filename=fc2_filename,
forces_fc3_filename=forces_fc3_filename,
forces_fc2_filename=forces_fc2_filename,
phono3py_yaml_filename=phono3py_yaml,
calculator=_calculator,
use_pypolymlp=use_pypolymlp,
log_level=log_level,
)
if use_pypolymlp and ph3py.fc3 is None and forces_in_dataset(ph3py.dataset):
ph3py.mlp_dataset = ph3py.dataset
ph3py.dataset = None

if produce_fc:
if ph3py.fc3 is None and use_pypolymlp:
run_pypolymlp_to_compute_forces(
ph3py, mlp_params=mlp_params, log_level=log_level
)

compute_force_constants_from_datasets(
ph3py,
read_fc,
fc_calculator=fc_calculator,
fc_calculator_options=fc_calculator_options,
symmetrize_fc=symmetrize_fc,
is_compact_fc=is_compact_fc,
use_pypolymlp=use_pypolymlp,
mlp_params=mlp_params,
log_level=log_level,
)

Expand All @@ -376,76 +381,62 @@ def load(
return ph3py


def set_dataset_and_force_constants(
def load_fc2_and_fc3(
ph3py: Phono3py,
ph3py_yaml: Optional[Phono3pyYaml] = None,
fc3_filename: Optional[os.PathLike] = None,
fc2_filename: Optional[os.PathLike] = None,
log_level: int = 0,
):
"""Set force constants."""
if fc3_filename is not None or pathlib.Path("fc3.hdf5").exists():
fc3 = _load_fc3(ph3py, fc3_filename=fc3_filename, log_level=log_level)
ph3py.fc3 = fc3

if fc2_filename is not None or pathlib.Path("fc2.hdf5").exists():
fc2 = _load_fc2(ph3py, fc2_filename=fc2_filename, log_level=log_level)
ph3py.fc2 = fc2


def load_dataset_and_phonon_dataset(
ph3py: Phono3py,
ph3py_yaml: Optional[Phono3pyYaml] = None,
forces_fc3_filename: Optional[Union[os.PathLike, Sequence]] = None,
forces_fc2_filename: Optional[Union[os.PathLike, Sequence]] = None,
phono3py_yaml_filename: Optional[os.PathLike] = None,
cutoff_pair_distance: Optional[float] = None,
calculator: Optional[str] = None,
use_pypolymlp: bool = False,
log_level: int = 0,
) -> dict:
"""Set displacements, forces, and create force constants.
Most of properties are stored in ph3py.
Returns
-------
dict
This contains flags indicating whether fc2 and fc3 were read from
file(s) or not. This information can be different from ph3py.fc3 is
(not) None and ph3py.fc2 is (not) None. Items are as follows:
fc3 : bool
fc2 : bool
"""
read_fc = {"fc2": False, "fc3": False}
read_fc["fc3"], dataset = _get_dataset_or_fc3(
):
"""Set displacements, forces, and create force constants."""
dataset = _select_and_load_dataset(
ph3py,
ph3py_yaml=ph3py_yaml,
fc3_filename=fc3_filename,
forces_fc3_filename=forces_fc3_filename,
phono3py_yaml_filename=phono3py_yaml_filename,
cutoff_pair_distance=cutoff_pair_distance,
calculator=calculator,
log_level=log_level,
)
if not read_fc["fc3"]:
if use_pypolymlp:
if forces_in_dataset(dataset):
ph3py.mlp_dataset = dataset
else:
ph3py.dataset = dataset
read_fc["fc2"], phonon_dataset = _get_dataset_phonon_dataset_or_fc2(
if dataset is not None:
ph3py.dataset = dataset

phonon_dataset = _select_and_load_phonon_dataset(
ph3py,
ph3py_yaml=ph3py_yaml,
fc2_filename=fc2_filename,
forces_fc2_filename=forces_fc2_filename,
calculator=calculator,
log_level=log_level,
)
if not read_fc["fc2"]:
if phonon_dataset is not None:
ph3py.phonon_dataset = phonon_dataset

return read_fc


def compute_force_constants_from_datasets(
ph3py: Phono3py,
read_fc: dict,
fc_calculator: Optional[str] = None,
fc_calculator_options: Optional[Union[dict, str]] = None,
symmetrize_fc: bool = True,
is_compact_fc: bool = True,
use_pypolymlp: bool = False,
mlp_params: Optional[Union[dict, str]] = None,
displacement_distance: Optional[float] = None,
number_of_snapshots: Optional[int] = None,
random_seed: Optional[int] = None,
log_level: int = 0,
):
"""Compute force constants from datasets.
Expand All @@ -462,30 +453,19 @@ def compute_force_constants_from_datasets(
"""
fc3_calculator = extract_fc2_fc3_calculators(fc_calculator, 3)
fc2_calculator = extract_fc2_fc3_calculators(fc_calculator, 2)
if not read_fc["fc3"]:
if use_pypolymlp:
run_pypolymlp_to_compute_forces(
ph3py,
mlp_params=mlp_params,
displacement_distance=displacement_distance,
number_of_snapshots=number_of_snapshots,
random_seed=random_seed,
log_level=log_level,
)
if forces_in_dataset(ph3py.dataset):
ph3py.produce_fc3(
symmetrize_fc3r=symmetrize_fc,
is_compact_fc=is_compact_fc,
fc_calculator=fc3_calculator,
fc_calculator_options=extract_fc2_fc3_calculators(
fc_calculator_options, 3
),
)
exist_fc2 = ph3py.fc2 is not None
if ph3py.fc3 is None and forces_in_dataset(ph3py.dataset):
ph3py.produce_fc3(
symmetrize_fc3r=symmetrize_fc,
is_compact_fc=is_compact_fc,
fc_calculator=fc3_calculator,
fc_calculator_options=extract_fc2_fc3_calculators(fc_calculator_options, 3),
)

if log_level and symmetrize_fc and fc_calculator is None:
print("fc3 was symmetrized.")
if log_level and symmetrize_fc and fc_calculator is None:
print("fc3 was symmetrized.")

if not read_fc["fc2"]:
if not exist_fc2:
if (
ph3py.phonon_supercell_matrix is None and forces_in_dataset(ph3py.dataset)
) or (
Expand All @@ -504,32 +484,34 @@ def compute_force_constants_from_datasets(
print("fc2 was symmetrized.")


def _get_dataset_or_fc3(
def _load_fc3(
ph3py: Phono3py,
ph3py_yaml: Optional[Phono3pyYaml] = None,
fc3_filename: Optional[os.PathLike] = None,
log_level: int = 0,
) -> np.ndarray:
p2s_map = ph3py.primitive.p2s_map
if fc3_filename is None:
_fc3_filename = "fc3.hdf5"
else:
_fc3_filename = fc3_filename
fc3 = read_fc3_from_hdf5(filename=_fc3_filename, p2s_map=p2s_map)
_check_fc3_shape(ph3py, fc3, filename=_fc3_filename)
if log_level:
print(f'fc3 was read from "{_fc3_filename}".')
return fc3


def _select_and_load_dataset(
ph3py: Phono3py,
ph3py_yaml: Optional[Phono3pyYaml] = None,
forces_fc3_filename: Optional[Union[os.PathLike, Sequence]] = None,
phono3py_yaml_filename: Optional[os.PathLike] = None,
cutoff_pair_distance: Optional[float] = None,
calculator: Optional[str] = None,
log_level: int = 0,
) -> tuple[bool, dict]:
p2s_map = ph3py.primitive.p2s_map
read_fc3 = False
) -> Optional[dict]:
dataset = None

if fc3_filename is not None or pathlib.Path("fc3.hdf5").exists():
if fc3_filename is None:
_fc3_filename = "fc3.hdf5"
else:
_fc3_filename = fc3_filename
fc3 = read_fc3_from_hdf5(filename=_fc3_filename, p2s_map=p2s_map)
_check_fc3_shape(ph3py, fc3, filename=_fc3_filename)
ph3py.fc3 = fc3
read_fc3 = True
if log_level:
print(f'fc3 was read from "{_fc3_filename}".')
elif (
if (
ph3py_yaml is not None
and ph3py_yaml.dataset is not None
and forces_in_dataset(ph3py_yaml.dataset)
Expand Down Expand Up @@ -570,32 +552,33 @@ def _get_dataset_or_fc3(
log_level,
)

return read_fc3, dataset
return dataset


def _load_fc2(
ph3py: Phono3py, fc2_filename: Optional[os.PathLike] = None, log_level: int = 0
) -> np.ndarray:
phonon_p2s_map = ph3py.phonon_primitive.p2s_map
if fc2_filename is None:
_fc2_filename = "fc2.hdf5"
else:
_fc2_filename = fc2_filename
fc2 = read_fc2_from_hdf5(filename=_fc2_filename, p2s_map=phonon_p2s_map)
_check_fc2_shape(ph3py, fc2, filename=_fc2_filename)
if log_level:
print(f'fc2 was read from "{_fc2_filename}".')
return fc2

def _get_dataset_phonon_dataset_or_fc2(

def _select_and_load_phonon_dataset(
ph3py: Phono3py,
ph3py_yaml: Optional[Phono3pyYaml] = None,
fc2_filename: Optional[os.PathLike] = None,
forces_fc2_filename: Optional[Union[os.PathLike, Sequence]] = None,
calculator: Optional[str] = None,
log_level: int = 0,
) -> tuple[bool, dict, dict]:
phonon_p2s_map = ph3py.phonon_primitive.p2s_map
read_fc2 = False
) -> Optional[dict]:
phonon_dataset = None
if fc2_filename is not None or pathlib.Path("fc2.hdf5").exists():
if fc2_filename is None:
_fc2_filename = "fc2.hdf5"
else:
_fc2_filename = fc2_filename
fc2 = read_fc2_from_hdf5(filename=_fc2_filename, p2s_map=phonon_p2s_map)
_check_fc2_shape(ph3py, fc2, filename=_fc2_filename)
ph3py.fc2 = fc2
read_fc2 = True
if log_level:
print(f'fc2 was read from "{_fc2_filename}".')
elif (
if (
ph3py_yaml is not None
and ph3py_yaml.phonon_dataset is not None
and forces_in_dataset(ph3py_yaml.phonon_dataset)
Expand Down Expand Up @@ -635,7 +618,7 @@ def _get_dataset_phonon_dataset_or_fc2(
log_level,
)

return read_fc2, phonon_dataset
return phonon_dataset


def _get_dataset_for_fc3(
Expand Down
Loading

0 comments on commit 2f91fea

Please sign in to comment.