-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
191 lines (170 loc) · 6.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import sys
import h5py
import model_viz.config as config
import model_viz.utils as utils
import model_viz.hdf_ops as hdf_ops
import model_viz.plotting as plotting
import model_viz.component_factory as component_factory
import dash
from dash import html, dcc, Input, Output, State, MATCH
import plotly.graph_objects as go
import plotly.io as pio
import functools
import dash_bootstrap_components as dbc
from typing import List
# Default Plotly template for every figure built in this process.
pio.templates.default = config.Plotter.theme

# The Dash application; its layout and callbacks are attached in main().
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.PULSE])
app.title = "Model Viz"

# Dropdown label -> plotter class. Insertion order is the order in which the
# options appear in the "graph_type" dropdown.
GRAPH_TYPES: dict[str, type[plotting.BasePlotter]] = {
    "Histogram 2D": plotting.Histogram2D,
    "Boxplot over Time": plotting.BoxPlotOverTime,
}
def generate_plots(
    groups: list[h5py.Group], graph_type: str
) -> List[plotting.BasePlotter]:
    """Build one plotter per HDF5 group using the selected graph type.

    Every group must contain a ``data`` dataset. It may also contain an
    ``overlay_data`` dataset, which is flattened before it is passed to the
    plotter. Each plot is titled with the last path component of its group
    name.

    Args:
        groups: HDF5 groups to plot, in display order.
        graph_type (str): Key of ``GRAPH_TYPES`` selecting the plotter class.

    Raises:
        NotImplementedError: If ``graph_type`` is not a key of ``GRAPH_TYPES``.

    Returns:
        list: Plotter instances on which ``create_plot`` has already been
        called. Note that these are plotter objects, not Plotly figures; the
        callers in this module read the figure from ``plot.fig``.
    """
    plotter_cls = GRAPH_TYPES.get(graph_type)
    if plotter_cls is None:
        raise NotImplementedError(f"Graph type {graph_type} not implemented")
    return [_render_group(plotter_cls, group) for group in groups]


def _render_group(plotter_cls, group):
    """Instantiate ``plotter_cls`` for one HDF5 group and build its figure."""
    # Last component of the HDF5 path, e.g. "/a/b/c" -> "c".
    title = group.name.rpartition("/")[2]
    values = group["data"][:]
    overlay = None
    if "overlay_data" in group:
        overlay = group["overlay_data"][:].flatten()
    plot = plotter_cls(data=values, overlay_data=overlay)
    plot.create_plot(title=title)
    return plot
def _build_layout(dash_tabs):
    """Return the page layout for the app.

    The layout contains the title, the graph-type dropdown, the export button
    with its download target, the group tabs, and the two side-by-side columns
    that the callbacks in main() fill in.
    """
    return dbc.Container(
        [
            html.Div(
                [
                    html.Br(),
                    html.H1("Model-Viz"),
                    html.Hr(),
                    dcc.Dropdown(
                        options=list(GRAPH_TYPES.keys()),
                        placeholder="Select Graph Type",
                        id="graph_type",
                    ),
                    html.Br(),
                    dbc.Spinner(
                        dbc.Button(
                            "Export All Plots",
                            color="primary",
                            id="export",
                            class_name="me-1",
                        )
                    ),
                    dcc.Download(id="download"),
                    html.Br(),
                    html.Div(dash_tabs),
                    # Left column: one main plot per member of the active group.
                    html.Div(
                        id="tab_content_1",
                        style={"width": "75%", "display": "inline-block"},
                    ),
                    # Right column: hover-driven detail figures.
                    html.Div(
                        id="tab_content_2",
                        style={"width": "25%", "display": "inline-block"},
                    ),
                ]
            )
        ],
        fluid=True,
    )


