Skip to content

Commit

Permalink
Refactoring around API
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Jul 31, 2024
1 parent b0f0898 commit 78df49b
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 144 deletions.
151 changes: 60 additions & 91 deletions src/phelel/api_phelel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@

from phelel.base.Dij_qij import DDijQij
from phelel.base.local_potential import DLocalPotential
from phelel.file_IO import write_dDijdu_hdf5, write_dVdu_hdf5, write_phelel_params_hdf5
from phelel.file_IO import write_phelel_params_hdf5
from phelel.version import __version__


@dataclass
class PhelelInput:
class PhelelDataset:
"""Data structure of input data to run derivatives."""

local_potentials: list
Dijs: list
qijs: list
local_potentials: list[np.ndarray]
Dijs: list[np.ndarray]
qijs: list[np.ndarray]
lm_channels: list[dict]
dataset: Optional[dict] = None
phonon_dataset: Optional[dict] = None
forces: Optional[np.ndarray] = None


Expand Down Expand Up @@ -79,6 +80,8 @@ def __init__(
symprec: float = 1e-5,
is_symmetry: bool = True,
calculator: Optional[str] = None,
nufft: Optional[str] = None,
finufft_eps: Optional[float] = None,
log_level: int = 0,
):
"""Init method.
Expand Down Expand Up @@ -115,18 +118,21 @@ def __init__(
Use crystal symmetry or not. Default is True.
calculator :
A dummy parameter.
nufft : str or None, optional
'finufft' only. Default is None, which corresponds to 'finufft'.
finufft_eps : float or None, optional
Accuracy of finufft interpolation. Default is None, which
corresponds to 1e-6.
log_level : int, optional
Log level. 0 is most quiet. Default is 0.
"""
self._unitcell = unitcell
if fft_mesh is None:
self._fft_mesh = None
else:
self.fft_mesh = fft_mesh
self._symprec = symprec
self._is_symmetry = is_symmetry
self._calculator = calculator
self._nufft = nufft
self._finufft_eps = finufft_eps
self._log_level = log_level

