diff --git a/phono3py/interface/wien2k.py b/phono3py/interface/wien2k.py index c8dbf955..a823fb3d 100644 --- a/phono3py/interface/wien2k.py +++ b/phono3py/interface/wien2k.py @@ -34,7 +34,7 @@ # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -from phono3py.phonon3.dataset import get_displacements_fc3 +from phono3py.phonon3.dataset import get_displacements_and_forces_fc3 def get_fc3_calc_dataset_wien2k( @@ -48,8 +48,7 @@ def get_fc3_calc_dataset_wien2k( """Read Wien2k output files and parse force sets.""" from phonopy.interface.wien2k import parse_set_of_forces - # disps, _ = get_displacements_and_forces_fc3(disp_dataset) - disps = get_displacements_fc3(disp_dataset) + disps, _ = get_displacements_and_forces_fc3(disp_dataset) force_sets = parse_set_of_forces( disps, force_filenames, diff --git a/phono3py/phonon3/dataset.py b/phono3py/phonon3/dataset.py index ad2a4122..d65aea92 100644 --- a/phono3py/phonon3/dataset.py +++ b/phono3py/phonon3/dataset.py @@ -39,7 +39,9 @@ import numpy as np -def get_displacements_and_forces_fc3(disp_dataset): +def get_displacements_and_forces_fc3( + disp_dataset: dict, +) -> tuple[np.ndarray, Optional[np.ndarray]]: """Return displacements and forces from disp_dataset. Note @@ -71,10 +73,13 @@ def get_displacements_and_forces_fc3(disp_dataset): forces = np.zeros_like(displacements) indices = [] count = 0 + forces_count = 0 for disp1 in disp_dataset["first_atoms"]: indices.append(count) displacements[count, disp1["number"]] = disp1["displacement"] - forces[count] = disp1["forces"] + if "forces" in disp1: + forces_count += 1 + forces[count] = disp1["forces"] count += 1 for disp1 in disp_dataset["first_atoms"]: @@ -86,69 +91,26 @@ def get_displacements_and_forces_fc3(disp_dataset): indices.append(count) displacements[count, disp1["number"]] = disp1["displacement"] displacements[count, disp2["number"]] = disp2["displacement"] - forces[count] = disp2["forces"] + if "forces" in disp2: + forces_count += 1 + forces[count] = disp2["forces"] count += 1 - return ( - np.array(displacements[indices], dtype="double", order="C"), - np.array(forces[indices], dtype="double", order="C"), - ) - elif "forces" in disp_dataset and "displacements" in disp_dataset: - return disp_dataset["displacements"], disp_dataset["forces"] - else: - raise RuntimeError("disp_dataset doesn't contain correct information.") - - -def get_displacements_fc3(disp_dataset): - """Return displacements and forces from disp_dataset. - - Note - ---- - Dipslacements and forces of all atoms in supercells are returned. - - Parameters - ---------- - disp_dataset : dict - Displacement dataset. - - Returns - ------- - displacements : ndarray - Displacements of all atoms in all supercells. - shape=(snapshots, supercell atoms, 3), dtype='double', order='C' - forces : ndarray or None - Forces of all atoms in all supercells. - shape=(snapshots, supercell atoms, 3), dtype='double', order='C' - None is returned when forces don't exist. - - """ - if "first_atoms" in disp_dataset: - natom = disp_dataset["natom"] - ndisp = len(disp_dataset["first_atoms"]) - for disp1 in disp_dataset["first_atoms"]: - ndisp += len(disp1["second_atoms"]) - displacements = np.zeros((ndisp, natom, 3), dtype="double", order="C") - indices = [] - count = 0 - for disp1 in disp_dataset["first_atoms"]: - indices.append(count) - displacements[count, disp1["number"]] = disp1["displacement"] - count += 1 - - for disp1 in disp_dataset["first_atoms"]: - for disp2 in disp1["second_atoms"]: - if "included" in disp2: - if disp2["included"]: - indices.append(count) - else: - indices.append(count) - displacements[count, disp1["number"]] = disp1["displacement"] - displacements[count, disp2["number"]] = disp2["displacement"] - count += 1 - - return np.array(displacements[indices], dtype="double", order="C") - elif "forces" in disp_dataset and "displacements" in disp_dataset: - return disp_dataset["displacements"] + if forces_count == 0: + forces = None + else: + forces = np.array(forces[indices], dtype="double", order="C") + assert forces_count == count + + displacements = np.array(displacements[indices], dtype="double", order="C") + return displacements, forces + elif "displacements" in disp_dataset: + displacements = disp_dataset["displacements"] + if "forces" in disp_dataset: + forces = disp_dataset["forces"] + else: + forces = None + return displacements, forces else: raise RuntimeError("disp_dataset doesn't contain correct information.")