Merge pull request #8 from LCOGT/reextract-backend
Front end changes for re-extraction
cmccully authored Nov 5, 2024
2 parents 63dc148 + b46e95f commit 11102a0
Showing 8 changed files with 331 additions and 73 deletions.
7 changes: 7 additions & 0 deletions CHANGES.md
@@ -0,0 +1,7 @@
0.2.0 (2024-11-05)
------------------
- Added the ability to re-extract via the GUI

0.1.0 (2024-06-06)
------------------
- Initial Release
217 changes: 207 additions & 10 deletions banzai_floyds_ui/gui/app.py
@@ -1,23 +1,43 @@
import dash
from dash import dcc, html
import dash_bootstrap_components as dbc
from django_plotly_dash import DjangoDash
import logging
import datetime
import requests
import asyncio
from banzai_floyds_ui.gui.utils.file_utils import fetch_all, get_related_frame
from banzai_floyds_ui.gui.plots import make_1d_sci_plot, make_2d_sci_plot, make_arc_2d_plot, make_arc_line_plots
from banzai_floyds_ui.gui.plots import make_profile_plot
from banzai_floyds_ui.gui.utils import file_utils
from banzai_floyds_ui.gui.utils.plot_utils import extraction_region_traces
from dash.exceptions import PreventUpdate
from banzai_floyds_ui.gui.utils.plot_utils import json_to_polynomial
from banzai_floyds_ui.gui.utils.plot_utils import EXTRACTION_REGION_LINE_ORDER
from banzai.utils import import_utils
from banzai.utils.stage_utils import get_stages_for_individual_frame
from banzai_floyds.frames import FLOYDSFrameFactory
from banzai_floyds import settings
from banzai_floyds.utils.profile_utils import profile_fits_to_data
import os
import banzai.main
import io
from banzai.logs import get_logger
from django.core.cache import cache


logger = logging.getLogger(__name__)
logger = get_logger()

dashboard_name = 'banzai-floyds'
app = DjangoDash(name=dashboard_name)
app = DjangoDash(name=dashboard_name, csrf_token_name='csrftoken')

# set up the context object for banzai

settings.fpack = True
settings.post_to_open_search = bool(os.environ.get('POST_TO_OPENSEARCH', False))
settings.post_to_archive = bool(os.environ.get('POST_TO_ARCHIVE', False))
settings.no_file_cache = True
settings.db_address = os.environ['DB_ADDRESS']
RUNTIME_CONTEXT = banzai.main.parse_args(settings, parse_system_args=False)
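# RUNTIME_CONTEXT carries the banzai pipeline configuration (database address, ordered
# stages, etc.) that the re-extraction and save callbacks below pass into the pipeline.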


def layout():
@@ -29,6 +49,30 @@ def layout():
html.Div(
id='options-container',
children=[
dbc.Modal([
dbc.ModalHeader(dbc.ModalTitle("Error"), className='bg-danger text-white'),
dbc.ModalBody("You must be logged in to save an extraction."),
],
id="error-logged-in-modal",
is_open=False,
),
dbc.Modal([
dbc.ModalHeader(dbc.ModalTitle("Error"), className='bg-danger text-white'),
dbc.ModalBody("Error extracting spectrum. Plots may not reflect extraction paramters."),
],
id="error-extract-failed-modal",
is_open=False,
),
dbc.Modal([
dbc.ModalHeader(dbc.ModalTitle("Error"), className='bg-danger text-white'),
dbc.ModalBody("""
Error saving spectrum.
Note that you need to have clicked the re-extract button at least once before saving.
"""),
],
id="error-extract-failed-on-save-modal",
is_open=False,
),
html.Div(
children=[
dcc.DatePickerRange(
@@ -72,8 +116,8 @@ def layout():
id='plot-container',
children=[
dcc.Store(id='initial-extraction-info'),
dcc.Store(id='file-list-metadata'),
dcc.Store(id='extraction-positions'),
dcc.Store(id='extraction-traces'),
dcc.Loading(
id='loading-arc-2d-plot-container',
type='default',
@@ -111,15 +155,22 @@ def layout():
config={'edits': {'shapePosition': True}}),
]
),
html.Div(['Extraction Type:',
dcc.RadioItems(['Optimal', 'Unweighted'], 'Optimal', inline=True, id='extraction-type',
style={"margin-right": "10px"})],
dbc.Button('Re-Extract', id='extract-button')),
dcc.Loading(
id='loading-extraction-plot-container',
type='default',
children=[
dcc.Store(id='extraction-traces'),
dcc.Store(id='extractions'),
dcc.Graph(id='extraction-plot',
style={'display': 'inline-block',
'width': '100%', 'height': '550px'}),
]
)
),
dbc.Button('Save Extraction', id='extract-button'),
]
)
]
@@ -129,7 +180,8 @@ def layout():
app.layout = layout


@app.expanded_callback(dash.dependencies.Output('file-list-dropdown', 'options'),
@app.expanded_callback([dash.dependencies.Output('file-list-dropdown', 'options'),
dash.dependencies.Output('file-list-metadata', 'data')],
[dash.dependencies.Input('date-range-picker', 'start_date'),
dash.dependencies.Input('date-range-picker', 'end_date'),
dash.dependencies.Input('site-picker', 'value')])
@@ -160,9 +212,9 @@ def callback_dropdown_files(*args, **kwargs):
logger.error(f"Failed to fetch data from archive: {e}. {response.content}")
return
data = response.json()['results']
results += [{'label': row['filename'], 'value': row['id']} for row in data]
results += [{'label': f'{row["filename"]} {row["OBJECT"]} {row["PROPID"]}', 'value': row['id']} for row in data]
results.sort(key=lambda x: x['label'])
return results
return results, results


# This callback snaps the lines back to the same y-position so they don't wander
@@ -240,7 +292,7 @@ def on_extraction_region_update(extraction_positions, initial_extraction_info):
for line in EXTRACTION_REGION_LINE_ORDER[1:]:
center = extraction_positions[str(order)]['center']
position = extraction_positions[str(order)][line]
positions_sigma[line] = abs(position - center)
positions_sigma[line] = position - center
positions_sigma[line] /= initial_extraction_info['refsigma'][str(order)]
center_delta = extraction_positions[str(order)]['center'] - initial_extraction_info['refcenter'][str(order)]
x, traces = extraction_region_traces(order_center_polynomial, center_polynomial, width_polynomial,
@@ -250,7 +302,7 @@ def on_extraction_region_update(extraction_positions, initial_extraction_info):
xs.append(x)
ys.append(trace)

return {'x': xs, 'y': ys}
return {'x': xs, 'y': ys}


@app.expanded_callback(
@@ -278,6 +330,9 @@ def callback_make_plots(*args, **kwargs):
sci_2d_frame, sci_2d_filename = get_related_frame(frame_id, archive_header, 'L1ID2D')
sci_2d_plot, extraction_data = make_2d_sci_plot(sci_2d_frame, sci_2d_filename)

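# Cache the 2D science frame and its original filename server side; trigger_reextract
# below reads them back so the frame does not have to be refetched from the archive.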
file_utils.cache_fits('science_2d_frame', sci_2d_frame)
cache.set('filename', sci_2d_frame['SCI'].header['ORIGNAME'] + '.fits')

profile_plot, initial_extraction_info = make_profile_plot(sci_2d_frame)

for key in extraction_data:
@@ -335,3 +390,145 @@ def update_extraction_positions(initial_extraction_info, relayout_data, current_
line_id = EXTRACTION_REGION_LINE_ORDER[line_index % 7]
current_extraction_positions[str(order)][line_id] = relayout_data[key_with_update]
return current_extraction_positions


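# Rebuild a banzai-floyds frame from the cached 2D science HDU, apply the user-adjusted
# extraction and background regions (expressed in units of the reference profile sigma),
# and rerun the pipeline stages from the extraction stage onward.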
def reextract(hdu, filename, extraction_positions, initial_extraction_info, runtime_context, extraction_type='optimal'):
# Convert 2d science hdu to banzai-floyds frame object
factory = FLOYDSFrameFactory()
buffer = io.BytesIO()
hdu.writeto(buffer)
buffer.seek(0)
file_info = {'filename': filename, 'data_buffer': buffer}
frame = factory.open(file_info, runtime_context)
# reset the weights and the background region
centers, widths = frame.profile_fits
for order in [1, 2]:
center_delta = extraction_positions[str(order)]['center'] - \
initial_extraction_info['positions'][str(order)]['center']
centers[order-1].coef[0] += center_delta
frame.profile_fits = centers, widths, frame['PROFILEFITS'].data
frame.profile = profile_fits_to_data(frame.data.shape, centers, widths,
frame.orders, frame.wavelengths.data)
extraction_windows = []
for order in [1, 2]:
lower = extraction_positions[str(order)]['extract_lower'] - extraction_positions[str(order)]['center']
lower /= initial_extraction_info['refsigma'][str(order)]
upper = extraction_positions[str(order)]['extract_upper'] - extraction_positions[str(order)]['center']
upper /= initial_extraction_info['refsigma'][str(order)]
extraction_windows.append([lower, upper])

frame.extraction_windows = extraction_windows
# Override the default optimal extraction weights from the profile
if extraction_type == 'unweighted':
frame.binned_data['weights'] = 1.0

background_windows = []
for order in [1, 2]:
order_background = []
for region in ['left', 'right']:
inner = extraction_positions[str(order)][f'bkg_{region}_inner'] - extraction_positions[str(order)]['center']
inner /= initial_extraction_info['refsigma'][str(order)]
outer = extraction_positions[str(order)][f'bkg_{region}_outer'] - extraction_positions[str(order)]['center']
outer /= initial_extraction_info['refsigma'][str(order)]
this_background = [inner, outer]
this_background.sort()
order_background.append(this_background)
background_windows.append(order_background)
frame.background_windows = background_windows
stages_to_do = get_stages_for_individual_frame(runtime_context.ORDERED_STAGES,
last_stage=runtime_context.LAST_STAGE[frame.obstype.upper()],
extra_stages=runtime_context.EXTRA_STAGES[frame.obstype.upper()])

# Starting at the extraction weights stage
start_index = stages_to_do.index('banzai_floyds.extract.Extractor')
stages_to_do = stages_to_do[start_index:]

for stage_name in stages_to_do:
stage_constructor = import_utils.import_attribute(stage_name)
stage = stage_constructor(runtime_context)
frames = stage.run([frame])
if not frames:
logger.error('Reduction stopped', extra_tags={'filename': filename})
return
logger.info('Reduction complete', extra_tags={'filename': filename})
return frame


@app.expanded_callback([dash.dependencies.Output('extractions', 'data'),
dash.dependencies.Output('error-extract-failed-modal', 'is_open')],
[dash.dependencies.Input('extraction-positions', 'data'),
dash.dependencies.Input('extraction-type', 'value')],
[dash.dependencies.State('initial-extraction-info', 'data'),
dash.dependencies.State('file-list-dropdown', 'value'),
dash.dependencies.State('file-list-metadata', 'data')],
prevent_initial_call=True)
def trigger_reextract(extraction_positions, extraction_type, initial_extraction_info,
frame_id, frame_data, **kwargs):

science_frame = file_utils.get_cached_fits('science_2d_frame')

if science_frame is None:
raise PreventUpdate

filename = cache.get('filename')
frame = reextract(science_frame, filename, extraction_positions, initial_extraction_info,
RUNTIME_CONTEXT, extraction_type=extraction_type.lower())

if frame is None:
return dash.no_update, True

file_utils.cache_frame('reextracted_frame', frame)
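# Assemble the updated trace data in the same order as the traces in the extraction plot
# (order 2 then order 1; flux, raw flux, background) so the clientside restyle below can
# update the figure by trace index.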
x = []
y = []
for order in [2, 1]:
where_order = frame.extracted['order'] == order
for flux in ['flux', 'fluxraw', 'background']:
x.append(frame.extracted['wavelength'][where_order])
y.append(frame.extracted[flux][where_order])
return {'x': x, 'y': y}, False


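# Clientside callback: restyle the existing traces of the extraction plot in place with
# the newly extracted data, avoiding a full server-side rebuild of the figure.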
app.clientside_callback(
"""
function(extraction_data) {
if (typeof extraction_data === "undefined") {
return window.dash_clientside.no_update;
}
var dccGraph = document.getElementById('extraction-plot');
var jsFigure = dccGraph.querySelector('.js-plotly-plot');
var trace_ids = [];
for (let i = 0; i < extraction_data.x.length; i++) {
trace_ids.push(i);
}
Plotly.restyle(jsFigure, extraction_data, trace_ids);
return window.dash_clientside.no_update;
}
""",
dash.dependencies.Input('extractions', 'data'),
prevent_initial_call=True)


@app.expanded_callback([dash.dependencies.Output("error-logged-in-modal", "is_open"),
dash.dependencies.Output('error-extract-failed-on-save-modal', 'is_open')],
dash.dependencies.Input('extract-button', 'n_clicks'),
prevent_initial_call=True)
def save_extraction(n_clicks, **kwargs):
if not n_clicks:
raise PreventUpdate

# If not logged in, open a modal saying you can only save if you are.
username = kwargs['session_state'].get('username')
if username is None:
return True, dash.no_update

# Run the reextraction
extracted_frame = file_utils.get_cached_frame('reextracted_frame')
if extracted_frame is None:
return dash.no_update, True

# Save the reducer into the header
extracted_frame.meta['REDUCER'] = username
# Push the results to the archive
extracted_frame.write(RUNTIME_CONTEXT)
# Return false to keep the error modal closed
return dash.no_update, dash.no_update