Merge pull request #904 from kbuma/bugfix/special-case-thermo-and-xas
special casing for thermo, xas and synth_descriptions collections in OpenData
Jason Munro authored Jan 5, 2024
2 parents 2f3d741 + 3d24140 commit d27555e
Showing 2 changed files with 116 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/maggma/stores/open_data.py
@@ -24,6 +24,7 @@ class S3IndexStore(MemoryStore):

def __init__(
self,
collection_name: str,
bucket: str,
prefix: str = "",
endpoint_url: Optional[str] = None,
@@ -33,13 +34,15 @@ def __init__(
"""Initializes an S3IndexStore
Args:
collection_name (str): name of the collection
bucket (str): Name of the bucket where the index is stored.
prefix (str, optional): The prefix to add to the name of the index, i.e. the manifest key.
Defaults to "".
endpoint_url (Optional[str], optional): S3-compatible endpoint URL.
Defaults to None, indicating to use the default configured AWS S3.
manifest_key (str, optional): The name of the index. Defaults to "manifest.json".
"""
self.collection_name = collection_name
self.bucket = bucket
self.prefix = prefix
self.endpoint_url = endpoint_url
Expand All @@ -48,6 +51,7 @@ def __init__(
self.s3_session_kwargs = {}
self.manifest_key = manifest_key

kwargs["collection_name"] = collection_name
super().__init__(**kwargs)

def _get_full_key_path(self) -> str:
@@ -189,16 +193,16 @@ def __init__(
super().__init__(**kwargs)

def _get_full_key_path(self, id: str) -> str:
if self.index.collection_name == "thermo" and self.key == "thermo_id":
material_id, thermo_type = id.split("_", 1)
return f"{self.sub_dir}{thermo_type}/{material_id}{self.object_file_extension}"
if self.index.collection_name == "xas" and self.key == "spectrum_id":
material_id, spectrum_type, absorbing_element, edge = id.rsplit("-", 3)
return f"{self.sub_dir}{edge}/{spectrum_type}/{absorbing_element}/{material_id}{self.object_file_extension}"
if self.index.collection_name == "synth_descriptions" and self.key == "doi":
return f"{self.sub_dir}{id.replace('/', '_')}{self.object_file_extension}"
return f"{self.sub_dir}{id}{self.object_file_extension}"

def _get_id_from_full_key_path(self, key: str) -> str:
prefix, suffix = self.sub_dir, self.object_file_extension
if prefix in key and suffix in key:
start_idx = key.index(prefix) + len(prefix)
end_idx = key.index(suffix, start_idx)
return key[start_idx:end_idx]
return ""

def _get_compression_function(self) -> Callable:
return gzip.compress

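To make the new key layout concrete, here is a minimal, self-contained sketch of the three special-cased mappings. The helper name full_key_path is hypothetical; the IDs come from the tests below, and the sub_dir/extension defaults mirror what those tests exercise ("" and ".json.gz").

def full_key_path(collection_name: str, key: str, id: str,
                  sub_dir: str = "", ext: str = ".json.gz") -> str:
    # Illustrative reproduction of the special-cased key construction above.
    if collection_name == "thermo" and key == "thermo_id":
        # "mp-1_R2SCAN" -> ("mp-1", "R2SCAN"); split once from the left.
        material_id, thermo_type = id.split("_", 1)
        return f"{sub_dir}{thermo_type}/{material_id}{ext}"
    if collection_name == "xas" and key == "spectrum_id":
        # Split from the right, since material IDs themselves contain "-".
        material_id, spectrum_type, absorbing_element, edge = id.rsplit("-", 3)
        return f"{sub_dir}{edge}/{spectrum_type}/{absorbing_element}/{material_id}{ext}"
    if collection_name == "synth_descriptions" and key == "doi":
        # Replacing "/" keeps each DOI as a single flat object name.
        return f"{sub_dir}{id.replace('/', '_')}{ext}"
    return f"{sub_dir}{id}{ext}"

assert full_key_path("thermo", "thermo_id", "mp-1_R2SCAN") == "R2SCAN/mp-1.json.gz"
assert full_key_path("xas", "spectrum_id", "mp-1-XAFS-Cr-K") == "K/XAFS/Cr/mp-1.json.gz"
assert full_key_path("synth_descriptions", "doi", "10.1149/2.051201jes") == "10.1149_2.051201jes.json.gz"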
104 changes: 104 additions & 0 deletions tests/stores/test_open_data.py
@@ -409,3 +409,107 @@ def test_pickle(s3store_w_subdir):
dobj = pickle.loads(sobj)
assert hash(dobj) == hash(s3store_w_subdir)
assert dobj == s3store_w_subdir


@pytest.fixture()
def thermo_store():
with mock_s3():
conn = boto3.resource("s3", region_name="us-east-1")
conn.create_bucket(Bucket="bucket1")

index = S3IndexStore(collection_name="thermo", bucket="bucket1", key="thermo_id")
store = OpenDataStore(index=index, bucket="bucket1", key="thermo_id")
store.connect()

store.update(
[
{
"thermo_id": "mp-1_R2SCAN",
"data": "asd",
store.last_updated_field: datetime.utcnow(),
}
]
)

yield store


def test_thermo_collection_special_handling(thermo_store):
assert thermo_store.s3_bucket.Object(thermo_store._get_full_key_path("mp-1_R2SCAN")).key == "R2SCAN/mp-1.json.gz"
thermo_store.update([{"thermo_id": "mp-2_RSCAN", "data": "asd"}])
index_docs = thermo_store.rebuild_index_from_s3_data()
assert len(index_docs) == 2
for doc in index_docs:
for key in doc:
assert key == "thermo_id" or key == "last_updated"


@pytest.fixture()
def xas_store():
with mock_s3():
conn = boto3.resource("s3", region_name="us-east-1")
conn.create_bucket(Bucket="bucket1")

index = S3IndexStore(collection_name="xas", bucket="bucket1", key="spectrum_id")
store = OpenDataStore(index=index, bucket="bucket1", key="spectrum_id")
store.connect()

store.update(
[
{
"spectrum_id": "mp-1-XAFS-Cr-K",
"data": "asd",
store.last_updated_field: datetime.utcnow(),
}
]
)

yield store


def test_xas_collection_special_handling(xas_store):
assert xas_store.s3_bucket.Object(xas_store._get_full_key_path("mp-1-XAFS-Cr-K")).key == "K/XAFS/Cr/mp-1.json.gz"
xas_store.update([{"spectrum_id": "mp-2-XAFS-Li-K", "data": "asd"}])
index_docs = xas_store.rebuild_index_from_s3_data()
assert len(index_docs) == 2
for doc in index_docs:
for key in doc:
assert key == "spectrum_id" or key == "last_updated"


@pytest.fixture()
def synth_descriptions_store():
with mock_s3():
conn = boto3.resource("s3", region_name="us-east-1")
conn.create_bucket(Bucket="bucket1")

index = S3IndexStore(collection_name="synth_descriptions", bucket="bucket1", key="doi")
store = OpenDataStore(index=index, bucket="bucket1", key="doi")
store.connect()

store.update(
[
{
"doi": "10.1149/2.051201jes",
"data": "asd",
store.last_updated_field: datetime.utcnow(),
}
]
)

yield store


def test_synth_descriptions_collection_special_handling(synth_descriptions_store):
assert (
synth_descriptions_store.s3_bucket.Object(
synth_descriptions_store._get_full_key_path("10.1149/2.051201jes")
).key
== "10.1149_2.051201jes.json.gz"
)
synth_descriptions_store.update([{"doi": "10.1039/C5CP01095K", "data": "asd"}])
index_docs = synth_descriptions_store.rebuild_index_from_s3_data()
assert len(index_docs) == 2
for doc in index_docs:
for key in doc:
assert key == "doi" or key == "last_updated"
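All three fixtures fake S3 with moto's mock_s3, so no AWS credentials are needed; assuming a standard development install, the new cases can be run in isolation with pytest tests/stores/test_open_data.py -k special_handling.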
