Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
D.Sarpa committed Jan 31, 2024
1 parent 063b832 commit 739be11
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
15 changes: 7 additions & 8 deletions src/quacc/calculators/espresso/espresso.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def __init__(
"wfcdir": os.environ.get("ESPRESSO_TMPDIR", "."),
}

self.outfiles = {"fildos": "pwscf.dos",
"filpdos": "pwscf.pdos_tot"}
self.outfiles = {"fildos": "pwscf.dos", "filpdos": "pwscf.pdos_tot"}

self.test_run = test_run

Expand Down Expand Up @@ -209,17 +208,17 @@ def read_results(self, directory: Path | str) -> dict[str, Any]:
fildos = self.outfiles["fildos"]
with Path(fildos).open("r") as fd:
lines = fd.readlines()
fermi = float(re.search(r"-?\d+\.?\d*", lines[0])[0])
fermi = float(re.search(r"-?\d+\.?\d*", lines[0]).group(0))
dos = np.loadtxt(lines[1:])
results = {fildos.name: {"dos": dos, "fermi": fermi}}
elif self.binary == "projwfc":
filpdos = self.outfiles["filpdos"]
with Path(filpdos).open("r") as fd:
lines = fd.readlines()
energy = np.loadtxt(lines[0:])
dos = np.loadtxt(lines[1:])
pdos = np.loadtxt(lines[2:])
results = {filpdos.name: {"energy": energy,"dos": dos, "pdos": pdos}}
lines = np.loadtxt(fd.readlines())
energy = lines[1:, 0]
dos = lines[1:, 1]
pdos = lines[1:, 2]
results = {filpdos.name: {"energy": energy, "dos": dos, "pdos": pdos}}
else:
results = {}

Expand Down
4 changes: 3 additions & 1 deletion src/quacc/recipes/espresso/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ProjwfcSchema(TypedDict):
non_scf_job: RunSchema
projwfc_job: RunSchema


@job
def dos_job(
prev_dir: str | Path,
Expand Down Expand Up @@ -78,6 +79,7 @@ def dos_job(
copy_files=prev_dir,
)


@job
def projwfc_job(
prev_dir: str | Path,
Expand Down Expand Up @@ -307,4 +309,4 @@ def projwfc_flow(
"static_job": static_results,
"non_scf_job": non_scf_results,
"projwfc_job": projwfc_results,
}
}
10 changes: 7 additions & 3 deletions tests/core/recipes/espresso_recipes/test_dos.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from pathlib import Path
from shutil import which

import pytest
from ase.build import bulk
from numpy.testing import assert_allclose

from quacc.recipes.espresso.dos import dos_flow,projwfc_job, projwfc_flow
from quacc.utils.files import copy_decompress_files,copy_decompress_tree
from quacc.recipes.espresso.dos import dos_flow, projwfc_flow, projwfc_job
from quacc.utils.files import copy_decompress_files, copy_decompress_tree

pytestmark = pytest.mark.skipif(
which("pw.x") is None or which("dos.x") is None, reason="QE not installed"
Expand All @@ -29,9 +31,11 @@ def test_projwfc_job(tmp_path, monkeypatch):
copy_decompress_tree({DATA_DIR / "dos_test/": "pwscf.save/*.gz"}, tmp_path)
copy_decompress_files([DATA_DIR / "Si.upf.gz"], tmp_path)
output = projwfc_job(tmp_path)
print(output)
assert output["name"] == "projwfc.x Projects-wavefunctions"
assert output["parameters"]["input_data"]["projwfc"] == {}


def test_dos_flow(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

Expand Down Expand Up @@ -135,4 +139,4 @@ def test_projwfc_flow(tmp_path, monkeypatch):
== "TF"
)
assert output["non_scf_job"]["results"]["nbands"] == 8
assert output["non_scf_job"]["results"]["nspins"] == 1
assert output["non_scf_job"]["results"]["nspins"] == 1

0 comments on commit 739be11

Please sign in to comment.