diff --git a/app.py b/app.py
index 187f2a3..b7d880d 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,11 @@
 # Written by Dr Daniel Buscombe, Marda Science LLC
 # for the USGS Coastal Change Hazards Program
-#
+##########################################################################
+# Note:
+# Modified by Jin Ikeda, LSU Center for Computation and Technology on 2024-10-30
+# Jin Ikeda added the ability to upload JSON files for editing, to support labeling work in Doodler
+##########################################################################
+
 # MIT License
 #
 # Copyright (c) 2020-2022, Marda Science LLC
@@ -23,19 +28,20 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-#========================================================
-## ``````````````````````````` local imports
 # allows loading of functions from the src directory
-import sys,os
+# ========================================================
+# ``````````````````````````` local imports
+import sys
+import os
 # sys.path.insert(1, 'app_files'+os.sep+'src')
 # from annotations_to_segmentations import *
 from doodler_engine.annotations_to_segmentations import *

-#========================================================
-## ``````````````````````````` imports
-##========================================================
+# ========================================================
+# ``````````````````````````` imports
+# ========================================================

-## dash/plotly/flask
+# dash/plotly/flask
 import plotly.express as px
 import plotly.graph_objects as go
 # import skimage.util
@@ -46,33 +52,42 @@ from dash.dependencies import Input, Output, State
 try:
     from dash import html
-except:
+except BaseException:
     import dash_html_components as html

 try:
     from dash import dcc
-except:
+except BaseException:
     import dash_core_components as dcc

 from flask import Flask
 from flask_caching import Cache

-#others
-import base64, PIL.Image, json, shutil, time, logging, psutil
+# others
+import base64
+import PIL.Image
+import json
+import shutil
+import time
+import logging
+import psutil
 from datetime import datetime

 from doodler_engine.app_funcs import uploaded_files, get_asset_files

-def get_unlabeled_files()-> list:
+
+def get_unlabeled_files() -> list:
     """Returns a sorted list of unlabeled files
-
+
     uses the environment variables UPLOAD_DIRECTORY and LABELED_DIRECTORY to determine
-    which files were are unlabeled and labeled
+    which files are unlabeled and which are labeled

     Returns:
         list: sorted list of unlabeled files
-    """
-    #this file must exist - it contains a list of images labeled in this session
+    """
+    # this file must exist - it contains a list of images labeled in this
+    # session
     filelist = 'files_done.txt'
-    files, labeled_files = uploaded_files(filelist,UPLOAD_DIRECTORY,LABELED_DIRECTORY)
+    files, labeled_files = uploaded_files(
+        filelist, UPLOAD_DIRECTORY, LABELED_DIRECTORY)
     logging.info('File list written to %s' % (filelist))
     files = [f.split('assets/')[-1] for f in files]
     labeled_files = [f.split('labeled/')[-1] for f in labeled_files]
@@ -83,19 +98,22 @@ def get_unlabeled_files()-> list:
     logging.info(f"Unlabeled files: {unlabeled_files}")
     return unlabeled_files

-##========================================================
+# ========================================================
+
+
 def make_and_return_default_figure(
-    images,#=[DEFAULT_IMAGE_PATH],
-    stroke_color,#=convert_integer_class_to_color(class_label_colormap,DEFAULT_LABEL_CLASS),
-    pen_width,#=DEFAULT_PEN_WIDTH,
-    shapes#=[],
+ images, # =[DEFAULT_IMAGE_PATH], + # =convert_integer_class_to_color(class_label_colormap,DEFAULT_LABEL_CLASS), + stroke_color, + pen_width, # =DEFAULT_PEN_WIDTH, + shapes # =[], ): """ create and return the default Dash/plotly figure object """ - fig = dummy_fig() #plot_utils. + fig = dummy_fig() # plot_utils. - add_layout_images_to_fig(fig, images) #plot_utils. + add_layout_images_to_fig(fig, images) # plot_utils. fig.update_layout( { @@ -111,7 +129,7 @@ def make_and_return_default_figure( return fig -##======================================================== +# ======================================================== def dummy_fig(): """ create a dummy figure to be later modified """ fig = go.Figure(go.Scatter(x=[], y=[])) @@ -122,13 +140,15 @@ def dummy_fig(): ) return fig -##======================================================== +# ======================================================== + + def pil2uri(img): """ conevrts PIL image to uri""" return ImageUriValidator.pil_image_to_uri(img) -##======================================================== +# ======================================================== def parse_contents(contents, filename, date): return html.Div([ html.H5(filename), @@ -145,41 +165,45 @@ def parse_contents(contents, filename, date): }) ]) -#======================================================== -## defaults -#======================================================== +# ======================================================== +# defaults +# ======================================================== + -DEFAULT_IMAGE_PATH = "assets"+os.sep+"logos"+os.sep+"dash-default.jpg" +DEFAULT_IMAGE_PATH = "assets" + os.sep + "logos" + os.sep + "dash-default.jpg" try: from my_defaults import * print('Hyperparameters imported from my_defaults.py') -except: +except BaseException: from doodler_engine.defaults import * print('Default hyperparameters imported from src/my_defaults.py') -#======================================================== -## logs -#======================================================== +# ======================================================== +# logs +# ======================================================== -logging.basicConfig(filename=os.getcwd()+os.sep+'app_files'+os.sep+'logs'+ - os.sep+datetime.now().strftime("%Y-%m-%d-%H-%M")+'.log', +logging.basicConfig(filename=os.getcwd() + os.sep + 'app_files' + os.sep + 'logs' + + os.sep + datetime.now().strftime("%Y-%m-%d-%H-%M") + '.log', level=logging.INFO) -logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p') +logging.basicConfig( + format='%(asctime)s %(message)s', + datefmt='%m/%d/%Y %I:%M:%S %p') -#======================================================== -## folders -#======================================================== +# ======================================================== +# folders +# ======================================================== -UPLOAD_DIRECTORY = os.getcwd()+os.sep+"assets" -LABELED_DIRECTORY = os.getcwd()+os.sep+"labeled" -results_folder = 'results'+os.sep+'results'+datetime.now().strftime("%Y-%m-%d-%H-%M") +UPLOAD_DIRECTORY = os.getcwd() + os.sep + "assets" +LABELED_DIRECTORY = os.getcwd() + os.sep + "labeled" +results_folder = 'results' + os.sep + 'results' + \ + datetime.now().strftime("%Y-%m-%d-%H-%M") try: os.mkdir(results_folder) logging.info(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) logging.info("Folder created: %s" % (results_folder)) -except: +except BaseException: pass logging.info(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) @@ 
-188,12 +212,12 @@ def parse_contents(contents, filename, date):
 if not os.path.exists(UPLOAD_DIRECTORY):
     os.makedirs(UPLOAD_DIRECTORY)
     logging.info(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
-    logging.info('Made the directory '+UPLOAD_DIRECTORY)
+    logging.info('Made the directory ' + UPLOAD_DIRECTORY)

-##========================================================
-## classes
-#========================================================
+# ========================================================
+# classes
+# ========================================================

 # the number of different classes for labels
 DEFAULT_LABEL_CLASS = 0
@@ -201,20 +225,20 @@ def parse_contents(contents, filename, date):
 try:
     with open('classes.txt') as f:
         classes = f.readlines()
-except: #in case classes.txt does not exist
+except BaseException:  # in case classes.txt does not exist
    print("classes.txt not found or badly formatted. \
-          Exit the program and fix the classes.txt file ... \ otherwise, will continue using default classes. ")
+          Exit the program and fix the classes.txt file ... otherwise, will continue using default classes. ")
    classes = ['water', 'land']

 class_label_names = [c.strip() for c in classes]

 NUM_LABEL_CLASSES = len(class_label_names)

-#========================================================
-## colormap
-#========================================================
+# ========================================================
+# colormap
+# ========================================================

-if NUM_LABEL_CLASSES<=10:
+if NUM_LABEL_CLASSES <= 10:
     class_label_colormap = px.colors.qualitative.G10
 else:
     class_label_colormap = px.colors.qualitative.Light24
@@ -229,9 +253,9 @@ def parse_contents(contents, filename, date):
 for f in class_label_names:
     logging.info(f)

-#========================================================
-## image asset files
-#========================================================
+# ========================================================
+# image asset files
+# ========================================================

 # files = get_asset_files()
 # list of unlabeled files in assets directory
 files = get_unlabeled_files()
@@ -241,18 +265,18 @@ def parse_contents(contents, filename, date):
 for f in files:
     logging.info(f)

-##========================================================
+# ========================================================
 # app, server, and cache
-#========================================================
+# ========================================================

 # Normally, Dash creates its own Flask server internally.
By creating our own, # we can create a route for downloading files directly: server = Flask(__name__) app = dash.Dash(server=server) -#app = dash.Dash(__name__) -#server = app.server -app.config.suppress_callback_exceptions=True +# app = dash.Dash(__name__) +# server = app.server +app.config.suppress_callback_exceptions = True # app = dash.Dash(__name__) # app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) # server = app.server @@ -260,258 +284,304 @@ def parse_contents(contents, filename, date): cache = Cache(app.server, config={ 'CACHE_TYPE': 'filesystem', - 'CACHE_DIR': 'app_files'+os.sep+'cache-directory' + 'CACHE_DIR': 'app_files' + os.sep + 'cache-directory' }) -##======================================================== -## app layout -##======================================================== +# ======================================================== +# app layout +# ======================================================== app.layout = html.Div( id="app-container", children=[ - #======================================================== - ## tab 1 - #======================================================== + + # Store to keep track of the last clicked classification button + dcc.Store(id="selected-label-class", data=class_label_names[0]), + # Store to keep track of all drawn shapes + dcc.Store(id="drawn-shapes", data=[]), + + # ======================================================== + # tab 1 + # ======================================================== html.Div( id="banner", children=[ html.H1( - "Doodler: Interactive Image Segmentation", - id="title", - className="seven columns", - ), + "Doodler: Interactive Image Segmentation", + id="title", + className="seven columns", + ), - html.Img(id="logo", src=app.get_asset_url("logos"+os.sep+"dash-logo-new.png")), - # html.Div(html.Img(src=app.get_asset_url('logos/dash-logo-new.png'), style={'height':'10%', 'width':'10%'})), #id="logo", - - html.H2(""), - dcc.Upload( - id="upload-data", - children=html.Div( - [" "] #(Label all classes that are present, in all regions of the image those classes occur) - ), - style={ - "width": "100%", - "height": "30px", - "lineHeight": "70px", - "borderWidth": "1px", - "borderStyle": "none", - "borderRadius": "1px", - "textAlign": "center", - "margin": "10px", - }, - multiple=True, - ), - html.H2(""), - html.Ul(id="file-list"), + html.Img( + id="logo", + src=app.get_asset_url( + "logos" + + os.sep + + "dash-logo-new.png")), + # html.Div(html.Img(src=app.get_asset_url('logos/dash-logo-new.png'), + # style={'height':'10%', 'width':'10%'})), #id="logo", + + html.H2(""), + dcc.Upload( + id="upload-data", + children=html.Div( + # (Label all classes that are present, in all regions of the image those classes occur) + [" "] + ), + style={ + "width": "100%", + "height": "30px", + "lineHeight": "70px", + "borderWidth": "1px", + "borderStyle": "none", + "borderRadius": "1px", + "textAlign": "center", + "margin": "10px", + }, + multiple=True, + ), + html.H2(""), + html.Ul(id="file-list"), - ], #children - ), #div banner id + ], # children + ), # div banner id - dcc.Tabs([ - dcc.Tab(label='Imagery and Controls', children=[ - - html.Div( - id="main-content", - children=[ + dcc.Tabs([ + dcc.Tab(label='Imagery and Controls', children=[ html.Div( - id="left-column", + id="main-content", children=[ - dcc.Loading( - id="segmentations-loading", - type="cube", + + html.Div( + id="left-column", children=[ - # Graph - dcc.Graph( - id="graph", - figure=make_and_return_default_figure( - 
images=[DEFAULT_IMAGE_PATH], - stroke_color=convert_integer_class_to_color(class_label_colormap,DEFAULT_LABEL_CLASS), - pen_width=DEFAULT_PEN_WIDTH, - shapes=[], - ), - config={ - 'displayModeBar': 'hover', - "displaylogo": False, - "modeBarButtonsToRemove": [ - "toImage", - "hoverClosestCartesian", - "hoverCompareCartesian", - "toggleSpikelines", - ], - "modeBarButtonsToAdd": [ - "drawopenpath", - "eraseshape", - ] - }, + dcc.Loading( + id="segmentations-loading", + type="cube", + children=[ + # Graph + dcc.Graph( + id="graph", + figure=make_and_return_default_figure( + images=[DEFAULT_IMAGE_PATH], + stroke_color=convert_integer_class_to_color( + class_label_colormap, DEFAULT_LABEL_CLASS), + pen_width=DEFAULT_PEN_WIDTH, + shapes=[], + ), + config={ + 'displayModeBar': 'hover', + "displaylogo": False, + "modeBarButtonsToRemove": [ + "hoverClosestCartesian", + "hoverCompareCartesian", + "toggleSpikelines", + ], + "modeBarButtonsToAdd": [ + "drawopenpath", + "eraseshape", + "toImage", # Able to take a photo + ] + }, + ), + ], ), + ], + className="ten columns app-background", ), - ], - className="ten columns app-background", - ), + html.Div( + id="right-column", + children=[ - html.Div( - id="right-column", - children=[ + html.H6("Label class"), + # Label class chosen with buttons + html.Div( + id="label-class-buttons", + children=[ + html.Button( + # "%2d" % (n,), + "%s" % (class_label_names[n],), + id={"type": "label-class-button", + "index": n}, + style={ + "background-color": convert_integer_class_to_color(class_label_colormap, c)}, + ) + for n, c in enumerate(class_labels) + ], + ), - html.H6("Label class"), - # Label class chosen with buttons - html.Div( - id="label-class-buttons", - children=[ - html.Button( - #"%2d" % (n,), - "%s" % (class_label_names[n],), - id={"type": "label-class-button", "index": n}, - style={"background-color": convert_integer_class_to_color(class_label_colormap,c)}, - ) - for n, c in enumerate(class_labels) - ], - ), + html.H6(id="pen-width-display"), + # Slider for specifying pen width + dcc.Slider( + id="pen-width", + min=0, + max=5, + step=1, + value=DEFAULT_PEN_WIDTH, + ), - html.H6(id="pen-width-display"), - # Slider for specifying pen width - dcc.Slider( - id="pen-width", - min=0, - max=5, - step=1, - value=DEFAULT_PEN_WIDTH, - ), + # Indicate showing most recently computed + # segmentation + dcc.Checklist( + id="crf-show-segmentation", + options=[ + { + "label": "Compute/Show segmentation", + "value": "Show segmentation", + } + ], + value=[], + ), - # Indicate showing most recently computed segmentation - dcc.Checklist( - id="crf-show-segmentation", - options=[ - { - "label": "Compute/Show segmentation", - "value": "Show segmentation", - } - ], - value=[], - ), + dcc.Markdown( + ">Post-processing settings" + ), - dcc.Markdown( - ">Post-processing settings" - ), + html.H6(id="theta-display"), + # Slider for specifying pen width + dcc.Slider( + id="crf-theta-slider", + min=1, + max=20, + step=1, + value=DEFAULT_CRF_THETA, + ), - html.H6(id="theta-display"), - # Slider for specifying pen width - dcc.Slider( - id="crf-theta-slider", - min=1, - max=20, - step=1, - value=DEFAULT_CRF_THETA, - ), + html.H6(id="mu-display"), + # Slider for specifying pen width + dcc.Slider( + id="crf-mu-slider", + min=1, + max=20, + step=1, + value=DEFAULT_CRF_MU, + ), - html.H6(id="mu-display"), - # Slider for specifying pen width - dcc.Slider( - id="crf-mu-slider", - min=1, - max=20, - step=1, - value=DEFAULT_CRF_MU, - ), + html.H6(id="crf-downsample-display"), + # Slider for 
specifying pen width + dcc.Slider( + id="crf-downsample-slider", + min=1, + max=6, + step=1, + value=DEFAULT_CRF_DOWNSAMPLE, + ), - html.H6(id="crf-downsample-display"), - # Slider for specifying pen width - dcc.Slider( - id="crf-downsample-slider", - min=1, - max=6, - step=1, - value=DEFAULT_CRF_DOWNSAMPLE, - ), + # html.H6(id="crf-gtprob-display"), + # # Slider for specifying pen width + # dcc.Slider( + # id="crf-gtprob-slider", + # min=0.5, + # max=0.95, + # step=0.05, + # value=DEFAULT_CRF_GTPROB, + # ), + + dcc.Markdown( + ">Classifier settings" + ), - # html.H6(id="crf-gtprob-display"), - # # Slider for specifying pen width - # dcc.Slider( - # id="crf-gtprob-slider", - # min=0.5, - # max=0.95, - # step=0.05, - # value=DEFAULT_CRF_GTPROB, - # ), - - dcc.Markdown( - ">Classifier settings" - ), + html.H6(id="rf-downsample-display"), + # Slider for specifying pen width + dcc.Slider( + id="rf-downsample-slider", + min=1, + max=20, + step=1, + value=DEFAULT_RF_DOWNSAMPLE, + ), - html.H6(id="rf-downsample-display"), - # Slider for specifying pen width - dcc.Slider( - id="rf-downsample-slider", - min=1, - max=20, - step=1, - value=DEFAULT_RF_DOWNSAMPLE, - ), + html.H6(id="numscales-display"), + # Slider for specifying pen width + dcc.Slider( + id="numscales-slider", + min=2, + max=6, + step=1, + value=DEFAULT_NUMSCALES, + ), - html.H6(id="numscales-display"), - # Slider for specifying pen width - dcc.Slider( - id="numscales-slider", - min=2, - max=6, - step=1, - value=DEFAULT_NUMSCALES, - ), + # File upload for drawn (saved) JSON + dcc.Upload( + id="upload-json-file", + children=html.Button( + "Upload JSON File", id="upload-button"), + accept=".json" + ), + # Display upload status + html.Div(id="JSON-upload"), + + # Save button to store edited polylines + # Input field for filename + dcc.Input( + id="filename-input", + type="text", + placeholder="Enter filename (optional)"), + html.Button( + "Save Polylines", id="save-button", n_clicks=0), # Save button + # Output message for save confirmation + html.Div(id="save-output"), + + # Output message for save confirmation, placed + # directly below the button + dcc.Interval( + id="shape-check-interval", + interval=3000, # Runs every 3 seconds; adjust as needed + n_intervals=0 + ) + ], + className="three columns app-background", + ), ], - className="three columns app-background", - ), - ], - className="ten columns", - ), #main content Div + className="ten columns", + ), # main content Div - #======================================================== - ## tab 2 - #======================================================== + # ======================================================== + # tab 2 + # ======================================================== - ]), - dcc.Tab(label='File List and Instructions', children=[ + ]), + dcc.Tab(label='File List and Instructions', children=[ - html.H4(children="Doodler"), - dcc.Markdown( - "> A user-interactive tool for fast segmentation of imagery (designed for natural environments), using a Multilayer Perceptron classifier and Conditional Random Field (CRF) refinement. \ + html.H4(children="Doodler"), + dcc.Markdown( + "> A user-interactive tool for fast segmentation of imagery (designed for natural environments), using a Multilayer Perceptron classifier and Conditional Random Field (CRF) refinement. \ Doodles are used to make a classifier model, which maps image features to unary potentials to create an initial image segmentation. The segmentation is then refined using a CRF model." 
- ), + ), - dcc.Input(id='my-id', value='Enter-user-ID', type="text"), - html.Button('Submit', id='button'), - html.Div(id='my-div'), - - html.H3("Select Image"), - dcc.Dropdown( - id="select-image", - optionHeight=15, - style={'fontSize': 13}, - options = [ - {'label': image.split('assets/')[-1], 'value': image } \ - for image in files - ], - - value='assets/logos/dash-default.jpg', # - multi=False, - ), - html.Div([html.Div(id='live-update-text'), - dcc.Interval(id='interval-component', interval=500, n_intervals=0)]), - - - html.P(children="This image/Copy"), - dcc.Textarea(id="thisimage_output", cols=80), - html.Br(), - - dcc.Markdown( - """ + dcc.Input(id='my-id', value='Enter-user-ID', type="text"), + html.Button('Submit', id='button'), + html.Div(id='my-div'), + + html.H3("Select Image"), + dcc.Dropdown( + id="select-image", + optionHeight=15, + style={'fontSize': 13}, + options=[ + {'label': image.split('assets/')[-1], 'value': image} \ + for image in files + ], + + value='assets/logos/dash-default.jpg', + multi=False, + ), + html.Div([html.Div(id='live-update-text'), + dcc.Interval(id='interval-component', interval=500, n_intervals=0)]), + + + html.P(children="This image/Copy"), + dcc.Textarea(id="thisimage_output", cols=80), + html.Br(), + + dcc.Markdown( + """ **Instructions:** * Before you begin, make a new 'classes.txt' file that contains a list of the classes you'd like to label * Optionally, you can copy the images you wish to label into the 'assets' folder (just jpg, JPG or jpeg extension, or mixtures of those, for now) @@ -526,22 +596,22 @@ def parse_contents(contents, filename, date): * As you go, the program only lists files that are yet to be labeled. It does this irrespective of your opinion of the segmentation, so you get 'one shot' before you select another image (i.e. you cant go back to redo) * [Code on GitHub](https://github.com/dbuscombe-usgs/dash_doodler). """ - ), - dcc.Markdown( - """ + ), + dcc.Markdown( + """ **Tips:** 1) Works best for small imagery, typically much smaller than 3000 x 3000 px images. This prevents out-of-memory errors, and also helps you identify small features\ 2) Less is usually more! It is often best to use small pen width and relatively few annotations. Don't be tempted to spend too long doodling; extra doodles can be strategically added to correct segmentations \ 3) Make doodles of every class present in the image, and also every region of the image (i.e. avoid label clusters) \ 4) If things get weird, hit the refresh button on your browser and it should reset the application. 
Don't worry, all your previous work is saved!\ 5) Remember to uncheck 'Show/compute segmentation' before you change parameter values or change image\ """ - ), + ), - ]),]), + ]),]), - #======================================================== - ## components that are not displayed, used for storing data in localhost - #======================================================== + # ======================================================== + # components that are not displayed, used for storing data in localhost + # ======================================================== html.Div( id="no-display", @@ -558,66 +628,69 @@ def parse_contents(contents, filename, date): dcc.Store(id="segmentation", data={}), dcc.Store(id="classified-image-store", data=""), ], - ), #nos-display div + ), # nos-display div - ], #children -) #app layout + ], # children +) # app layout # ##======================================================== -##======================================================== -## app callbacks -##======================================================== +# ======================================================== +# app callbacks +# ======================================================== @app.callback( [ - Output("select-image","options"), - Output("graph", "figure"), - Output("image-list-store", "data"), - Output("masks", "data"), - Output('my-div', 'children'), - Output("segmentation", "data"), - Output('thisimage_output', 'value'), - Output("pen-width-display", "children"), - Output("theta-display", "children"), - Output("mu-display", "children"), - Output("crf-downsample-display", "children"), - # Output("crf-gtprob-display", "children"), - Output("rf-downsample-display", "children"), - Output("numscales-display", "children"), - Output("classified-image-store", "data"), + Output("select-image", "options"), + Output("graph", "figure"), + Output("image-list-store", "data"), + Output("masks", "data"), + Output('my-div', 'children'), + Output("segmentation", "data"), + Output('thisimage_output', 'value'), + Output("pen-width-display", "children"), + Output("theta-display", "children"), + Output("mu-display", "children"), + Output("crf-downsample-display", "children"), + # Output("crf-gtprob-display", "children"), + Output("rf-downsample-display", "children"), + Output("numscales-display", "children"), + Output("classified-image-store", "data"), + Output("JSON-upload", "children"), # Display upload status ], [ - Input("upload-data", "filename"), - Input("upload-data", "contents"), - Input("graph", "relayoutData"), - Input( - {"type": "label-class-button", "index": dash.dependencies.ALL}, - "n_clicks_timestamp", - ), - Input("crf-theta-slider", "value"), - Input('crf-mu-slider', "value"), - Input("pen-width", "value"), - Input("crf-show-segmentation", "value"), - Input("crf-downsample-slider", "value"), - # Input("crf-gtprob-slider", "value"), - Input("rf-downsample-slider", "value"), - Input("numscales-slider", "value"), - Input("select-image", "value"), - Input('interval-component', 'n_intervals'), + Input("upload-data", "filename"), + Input("upload-data", "contents"), + Input("graph", "relayoutData"), + Input( + {"type": "label-class-button", "index": dash.dependencies.ALL}, + "n_clicks_timestamp", + ), + Input("crf-theta-slider", "value"), + Input('crf-mu-slider', "value"), + Input("pen-width", "value"), + Input("crf-show-segmentation", "value"), + Input("crf-downsample-slider", "value"), + # Input("crf-gtprob-slider", "value"), + Input("rf-downsample-slider", "value"), + 
Input("numscales-slider", "value"), + Input("select-image", "value"), + Input('interval-component', 'n_intervals'), + # Contents of the upload xxx.json file + Input("upload-json-file", "contents"), ], [ - State("image-list-store", "data"), - State('my-id', 'value'), - State("masks", "data"), - State("segmentation", "data"), - State("classified-image-store", "data"), + State("image-list-store", "data"), + State('my-id', 'value'), + State("masks", "data"), + State("segmentation", "data"), + State("classified-image-store", "data"), + State("upload-json-file", "filename"), # Display imported filename ], ) - # ##======================================================== -##======================================================== -## app callback function -##======================================================== +# ======================================================== +# app callback function +# ======================================================== def update_output( uploaded_filenames, uploaded_file_contents, @@ -632,12 +705,14 @@ def update_output( n_sigmas, select_image_value, n_intervals, + json_contents, image_list_data, my_id_value, masks_data, segmentation_data, segmentation_store_data, - ): + json_filename +): """ This is where all the action happens, and is called any time a button is pressed This function is automatically called, and the inputs and outputs match, in order, @@ -646,8 +721,9 @@ def update_output( The callback context is first defined, which dictates what the function does """ - callback_context = [p["prop_id"] for p in dash.callback_context.triggered][0] - #print(callback_context) + callback_context = [p["prop_id"] + for p in dash.callback_context.triggered][0] + # print("callback_context:",callback_context) multichannel = True intensity = True @@ -659,19 +735,22 @@ def update_output( files = '' options = [] - # Remove any "_" from my_id_value and if the my_id_value is empty replace with TEMPID - my_id_value = my_id_value.replace("_","") - if(len(my_id_value) == 0): - my_id_value='TEMPID' + # Remove any "_" from my_id_value and if the my_id_value is empty replace + # with TEMPID + my_id_value = my_id_value.replace("_", "") + if (len(my_id_value) == 0): + my_id_value = 'TEMPID' - if callback_context=='interval-component.n_intervals': + if callback_context == 'interval-component.n_intervals': unlabeled_files = get_unlabeled_files() - options = [{'label': image, 'value': image } for image in unlabeled_files] - logging.info('Checked assets and labeled lists and revised list of images yet to label') + options = [{'label': image, 'value': image} + for image in unlabeled_files] + logging.info( + 'Checked assets and labeled lists and revised list of images yet to label') if select_image_value is not None: if 'assets' not in select_image_value: - select_image_value = 'assets'+os.sep+select_image_value + select_image_value = 'assets' + os.sep + select_image_value if callback_context == "graph.relayoutData": try: @@ -679,14 +758,14 @@ def update_output( masks_data["shapes"] = graph_relayoutData["shapes"] else: return dash.no_update - except: + except BaseException: return dash.no_update elif callback_context == "select-image.value": - masks_data={"shapes": []} - segmentation_data={} + masks_data = {"shapes": []} + segmentation_data = {} - logging.info('New image selected') + logging.info('New image selected') pen_width = pen_width_value @@ -700,16 +779,55 @@ def update_output( )[0] fig = make_and_return_default_figure( - images = [select_image_value], - 
stroke_color=convert_integer_class_to_color(class_label_colormap,label_class_value), + images=[select_image_value], + stroke_color=convert_integer_class_to_color( + class_label_colormap, label_class_value), pen_width=pen_width, shapes=masks_data["shapes"], ) logging.info('Main figure window updated with new image') + ########################################################################## + # Loading a JSON file part + ########################################################################## + json_processed_Flag = False # Flag to control return after JSON processing + + # JSON file handling + if callback_context == "upload-json-file.contents" and json_contents is not None: + # print("json_contents available:", json_contents) # for Debugging + print("json_filename:", json_filename) + + try: + # print("Attempting to add polylines to the current image") # for + # Debugging + # Read polylines coordinates + shapes = parse_jason(json_contents, json_filename) + # Update loaded data as drawn polylines + masks_data["shapes"] = shapes + # print("shapes", shapes) # for Debugging + print("JSON file loaded successfully") + + fig = make_and_return_default_figure( + images=[select_image_value or DEFAULT_IMAGE_PATH], + stroke_color=convert_integer_class_to_color( + class_label_colormap, DEFAULT_LABEL_CLASS), + pen_width=pen_width_value, + shapes=shapes + ) + json_processed_Flag = True # Set flag indicating JSON was processed + except Exception as e: + error_message = f"There was an error processing this file: {str(e)}" + print(error_message) + return [ + dash.no_update, dash.no_update, dash.no_update, dash.no_update, + dash.no_update, dash.no_update, dash.no_update, dash.no_update, + dash.no_update, dash.no_update, dash.no_update, dash.no_update, + dash.no_update, dash.no_update, error_message + ] + if ("Show segmentation" in show_segmentation_value) and ( - len(masks_data["shapes"]) > 0): + len(masks_data["shapes"]) > 0): # to store segmentation data in the store, we need to base64 encode the # PIL.Image and hash the set of shapes to use this as the key # to retrieve the segmentation data, we need to base64 decode to a PIL.Image @@ -717,94 +835,187 @@ def update_output( sh = shapes_to_key( [ masks_data["shapes"], - '', #segmentation_features_value, - '', #sigma_range_slider_value, + '', # segmentation_features_value, + '', # sigma_range_slider_value, ] ) segimgpng = None # start timer - if os.name=='posix': # true if linux/mac or cygwin on windows + if os.name == 'posix': # true if linux/mac or cygwin on windows start = time.time() - else: # windows + else: # windows try: start = time.clock() - except: + except BaseException: start = time.perf_counter() - # this is the function that computes and updates the segmentation whenever the checkbox is checked - segimgpng, seg, img, color_doodles, doodles = show_segmentation( + # this is the function that computes and updates the segmentation + # whenever the checkbox is checked + segimgpng, seg, img, color_doodles, doodles = show_segmentation( [select_image_value], masks_data["shapes"], callback_context, - crf_theta_slider_value, crf_mu_slider_value, results_folder, rf_downsample_value, crf_downsample_value, 1.0, my_id_value, - n_sigmas, multichannel, intensity, edges, texture,class_label_colormap + crf_theta_slider_value, crf_mu_slider_value, results_folder, rf_downsample_value, crf_downsample_value, 1.0, my_id_value, + n_sigmas, multichannel, intensity, edges, texture, class_label_colormap ) logging.info('... 
showing segmentation on screen') logging.info('percent RAM usage: %f' % (psutil.virtual_memory()[2])) - if os.name=='posix': # true if linux/mac - elapsed = (time.time() - start)#/60 - else: # windows - #elapsed = (time.clock() - start)#/60 + if os.name == 'posix': # true if linux/mac + elapsed = (time.time() - start) # /60 + else: # windows + # elapsed = (time.clock() - start)#/60 try: - elapsed = (time.clock() - start)#/60 - except: - elapsed = (time.perf_counter() - start)#/60 + elapsed = (time.clock() - start) # /60 + except BaseException: + elapsed = (time.perf_counter() - start) # /60 logging.info('Processing took %s seconds' % (str(elapsed))) - lstack = (np.arange(seg.max()) == seg[...,None]-1).astype(int) #one-hot encode the 2D label into 3D stack of IxJxN classes + # one-hot encode the 2D label into 3D stack of IxJxN classes + lstack = (np.arange(seg.max()) == seg[..., None] - 1).astype(int) logging.info('One-hot encoded label stack created') - - if type(select_image_value) is list: + if isinstance(select_image_value, list): if 'jpg' in select_image_value[0]: - colfile = select_image_value[0].replace('assets',results_folder).replace('.jpg','_label'+datetime.now().strftime("%Y-%m-%d-%H-%M")+'_'+my_id_value+'.png') + colfile = select_image_value[0].replace( + 'assets', + results_folder).replace( + '.jpg', + '_label' + + datetime.now().strftime("%Y-%m-%d-%H-%M") + + '_' + + my_id_value + + '.png') if 'JPG' in select_image_value[0]: - colfile = select_image_value[0].replace('assets',results_folder).replace('.JPG','_label'+datetime.now().strftime("%Y-%m-%d-%H-%M")+'_'+my_id_value+'.png') + colfile = select_image_value[0].replace( + 'assets', + results_folder).replace( + '.JPG', + '_label' + + datetime.now().strftime("%Y-%m-%d-%H-%M") + + '_' + + my_id_value + + '.png') if 'jpeg' in select_image_value[0]: - colfile = select_image_value[0].replace('assets',results_folder).replace('.jpeg','_label'+datetime.now().strftime("%Y-%m-%d-%H-%M")+'_'+my_id_value+'.png') - - if np.ndim(img)==3: - imsave(colfile,label_to_colors(seg-1, img[:,:,0]==0, alpha=128, colormap=class_label_colormap, color_class_offset=0, do_alpha=False)) + colfile = select_image_value[0].replace( + 'assets', + results_folder).replace( + '.jpeg', + '_label' + + datetime.now().strftime("%Y-%m-%d-%H-%M") + + '_' + + my_id_value + + '.png') + + if np.ndim(img) == 3: + imsave(colfile, + label_to_colors(seg - 1, + img[:, + :, + 0] == 0, + alpha=128, + colormap=class_label_colormap, + color_class_offset=0, + do_alpha=False)) else: - imsave(colfile,label_to_colors(seg-1, img==0, alpha=128, colormap=class_label_colormap, color_class_offset=0, do_alpha=False)) + imsave( + colfile, + label_to_colors( + seg - 1, + img == 0, + alpha=128, + colormap=class_label_colormap, + color_class_offset=0, + do_alpha=False)) orig_image = imread(select_image_value[0]) - if np.ndim(orig_image)>3: - orig_image = orig_image[:,:,:3] + if np.ndim(orig_image) > 3: + orig_image = orig_image[:, :, :3] else: if 'jpg' in select_image_value: - colfile = select_image_value.replace('assets',results_folder).replace('.jpg','_label'+datetime.now().strftime("%Y-%m-%d-%H-%M")+'_'+my_id_value+'.png') + colfile = select_image_value.replace( + 'assets', + results_folder).replace( + '.jpg', + '_label' + + datetime.now().strftime("%Y-%m-%d-%H-%M") + + '_' + + my_id_value + + '.png') if 'JPG' in select_image_value: - colfile = select_image_value.replace('assets',results_folder).replace('.JPG','_label'+datetime.now().strftime("%Y-%m-%d-%H-%M")+'_'+my_id_value+'.png') + 
colfile = select_image_value.replace( + 'assets', + results_folder).replace( + '.JPG', + '_label' + + datetime.now().strftime("%Y-%m-%d-%H-%M") + + '_' + + my_id_value + + '.png') if 'jpeg' in select_image_value: - colfile = select_image_value.replace('assets',results_folder).replace('.jpeg','_label'+datetime.now().strftime("%Y-%m-%d-%H-%M")+'_'+my_id_value+'.png') - - if np.ndim(img)==3: - imsave(colfile,label_to_colors(seg-1, img[:,:,0]==0, alpha=128, colormap=class_label_colormap, color_class_offset=0, do_alpha=False)) + colfile = select_image_value.replace( + 'assets', + results_folder).replace( + '.jpeg', + '_label' + + datetime.now().strftime("%Y-%m-%d-%H-%M") + + '_' + + my_id_value + + '.png') + + if np.ndim(img) == 3: + imsave(colfile, + label_to_colors(seg - 1, + img[:, + :, + 0] == 0, + alpha=128, + colormap=class_label_colormap, + color_class_offset=0, + do_alpha=False)) else: - imsave(colfile,label_to_colors(seg-1, img==0, alpha=128, colormap=class_label_colormap, color_class_offset=0, do_alpha=False)) + imsave( + colfile, + label_to_colors( + seg - 1, + img == 0, + alpha=128, + colormap=class_label_colormap, + color_class_offset=0, + do_alpha=False)) orig_image = imread(select_image_value) - if np.ndim(orig_image)>3: - orig_image = orig_image[:,:,:3] + if np.ndim(orig_image) > 3: + orig_image = orig_image[:, :, :3] # orig_image = img_to_ubyte_array(select_image_value) logging.info('RGB label image saved to %s' % (colfile)) - settings_dict = np.array([pen_width, crf_downsample_value, rf_downsample_value, crf_theta_slider_value, crf_mu_slider_value, 1.0, n_sigmas]) + settings_dict = np.array([pen_width, + crf_downsample_value, + rf_downsample_value, + crf_theta_slider_value, + crf_mu_slider_value, + 1.0, + n_sigmas]) - if type(select_image_value) is list: + if isinstance(select_image_value, list): if 'jpg' in select_image_value[0]: - numpyfile = select_image_value[0].replace('assets',results_folder).replace('.jpg','_'+my_id_value+'.npz') + numpyfile = select_image_value[0].replace( + 'assets', results_folder).replace( + '.jpg', '_' + my_id_value + '.npz') if 'JPG' in select_image_value[0]: - numpyfile = select_image_value[0].replace('assets',results_folder).replace('.JPG','_'+my_id_value+'.npz') + numpyfile = select_image_value[0].replace( + 'assets', results_folder).replace( + '.JPG', '_' + my_id_value + '.npz') if 'jpeg' in select_image_value[0]: - numpyfile = select_image_value[0].replace('assets',results_folder).replace('.jpeg','_'+my_id_value+'.npz') + numpyfile = select_image_value[0].replace( + 'assets', results_folder).replace( + '.jpeg', '_' + my_id_value + '.npz') if os.path.exists(numpyfile): saved_data = np.load(numpyfile) @@ -812,7 +1023,7 @@ def update_output( for k in saved_data.keys(): tmp = saved_data[k] name = str(k) - savez_dict['0'+name] = tmp + savez_dict['0' + name] = tmp del tmp savez_dict['image'] = img.astype(np.uint8) @@ -822,7 +1033,7 @@ def update_output( savez_dict['doodles'] = doodles.astype(np.uint8) savez_dict['settings'] = settings_dict savez_dict['classes'] = class_label_names - np.savez_compressed(numpyfile, **savez_dict ) + np.savez_compressed(numpyfile, **savez_dict) else: savez_dict = dict() @@ -833,15 +1044,22 @@ def update_output( savez_dict['doodles'] = doodles.astype(np.uint8) savez_dict['settings'] = settings_dict savez_dict['classes'] = class_label_names - np.savez_compressed(numpyfile, **savez_dict ) #save settings too + np.savez_compressed( + numpyfile, **savez_dict) # save settings too else: if 'jpg' in select_image_value: - 
numpyfile = select_image_value.replace('assets',results_folder).replace('.jpg','_'+my_id_value+'.npz') + numpyfile = select_image_value.replace( + 'assets', results_folder).replace( + '.jpg', '_' + my_id_value + '.npz') if 'JPG' in select_image_value: - numpyfile = select_image_value.replace('assets',results_folder).replace('.JPG','_'+my_id_value+'.npz') + numpyfile = select_image_value.replace( + 'assets', results_folder).replace( + '.JPG', '_' + my_id_value + '.npz') if 'jpeg' in select_image_value: - numpyfile = select_image_value.replace('assets',results_folder).replace('.jpeg','_'+my_id_value+'.npz') + numpyfile = select_image_value.replace( + 'assets', results_folder).replace( + '.jpeg', '_' + my_id_value + '.npz') if os.path.exists(numpyfile): saved_data = np.load(numpyfile) @@ -849,7 +1067,7 @@ def update_output( for k in saved_data.keys(): tmp = saved_data[k] name = str(k) - savez_dict['0'+name] = tmp + savez_dict['0' + name] = tmp del tmp savez_dict['image'] = img.astype(np.uint8) @@ -859,7 +1077,8 @@ def update_output( savez_dict['doodles'] = doodles.astype(np.uint8) savez_dict['settings'] = settings_dict savez_dict['classes'] = class_label_names - np.savez_compressed(numpyfile, **savez_dict )#save settings too + np.savez_compressed( + numpyfile, **savez_dict) # save settings too else: savez_dict = dict() @@ -870,7 +1089,8 @@ def update_output( savez_dict['doodles'] = doodles.astype(np.uint8) savez_dict['settings'] = settings_dict savez_dict['classes'] = class_label_names - np.savez_compressed(numpyfile, **savez_dict )#save settings too + np.savez_compressed( + numpyfile, **savez_dict) # save settings too logging.info('percent RAM usage: %f' % (psutil.virtual_memory()[2])) @@ -886,39 +1106,51 @@ def update_output( segmentation_store_data = pil2uri( seg_pil( select_image_value, segimgpng, do_alpha=True - ) #plot_utils. + ) # plot_utils. ) - shutil.copyfile(select_image_value, select_image_value.replace('assets', 'labeled')) #move - except: + shutil.copyfile( + select_image_value, + select_image_value.replace( + 'assets', + 'labeled')) # move + except BaseException: segmentation_store_data = pil2uri( seg_pil( PIL.Image.open(select_image_value), segimgpng, do_alpha=True - ) #plot_utils. + ) # plot_utils. ) - shutil.copyfile(select_image_value, select_image_value.replace('assets', 'labeled')) #move + shutil.copyfile( + select_image_value, + select_image_value.replace( + 'assets', + 'labeled')) # move - logging.info('%s moved to labeled folder' % (select_image_value.replace('assets', 'labeled'))) + logging.info('%s moved to labeled folder' % + (select_image_value.replace('assets', 'labeled'))) images_to_draw = [] if segimgpng is not None: images_to_draw = [segimgpng] - fig = add_layout_images_to_fig(fig, images_to_draw) #plot_utils. + fig = add_layout_images_to_fig(fig, images_to_draw) # plot_utils. 
show_segmentation_value = [] image_list_data.append(select_image_value) try: - os.remove('my_defaults.py') - except: - pass + os.remove('my_defaults.py') + except BaseException: + pass - #write defaults back out to file + # write defaults back out to file with open('my_defaults.py', 'a') as the_file: the_file.write('DEFAULT_PEN_WIDTH = {}\n'.format(pen_width)) - the_file.write('DEFAULT_CRF_DOWNSAMPLE = {}\n'.format(crf_downsample_value)) - the_file.write('DEFAULT_RF_DOWNSAMPLE = {}\n'.format(rf_downsample_value)) - the_file.write('DEFAULT_CRF_THETA = {}\n'.format(crf_theta_slider_value)) + the_file.write( + 'DEFAULT_CRF_DOWNSAMPLE = {}\n'.format(crf_downsample_value)) + the_file.write( + 'DEFAULT_RF_DOWNSAMPLE = {}\n'.format(rf_downsample_value)) + the_file.write( + 'DEFAULT_CRF_THETA = {}\n'.format(crf_theta_slider_value)) the_file.write('DEFAULT_CRF_MU = {}\n'.format(crf_mu_slider_value)) # the_file.write('DEFAULT_CRF_GTPROB = {}\n'.format(gt_prob)) the_file.write('DEFAULT_NUMSCALES = {}\n'.format(n_sigmas)) @@ -931,39 +1163,207 @@ def update_output( # logging.info(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) logging.info('percent RAM usage: %f' % (psutil.virtual_memory()[2])) - if len(files) == 0: + if json_processed_Flag: return [ - options, - fig, - image_list_data, - masks_data, - segmentation_data, - 'User ID: "{}"'.format(my_id_value) , - select_image_value, - "Pen width (default: %d): %d" % (DEFAULT_PEN_WIDTH,pen_width), - "Blur factor (default: %d): %d" % (DEFAULT_CRF_THETA, crf_theta_slider_value), - "Model independence factor (default: %d): %d" % (DEFAULT_CRF_MU,crf_mu_slider_value), - "CRF downsample factor (default: %d): %d" % (DEFAULT_CRF_DOWNSAMPLE,crf_downsample_value), - # "User-defined quality score (1=perfect. default: %f): %f" % (DEFAULT_CRF_GTPROB,gt_prob), - "Classifier downsample factor (default: %d): %d" % (DEFAULT_RF_DOWNSAMPLE,rf_downsample_value), - "Number of scales (default: %d): %d" % (DEFAULT_NUMSCALES,n_sigmas), - segmentation_store_data, + options, + fig, + image_list_data, + masks_data, + segmentation_data, + 'User ID: "{}"'.format(my_id_value), + select_image_value, + "Pen width (default: %d): %d" % (DEFAULT_PEN_WIDTH, pen_width), + "Blur factor (default: %d): %d" % ( + DEFAULT_CRF_THETA, crf_theta_slider_value), + "Model independence factor (default: %d): %d" % ( + DEFAULT_CRF_MU, crf_mu_slider_value), + "CRF downsample factor (default: %d): %d" % ( + DEFAULT_CRF_DOWNSAMPLE, crf_downsample_value), + # "User-defined quality score (1=perfect. default: %f): %f" % (DEFAULT_CRF_GTPROB,gt_prob), + "Classifier downsample factor (default: %d): %d" % ( + DEFAULT_RF_DOWNSAMPLE, rf_downsample_value), + "Number of scales (default: %d): %d" % ( + DEFAULT_NUMSCALES, n_sigmas), + segmentation_store_data, + "JSON file loaded successfully." ] + + elif (len(files) == 0): + return [ + options, + fig, + image_list_data, + masks_data, + segmentation_data, + 'User ID: "{}"'.format(my_id_value), + select_image_value, + "Pen width (default: %d): %d" % (DEFAULT_PEN_WIDTH, pen_width), + "Blur factor (default: %d): %d" % ( + DEFAULT_CRF_THETA, crf_theta_slider_value), + "Model independence factor (default: %d): %d" % ( + DEFAULT_CRF_MU, crf_mu_slider_value), + "CRF downsample factor (default: %d): %d" % ( + DEFAULT_CRF_DOWNSAMPLE, crf_downsample_value), + # "User-defined quality score (1=perfect. 
default: %f): %f" % (DEFAULT_CRF_GTPROB,gt_prob), + "Classifier downsample factor (default: %d): %d" % ( + DEFAULT_RF_DOWNSAMPLE, rf_downsample_value), + "Number of scales (default: %d): %d" % ( + DEFAULT_NUMSCALES, n_sigmas), + segmentation_store_data, + "Upload a JSON file to load polylines." + ] + else: return [ - options, - fig, - image_list_data, - masks_data, - segmentation_data, - 'User ID: "{}"'.format(my_id_value) , - select_image_value, - "Pen width (default: %d): %d" % (DEFAULT_PEN_WIDTH,pen_width), - "Blur factor (default: %d): %d" % (DEFAULT_CRF_THETA, crf_theta_slider_value), - "Model independence factor (default: %d): %d" % (DEFAULT_CRF_MU,crf_mu_slider_value), - "CRF downsample factor (default: %d): %d" % (DEFAULT_CRF_DOWNSAMPLE,crf_downsample_value), - # "User-defined quality score (1=perfect. default: %f): %f" % (DEFAULT_CRF_GTPROB,gt_prob), - "Classifier downsample factor (default: %d): %d" % (DEFAULT_RF_DOWNSAMPLE,rf_downsample_value), - "Number of scales (default: %d): %d" % (DEFAULT_NUMSCALES,n_sigmas), - segmentation_store_data, + options, + fig, + image_list_data, + masks_data, + segmentation_data, + 'User ID: "{}"'.format(my_id_value), + select_image_value, + "Pen width (default: %d): %d" % (DEFAULT_PEN_WIDTH, pen_width), + "Blur factor (default: %d): %d" % ( + DEFAULT_CRF_THETA, crf_theta_slider_value), + "Model independence factor (default: %d): %d" % ( + DEFAULT_CRF_MU, crf_mu_slider_value), + "CRF downsample factor (default: %d): %d" % ( + DEFAULT_CRF_DOWNSAMPLE, crf_downsample_value), + # "User-defined quality score (1=perfect. default: %f): %f" % (DEFAULT_CRF_GTPROB,gt_prob), + "Classifier downsample factor (default: %d): %d" % ( + DEFAULT_RF_DOWNSAMPLE, rf_downsample_value), + "Number of scales (default: %d): %d" % ( + DEFAULT_NUMSCALES, n_sigmas), + segmentation_store_data, + "Upload a JSON file to load polylines." 
] + +# Function for loading JSON file + + +def parse_jason(contents, filename): + content_type, content_string = contents.split(',') + decoded = base64.b64decode(content_string) + + if filename.endswith('.json'): + data = json.loads(decoded) + if data: + shapes = [] + for shape in data: + # Remove `label_class` if it exists and is a string or null + if "label_class" in shape and ( + shape["label_class"] is None or isinstance(shape["label_class"], str)): + del shape["label_class"] + shapes.append(shape) + + return shapes + else: + raise ValueError("No shapes found in JSON data.") + else: + raise ValueError("Unsupported file type.") + + +@app.callback( + Output("selected-label-class", "data"), + [Input({"type": "label-class-button", + "index": dash.dependencies.ALL}, + "n_clicks_timestamp")], +) +def update_selected_label_class(n_clicks_timestamp): + if n_clicks_timestamp and any( + n_clicks_timestamp): # Check if any button was clicked + # Identify the button with the most recent click + clicked_index = n_clicks_timestamp.index(max(n_clicks_timestamp)) + # Retrieve the corresponding label class + selected_class = class_label_names[clicked_index] + # Not only for Debugging: Display selected class in the console (keep) + print("Selected label class:", selected_class) + return selected_class # Store the label class in dcc.Store + return None + + +@app.callback( + Output("drawn-shapes", "data"), + [Input("shape-check-interval", "n_intervals")], + [State("graph", "figure"), State("drawn-shapes", "data"), + State("selected-label-class", "data")] +) +def check_for_current_shapes( + n_intervals, figure_data, current_shapes, label_class): + if current_shapes is None: + current_shapes = [] # Initialize if empty + + # Extract paths and shape details currently visible in the graph layout + layout_shapes = { + shape["path"]: shape for shape in figure_data.get( + "layout", {}).get( + "shapes", [])} + layout_paths = set(layout_shapes.keys()) + # print("Paths in layout:", layout_paths) # Debugging output + + # Extract paths of shapes stored in current_shapes + stored_paths = {shape["path"] for shape in current_shapes} + # print("Paths in current_shapes:", stored_paths) # Debugging output + + # Detect erased shapes (those in stored_paths but not in layout_paths) + erased_paths = stored_paths - layout_paths + if erased_paths: + # print("Erased shapes detected:", erased_paths) # Debugging output + erased_paths = set() # Clear the erased_paths set to delete from memory + + # Remove erased shapes from current_shapes + updated_shapes = [ + shape for shape in current_shapes if shape["path"] in layout_paths] + + # Detect and update modified shapes + for shape in updated_shapes: + path = shape["path"] + if path in layout_shapes and shape != layout_shapes[path]: + # print("Modified shape detected:", shape) # Debugging output + # Update shape's vertices and other details + shape.update(layout_shapes[path]) + + # Add any new shapes from layout_paths not in stored_paths + new_shapes = [ + # Add label_class to each new shape + {**shape, "label_class": label_class} + for shape in layout_shapes.values() + if shape["path"] not in stored_paths + ] + if new_shapes: + updated_shapes.extend(new_shapes) + print( + "New shapes added to current_shapes with label class:", + new_shapes) # Debugging output + + return updated_shapes + + +@app.callback( + Output("save-output", "children"), + Input("save-button", "n_clicks"), + State("drawn-shapes", "data"), # Retrieve all drawn shapes + State("filename-input", "value") # Get the 
filename from input +) +def save_polylines(n_clicks, all_shapes, filename): + if n_clicks > 0: + if not all_shapes: + return "No shapes to save." + + # Define the save path based on filename input + if not filename or filename.strip() == "": + filename = f"polylines_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + else: + filename = filename.strip() + ".json" + + save_path = os.path.join("results", filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # Save all shapes to JSON + with open(save_path, 'w') as file: + json.dump(all_shapes, file) + + print("Polylines saved successfully.") # Confirm save in terminal + return f"Polylines saved to {save_path}" + + return "Click 'Save Polylines' to save your drawing."
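For reference, the JSON files exchanged by the new upload-json-file control and the save_polylines callback are simply the drawn-shapes store serialized with json.dump: a flat list of Plotly shape dictionaries, each identified by its "path" string, plus the "label_class" value that check_for_current_shapes attaches to newly drawn doodles. A minimal sketch of that structure is below; the field names assume Plotly's standard open-path shape layout, and the exact keys in a real file depend on what the graph's relayoutData emits.

# Illustrative only: the kind of list save_polylines writes and parse_jason reads back.
# Keys follow Plotly's usual open-path shape layout (assumption); "label_class" is the
# extra field added by check_for_current_shapes.
example_shapes = [
    {
        "editable": True,
        "line": {"color": "#3366CC", "width": 5},
        "opacity": 1,
        "type": "path",
        "path": "M334.5,118.6L339.3,119.1L344.6,121.9",
        "label_class": "water",
    },
    {
        "editable": True,
        "line": {"color": "#DC3912", "width": 5},
        "opacity": 1,
        "type": "path",
        "path": "M120.0,240.5L125.8,244.2L131.1,250.0",
        "label_class": "land",
    },
]

On upload, parse_jason drops the "label_class" entries (string or null values only) before the remaining dictionaries are handed to make_and_return_default_figure as masks_data["shapes"], presumably because label_class is not a recognized Plotly shape property.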