Skip to content

Commit

Permalink
Merge pull request #92 from seung-lab/openai
Browse files Browse the repository at this point in the history
Call openai compatible services to explain failed tasks
  • Loading branch information
ranlu authored Apr 28, 2024
2 parents 9a9b943 + 913eb8e commit c8953cf
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 9 deletions.
55 changes: 48 additions & 7 deletions dags/slack_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,61 @@ def task_retry_alert(context):
).format(**locals())
slack_alert(f":exclamation: Task up for retry {last_try} times already, <{log_url}|check the latest error log>", context)


def interpret_error_message(error_message):
    """Ask an OpenAI-compatible LLM to summarize an error message.

    Reads the Airflow connection ``LLMServer`` for the endpoint (host),
    API key (password) and model name (``extra["model"]``), then sends the
    error text for a short explanation.

    :param error_message: raw log excerpt containing the error/traceback.
    :return: the model's explanation string, or ``None`` if the connection
        is missing/misconfigured or the LLM call fails (best-effort by design).
    """
    from langchain_openai import ChatOpenAI
    # "airflow.hooks.base_hook" is the deprecated Airflow 1.x path;
    # "airflow.hooks.base" is the supported location in Airflow 2.x.
    from airflow.hooks.base import BaseHook

    try:
        conn = BaseHook.get_connection("LLMServer")
        base_url = conn.host
        api_key = conn.password
        extra_args = conn.extra_dejson
        # Client construction can also raise on bad config (e.g. empty key),
        # so keep it inside the best-effort guard.
        model = ChatOpenAI(
            base_url=base_url,
            api_key=api_key,
            model=extra_args.get("model", "gpt-3.5-turbo"),
        )
    except Exception:
        # No/broken LLM config must never break the alerting path.
        return None

    messages = [
        ("system", "You are a helpful assistant, identify and explain the error message in a few words"),
        ("human", error_message),
    ]
    try:
        msg = model.invoke(messages)
        return msg.content
    except Exception:
        # Network/API failure: caller falls back to posting the raw log.
        return None


def task_failure_alert(context):
    """Post a Slack alert for a failed task, with an LLM-interpreted log excerpt.

    Always posts a failure notice with a link to the task's log in the web UI.
    For tasks on the ``manager`` queue it additionally scans the last try's log
    for error/traceback lines and posts either an LLM explanation of the
    excerpt or the raw excerpt if the LLM is unavailable.

    :param context: Airflow task callback context (uses ``task_instance``).
    """
    import urllib.parse
    from airflow.models import Variable
    from airflow.utils.log.log_reader import TaskLogReader
    ti = context.get('task_instance')
    last_try = ti.try_number - 1
    iso = urllib.parse.quote(ti.execution_date.isoformat())
    webui_ip = Variable.get("webui_ip", default_var="localhost")
    log_url = f"https://{webui_ip}/airflow/log?dag_id={ti.dag_id}&task_id={ti.task_id}&execution_date={iso}"
    slack_alert(f":exclamation: Task failed, <{log_url}|check the latest error log>", context)

    task_log_reader = TaskLogReader()

    if ti.queue == "manager":
        metadata = {}
        # Keep this a string: it is interpolated into the Slack message below,
        # so an empty string (not []) is the correct "nothing found" value.
        error_message = ""
        targets = ['error', 'traceback', 'exception']
        for text in task_log_reader.read_log_stream(ti, last_try, metadata):
            lines = text.split("\n")
            for i, l in enumerate(lines):
                if any(x in l.lower() for x in targets):
                    # Grab the matching line plus up to 9 lines of context;
                    # a later chunk's match intentionally supersedes earlier ones.
                    error_message = "\n".join(lines[i:i+10])
                    break
        parsed_msg = interpret_error_message(error_message)
        if parsed_msg:
            # Reuse the result; calling interpret_error_message again would
            # make a second (paid) LLM request for the same text.
            slack_message(parsed_msg)
        else:
            slack_message(f"Failed to use the LLM server to interpret the error message ```{error_message}```")


def task_done_alert(context):
    """Post a Slack notification that the task completed successfully."""
    message = ":heavy_check_mark: Task Finished"
    return slack_alert(message, context)
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ RUN savedAptMark="$(apt-mark showmanual)" \
&& CONSTRAINT_URL="https://raw.githubusercontent.com/apache/airflow/constraints-${AIRFLOW_VERSION}/constraints-${PYTHON_VERSION}.txt" \
&& pip install --no-cache-dir -U pip \
&& pip install --no-cache-dir --compile --global-option=build git+https://github.com/seung-lab/chunk_iterator#egg=chunk-iterator \
&& pip install --no-cache-dir igneous-pipeline onnx \
&& pip install --no-cache-dir "apache-airflow[celery,postgres,rabbitmq,docker,slack,google,statsd]==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" \
&& pip install --no-cache-dir "apache-airflow[celery,postgres,rabbitmq,docker,slack,google,statsd,openai]==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" \
&& pip install --no-cache-dir -U igneous-pipeline onnx langchain-openai \
&& mkdir -p ${AIRFLOW_HOME}/version \
&& groupadd -r docker \
&& groupadd -r ${AIRFLOW_USER} \
Expand Down
7 changes: 7 additions & 0 deletions pipeline/init_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def parse_metadata():
conn_id='Slack', conn_type='http',
host='localhost', extra=json.dumps({"notification_channel": os.environ.get("SLACK_NOTIFICATION_CHANNEL", "seuron-alerts")})))

db_utils.merge_conn(
models.Connection(
conn_id='LLMServer', conn_type='openai',
host=os.environ.get('LLM_HOST', 'https://api.openai.com/v1'),
password="",
extra=json.dumps({"model": os.environ.get("LLM_MODEL", "gpt-3.5-turbo")})))

if os.environ.get("VENDOR", None) == "Google":
deployment = os.environ.get("DEPLOYMENT", None)
zone = os.environ.get("ZONE", None)
Expand Down

0 comments on commit c8953cf

Please sign in to comment.