Skip to content

Commit

Permalink
Update get_displacements_and_forces_fc3
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Dec 28, 2024
1 parent 3dda1e3 commit 47a2fa8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 66 deletions.
5 changes: 2 additions & 3 deletions phono3py/interface/wien2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from phono3py.phonon3.dataset import get_displacements_fc3
from phono3py.phonon3.dataset import get_displacements_and_forces_fc3


def get_fc3_calc_dataset_wien2k(
Expand All @@ -48,8 +48,7 @@ def get_fc3_calc_dataset_wien2k(
"""Read Wien2k output files and parse force sets."""
from phonopy.interface.wien2k import parse_set_of_forces

# disps, _ = get_displacements_and_forces_fc3(disp_dataset)
disps = get_displacements_fc3(disp_dataset)
disps, _ = get_displacements_and_forces_fc3(disp_dataset)
force_sets = parse_set_of_forces(
disps,
force_filenames,
Expand Down
88 changes: 25 additions & 63 deletions phono3py/phonon3/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
import numpy as np


def get_displacements_and_forces_fc3(disp_dataset):
def get_displacements_and_forces_fc3(
disp_dataset: dict,
) -> tuple[np.ndarray, Optional[np.ndarray]]:
"""Return displacements and forces from disp_dataset.
Note
Expand Down Expand Up @@ -71,10 +73,13 @@ def get_displacements_and_forces_fc3(disp_dataset):
forces = np.zeros_like(displacements)
indices = []
count = 0
forces_count = 0
for disp1 in disp_dataset["first_atoms"]:
indices.append(count)
displacements[count, disp1["number"]] = disp1["displacement"]
forces[count] = disp1["forces"]
if "forces" in disp1:
forces_count += 1
forces[count] = disp1["forces"]
count += 1

for disp1 in disp_dataset["first_atoms"]:
Expand All @@ -86,69 +91,26 @@ def get_displacements_and_forces_fc3(disp_dataset):
indices.append(count)
displacements[count, disp1["number"]] = disp1["displacement"]
displacements[count, disp2["number"]] = disp2["displacement"]
forces[count] = disp2["forces"]
if "forces" in disp2:
forces_count += 1
forces[count] = disp2["forces"]
count += 1
return (
np.array(displacements[indices], dtype="double", order="C"),
np.array(forces[indices], dtype="double", order="C"),
)
elif "forces" in disp_dataset and "displacements" in disp_dataset:
return disp_dataset["displacements"], disp_dataset["forces"]
else:
raise RuntimeError("disp_dataset doesn't contain correct information.")


def get_displacements_fc3(disp_dataset):
"""Return displacements and forces from disp_dataset.
Note
----
Dipslacements and forces of all atoms in supercells are returned.
Parameters
----------
disp_dataset : dict
Displacement dataset.
Returns
-------
displacements : ndarray
Displacements of all atoms in all supercells.
shape=(snapshots, supercell atoms, 3), dtype='double', order='C'
forces : ndarray or None
Forces of all atoms in all supercells.
shape=(snapshots, supercell atoms, 3), dtype='double', order='C'
None is returned when forces don't exist.
"""
if "first_atoms" in disp_dataset:
natom = disp_dataset["natom"]
ndisp = len(disp_dataset["first_atoms"])
for disp1 in disp_dataset["first_atoms"]:
ndisp += len(disp1["second_atoms"])
displacements = np.zeros((ndisp, natom, 3), dtype="double", order="C")
indices = []
count = 0
for disp1 in disp_dataset["first_atoms"]:
indices.append(count)
displacements[count, disp1["number"]] = disp1["displacement"]
count += 1

for disp1 in disp_dataset["first_atoms"]:
for disp2 in disp1["second_atoms"]:
if "included" in disp2:
if disp2["included"]:
indices.append(count)
else:
indices.append(count)
displacements[count, disp1["number"]] = disp1["displacement"]
displacements[count, disp2["number"]] = disp2["displacement"]
count += 1

return np.array(displacements[indices], dtype="double", order="C")

elif "forces" in disp_dataset and "displacements" in disp_dataset:
return disp_dataset["displacements"]
if forces_count == 0:
forces = None
else:
forces = np.array(forces[indices], dtype="double", order="C")
assert forces_count == count

displacements = np.array(displacements[indices], dtype="double", order="C")
return displacements, forces
elif "displacements" in disp_dataset:
displacements = disp_dataset["displacements"]
if "forces" in disp_dataset:
forces = disp_dataset["forces"]
else:
forces = None
return displacements, forces
else:
raise RuntimeError("disp_dataset doesn't contain correct information.")

Expand Down

0 comments on commit 47a2fa8

Please sign in to comment.