feat: updates to metadata app to pull status from redshift db
dbirman committed Nov 7, 2024
1 parent 5b64887 commit 2bc1543
Showing 2 changed files with 50 additions and 39 deletions.
13 changes: 5 additions & 8 deletions src/aind_metadata_viz/app.py
@@ -1,12 +1,9 @@
 import panel as pn
 import altair as alt
 import pandas as pd
-from aind_metadata_viz import docdb
-from aind_metadata_viz.docdb import _get_all
+from aind_metadata_viz import database
 from aind_data_schema import __version__ as ads_version
 
-_get_all(test_mode=True)
-
 pn.extension(design="material")
 pn.extension("vega")
 alt.themes.enable("ggplot2")
@@ -64,14 +61,14 @@
 )
 color_list = list(colors.values())
 
-db = docdb.Database()
+db = database.Database()
 
 modality_selector = pn.widgets.Select(
-    name="Filter by modality:", options=["all"] + docdb.MODALITIES
+    name="Filter by modality:", options=["all"] + database.MODALITIES
 )
 
 top_selector = pn.widgets.Select(
-    name="Filter by core file:", options=docdb.ALL_FILES
+    name="Filter by core file:", options=database.ALL_FILES
 )
 
 field_selector = pn.widgets.Select(name="Filter download by field:", options=[])
@@ -139,7 +136,7 @@ def modality_present_chart():
"""Build a chart of presence split by modality"""

df = pd.DataFrame()
for modality in docdb.MODALITIES:
for modality in database.MODALITIES:
sum_longform_df = db.get_modality_presence(modality=modality)
df = pd.concat([df, sum_longform_df])

76 changes: 45 additions & 31 deletions src/aind_metadata_viz/database.py
@@ -1,4 +1,6 @@
 from aind_data_access_api.document_db import MetadataDbClient
+from aind_data_access_api.rds_tables import RDSCredentials
+from aind_data_access_api.rds_tables import Client
 import panel as pn
 import pandas as pd
 import param
@@ -24,6 +26,18 @@
 DATABASE = os.getenv("DATABASE", "metadata_index")
 COLLECTION = os.getenv("COLLECTION", "data_assets")
 
+DEV_OR_PROD = "dev" if "test" in API_GATEWAY_HOST else "prod"
+REDSHIFT_SECRETS = f"/aind/{DEV_OR_PROD}/redshift/credentials/readwrite"
+RDS_TABLE_NAME = f"metadata_status_{DEV_OR_PROD}"
+
+CHUNK_SIZE = 1000
+
+rds_client = Client(
+    credentials=RDSCredentials(
+        aws_secrets_name=REDSHIFT_SECRETS
+    ),
+)
+
 docdb_api_client = MetadataDbClient(
     host=API_GATEWAY_HOST,
     database=DATABASE,
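
For reference, the dev/prod toggle above resolves from the gateway host string. A minimal sketch of the behavior, with a hypothetical host value:

# Hypothetical host, for illustration only; any host containing "test" selects dev
api_gateway_host = "api.test.example.org"
dev_or_prod = "dev" if "test" in api_gateway_host else "prod"
print(f"/aind/{dev_or_prod}/redshift/credentials/readwrite")  # /aind/dev/redshift/credentials/readwrite
print(f"metadata_status_{dev_or_prod}")                       # metadata_status_dev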
@@ -68,7 +82,10 @@ def __init__(
"""Initialize"""
# get data
self._file_data = _get_file_presence(test_mode=test_mode)
self._field_data = _get_field_presence(test_mode=test_mode)
self._status_data = _get_status()

# inner join only keeps records that are in both dataframes
self.data = pd.merge(self._file_data, self._status_data, on="_id", how="inner")

# setup
(expected_files, _) = self.get_expected_files()
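
A minimal sketch of what that inner join does, using toy frames (every column except _id is made up):

import pandas as pd

file_data = pd.DataFrame({"_id": ["a", "b", "c"], "subject": ["valid", "missing", "valid"]})
status_data = pd.DataFrame({"_id": ["a", "c", "d"], "status": ["Valid", "Invalid", "Valid"]})

merged = pd.merge(file_data, status_data, on="_id", how="inner")
print(merged["_id"].tolist())  # ['a', 'c']: 'b' and 'd' appear in only one frame, so they drop out

Note that any asset indexed in DocDB but absent from the Redshift status table (or vice versa) disappears from self.data entirely.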
@@ -86,7 +103,7 @@ def data_filtered(self):
"""
mod_filter = not (self.modality_filter == "all")

filtered_df = self._file_data.copy()
filtered_df = self.data.copy()

# Filter by modality
if mod_filter:
@@ -109,12 +126,13 @@ def data_modality_filtered(self, modality: str):
         modality : str
             Modality.ONE_OF
         """
-        filtered_df = self._file_data.copy()
+        filtered_df = self.data.copy()
 
         # Apply derived filter
         if not (self.derived_filter == "All assets"):
             filtered_df = filtered_df[filtered_df["derived"] == (self.derived_filter == "Derived")]
 
         filtered_df = filtered_df[ALL_FILES + EXTRA_FIELDS]
+        filtered_df = filtered_df[filtered_df['modalities'].apply(lambda x: modality in x.split(','))]
 
         return filtered_df
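
The new filter assumes modalities is stored as a comma-joined string. A minimal sketch with toy values:

import pandas as pd

df = pd.DataFrame({
    "_id": ["a", "b", "c"],
    "modalities": ["ecephys,behavior", "ophys", "behavior"],
})

modality = "behavior"
mask = df["modalities"].apply(lambda x: modality in x.split(","))
print(df[mask]["_id"].tolist())  # ['a', 'c']

Splitting before the membership test avoids the substring false positives that a plain "modality in x" check would allow.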
@@ -150,6 +168,7 @@ def get_file_presence(self):
         # Melt to long form
         df = self.data_filtered.copy()
         df.drop(EXTRA_FIELDS, axis=1, inplace=True)
+        df = df[ALL_FILES]
 
         df_melted = df.melt(var_name="file", value_name="state")
         # Get sum
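
For context, melt reshapes the wide file-by-asset table into one (file, state) row per cell. A toy sketch (the state labels are made up):

import pandas as pd

df = pd.DataFrame({
    "subject": ["valid", "missing"],
    "procedures": ["valid", "valid"],
})

df_melted = df.melt(var_name="file", value_name="state")
print(df_melted.values.tolist())
# [['subject', 'valid'], ['subject', 'missing'], ['procedures', 'valid'], ['procedures', 'valid']]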
@@ -207,7 +226,7 @@ def set_field(self, field: str):
     def get_file_field_presence(self):
         """Get the presence of fields in a specific file
         """
-        field_df = self._field_data[self.file]
+        field_df = self.data.filter(regex=rf'^{self.file}\.').rename(columns=lambda col: col.replace(f"{self.file}.", ""))
 
         # we need to filter by the derived/modality filters here but they are in the other dataframe
         if not (self.derived_filter == "All assets"):
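
A minimal sketch of that select-and-rename, assuming the merged table carries flattened "file.field" column names (the toy columns here are assumptions):

import pandas as pd

df = pd.DataFrame({
    "_id": ["a"],
    "subject.subject_id": ["present"],
    "subject.genotype": ["missing"],
    "procedures.subject_procedures": ["present"],
})

file = "subject"
field_df = df.filter(regex=rf"^{file}\.").rename(
    columns=lambda col: col.replace(f"{file}.", "")
)
print(list(field_df.columns))  # ['subject_id', 'genotype']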
@@ -259,6 +278,24 @@ def get_csv(self, vp_state: str = "Not Valid/Present"):
         sio = StringIO()
         df.to_csv(sio, index=False)
         return sio.getvalue()
+
+
+@pn.cache(ttl=CACHE_RESET_DAY)
+def _get_metadata(test_mode=False) -> pd.DataFrame:
+    """Get the metadata fields, modality, derived, name, location, created
+
+    Parameters
+    ----------
+    test_mode : bool, optional
+        _description_, by default False
+    """
+
+
+def _get_status() -> pd.DataFrame:
+    """Get the status of the metadata
+    """
+    response = rds_client.read_table(RDS_TABLE_NAME)
+    return response
 
 
 @pn.cache(ttl=CACHE_RESET_DAY)
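
A minimal usage sketch for the status pull. That read_table returns a DataFrame follows from the merge on _id above; the column names are assumptions:

status_df = rds_client.read_table(RDS_TABLE_NAME)  # pandas DataFrame, one row per asset
print(status_df.columns.tolist())                  # hypothetical: ['_id', 'status', ...]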
@@ -271,8 +308,8 @@ def _get_file_presence(test_mode=False) -> pd.DataFrame:
         _description_, by default False
     """
     record_list = _get_all(test_mode=test_mode)
-    processed = process_record_list(record_list, ALL_FILES)
 
+    records = []
     # Now add some information about the records, i.e. modality, derived state, etc.
     for i, record in enumerate(record_list):
         if (
@@ -298,37 +335,14 @@
"created": record["created"],
}

processed[i] = {**processed[i], **info_data}
records.append(info_data)

return pd.DataFrame(
processed,
columns=ALL_FILES
+ ["modalities", "derived", "name", "_id", "location", "created"],
records,
columns=["modalities", "derived", "name", "_id", "location", "created"],
)


-@pn.cache(ttl=CACHE_RESET_DAY)
-def _get_field_presence(test_mode=False) -> dict:
-    """Get all and convert to data frame format
-    returns a dictionary {file: field_df}
-    """
-    record_list = _get_all(test_mode=test_mode)
-
-    file_dfs = {}
-    # filter by file
-    for file in ALL_FILES:
-        expected_fields = second_layer_field_mappings[file]
-        # get field presence
-        field_record_list = [record[file] if file in record else None for record in record_list]
-        processed = process_record_list(field_record_list, expected_fields, parent=file)
-
-        file_df = pd.DataFrame(processed, columns=expected_fields)
-
-        file_dfs[file] = file_df
-
-    return file_dfs
-
 @pn.cache(ttl=CACHE_RESET_DAY)
 def _get_all(test_mode=False):
     filter = {}
