diff --git a/store/neurostore/models/event_listeners.py b/store/neurostore/models/event_listeners.py index 27b169dee..0583eb0ba 100644 --- a/store/neurostore/models/event_listeners.py +++ b/store/neurostore/models/event_listeners.py @@ -1,7 +1,14 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import event from .data import ( - AnnotationAnalysis, Annotation, Studyset, Study, Analysis, Point, Image, _check_type + AnnotationAnalysis, + Annotation, + Studyset, + Study, + Analysis, + Point, + Image, + _check_type, ) from ..database import db @@ -124,37 +131,37 @@ def add_annotation_analyses_study(study, analyses, collection_adapter): # Define an event listener to update Base-Study flags -@event.listens_for(Analysis.points, 'append') -@event.listens_for(Analysis.points, 'remove') -@event.listens_for(Analysis.images, 'append') -@event.listens_for(Analysis.images, 'remove') -def update_base_study_flags(target, value, initiator): - base_study = getattr(getattr(target, 'study', None), 'base_study', None) - updated = False - if base_study is not None: - base_study.has_coordinates = isinstance(value, Point) or any( - analysis.points for study in base_study.versions for analysis in study.analyses - ) - base_study.has_images = isinstance(value, Image) or any( - analysis.images for study in base_study.versions for analysis in study.analyses - ) - db.session.add(base_study) - updated = True - - return updated - - -@event.listens_for(db.session, 'after_flush') -def update_base_study_flags_item_delete(session, flush_context): - any_updates = False - for obj in session.deleted: - if isinstance(obj, (Point, Image)): - target = obj.analysis_id - value = obj - initiator = "DELETE" - res = update_base_study_flags(target, value, initiator) - if not any_updates and res: - any_updates = True - - if any_updates: - db.session.commit() +# @event.listens_for(Analysis.points, 'append') +# @event.listens_for(Analysis.points, 'remove') +# @event.listens_for(Analysis.images, 'append') +# @event.listens_for(Analysis.images, 'remove') +# def update_base_study_flags(target, value, initiator): +# base_study = getattr(getattr(target, 'study', None), 'base_study', None) +# updated = False +# if base_study is not None: +# base_study.has_coordinates = isinstance(value, Point) or any( +# analysis.points for study in base_study.versions for analysis in study.analyses +# ) +# base_study.has_images = isinstance(value, Image) or any( +# analysis.images for study in base_study.versions for analysis in study.analyses +# ) +# db.session.add(base_study) +# updated = True + +# return updated + + +# @event.listens_for(db.session, 'after_flush') +# def update_base_study_flags_item_delete(session, flush_context): +# any_updates = False +# for obj in session.deleted: +# if isinstance(obj, (Point, Image)): +# target = obj.analysis_id +# value = obj +# initiator = "DELETE" +# res = update_base_study_flags(target, value, initiator) +# if not any_updates and res: +# any_updates = True + +# if any_updates: +# db.session.commit() diff --git a/store/neurostore/resources/base.py b/store/neurostore/resources/base.py index 0ad3872d2..b09c04eac 100644 --- a/store/neurostore/resources/base.py +++ b/store/neurostore/resources/base.py @@ -37,6 +37,7 @@ def create_user(): from auth0.v3.authentication.users import Users + auth = request.headers.get("Authorization", None) token = auth.split()[1] profile_info = Users( @@ -46,8 +47,7 @@ def create_user(): # user signed up with auth0, but has not made any queries yet... # should have endpoint to "create user" after sign on with auth0 current_user = User( - external_id=connexion.context["user"], - name=profile_info.get("name", "Unknown") + external_id=connexion.context["user"], name=profile_info.get("name", "Unknown") ) return current_user @@ -179,6 +179,7 @@ def update_or_create(cls, data, id=None, commit=True): setattr(record, field, nested) + # add other custom update after the nested attributes are handled... if commit: db.session.add_all(to_commit) try: @@ -324,8 +325,8 @@ def delete(self, id): abort(403) else: db.session.delete(record) - - db.session.commit() + self.post_delete(record) + db.session.commit() # clear relevant caches clear_cache(self.__class__, record, request.path) @@ -335,6 +336,9 @@ def delete(self, id): def insert_data(self, id, data): return data + def post_delete(record): + pass + LIST_USER_ARGS = { "search": fields.String(missing=None), diff --git a/store/neurostore/resources/data.py b/store/neurostore/resources/data.py index 02c750f62..6fb1aa326 100644 --- a/store/neurostore/resources/data.py +++ b/store/neurostore/resources/data.py @@ -367,6 +367,23 @@ def custom_record_update(record): return record + def post_delete(record): + base_study = getattr(record, "base_study", None) + if base_study: + base_study.has_images = any( + a.images + for study in base_study.versions + if study is not record + for a in study.analyses + ) + base_study.has_coordinates = any( + a.points + for study in base_study.versions + if study is not record + for a in study.analyses + ) + db.session.add(base_study) + @view_maker class AnalysesView(ObjectView, ListView): @@ -385,6 +402,34 @@ class AnalysesView(ObjectView, ListView): } _search_fields = ("name", "description") + def custom_record_update(record): + # need to do this custom update after the nested attributes are set + pass + + def post_delete(record): + study = getattr(record, "study", None) + if not study: + return + + if len(study.analyses) > 1: + return + + base_study = getattr(study, "base_study", None) + if base_study: + base_study.has_images = any( + a.images + for study in base_study.versions + for a in study.analyses + if a is not record + ) + base_study.has_coordinates = any( + a.points + for study in base_study.versions + for a in study.analyses + if a is not record + ) + db.session.add(base_study) + @view_maker class ConditionsView(ObjectView, ListView): @@ -399,7 +444,35 @@ class ImagesView(ObjectView, ListView): _search_fields = ("filename", "space", "value_type", "analysis_name") def custom_record_update(record): - pass + base_study = getattr(getattr(record, "study", None), "base_study", None) + if base_study is not None and base_study.has_images is False: + base_study.has_images = True + db.session.add(base_study) + return record + + def post_delete(record): + analysis = getattr(record, "analysis", None) + if not analysis: + return + + study = getattr(analysis, "study", None) + if not study: + return + + if len(study.analyses) > 1: + return + + if set(analysis.images) == set([record]): + base_study = getattr(study, "base_study", None) + if base_study: + base_study.has_images = any( + a.images + for study in base_study.versions + for a in study.analyses + if a is not analysis + ) + db.session.add(base_study) + @view_maker class PointsView(ObjectView, ListView): @@ -412,6 +485,45 @@ class PointsView(ObjectView, ListView): } _search_fields = ("space", "analysis_name") + def custom_record_update(record): + base_study = getattr( + getattr(getattr(record, "analysis", None), "study", None), + "base_study", + None, + ) + if base_study is not None and base_study.has_coordinates is False: + base_study.has_coordinates = True + db.session.add(base_study) + return record + + def post_delete(record): + analysis = getattr(record, "analysis", None) + # nothing to update if point not connected to analysis + if not analysis: + return + + study = getattr(analysis, "study", None) + # nothing to update if analysis is not connected to study + if not study: + return + + # nothing to update if there is more than 1 analysis + if len(study.analyses) > 1: + return + + # only care if the point is the last point in the analysis + if set(analysis.points) == set([record]): + base_study = getattr(study, "base_study", None) + if base_study: + base_study.has_coordinates = any( + a.points + for study in base_study.versions + for a in study.analyses + if a is not analysis + ) + + db.session.add(base_study) + @view_maker class PointValuesView(ObjectView, ListView): diff --git a/store/neurostore/tests/api/test_base_studies.py b/store/neurostore/tests/api/test_base_studies.py index 61932af1a..275c2cee3 100644 --- a/store/neurostore/tests/api/test_base_studies.py +++ b/store/neurostore/tests/api/test_base_studies.py @@ -56,7 +56,8 @@ def test_has_coordinates_images(auth_client, session): # update analysis with points analysis_image = auth_client.put( - f"/api/analyses/{analysis_id}", data={"images": [{"filename": "my_fake_image.nii.gz"}]} + f"/api/analyses/{analysis_id}", + data={"images": [{"filename": "my_fake_image.nii.gz"}]}, ) assert analysis_image.status_code == 200 diff --git a/store/neurostore/tests/conftest.py b/store/neurostore/tests/conftest.py index 70f2864e9..96d3aca0a 100644 --- a/store/neurostore/tests/conftest.py +++ b/store/neurostore/tests/conftest.py @@ -138,7 +138,7 @@ def app(mock_auth): _app.config["SQLALCHEMY_ENGINE_OPTIONS"] = { "max_overflow": -1, "pool_timeout": 5, - "pool_size": 0 + "pool_size": 0, } cache.clear() # Establish an application context before running the tests. diff --git a/store/neurostore/tests/test_auth.py b/store/neurostore/tests/test_auth.py index 62a49c441..fe1010899 100644 --- a/store/neurostore/tests/test_auth.py +++ b/store/neurostore/tests/test_auth.py @@ -22,7 +22,7 @@ def test_creating_new_user_on_db(add_users): client = Client( token=token_info[user_name]["token"], - username=token_info[user_name]["external_id"] + username=token_info[user_name]["external_id"], ) client.post("/api/studies/", data={"name": "my study"})