diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 48bb4275..53324145 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -1314,6 +1314,40 @@ def get_wulff_shape(self, material_id: str): millers, energies = zip(*miller_energy_map.items()) return WulffShape(lattice, millers, energies) + def get_charge_density_from_task_id( + self, task_id: str, inc_task_doc: bool = False + ) -> Chgcar | tuple[Chgcar, TaskDoc | dict] | None: + """Get charge density data for a given task_id. + + Arguments: + task_id (str): A task id + inc_task_doc (bool): Whether to include the task document in the returned data. + + Returns: + (Chgcar, (Chgcar, TaskDoc | dict), None): Pymatgen Chgcar object, or tuple with object and TaskDoc + """ + decoder = MontyDecoder().decode if self.monty_decode else json.loads + chgcar = ( + self.materials.tasks._query_open_data( + bucket="materialsproject-parsed", + key=f"chgcars/{str(task_id)}.json.gz", + decoder=decoder, + fields=["data"], + )[0] + or {} + ) + + if not chgcar: + raise MPRestError(f"No charge density fetched for task_id {task_id}.") + + chgcar = chgcar[0]["data"] # type: ignore + + if inc_task_doc: + task_doc = self.materials.tasks.search(task_ids=task_id)[0] + return chgcar, task_doc + + return chgcar + def get_charge_density_from_material_id( self, material_id: str, inc_task_doc: bool = False ) -> Chgcar | tuple[Chgcar, TaskDoc | dict] | None: @@ -1331,7 +1365,7 @@ def get_charge_density_from_material_id( task_ids = self.get_task_ids_associated_with_material_id( material_id, calc_types=[CalcType.GGA_Static, CalcType.GGA_U_Static] ) - results: list[TaskDoc] = self.tasks.search( + results: list[TaskDoc] = self.materials.tasks.search( task_ids=task_ids, fields=["last_updated", "task_id"] ) # type: ignore @@ -1344,33 +1378,8 @@ def get_charge_density_from_material_id( if self.use_document_model else x["last_updated"], # type: ignore ) - - decoder = MontyDecoder().decode if self.monty_decode else json.loads - chgcar = ( - self.tasks._query_open_data( - bucket="materialsproject-parsed", - key=f"chgcars/{str(latest_doc.task_id)}.json.gz", - decoder=decoder, - fields=["data"], - )[0] - or {} - ) - - if not chgcar: - raise MPRestError(f"No charge density fetched for {material_id}.") - - chgcar = chgcar[0]["data"] # type: ignore - - if inc_task_doc: - task_doc = self.tasks.search( - task_ids=latest_doc.task_id - if self.use_document_model - else latest_doc["task_id"] - )[0] - - return chgcar, task_doc - - return chgcar + task_id = latest_doc.task_id if self.use_document_model else latest_doc["task_id"] + return self.get_charge_density_from_task_id(task_id, inc_task_doc) def get_download_info(self, material_ids, calc_types=None, file_patterns=None): """Get a list of URLs to retrieve raw VASP output files from the NoMaD repository diff --git a/tests/test_mprester.py b/tests/test_mprester.py index cbe3d7ec..f3d81a7d 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -292,8 +292,7 @@ def test_get_phonon_data_by_material_id(self, mpr): dos = mpr.get_phonon_dos_by_material_id("mp-2172") assert isinstance(dos, PhononDos) - # @pytest.mark.skip(reason="Test needs fixing with ENV variables") - def test_get_charge_density_data(self, mpr): + def test_get_charge_density_from_material_id(self, mpr): chgcar = mpr.get_charge_density_from_material_id("mp-149") assert isinstance(chgcar, Chgcar) @@ -303,6 +302,16 @@ def test_get_charge_density_data(self, mpr): assert isinstance(chgcar, Chgcar) assert isinstance(task_doc, TaskDoc) + def test_get_charge_density_from_task_id(self, mpr): + chgcar = mpr.get_charge_density_from_task_id("mp-2246557") + assert isinstance(chgcar, Chgcar) + + chgcar, task_doc = mpr.get_charge_density_from_task_id( + "mp-2246557", inc_task_doc=True + ) + assert isinstance(chgcar, Chgcar) + assert isinstance(task_doc, TaskDoc) + def test_get_wulff_shape(self, mpr): ws = mpr.get_wulff_shape("mp-126") assert isinstance(ws, WulffShape)