From 9259df9b508cd77fb0680d9639ba8992bf5460a2 Mon Sep 17 00:00:00 2001 From: jacek-oet Date: Tue, 3 Sep 2024 18:12:37 +0200 Subject: [PATCH] feat: update data table --- website/components/benchmark_table.py | 156 ++++++++++++++++++++------ website/raw-results.py | 1 - 2 files changed, 124 insertions(+), 33 deletions(-) diff --git a/website/components/benchmark_table.py b/website/components/benchmark_table.py index baf533c..7950544 100644 --- a/website/components/benchmark_table.py +++ b/website/components/benchmark_table.py @@ -1,39 +1,131 @@ import pandas as pd -from st_aggrid import AgGrid, GridOptionsBuilder -from st_aggrid.shared import GridUpdateMode - +import streamlit as st +import streamlit.components.v1 as components +from pandas.api.types import ( + is_categorical_dtype, + is_datetime64_any_dtype, + is_numeric_dtype, + is_object_dtype, +) def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: - gb = GridOptionsBuilder.from_dataframe(df) - gb.configure_grid_options( - paginationPageSizeSelector=[10, 20, 50, 100], - ) - # Add styling options - gb.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10) - gb.configure_side_bar() - gb.configure_default_column( - resizable=True, - sortable=True, - filter=True, - floatingFilter=True, - ) - - grid_options = gb.build() - - grid_response = AgGrid( - df, - height=570, - gridOptions=grid_options, - update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.FILTERING_CHANGED, - fit_columns_on_grid_load=True, - enable_enterprise_modules=True, - theme="alpine", - ) - - filtered_df = pd.DataFrame(grid_response["data"]) - return filtered_df + modify = st.checkbox("Add filters") + + if modify: + df = df.copy() + # Try to convert datetimes into a standard format (datetime, no timezone) + for col in df.columns: + if is_object_dtype(df[col]): + try: + df[col] = pd.to_datetime(df[col], format='%Y-%m-%d') # Update the format as needed + except Exception: + pass + + if is_datetime64_any_dtype(df[col]): + df[col] = df[col].dt.tz_localize(None) + + filter_menu = st.columns(2) + with filter_menu[0]: + to_filter_columns = st.multiselect("Filter dataframe on", df.columns) + + with filter_menu[1]: + modification_container = st.container() + + with modification_container: + for column in to_filter_columns: + left, right = st.columns((1, 20)) + # Treat columns with < 10 unique values as categorical + if is_categorical_dtype(df[column]) or df[column].nunique() < 10: + user_cat_input = right.multiselect( + f"Values for {column}", + df[column].unique(), + default=list(df[column].unique()), + ) + df = df[df[column].isin(user_cat_input)] + elif is_numeric_dtype(df[column]): + _min = float(df[column].min()) + _max = float(df[column].max()) + step = (_max - _min) / 100 + user_num_input = right.slider( + f"Values for {column}", + min_value=_min, + max_value=_max, + value=(_min, _max), + step=step, + ) + df = df[df[column].between(*user_num_input)] + elif is_datetime64_any_dtype(df[column]): + user_date_input = right.date_input( + f"Values for {column}", + value=( + df[column].min(), + df[column].max(), + ), + ) + if len(user_date_input) == 2: + user_date_input = tuple(map(pd.to_datetime, user_date_input)) + start_date, end_date = user_date_input + df = df.loc[df[column].between(start_date, end_date)] + else: + user_text_input = right.text_input( + f"Substring or regex in {column}", + ) + if user_text_input: + df = df[df[column].astype(str).str.contains(user_text_input)] + + return df + + +def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame: + top_menu = st.columns(4) + with top_menu[0]: + sort = st.checkbox("Add Sort") + if sort: + with top_menu[1]: + sort_field = st.selectbox("Sort By", options=df.columns) + with top_menu[2]: + sort_direction = st.radio( + "Direction", options=["⬆️", "⬇️"], horizontal=True + ) + df = df.sort_values( + by=sort_field, ascending=sort_direction == "⬆️", ignore_index=True + ) + + # Add a text input for global search + with top_menu[3]: + search = st.text_input("Search") + if search: + df = df[df.apply(lambda row: row.astype(str).str.contains(search, case=False).any(), axis=1)] + + return df + + +def split_frame(input_df, rows) -> list: + df = [input_df[i:i + rows] for i in range(0, len(input_df), rows)] + return df + + +def paginate_dataframe(df: pd.DataFrame): + pagination = st.container() + bottom_menu = st.columns((6, 1, 1)) + with bottom_menu[2]: + batch_size = st.selectbox("Page Size", options=[10, 25, 50, 100]) + with bottom_menu[1]: + total_pages = ( + int(len(df) / batch_size) if int(len(df) / batch_size) > 0 else 1 + ) + current_page = st.number_input( + "Page", min_value=1, max_value=total_pages, step=1 + ) + with bottom_menu[0]: + st.markdown(f"Page **{current_page}** of **{total_pages}** ") + + pages = split_frame(df, batch_size) + pagination.dataframe(data=(pages[current_page - 1] if int(len(pages)) > 0 else []), use_container_width=True) def display_table(df: pd.DataFrame): filtered_df = filter_dataframe(df) - return filtered_df + sorted_df = sort_dataframe(filtered_df) + paginate_dataframe(sorted_df) + diff --git a/website/raw-results.py b/website/raw-results.py index 738e523..2d33b10 100644 --- a/website/raw-results.py +++ b/website/raw-results.py @@ -20,5 +20,4 @@ df = pd.read_csv(data_url) -# Remove the st.dataframe call and use AgGrid directly filtered_df = display_table(df)