Skip to content

Commit

Permalink
add MPRester.get_charge_density_from_task_id()
Browse files Browse the repository at this point in the history
  • Loading branch information
teddykoker authored and tschaume committed Aug 19, 2024
1 parent 56bf699 commit a9e1bc5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 30 deletions.
65 changes: 37 additions & 28 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,40 @@ def get_wulff_shape(self, material_id: str):
millers, energies = zip(*miller_energy_map.items())
return WulffShape(lattice, millers, energies)

def get_charge_density_from_task_id(
self, task_id: str, inc_task_doc: bool = False
) -> Chgcar | tuple[Chgcar, TaskDoc | dict] | None:
"""Get charge density data for a given task_id.
Arguments:
task_id (str): A task id
inc_task_doc (bool): Whether to include the task document in the returned data.
Returns:
(Chgcar, (Chgcar, TaskDoc | dict), None): Pymatgen Chgcar object, or tuple with object and TaskDoc
"""
decoder = MontyDecoder().decode if self.monty_decode else json.loads
chgcar = (
self.materials.tasks._query_open_data(
bucket="materialsproject-parsed",
key=f"chgcars/{str(task_id)}.json.gz",
decoder=decoder,
fields=["data"],
)[0]
or {}
)

if not chgcar:
raise MPRestError(f"No charge density fetched for task_id {task_id}.")

chgcar = chgcar[0]["data"] # type: ignore

if inc_task_doc:
task_doc = self.materials.tasks.search(task_ids=task_id)[0]
return chgcar, task_doc

return chgcar

def get_charge_density_from_material_id(
self, material_id: str, inc_task_doc: bool = False
) -> Chgcar | tuple[Chgcar, TaskDoc | dict] | None:
Expand All @@ -1331,7 +1365,7 @@ def get_charge_density_from_material_id(
task_ids = self.get_task_ids_associated_with_material_id(
material_id, calc_types=[CalcType.GGA_Static, CalcType.GGA_U_Static]
)
results: list[TaskDoc] = self.tasks.search(
results: list[TaskDoc] = self.materials.tasks.search(
task_ids=task_ids, fields=["last_updated", "task_id"]
) # type: ignore

Expand All @@ -1344,33 +1378,8 @@ def get_charge_density_from_material_id(
if self.use_document_model
else x["last_updated"], # type: ignore
)

decoder = MontyDecoder().decode if self.monty_decode else json.loads
chgcar = (
self.tasks._query_open_data(
bucket="materialsproject-parsed",
key=f"chgcars/{str(latest_doc.task_id)}.json.gz",
decoder=decoder,
fields=["data"],
)[0]
or {}
)

if not chgcar:
raise MPRestError(f"No charge density fetched for {material_id}.")

chgcar = chgcar[0]["data"] # type: ignore

if inc_task_doc:
task_doc = self.tasks.search(
task_ids=latest_doc.task_id
if self.use_document_model
else latest_doc["task_id"]
)[0]

return chgcar, task_doc

return chgcar
task_id = latest_doc.task_id if self.use_document_model else latest_doc["task_id"]
return self.get_charge_density_from_task_id(task_id, inc_task_doc)

def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
"""Get a list of URLs to retrieve raw VASP output files from the NoMaD repository
Expand Down
13 changes: 11 additions & 2 deletions tests/test_mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,7 @@ def test_get_phonon_data_by_material_id(self, mpr):
dos = mpr.get_phonon_dos_by_material_id("mp-2172")
assert isinstance(dos, PhononDos)

# @pytest.mark.skip(reason="Test needs fixing with ENV variables")
def test_get_charge_density_data(self, mpr):
def test_get_charge_density_from_material_id(self, mpr):
chgcar = mpr.get_charge_density_from_material_id("mp-149")
assert isinstance(chgcar, Chgcar)

Expand All @@ -303,6 +302,16 @@ def test_get_charge_density_data(self, mpr):
assert isinstance(chgcar, Chgcar)
assert isinstance(task_doc, TaskDoc)

def test_get_charge_density_from_task_id(self, mpr):
chgcar = mpr.get_charge_density_from_task_id("mp-2246557")
assert isinstance(chgcar, Chgcar)

chgcar, task_doc = mpr.get_charge_density_from_task_id(
"mp-2246557", inc_task_doc=True
)
assert isinstance(chgcar, Chgcar)
assert isinstance(task_doc, TaskDoc)

def test_get_wulff_shape(self, mpr):
ws = mpr.get_wulff_shape("mp-126")
assert isinstance(ws, WulffShape)
Expand Down

0 comments on commit a9e1bc5

Please sign in to comment.