Skip to content

Commit

Permalink
unbundle rebundle
Browse files Browse the repository at this point in the history
  • Loading branch information
agahkarakuzu committed Sep 3, 2024
1 parent 81ebcb3 commit 34d1d7d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
16 changes: 9 additions & 7 deletions api/neurolibre_celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,22 @@ def get_time():
"""

class BaseNeuroLibreTask:
def __init__(self, celery_task, task_title, screening=None, payload=None):
def __init__(self, celery_task, screening=None, payload=None):
self.celery_task = celery_task
self.task_title = task_title
self.payload = payload
self.task_id = celery_task.request.id
if screening:
if not isinstance(screening, ScreeningClient):
raise TypeError("The 'screening' parameter must be an instance of ScreeningClient")
self.screening = screening
# If passed here, must be JSON serialization of ScreeningClient object.
# We need to unpack these to pass to ScreeningClient to initialize it as an object.
standard_attrs = ['task_name', 'issue_id', 'target_repo_url', 'task_id', 'comment_id', 'commit_hash']
standard_dict = {key: screening.pop(key) for key in standard_attrs if key in screening}
extra_payload = screening
self.screening = ScreeningClient(**standard_dict, **extra_payload)
self.owner_name, self.repo_name, self.provider_name = get_owner_repo_provider(screening.target_repo_url, provider_full_name=True)
elif payload:
# This will be probably deprecated soon.
# This will be probably deprecated soon. For now, reserve for backward compatibility.
self.screening = ScreeningClient(
self.task_title,
payload['task_name'],
payload['issue_id'],
payload['repo_url'],
self.task_id,
Expand Down
19 changes: 18 additions & 1 deletion api/screening_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, task_name, issue_id, target_repo_url = None, task_id="0000000
self.target_repo_url = target_repo_url
self.commit_hash = commit_hash
self.comment_id = comment_id
self.__extra_payload = extra_payload

for key, value in extra_payload.items():
setattr(self, key, value)
Expand All @@ -37,12 +38,28 @@ def __init__(self, task_name, issue_id, target_repo_url = None, task_id="0000000
else:
self.repo_object = None

# If no comment ID is provided, create a new comment with a pending status
if self.comment_id is None:
self.comment_id = self.respond().PENDING("Awaiting task assignment...")

def to_dict(self):
# Convert the object to a dictionary to pass to Celery
result = {
'task_name': self.task_name,
'issue_id': self.issue_id,
'target_repo_url': self.target_repo_url,
'task_id': self.task_id,
'comment_id': self.comment_id,
'commit_hash': self.commit_hash,
}
result.update(self.__extra_payload)
return result

def start_celery_task(self, celery_task_func):

task_result = celery_task_func.apply_async(args=[self])
# This trick is needed to pass the ScreeningClient object to the Celery task.
# This is because the ScreeningClient object cannot be serialized into JSON, which is required by Redis.
task_result = celery_task_func.apply_async(args=[self.to_dict()])

if task_result.task_id is not None:
self.task_id = task_result.task_id
Expand Down

0 comments on commit 34d1d7d

Please sign in to comment.