Skip to content

Commit

Permalink
Minor changes to make the UX better given how long the reextraction t…
Browse files Browse the repository at this point in the history
…akes.
  • Loading branch information
cmccully committed Oct 31, 2024
1 parent ed54990 commit 47df98f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 39 deletions.
53 changes: 22 additions & 31 deletions banzai_floyds_ui/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from banzai_floyds_ui.gui.utils.file_utils import fetch_all, get_related_frame
from banzai_floyds_ui.gui.plots import make_1d_sci_plot, make_2d_sci_plot, make_arc_2d_plot, make_arc_line_plots
from banzai_floyds_ui.gui.plots import make_profile_plot
from banzai_floyds_ui.gui.utils.file_utils import get_filename, cache_fits, get_cached_fits
from banzai_floyds_ui.gui.utils import file_utils
from banzai_floyds_ui.gui.utils.plot_utils import extraction_region_traces
from dash.exceptions import PreventUpdate
from banzai_floyds_ui.gui.utils.plot_utils import json_to_polynomial
Expand All @@ -22,12 +22,13 @@
import banzai.main
import io
from banzai.logs import get_logger
from django.core.cache import cache


logger = get_logger()

dashboard_name = 'banzai-floyds'
app = DjangoDash(name=dashboard_name)
app = DjangoDash(name=dashboard_name, csrf_token_name='csrftoken')

# set up the context object for banzai

Expand All @@ -51,29 +52,23 @@ def layout():
dbc.Modal([
dbc.ModalHeader(dbc.ModalTitle("Error"), className='bg-danger text-white'),
dbc.ModalBody("You must be logged in to save an extraction."),
dbc.ModalFooter(
dbc.Button("Close", id="logged-in-close", className="ms-auto", n_clicks=0)
),
],
id="error-logged-in-modal",
is_open=False,
),
dbc.Modal([
dbc.ModalHeader(dbc.ModalTitle("Error"), className='bg-danger text-white'),
dbc.ModalBody("Error extracting spectrum. Plots may not reflect extraction paramters."),
dbc.ModalFooter(
dbc.Button("Close", id="extract-fail-close", className="ms-auto", n_clicks=0)
),
],
id="error-extract-failed-modal",
is_open=False,
),
dbc.Modal([
dbc.ModalHeader(dbc.ModalTitle("Error"), className='bg-danger text-white'),
dbc.ModalBody("Error extracting spectrum. Spectrum will not be saved."),
dbc.ModalFooter(
dbc.Button("Close", id="save-fail-close", className="ms-auto", n_clicks=0)
),
dbc.ModalBody("""
Error saving spectrum.
Note you need to have clicked the re-extract button at least before saving.
"""),
],
id="error-extract-failed-on-save-modal",
is_open=False,
Expand Down Expand Up @@ -123,8 +118,6 @@ def layout():
dcc.Store(id='initial-extraction-info'),
dcc.Store(id='file-list-metadata'),
dcc.Store(id='extraction-positions'),
dcc.Store(id='extraction-traces'),
dcc.Store(id='extractions'),
dcc.Loading(
id='loading-arc-2d-plot-container',
type='default',
Expand Down Expand Up @@ -162,17 +155,21 @@ def layout():
config={'edits': {'shapePosition': True}}),
]
),
html.Div('Extraction Type:'),
html.Div(['Extraction Type:',
dcc.RadioItems(['Optimal', 'Unweighted'], 'Optimal', inline=True, id='extraction-type',
style={"margin-right": "10px"})],
dbc.Button('Re-Extract', id='extract-button')),
dcc.Loading(
id='loading-extraction-plot-container',
type='default',
children=[
dcc.Store(id='extraction-traces'),
dcc.Store(id='extractions'),
dcc.Graph(id='extraction-plot',
style={'display': 'inline-block',
'width': '100%', 'height': '550px;'}),
]
),
dcc.RadioItems(['Optimal', 'Unweighted'], 'Optimal', inline=True, id='extraction-type'),
dbc.Button('Save Extraction', id='extract-button'),
]
)
Expand Down Expand Up @@ -215,7 +212,7 @@ def callback_dropdown_files(*args, **kwargs):
logger.error(f"Failed to fetch data from archive: {e}. {response.content}")
return
data = response.json()['results']
results += [{'label': row['filename'], 'value': row['id']} for row in data]
results += [{'label': f'{row["filename"]} {row["OBJECT"]} {row["PROPID"]}', 'value': row['id']} for row in data]
results.sort(key=lambda x: x['label'])
return results, results

Expand Down Expand Up @@ -333,7 +330,8 @@ def callback_make_plots(*args, **kwargs):
sci_2d_frame, sci_2d_filename = get_related_frame(frame_id, archive_header, 'L1ID2D')
sci_2d_plot, extraction_data = make_2d_sci_plot(sci_2d_frame, sci_2d_filename)