def main(argv):
    """Launch the Model-Viz Dash app for the HDF5 file named in ``argv[0]``.

    Steps:
    - Build one tab per group returned by ``reader.get_all_groups()``.
    - Register the callbacks that render plots for the selected tab and graph
      type, the hover-driven detail histogram, and the "Export All Plots"
      download.
    - Start the Dash server in debug mode.

    Args:
        argv: Command-line arguments without the program name. ``argv[0]`` is
            the path of the HDF5 file to visualise.

    Raises:
        SystemExit: With a usage message when no file path is given.
    """
    if not argv:
        sys.exit("usage: main.py HDF5_FILE")

    reader = hdf_ops.HDFReader("reader", argv[0])
    groups_to_plot = reader.get_all_groups()
    # One tab per group. The callbacks treat ``active_tab`` as a group name,
    # i.e. they rely on DashTab using ``component_id`` as the tab id.
    group_tabs = {
        group: component_factory.DashTab(label=group, component_id=group)
        for group in groups_to_plot
    }
    dash_tabs = component_factory.DashTabs(
        component_id="dash_tabs", tabs=list(group_tabs.values())
    ).generate_component()

    app.layout = _build_layout(dash_tabs)

    @app.callback(
        Output("tab_content_1", "children"),
        Output("tab_content_2", "children"),
        Input("graph_type", "value"),
        Input("dash_tabs", "active_tab"),
        prevent_initial_call=True,
    )
    # Memoise per (graph_type, active_tab). Revisiting a tab reuses the built
    # components instead of re-reading the HDF5 data.
    @functools.lru_cache(maxsize=32)
    def update_graph1(graph_type, active_tab):
        """Render the active tab's plots for the selected graph type.

        Returns two lists:
        - the main figures, for the left column;
        - for each main figure, an empty placeholder figure for the right
          column, which update_graph2 fills on hover.

        Raises NotImplementedError for a tab that is not one of the groups.
        """
        if graph_type is None:
            # No graph type selected: leave both columns unchanged.
            return dash.no_update, dash.no_update
        if active_tab not in group_tabs:
            raise NotImplementedError(f"Active tab {active_tab} not implemented")
        plots = generate_plots(groups_to_plot[active_tab], graph_type)
        main_graphs = [
            dcc.Graph(
                id={"type": "dcc_go_1", "index": plot.title},
                figure=plot.fig,
                style=config.Plotter.graph_div_style,
            )
            for plot in plots
        ]
        detail_graphs = [
            dcc.Graph(
                id={"type": "dcc_go_2", "index": plot.title},
                figure=go.Figure(),
                style=config.Plotter.graph_div_style,
            )
            for plot in plots
        ]
        return main_graphs, detail_graphs

    @app.callback(
        Output({"type": "dcc_go_2", "index": MATCH}, "figure"),
        Input({"type": "dcc_go_1", "index": MATCH}, "hoverData"),
        State({"type": "dcc_go_1", "index": MATCH}, "id"),
        State("dash_tabs", "active_tab"),
        prevent_initial_call=True,
    )
    def update_graph2(hover_data, graph_id, active_tab):
        """Draw a histogram of the hovered column next to the hovered plot.

        ``graph_id["index"]`` is the plot title. It is passed to
        ``reader.get_group`` together with the active tab to locate the data.
        """
        if hover_data is None:
            return dash.no_update
        graph_title = graph_id["index"]
        # x-coordinate of the first hovered point, truncated to an int and
        # used as a column index into the 2-D ``data`` dataset.
        x = int(hover_data["points"][0]["x"])
        group = reader.get_group(active_tab, [graph_title])
        data = group["data"]
        overlay_data = None
        if "overlay_data" in group:
            # ``.item()`` extracts the single value. Calling int() on a
            # 1-element array is deprecated since NumPy 1.25. As before, this
            # fails unless the column holds exactly one value. Assumes the
            # slice is a NumPy array, as h5py datasets return.
            overlay_data = int(group["overlay_data"][:, x].item())
        fig = plotting.Histogram(
            data=data[:, x], overlay_data=overlay_data
        ).create_plot()
        # NOTE(review): generate_plots() reads ``plot.fig`` after
        # create_plot(); confirm create_plot() also returns the figure here.
        return fig

    @app.callback(
        Output("export", "n_clicks"),
        Output("download", "data"),
        Input("export", "n_clicks"),
        # State rather than Input, so that choosing a graph type cannot start
        # an export by itself. With graph_type as an Input, a click made
        # before any type was selected left n_clicks set, and the later
        # selection immediately triggered a download.
        State("graph_type", "value"),
        prevent_initial_call=True,
    )
    def export_plots(n_clicks, graph_type):
        """Export every plot of every group into one file and download it.

        Returns ``None`` for ``n_clicks``, which resets the button counter,
        together with the ``dcc.send_file`` payload for
        ``config.output_filename``.
        """
        if graph_type is None or n_clicks is None:
            return dash.no_update, dash.no_update
        files = []
        try:
            for group in groups_to_plot.values():
                files.extend(
                    plot.export_plot() for plot in generate_plots(group, graph_type)
                )
            utils.merge_pdf_files(files, config.output_filename)
        finally:
            # Remove the per-plot exports even if exporting or merging fails.
            if files:
                utils.delete_files(files, delete_dir=True)
        return None, dcc.send_file(config.output_filename)

    # ``Dash.run_server`` is deprecated in Dash 2.x and removed in Dash 3.0.
    # Prefer ``run`` and fall back only for old releases that lack it.
    if hasattr(app, "run"):
        app.run(debug=True)
    else:
        app.run_server(debug=True)
if __name__ == "__main__":
    # Forward everything after the script name; main() takes the HDF5 file
    # path from the first of these arguments.
    main(argv=sys.argv[1:])