ph = self._get_phonopy(supercell_matrix, primitive_matrix)
Expand Down Expand Up @@ -156,7 +162,17 @@ def __init__(
self._dVdu = None
self._dDijdu = None

self._raw_data = None
if fft_mesh is None:
self._fft_mesh = None
else:
self.fft_mesh = fft_mesh

self._dDijdu = DDijQij(
self._supercell,
symmetry=self._symmetry,
atom_indices=self._atom_indices_in_derivatives,
verbose=self._log_level > 0,
)

@property
def version(self) -> str:
Expand Down Expand Up @@ -306,12 +322,22 @@ def phonon(self) -> Phonopy:

@property
def fft_mesh(self) -> np.ndarray:
"""Return FFT mesh numbers."""
"""Setter and getter of FFT mesh numbers."""
return self._fft_mesh

@fft_mesh.setter
def fft_mesh(self, fft_mesh):
def fft_mesh(self, fft_mesh: Union[Sequence, np.ndarray]):
self._fft_mesh = np.array(fft_mesh, dtype="int_")
self._dVdu = DLocalPotential(
self._fft_mesh,
self._p2s_matrix,
self._supercell,
symmetry=self._symmetry,
atom_indices=self._atom_indices_in_derivatives,
nufft=self._nufft,
finufft_eps=self._finufft_eps,
verbose=self._log_level > 0,
)

@property
def dVdu(self) -> Optional[DLocalPotential]:
Expand All @@ -323,7 +349,7 @@ def dVdu(self, dVdu: DLocalPotential):
self._dVdu = dVdu

@property
def dDijdu(self) -> Optional[DDijQij]:
def dDijdu(self) -> DDijQij:
"""Return DDijQij class instance."""
return self._dDijdu

Expand Down Expand Up @@ -368,7 +394,7 @@ def generate_displacements(
)
self._dataset = ph.dataset

def run_derivatives(self, phe_input: PhelelInput, nufft=None, finufft_eps=None):
def run_derivatives(self, phe_input: PhelelDataset):
"""Run displacement derivatives calculations from temporary raw data.
Note
Expand All @@ -383,20 +409,27 @@ def run_derivatives(self, phe_input: PhelelInput, nufft=None, finufft_eps=None):
)
raise RuntimeError(msg)

self.prepare_phonon(dataset=phe_input.dataset, forces=phe_input.forces)
self.run_dVdu(
phe_input.local_potentials,
dataset=phe_input.dataset,
nufft=nufft,
finufft_eps=finufft_eps,
)
self.run_dDijdu(
phe_input.Dijs,
phe_input.qijs,
if phe_input.dataset is not None:
self._dataset = phe_input.dataset
loc_pots = phe_input.local_potentials
Dijs = phe_input.Dijs
qijs = phe_input.qijs

if phe_input.phonon_dataset is not None:
self._prepare_phonon(
dataset=phe_input.phonon_dataset, forces=phe_input.forces
)
else:
self._prepare_phonon(dataset=self._dataset, forces=phe_input.forces)
self._dVdu.run(loc_pots[0], loc_pots[1:], self._dataset["first_atoms"])
self._dDijdu.run(
Dijs[0],
Dijs[1:],
qijs[0],
qijs[1:],
self._dataset["first_atoms"],
phe_input.lm_channels,
dataset=phe_input.dataset,
)
self._raw_data = None

def save_hdf5(
self, filename: Union[str, bytes, os.PathLike, io.IOBase] = "phelel_params.hdf5"
Expand All @@ -421,71 +454,7 @@ def save_hdf5(
filename=filename,
)

def run_dVdu(
self,
loc_pots,
dataset=None,
nufft=None,
finufft_eps=None,
write_hdf5=False,
):
"""Calculate dV/du.
Parameters
----------
nufft : str or None, optional
'finufft' only. Default is None, which corresponds to 'finufft'.
finufft_eps : float or None, optional
Accuracy of finufft interpolation. Default is None, which
corresponds to 1e-6.
"""
dVdu = DLocalPotential(
self._fft_mesh,
self._p2s_matrix,
self._supercell,
symmetry=self._symmetry,
atom_indices=self._atom_indices_in_derivatives,
nufft=nufft,
finufft_eps=finufft_eps,
verbose=True,
)
if dataset is not None:
self._dataset = dataset
displacements = self._dataset["first_atoms"]
dVdu.run(loc_pots[0], loc_pots[1:], displacements)

if write_hdf5:
write_dVdu_hdf5(
dVdu,
self._supercell_matrix,
self._primitive_matrix,
self._primitive,
self._unitcell,
self._supercell,
filename="dVdu.hdf5",
)
self._dVdu = dVdu

def run_dDijdu(self, Dijs, qijs, lm_channels, dataset=None, write_hdf5=False):
"""Calculate dDij/du."""
dDijdu = DDijQij(
self._supercell,
symmetry=self._symmetry,
atom_indices=self._atom_indices_in_derivatives,
verbose=True,
)
if dataset is not None:
self._dataset = dataset
displacements = self._dataset["first_atoms"]
dDijdu.run(Dijs[0], Dijs[1:], qijs[0], qijs[1:], displacements, lm_channels)

if write_hdf5:
write_dDijdu_hdf5(dDijdu)

self._dDijdu = dDijdu

def prepare_phonon(
def _prepare_phonon(
self,
dataset: Optional[dict] = None,
forces: Optional[
Expand Down
4 changes: 1 addition & 3 deletions src/phelel/cui/phelel_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def get_parser():
"--dim-phonon",
nargs="+",
dest="phonon_supercell_dimension",
metavar="INT",
default=None,
help=(
"Supercell dimensions for phonon with three integers or "
Expand All @@ -65,9 +64,8 @@ def get_parser():
)
parser.add_argument(
"--fft-mesh",
nargs=3,
nargs="+",
dest="fft_mesh_numbers",
metavar="INT",
default=None,
help="FFT mesh numbers used in primitive cell",
)
Expand Down
38 changes: 20 additions & 18 deletions src/phelel/cui/phelel_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,21 @@ def main(**argparse_control):
cell_info["phonon_supercell_matrix"] = ph_smat
phonon_supercell_matrix = cell_info["phonon_supercell_matrix"]

if settings.create_displacements:
phelel = create_phelel_supercells(
cell_info,
settings,
symprec,
log_level=log_level,
)
finalize_phelel(
phelel,
confs=phelel_conf.confs,
log_level=log_level,
displacements_mode=True,
filename="phelel_disp.yaml",
)

fft_mesh = settings.fft_mesh_numbers
phelel = Phelel(
unitcell,
Expand All @@ -143,6 +158,7 @@ def main(**argparse_control):
fft_mesh=fft_mesh,
symprec=symprec,
is_symmetry=settings.is_symmetry,
finufft_eps=settings.finufft_eps,
)

if log_level > 0:
Expand Down Expand Up @@ -179,20 +195,6 @@ def main(**argparse_control):
print_cell(phelel.phonon.supercell)
print("-" * 76)

if settings.create_displacements:
phelel = create_phelel_supercells(
cell_info,
settings,
symprec,
log_level=log_level,
)
finalize_phelel(
phelel,
confs=phelel_conf.confs,
log_level=log_level,
displacements_mode=True,
filename="phelel_disp.yaml",
)
##################################
# Create dV/du, dDij/du, dqij/du #
##################################
Expand Down Expand Up @@ -229,13 +231,13 @@ def main(**argparse_control):
create_derivatives(
phelel,
settings.create_derivatives,
finufft_eps=settings.finufft_eps,
subtract_rfs=settings.subtract_rfs,
log_level=log_level,
)
phelel.save_hdf5(filename="phelel_params.hdf5")
if log_level > 0:
print('"phelel_params.hdf5" has been created.')
if phelel.fft_mesh is not None:
phelel.save_hdf5(filename="phelel_params.hdf5")
if log_level > 0:
print('"phelel_params.hdf5" has been created.')
print_end()
sys.exit(0)

Expand Down
5 changes: 1 addition & 4 deletions src/phelel/cui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ def _read_options(self):
self._confs["create_derivatives"] = " ".join(dir_names)
if "fft_mesh_numbers" in self._args:
if self._args.fft_mesh_numbers:
if len(self._args.fft_mesh_numbers) == 1:
self._confs["fft_mesh"] = self._args.fft_mesh_numbers[0]
elif len(self._args.fft_mesh_numbers) == 3:
self._confs["fft_mesh"] = " ".join(self._args.fft_mesh_numbers)
self._confs["fft_mesh"] = " ".join(self._args.fft_mesh_numbers)
if "finufft_eps" in self._args:
if self._args.finufft_eps is not None:
self._confs["finufft_eps"] = self._args.finufft_eps
Expand Down
Loading

0 comments on commit 78df49b

Please sign in to comment.