cache_fits('science_2d_frame', sci_2d_frame)
file_utils.cache_fits('science_2d_frame', sci_2d_frame)
cache.set('filename', sci_2d_frame['SCI'].header['ORIGNAME'] + '.fits')

profile_plot, initial_extraction_info = make_profile_plot(sci_2d_frame)

Expand Down Expand Up @@ -467,18 +465,19 @@ def reextract(hdu, filename, extraction_positions, initial_extraction_info, runt
def trigger_reextract(extraction_positions, extraction_type, initial_extraction_info,
frame_id, frame_data, **kwargs):

science_frame = get_cached_fits('science_2d_frame')
science_frame = file_utils.get_cached_fits('science_2d_frame')

if science_frame is None:
raise PreventUpdate

filename = get_filename(frame_id, frame_data)
filename = cache.get('filename')
frame = reextract(science_frame, filename, extraction_positions, initial_extraction_info,
RUNTIME_CONTEXT, extraction_type=extraction_type.lower())

if frame is None:
return dash.no_update, True

file_utils.cache_frame('reextracted_frame', frame)
x = []
y = []
for order in [2, 1]:
Expand Down Expand Up @@ -512,29 +511,21 @@ def trigger_reextract(extraction_positions, extraction_type, initial_extraction_
@app.expanded_callback([dash.dependencies.Output("error-logged-in-modal", "is_open"),
dash.dependencies.Output('error-extract-failed-on-save-modal', 'is_open')],
dash.dependencies.Input('extract-button', 'n_clicks'),
[dash.dependencies.State('extraction-type', 'value'),
dash.dependencies.State('extraction-positions', 'data'),
dash.dependencies.State('initial-extraction-info', 'data')],
prevent_initial_call=True)
def save_extraction(n_clicks, extraction_type, extraction_positions, initial_extraction_info,
frame_id, frame_data, **kwargs):
def save_extraction(n_clicks, **kwargs):
if not n_clicks:
raise PreventUpdate
science_frame = get_cached_fits('science_2d_frame')
if science_frame is None:
raise PreventUpdate

# If not logged in, open a modal saying you can only save if you are.
username = kwargs['session_state'].get('username')
if username is None:
return True, dash.no_update

filename = get_filename(frame_id, frame_data)
# Run the reextraction
extracted_frame = reextract(science_frame, filename, extraction_positions, initial_extraction_info,
RUNTIME_CONTEXT, extraction_type=extraction_type.lower())
extracted_frame = file_utils.get_cached_frame('reextracted_frame')
if extracted_frame is None:
return dash.no_update, True

# Save the reducer into the header
extracted_frame.meta['REDUCER'] = username
# Push the results to the archive
Expand Down
25 changes: 17 additions & 8 deletions banzai_floyds_ui/gui/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests
from django.core.cache import cache
from io import BytesIO
import pickle


async def fetch(url, params, headers):
Expand Down Expand Up @@ -45,14 +46,6 @@ def get_related_frame(frame_id, archive_header, related_frame_key):
return download_frame(archive_header, params=params, list_endpoint=True), related_frame_filename


def get_filename(frame_id, frame_data):
for row in frame_data:
if row['value'] == frame_id:
filename = row['label']
break
return filename


def cache_fits(key_name, hdulist):
buffer = BytesIO()
hdulist.writeto(buffer)
Expand All @@ -67,3 +60,19 @@ def get_cached_fits(key_name):
buffer = BytesIO(cached_value)
buffer.seek(0)
return fits.open(buffer)


def get_cached_frame(key_name):
cached_value = cache.get(key_name)
if cached_value is None:
return None
buffer = BytesIO(cached_value)
buffer.seek(0)
return pickle.load(buffer)


def cache_frame(key_name, frame):
buffer = BytesIO()
pickle.dump(frame, buffer)
buffer.seek(0)
cache.set(key_name, buffer.read(), timeout=None)
3 changes: 3 additions & 0 deletions banzai_floyds_ui/gui/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django.shortcuts import render
from django.views.decorators.http import require_http_methods
from banzai_floyds_ui.gui.forms import LoginForm
from django.views.decorators.csrf import csrf_protect
from django.conf import settings
import requests

Expand All @@ -16,6 +17,7 @@ def banzai_floyds_view(request, template_name="floyds.html", **kwargs):
return render(request, template_name=template_name, context=context)


@csrf_protect
@require_http_methods(["POST"])
def login_view(request):
form = LoginForm(request.POST)
Expand All @@ -35,6 +37,7 @@ def login_view(request):
return render(request, 'floyds.html')


@csrf_protect
@require_http_methods(["POST"])
def logout_view(request):
if 'auth_token' in request.session:
Expand Down

0 comments on commit 47df98f

Please sign in to comment.