diff --git a/api/neurolibre_celery_tasks.py b/api/neurolibre_celery_tasks.py index 6babdd5..d483e0e 100644 --- a/api/neurolibre_celery_tasks.py +++ b/api/neurolibre_celery_tasks.py @@ -72,19 +72,27 @@ def get_time(): """ class BaseNeuroLibreTask: - def __init__(self, celery_task, task_title, payload, screening=None): + def __init__(self, celery_task, task_title, screening=None, payload=None): self.celery_task = celery_task self.task_title = task_title self.payload = payload self.task_id = celery_task.request.id - self.screening = screening if screening else ScreeningClient( - self.task_title, - payload['issue_id'], - payload['repo_url'], - self.task_id, - payload['comment_id'] - ) - self.owner, self.repo, self.provider = get_owner_repo_provider(payload['repo_url'], provider_full_name=True) + if screening: + if not isinstance(screening, ScreeningClient): + raise TypeError("The 'screening' parameter must be an instance of ScreeningClient") + self.screening = screening + self.owner_name, self.repo_name, self.provider_name = get_owner_repo_provider(screening.target_repo_url, provider_full_name=True) + elif payload: + # This will be probably deprecated soon. + self.screening = ScreeningClient( + self.task_title, + payload['issue_id'], + payload['repo_url'], + self.task_id, + payload['comment_id']) + self.owner_name, self.repo_name, self.provider_name = get_owner_repo_provider(payload['repo_url'], provider_full_name=True) + else: + raise ValueError("Either screening or payload must be provided.") def start(self, message=""): self.screening.respond().STARTED(message) @@ -138,21 +146,21 @@ def sleep_task(self, seconds): return 'done sleeping for {} seconds'.format(seconds) @celery_app.task(bind=True) -def preview_download_data(self, payload): +def preview_download_data(self, screening): """ Downloading data to the preview server. """ - task = BaseNeuroLibreTask(self, "DATA DOWNLOAD (REPO2DATA)", payload) + task = BaseNeuroLibreTask(self, "DATA DOWNLOAD (REPO2DATA)", screening) task.start("Started downloading the data.") try: - contents = task.screening.repo.get_contents("binder/data_requirement.json") + contents = task.screening.repo_object.get_contents("binder/data_requirement.json") logging.debug(contents.decoded_content) data_manifest = json.loads(contents.decoded_content) # Create a temporary directory to store the data manifest - os.makedirs(task.join_data_root_path("tmp_repo2data",task.owner,task.repo),exist_ok=True) + os.makedirs(task.join_data_root_path("tmp_repo2data",task.owner_name,task.repo_name),exist_ok=True) # Write the data manifest to the temporary directory - json_path = task.join_data_root_path("tmp_repo2data",task.owner,task.repo,"data_requirement.json") + json_path = task.join_data_root_path("tmp_repo2data",task.owner_name,task.repo_name,"data_requirement.json") with open(json_path,"w") as f: json.dump(data_manifest,f) if not data_manifest: @@ -160,13 +168,13 @@ def preview_download_data(self, payload): project_name = data_manifest['projectName'] except Exception as e: message = f"Data download has failed: {str(e)}" - if payload['email']: - send_email(payload['email'], f"{JOURNAL_NAME}: Data download request", message) + if screening.email: + send_email(screening.email, f"{JOURNAL_NAME}: Data download request", message) else: task.fail(f"Data exists for {project_name}; not overwriting by default! Please set overwrite=True.") data_path = task.join_data_root_path(project_name) - if os.path.exists(data_path) and not payload['overwrite']: + if os.path.exists(data_path) and not screening.is_overwrite: task.fail(f"Data exists for {project_name} already downloaded to {data_path}; \ not overwriting by default! Please set overwrite=True.") return @@ -178,8 +186,8 @@ def preview_download_data(self, payload): message = f"Downloaded data in {downloaded_data_path}." # Update status - if payload['email']: - send_email(payload['email'], f"{JOURNAL_NAME}: Data download request", message) + if screening.email: + send_email(screening.email, f"{JOURNAL_NAME}: Data download request", message) self.update_state(state=states.SUCCESS, meta={'message': message}) else: task.succeed(message) diff --git a/api/screening_client.py b/api/screening_client.py index b75d4c9..7a920ee 100644 --- a/api/screening_client.py +++ b/api/screening_client.py @@ -33,9 +33,10 @@ def __init__(self, task_name, issue_id, target_repo_url = None, task_id=None, co self.__gh_bot_token = os.getenv('GH_BOT') self.github_client = Github(self.__gh_bot_token) if self.target_repo_url: - self.repo = self.github_client.get_repo(self.gh_filter(self.target_repo_url)) + self.repo_object = self.github_client.get_repo(self.gh_filter(self.target_repo_url)) else: - self.repo = None + self.repo_object = None + if self.comment_id is None: self.comment_id = self.respond().PENDING("Awaiting task assignment...") @@ -71,7 +72,7 @@ def gh_response_template(self, message="", collapse=True): message = "" return { - "PENDING": f"⚫ **{self.task_name}**\n----------------------------\n**Status:** Waiting for task assignment\n**Last updated:** {cur_time}\n{message}\n:recycle: [Refresh](https://github.com/neurolibre/neurolibre-reviews/issues/{self.issue_id}#issuecomment-{self.comment_id})", + "PENDING": f"⚫ **{self.task_name}**\n----------------------------\n**Status:** Waiting for task assignment\n**Last updated:** {cur_time}\n{message}", "RECEIVED": f"⚪ **{self.task_name}**\n----------------------------\n**Status:** Assigned to task `{self.task_id[0:8]}`\n**Last updated:** {cur_time}\n{message}\n:recycle: [Refresh](https://github.com/neurolibre/neurolibre-reviews/issues/{self.issue_id}#issuecomment-{self.comment_id})", "STARTED": f"🟠 **{self.task_name}**\n----------------------------\n**Status:** In progress `{self.task_id[0:8]}`\n**Last updated:** {cur_time}\n{message}\n:recycle: [Refresh](https://github.com/neurolibre/neurolibre-reviews/issues/{self.issue_id}#issuecomment-{self.comment_id})", "SUCCESS": f"🟢 **{self.task_name}**\n----------------------------\n**Status:** Success `{self.task_id[0:8]}`\n**Last updated:** {cur_time}\n{message}",