diff --git a/code/Home.py b/code/Home.py
index 31b8ecb..f4c7646 100644
--- a/code/Home.py
+++ b/code/Home.py
@@ -33,7 +33,7 @@
from util.streamlit import (_plot_population_x_y, add_auto_train_manager,
add_dot_property_mapper, add_session_filter,
add_xy_selector, add_xy_setting,
- aggrid_interactive_table_curriculum,
+ aggrid_interactive_table_basic,
aggrid_interactive_table_session, data_selector,
add_footnote)
from util.url_query_helper import (checkbox_wrapper_for_url_query,
@@ -280,7 +280,7 @@ def show_curriculums():
pass
# ------- Layout starts here -------- #
-def init(if_load_docDB_override=None):
+def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
# Clear specific session state and all filters
for key in st.session_state:
@@ -296,12 +296,17 @@ def init(if_load_docDB_override=None):
# Because sync_URL_to_session_state() needs df to be loaded (for dynamic column filtering),
# 'if_load_bpod_sessions' has not been synced from URL to session state yet.
# So here we need to manually get it from URL or session state.
- if (st.query_params['if_load_bpod_sessions'].lower() == 'true'
+ _if_load_bpod = if_load_bpod_data_override if if_load_bpod_data_override is not None else (
+ st.query_params['if_load_bpod_sessions'].lower() == 'true'
if 'if_load_bpod_sessions' in st.query_params
else st.session_state.if_load_bpod_sessions
if 'if_load_bpod_sessions' in st.session_state
- else False):
+ else False)
+
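+    # Record whether bpod sessions have been merged into the master df,
+    # so that other pages (e.g., "Data inventory") can check st.session_state.bpod_loaded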
+ st.session_state.bpod_loaded = False
+ if _if_load_bpod:
df_bpod = load_data(['sessions'], data_source='bpod')
+ st.session_state.bpod_loaded = True
# For historial reason, the suffix of df['sessions_bonsai'] just mean the data of the Home.py page
df['sessions_bonsai'] = pd.concat([df['sessions_bonsai'], df_bpod['sessions_bonsai']], axis=0)
@@ -377,6 +382,10 @@ def _get_data_source(rig):
_df.dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now)
_df.drop(_df.query('session < 1').index, inplace=True)
+    # Remove invalid subject_id (keep only IDs between 300000 and 999999)
+ _df = _df[(999999 > _df["subject_id"].astype(int))
+ & (_df["subject_id"].astype(int) > 300000)]
+
# Remove abnormal values
_df.loc[_df['weight_after'] > 100,
['weight_after', 'weight_after_ratio', 'water_in_session_total', 'water_after_session', 'water_day_total']
@@ -745,7 +754,7 @@ def app():
pre_selected_rows = None
# Show df_curriculum
- aggrid_interactive_table_curriculum(df=df_curriculums,
+ aggrid_interactive_table_basic(df=df_curriculums,
pre_selected_rows=pre_selected_rows)
diff --git a/code/__init__.py b/code/__init__.py
index 005fd5f..fe3e3b8 100644
--- a/code/__init__.py
+++ b/code/__init__.py
@@ -1 +1 @@
-__ver__ = 'v2.5.6'
+__ver__ = 'v2.6.0'
diff --git a/code/pages/0_Data inventory.py b/code/pages/0_Data inventory.py
new file mode 100644
index 0000000..cd07753
--- /dev/null
+++ b/code/pages/0_Data inventory.py
@@ -0,0 +1,522 @@
+import logging
+import re
+import json
+
+from matplotlib_venn import venn2, venn3, venn2_circles, venn3_circles
+import matplotlib.pyplot as plt
+import streamlit as st
+from streamlit_dynamic_filters import DynamicFilters
+import pandas as pd
+import numpy as np
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from streamlit_plotly_events import plotly_events
+
+import time
+import streamlit_nested_layout
+
+from util.streamlit import aggrid_interactive_table_basic, download_df, add_footnote
+from util.fetch_data_docDB import (
+ fetch_queries_from_docDB,
+ fetch_queries_from_docDB_parallel,
+)
+from util.reformat import formatting_metadata_df
+from util.aws_s3 import load_raw_sessions_on_VAST
+from Home import init
+
+
+try:
+ st.set_page_config(layout="wide",
+ page_title='Foraging behavior browser',
+ page_icon=':mouse2:',
+ menu_items={
+ 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues",
+ 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/"
+ }
+ )
+except:
+ pass
+
+st.markdown(
+"""
+""",
+unsafe_allow_html=True,
+)
+
+# Load QUERY_PRESET from json
+with open("data_inventory_QUERY_PRESET.json", "r") as f:
+ QUERY_PRESET = json.load(f)
+
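+# Each QUERY_PRESET entry is expected to provide at least an "alias" (used as a flag column name below)
+# and a docDB "filter" passed to retrieve_docdb_records, roughly like (illustrative example, not the actual preset file):
+#   {"alias": "raw data in docDB", "filter": {"data_description.modality.abbreviation": "behavior"}}
+
+# Boolean flag columns marking where each session was found
+# (Han's temp pipeline, raw data on VAST, and each docDB query alias)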
+META_COLUMNS = [
+ "Han_temp_pipeline (bpod)",
+ "Han_temp_pipeline (bonsai)",
+ "VAST_raw_data_on_VAST",
+] + [query["alias"] for query in QUERY_PRESET]
+
+@st.cache_data(ttl=3600*12)
+def merge_queried_dfs(dfs, queries_to_merge):
+ # Combine queried dfs using df_unique_mouse_date (on index "subject_id", "session_date" only)
+ df_merged = dfs[queries_to_merge[0]["alias"]]["df_unique_mouse_date"]
+ for df in [dfs[query["alias"]]["df_unique_mouse_date"] for query in queries_to_merge[1:]]:
+ df_merged = df_merged.combine_first(df) # Combine nwb_names
+
+ # Recover the column order of QUERY_PRESET
+ query_cols = [query["alias"] for query in queries_to_merge]
+ df_merged = df_merged.reindex(
+ columns=[
+ other_col for other_col in df_merged.columns if other_col not in query_cols
+ ]
+ + query_cols
+ )
+ return df_merged
+
+
+def _filter_df_by_patch_ids(df, patch_ids):
+    """ Filter df by Venn patch ids such as ["110", "001"].
+
+    Each character of a patch id maps positionally to a column: "1" requires True, "0" requires False.
+    Returns the index of rows that match any of the given patch ids.
+    """
+ # Turn NAN to False
+ df = df.fillna(False)
+ conditions = []
+ for patch_id in patch_ids:
+ # Convert patch_id string to boolean conditions
+ condition = True
+ for i, col in enumerate(df.columns):
+ if patch_id[i] == "1": # Include rows where the column is True
+ condition &= df[col]
+ elif patch_id[i] == "0": # Include rows where the column is not True
+ condition &= ~df[col]
+ conditions.append(condition)
+
+ # Combine all conditions with OR
+ final_condition = pd.concat(conditions, axis=1).any(axis=1)
+ return df[final_condition].index
+
+@st.cache_data(ttl=3600*24)
+def generate_venn(df, venn_preset):
+    """ Draw a 2- or 3-circle Venn diagram defined by venn_preset; return (fig, notes) """
+ circle_settings = venn_preset["circle_settings"]
+
+ fig, ax = plt.subplots()
+ if len(circle_settings) == 2:
+ v_func = venn2
+ c_func = venn2_circles
+ elif len(circle_settings) == 3:
+ v_func = venn3
+ c_func = venn3_circles
+ else:
+        st.warning("The number of circles in a Venn diagram must be 2 or 3.")
+ return None, None
+
+ v = v_func(
+ [set(df.index[df[c_s["column"]]==True]) for c_s in circle_settings],
+ set_labels=[c_s["column"] for c_s in circle_settings],
+ )
+ c = c_func(
+ [set(df.index[df[c_s["column"]]==True]) for c_s in circle_settings],
+ )
+
+ # Set edge color and style
+ for i, c_s in enumerate(circle_settings):
+ edge_color = c_s.get("edge_color", "black")
+ edge_style = c_s.get("edge_style", "solid")
+
+ c[i].set_edgecolor(edge_color)
+ c[i].set_linestyle(edge_style)
+ v.get_label_by_id(["A", "B", "C"][i]).set_color(edge_color)
+
+ # Clear all patch color
+ for patch in v.patches:
+ if patch: # Some patches might be None
+ patch.set_facecolor('none')
+
+ notes = []
+ for patch_setting in venn_preset["patch_settings"]:
+ # Set color
+ for patch_id in patch_setting["patch_ids"]:
+ if v.get_patch_by_id(patch_id):
+ v.get_patch_by_id(patch_id).set_color(patch_setting["color"])
+ # Add notes
+ notes.append(f"#### :{patch_setting['emoji']}: :{patch_setting['color']}[{patch_setting['notes']}]")
+
+ return fig, notes
+
+def _show_records_on_sidebar(dfs, file_name_prefix, source_str="docDB"):
+ df = dfs["df"]
+ df_multi_sessions_per_day = dfs["df_multi_sessions_per_day"]
+ df_unique_mouse_date = dfs["df_unique_mouse_date"]
+
+ with st.expander(f"{len(df)} records from {source_str}"):
+ download_df(df, label="Download as CSV", file_name=f"{file_name_prefix}.csv")
+ st.write(df)
+
+ st.markdown(":heavy_exclamation_mark: :red[Multiple sessions per day should be resolved!]")
+ with st.expander(f"{len(df_multi_sessions_per_day)} have multiple sessions per day"):
+ download_df(
+ df_multi_sessions_per_day,
+ label="Download as CSV",
+ file_name=f"{file_name_prefix}_multi_sessions_per_day.csv",
+ )
+ st.write(df_multi_sessions_per_day)
+
+ with st.expander(f"{len(df_unique_mouse_date)} unique mouse-date pairs"):
+ download_df(
+ df_unique_mouse_date,
+ label="Download as CSV",
+ file_name=f"{file_name_prefix}_unique_mouse_date.csv",
+ )
+ st.write(df_unique_mouse_date)
+
+ # if len(df_unique_mouse_date) != df_merged[query["alias"]].sum():
+ # st.warning('''len(df_unique_mouse_date) != df_merged[query["alias"]].sum()!''')
+
+def add_sidebar(df_merged, dfs_docDB, df_Han_pipeline, dfs_raw_on_VAST, docDB_retrieve_time):
+ # Sidebar
+ with st.sidebar:
+ for query in QUERY_PRESET:
+ with st.expander(f"### {query['alias']}"):
+
+ # Turn query to json with indent=4
+ # with st.expander("Show docDB query"):
+ query_json = json.dumps(query["filter"], indent=4)
+ st.code(query_json)
+
+ # Show records
+ _show_records_on_sidebar(dfs_docDB[query["alias"]], file_name_prefix=query["alias"], source_str="docDB")
+ st.markdown('#### See [how to use above queries](https://aind-data-access-api.readthedocs.io/en/latest/UserGuide.html#document-database-docdb) in your own code.')
+
+ st.markdown('''## 2. From Han's temporary pipeline (the "Home" page)''')
+ hardwares = ["bonsai", "bpod"]
+ for hardware in hardwares:
+ df_this_hardware = df_Han_pipeline[
+ df_Han_pipeline[f"Han_temp_pipeline ({hardware})"].notnull()
+ ]
+ with st.expander(
+ f"### {len(df_this_hardware)} {hardware} sessions"
+ + (" (old data, not growing)" if hardware == "bpod" else "")
+ ):
+ download_df(
+ df_this_hardware,
+ label="Download as CSV",
+ file_name=f"Han_temp_pipeline_{hardware}.csv",
+ )
+ st.write(df_this_hardware)
+
+ st.markdown('''## 3. From VAST /scratch: existing raw data''')
+ _show_records_on_sidebar(dfs_raw_on_VAST, file_name_prefix="raw_on_VAST", source_str="VAST /scratch")
+
+ add_footnote()
+
+@st.cache_data(ttl=3600*24)
+def plot_histogram_over_time(df, venn_preset, time_period="Daily", if_sync_y_limits=True, if_separate_plots=False):
+ """Generate histogram over time for the columns and patches in preset
+ """
+ df["Daily"] = df["session_date"]
+ df["Weekly"] = df["session_date"].dt.to_period("W").dt.start_time
+ df["Monthly"] = df["session_date"].dt.to_period("M").dt.start_time
+ df["Quarterly"] = df["session_date"].dt.to_period("Q").dt.start_time
+
+ # Function to count "True" values for a given column over a specific time period
+ def count_true_values(df, time_period, column):
+ return df.groupby(time_period)[column].apply(lambda x: (x == True).sum())
+
+ # Preparing subplots for each circle/patch in venn
+ columns = [c_s["column"] for c_s in venn_preset["circle_settings"]] + [
+ str(p_s["patch_ids"]) for p_s in venn_preset.get("patch_settings", [])
+ if not p_s.get("skip_timeline", False)
+ ]
+    colors = [c_s["edge_color"] for c_s in venn_preset["circle_settings"]] + [
+        p_s["color"] for p_s in venn_preset.get("patch_settings", [])
+        if not p_s.get("skip_timeline", False)  # Keep colors aligned with the filtered columns list above
+    ]
+ if if_separate_plots:
+ fig = make_subplots(
+ rows=len(columns),
+ cols=1,
+ shared_xaxes=True,
+ vertical_spacing=0.05,
+ subplot_titles=columns,
+ )
+
+ # Adding traces for each column
+ max_counts = 0
+ for i, column in enumerate(columns):
+ counts = count_true_values(df, time_period, column)
+ max_counts = max(max_counts, counts.max())
+ fig.add_trace(
+ go.Bar(
+ x=counts.index,
+ y=counts.values,
+ name=column,
+ marker=dict(color=colors[i]),
+ ),
+ row=i + 1,
+ col=1,
+ )
+
+ # Sync y limits
+ if if_sync_y_limits:
+ fig.update_yaxes(range=[0, max_counts * 1.1])
+
+ # Updating layout
+ fig.update_layout(
+ height=200 * len(columns),
+ showlegend=False,
+ title=f"{time_period} counts",
+ )
+ else: # side-by-side histograms in the same plot
+ fig = go.Figure()
+ for i, column in enumerate(columns):
+ fig.add_trace(go.Histogram(
+ x=df[df[column]==True]["session_date"],
+ xbins=dict(size="M1"), # Only monthly bins look good
+ name=column,
+ marker_color=colors[i],
+ opacity=0.75
+ ))
+
+ # Update layout for grouped histogram
+ fig.update_layout(
+ height=500,
+ bargap=0.05, # Gap between bars of adjacent locations
+ bargroupgap=0.1, # Gap between bars of the same location
+ barmode='group', # Grouped style
+ showlegend=True,
+ legend=dict(
+ orientation="h", # Horizontal legend
+ y=-0.2, # Position below the plot
+ x=0.5, # Center the legend
+ xanchor="center", # Anchor the legend's x position
+ yanchor="top" # Anchor the legend's y position
+ ),
+ title="Monthly counts"
+ )
+
+ return fig
+
+def app():
+ # --- 1. Generate combined dataframe from docDB queries ---
+ with st.sidebar:
+ st.markdown('# Metadata sources:')
+ st.markdown('## 1. From docDB queries')
+
+ with st.expander("MetadataDbClient settings"):
+ with st.form("MetadataDbClient settings"):
+ parallel = st.checkbox("Parallel fetching", value=False)
+ pagination = st.checkbox("Pagination", value=False)
+ paginate_batch_size = st.number_input("Pagination batch size", value=5000, disabled=not pagination)
+ st.form_submit_button("OK")
+
+ cols = st.columns([1.5, 1])
+ with cols[0]:
+ start_time = time.time()
+ fetch_fun = fetch_queries_from_docDB_parallel if parallel else fetch_queries_from_docDB
+ dfs_docDB = fetch_fun(
+ queries_to_merge=QUERY_PRESET,
+ pagination=pagination,
+ paginate_batch_size=paginate_batch_size,
+ )
+ df_merged = merge_queried_dfs(dfs_docDB, QUERY_PRESET)
+
+ docDB_retrieve_time = time.time() - start_time
+ st.markdown(f"Finished in {docDB_retrieve_time:.3f} secs.")
+
+ with cols[1]:
+ if st.button('Re-fetch docDB queries'):
+ st.cache_data.clear()
+ st.rerun()
+
+ if df_merged is None:
+ st.cache_data.clear() # Fetch failed, re-fetch
+ return
+
+ # --- 2. Merge in the master df in the Home page (Han's temporary pipeline) ---
+ # Data from Home.init (all sessions from Janelia bpod + AIND bpod + AIND bonsai)
+ df_from_Home = st.session_state.df["sessions_bonsai"]
+ # Only keep AIND sessions
+    df_from_Home = df_from_Home.query("institute == 'AIND'").copy()  # copy() to avoid SettingWithCopyWarning on the .loc assignments below
+ df_from_Home.loc[df_from_Home.hardware == "bpod", "Han_temp_pipeline (bpod)"] = True
+ df_from_Home.loc[df_from_Home.hardware == "bonsai", "Han_temp_pipeline (bonsai)"] = True
+
+ # Only keep subject_id and session_date as index
+ df_Han_pipeline = (
+ df_from_Home[
+ [
+ "subject_id",
+ "session_date",
+ "Han_temp_pipeline (bpod)",
+ "Han_temp_pipeline (bonsai)",
+ ]
+ ]
+ .set_index(["subject_id", "session_date"])
+ .sort_index(
+ level=["session_date", "subject_id"],
+ ascending=[False, False],
+ )
+ )
+
+    # Merge with df_merged
+ df_merged = df_merged.combine_first(df_Han_pipeline)
+
+ # --- 3. Get raw data on VAST ---
+ raw_sessions_on_VAST = load_raw_sessions_on_VAST()
+
+ # Example entry of raw_sessions_on_VAST:
+ # Z:\svc_aind_behavior_transfer\447-3-D\751153\behavior_751153_2024-10-20_17-13-34\behavior
+ # Z:\svc_aind_behavior_transfer\2023late_DataNoMeta_Reorganized\687553_2023-11-20_09-48-24\behavior
+ # Z:\svc_aind_behavior_transfer\2023late_DataNoMeta_Reorganized\687553_2023-11-13_11-09-55\687553_2023-12-01_09-41-43\TrainingFolder
+ # Let's find the strings between two \\s that precede "behavior" or "TrainingFolder"
+
+ # Parse "name" from full path on VAST
+ re_pattern = R"\\([^\\]*)\\(?:behavior|TrainingFolder)$"
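+    # e.g. r"...\behavior_751153_2024-10-20_17-13-34\behavior" -> "behavior_751153_2024-10-20_17-13-34"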
+ session_names = [re.findall(re_pattern, path)[0] for path in raw_sessions_on_VAST]
+ df_raw_sessions_on_VAST = pd.DataFrame(raw_sessions_on_VAST, columns=["full_path"])
+ df_raw_sessions_on_VAST["name"] = session_names
+ df_raw_sessions_on_VAST["raw_data_on_VAST"] = True
+
+ # Formatting metadata dataframe
+ (
+ df_raw_sessions_on_VAST,
+ df_raw_sessions_on_VAST_unique_mouse_date,
+ df_raw_sessions_on_VAST_multi_sessions_per_day
+ ) = formatting_metadata_df(df_raw_sessions_on_VAST, source_prefix="VAST")
+
+ dfs_raw_on_VAST = {
+ "df": df_raw_sessions_on_VAST,
+ "df_unique_mouse_date": df_raw_sessions_on_VAST_unique_mouse_date,
+ "df_multi_sessions_per_day": df_raw_sessions_on_VAST_multi_sessions_per_day,
+ }
+
+ # Merging with df_merged (using the unique mouse-date dataframe)
+ df_merged = df_merged.combine_first(df_raw_sessions_on_VAST_unique_mouse_date)
+ df_merged.sort_index(level=["session_date", "subject_id"], ascending=[False, False], inplace=True)
+
+ # --- Add sidebar ---
+ add_sidebar(df_merged, dfs_docDB, df_Han_pipeline, dfs_raw_on_VAST, docDB_retrieve_time)
+
+ # --- Main contents ---
+ st.markdown(f"# Data inventory for dynamic foraging")
+ st.markdown(f"### Merged metadata (n = {len(df_merged)}, see the sidebar for details)")
+ download_df(df_merged, label="Download merged df as CSV", file_name="df_docDB_queries.csv")
+
+ aggrid_interactive_table_basic(
+ df_merged.reset_index(),
+ height=400,
+ configure_columns=[
+ dict(
+ field="session_date",
+ type=["customDateTimeFormat"],
+ custom_format_string="yyyy-MM-dd",
+ )
+ ],
+ )
+
+ # --- Venn diagram from presets ---
+ with open("data_inventory_VENN_PRESET.json", "r") as f:
+ VENN_PRESET = json.load(f)
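+    # Each Venn preset is expected to look roughly like (illustrative example, not the actual preset file):
+    #   {"name": "...",
+    #    "circle_settings": [{"column": "...", "edge_color": "blue", "edge_style": "solid"}, ...],
+    #    "patch_settings": [{"patch_ids": ["110"], "color": "green", "emoji": "white_check_mark", "notes": "..."}]}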
+
+ if VENN_PRESET:
+
+ cols = st.columns([2, 1])
+ cols[0].markdown("## Venn diagrams from presets")
+ with cols[1].expander("Time view settings", expanded=True):
+ cols_1 = st.columns([1, 1])
+ if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True)
+ if_sync_y_limits = cols_1[0].checkbox(
+ "Sync Y limits", value=True, disabled=not if_separate_plots
+ )
+ time_period = cols_1[1].selectbox(
+ "Bin size",
+ ["Daily", "Weekly", "Monthly", "Quarterly"],
+ index=1,
+ disabled=not if_separate_plots,
+ )
+
+ for i_venn, venn_preset in enumerate(VENN_PRESET):
+ # -- Venn diagrams --
+ st.markdown(f"### ({i_venn+1}). {venn_preset['name']}")
+ fig, notes = generate_venn(
+ df_merged,
+ venn_preset
+ )
+ for note in notes:
+ st.markdown(note)
+
+ cols = st.columns([1, 1])
+ with cols[0]:
+ st.pyplot(fig, use_container_width=True)
+
+ # -- Show and download df for this Venn --
+ circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]]
+            # Build the dataframe for this Venn (also used for the histogram over time below)
+ df_this_preset = df_merged[circle_columns]
+ # Filter out rows that have at least one True in this Venn
+ df_this_preset = df_this_preset[df_this_preset.any(axis=1)]
+
+ # Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
+ for patch_setting in venn_preset.get("patch_settings", []):
+ idx = _filter_df_by_patch_ids(
+ df_this_preset[circle_columns],
+ patch_setting["patch_ids"]
+ )
+ df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True
+
+ # Join in other extra columns
+ df_this_preset = df_this_preset.join(
+ df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left"
+ )
+
+ with cols[0]:
+ download_df(
+ df_this_preset,
+ label="Download as CSV for this Venn diagram",
+ file_name=f"df_{venn_preset['name']}.csv",
+ )
+ with st.expander(f"Show dataframe, n = {len(df_this_preset)}"):
+ st.write(df_this_preset)
+
+ with cols[1]:
+ # -- Show histogram over time --
+ fig = plot_histogram_over_time(
+ df=df_this_preset.reset_index(),
+ venn_preset=venn_preset,
+ time_period=time_period,
+ if_sync_y_limits=if_sync_y_limits,
+ if_separate_plots=if_separate_plots,
+ )
+ plotly_events(
+ fig,
+ click_event=False,
+ hover_event=False,
+ select_event=False,
+ override_height=fig.layout.height * 1.1,
+ override_width=fig.layout.width,
+ )
+
+ st.markdown("---")
+
+ # --- User-defined Venn diagram ---
+ # Multiselect for selecting queries up to three
+ # st.markdown('---')
+ # st.markdown("## Venn diagram from user-selected queries")
+ # selected_queries = st.multiselect(
+ # "Select queries to filter sessions",
+ # meta_columns,
+ # default=meta_columns[:3],
+ # key="selected_queries",
+ # )
+
+ # columns_to_venn = selected_queries
+ # fig = generate_venn(df_merged, columns_to_venn)
+ # st.columns([1, 1])[0].pyplot(fig, use_container_width=True)
+
+
+if __name__ == "__main__":
+
+ # Share the same master df as the Home page
+    if ("df" not in st.session_state
+            or "sessions_bonsai" not in st.session_state.df
+            or not st.session_state.get("bpod_loaded", False)):
+        with st.spinner("Loading data from Han's temp pipeline..."):
+            init(if_load_docDB_override=False, if_load_bpod_data_override=True)
+
+ app()
diff --git a/code/pages/3_AIND data access playground.py b/code/pages/3_AIND data access playground.py
deleted file mode 100644
index 03324c7..0000000
--- a/code/pages/3_AIND data access playground.py
+++ /dev/null
@@ -1,30 +0,0 @@
-'''Migrated from David's toy app https://codeocean.allenneuraldynamics.org/capsule/9532498/tree
-'''
-
-import logging
-
-import streamlit as st
-from streamlit_dynamic_filters import DynamicFilters
-from util.fetch_data_docDB import load_data_from_docDB
-
-try:
- st.set_page_config(layout="wide",
- page_title='Foraging behavior browser',
- page_icon=':mouse2:',
- menu_items={
- 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues",
- 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/"
- }
- )
-except:
- pass
-
-df = load_data_from_docDB()
-
-st.markdown(f'### Note: the dataframe showing here has been merged in to the master table on the Home page!')
-
-dynamic_filters = DynamicFilters(
- df=df,
- filters=['subject_id', 'subject_genotype'])
-dynamic_filters.display_filters()
-dynamic_filters.display_df()
diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py
index 17f258b..300dffc 100644
--- a/code/util/aws_s3.py
+++ b/code/util/aws_s3.py
@@ -37,6 +37,12 @@ def load_data(tables=['sessions'], data_source = 'bonsai'):
width=500)
return df
+@st.cache_data(ttl=12*3600)
+def load_raw_sessions_on_VAST():
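+    """Load the list of raw session folder paths on VAST from a JSON file stored on S3 (next to the bonsai nwb folder)"""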
+ file_name = s3_nwb_folder['bonsai'] + 'raw_sessions_on_VAST.json'
+ with fs.open(file_name) as f:
+ raw_sessions_on_VAST = json.load(f)
+ return raw_sessions_on_VAST
def draw_session_plots_quick_preview(df_to_draw_session):
diff --git a/code/util/fetch_data_docDB.py b/code/util/fetch_data_docDB.py
index d063ed4..9b867be 100644
--- a/code/util/fetch_data_docDB.py
+++ b/code/util/fetch_data_docDB.py
@@ -3,6 +3,8 @@
import logging
import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import threading
import pandas as pd
import semver
@@ -13,6 +15,7 @@
from aind_data_access_api.document_db import MetadataDbClient
+from util.reformat import formatting_metadata_df
@st.cache_data(ttl=3600*12) # Cache the df_docDB up to 12 hours
def load_data_from_docDB():
@@ -221,3 +224,95 @@ def find_result(x, lookup):
if result_name.startswith(x):
return result
return {}
+
+
+# --------- Helper functions for Data Inventory (Han) ----------
+@st.cache_data(ttl=3600*24)
+def fetch_single_query(query, pagination, paginate_batch_size):
+ """ Fetch a query from docDB and process the result into dataframe
+
+    Return:
+    - df: dataframe of the original returned records
+    - df_unique_mouse_date: dataframe with one row per unique mouse-date pair (multiple sessions per day combined)
+    - df_multi_sessions_per_day: dataframe of the records that have multiple sessions on the same day
+ """
+
+ # --- Fetch data from docDB ---
+ client = load_client()
+ results = client.retrieve_docdb_records(
+ filter_query=query["filter"],
+ projection={
+ "_id": 0,
+ "name": 1,
+ "rig.rig_id": 1,
+ "session.experimenter_full_name": 1,
+ },
+ paginate=pagination,
+ paginate_batch_size=paginate_batch_size,
+ )
+ print(f"Done querying {query['alias']}!")
+
+ # --- Process data into dataframe ---
+ df = pd.json_normalize(results)
+
+ # Formatting dataframe
+ df, df_unique_mouse_date, df_multi_sessions_per_day = formatting_metadata_df(df)
+ df_unique_mouse_date[query["alias"]] = True # Add a column to prepare for merging
+
+ return df, df_unique_mouse_date, df_multi_sessions_per_day
+
+# Don't cache this function otherwise the progress bar won't work
+def fetch_queries_from_docDB(queries_to_merge, pagination=False, paginate_batch_size=5000):
+    """ Fetch all preset queries from docDB serially, with a progress bar """
+
+ dfs = {}
+
+ # Fetch data in serial
+ p_bar = st.progress(0, text="Querying docDB in serial...")
+ for i, query in enumerate(queries_to_merge):
+ df, df_unique_mouse_date, df_multi_sessions_per_day = fetch_single_query(
+ query, pagination=pagination, paginate_batch_size=paginate_batch_size
+ )
+ dfs[query["alias"]] = {
+ "df": df,
+ "df_unique_mouse_date": df_unique_mouse_date,
+ "df_multi_sessions_per_day": df_multi_sessions_per_day,
+ }
+ p_bar.progress((i+1) / len(queries_to_merge), text=f"Querying docDB... ({i+1}/{len(queries_to_merge)})")
+
+ if not dfs:
+        st.warning("Querying docDB failed! Try enabling \"Pagination\" in the MetadataDbClient settings, or ask Han.")
+ return None
+
+ return dfs
+
+@st.cache_data(ttl=3600*12)
+def fetch_queries_from_docDB_parallel(queries_to_merge, pagination=False, paginate_batch_size=5000):
+    """ Fetch all preset queries from docDB in parallel (one thread per query) """
+
+ dfs = {}
+
+ # Fetch data in parallel
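+    # One worker thread per query; results are collected with as_completed() and keyed by each query's alias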
+ with ThreadPoolExecutor(max_workers=len(queries_to_merge)) as executor:
+        future_to_query = {
+            executor.submit(
+                fetch_single_query,
+                query,
+                pagination=pagination,
+                paginate_batch_size=paginate_batch_size,
+            ): query
+            for query in queries_to_merge
+        }
+        for i, future in enumerate(as_completed(future_to_query), 1):
+            query = future_to_query[future]
+            try:
+                df, df_unique_mouse_date, df_multi_sessions_per_day = future.result()
+                dfs[query["alias"]] = {
+                    "df": df,
+                    "df_unique_mouse_date": df_unique_mouse_date,
+                    "df_multi_sessions_per_day": df_multi_sessions_per_day,
+                }
+            except Exception as e:
+                print(f"Error querying {query['alias']}: {e}")
+
+ return dfs
\ No newline at end of file
diff --git a/code/util/reformat.py b/code/util/reformat.py
new file mode 100644
index 0000000..dd3da68
--- /dev/null
+++ b/code/util/reformat.py
@@ -0,0 +1,98 @@
+
+""" Helper functions to reformat the data
+"""
+import re
+import pandas as pd
+
+# Function to split the `nwb_name` column
+def split_nwb_name(nwb_name):
+ """Turn the nwb_name into subject_id, session_date, nwb_suffix in order to be merged to
+ the main df.
+
+ Parameters
+ ----------
+ nwb_name : str. The name of the nwb file. This function can handle the following formats:
+ "721403_2024-08-09_08-39-12.nwb"
+ "685641_2023-10-04.nwb",
+ "behavior_754280_2024-11-14_11-06-24.nwb",
+ "behavior_1_2024-08-05_15-48-54",
+ "
+ ...
+
+ Returns
+ -------
+ subject_id : str. The subject ID
+ session_date : str. The session date
+ nwb_suffix : int. The nwb suffix (converted from session time if available, otherwise 0)
+ """
+
+    pattern = R"(?:\w+_)?(?P<subject_id>\d+)_(?P<session_date>\d{4}-\d{2}-\d{2})(?:_(?P