-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
236 lines (188 loc) · 8 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import json
import logging
import os
import threading

import openlit as openlit
from fastapi import FastAPI, Request
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from opentelemetry import trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
from opentelemetry.trace.status import Status, StatusCode
from prometheus_client import Histogram
from starlette_prometheus import PrometheusMiddleware, metrics

from pythia.ai_hallucination import (
    ask_pythia,
    entity_search,
    predicate_search,
    search_qids,
)
from pythia.validator import ValidatorPool
from read_metrics import get_traces, extract_prompt_and_completion
logging.basicConfig(level=logging.INFO)

# One Prometheus histogram per model score that update_metrics_job records.
accuracy_metric = Histogram(name='model_accuracy', documentation='Accuracy of the model')
entailment_metric = Histogram(name='model_entailment', documentation='Entailment score of the model')
contradiction_metric = Histogram(name='model_contradiction', documentation='Contradiction score of the model')
neutral_metric = Histogram(name='model_neutral', documentation='Neutral score of the model')

# The web application: starlette_prometheus's middleware is installed and its
# `metrics` handler is served at /metrics.
app = FastAPI()
app.add_middleware(PrometheusMiddleware)
app.add_route("/metrics", metrics)
@app.post('/ask-pythia')
async def ask_pythia_api(request: Request):
    """Check a response against a reference text with ``ask_pythia``.

    The request body must be a JSON object with:

    * ``response`` (required): passed to ``ask_pythia`` as ``input_response``.
    * ``reference`` (required): passed as ``input_reference``.
    * ``question`` (optional): passed as ``question``; ``None`` when absent.
    * ``validators`` (optional): passed as ``validators_list``; ``None`` when absent.

    Returns:
        Whatever ``ask_pythia`` returns, serialized by FastAPI.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON object,
            or lacks ``response`` or ``reference``.
    """
    try:
        data = await request.json()
    except ValueError as e:
        # Malformed JSON (json.JSONDecodeError is a ValueError) is a client error.
        raise HTTPException(status_code=400, detail="Invalid JSON passed {}".format(e)) from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON passed, body must be a JSON object")

    try:
        input_response = data["response"]
    except KeyError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON passed, must include a 'response' key {}".format(e)) from e
    try:
        input_reference = data["reference"]
    except KeyError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON passed, must include a 'reference' key {}".format(e)) from e

    # Optional fields default to None.
    if "question" not in data:
        print("No question set")
    if "validators" not in data:
        print("No validators set")
    question = data.get("question")
    validators = data.get("validators")

    # ask_pythia is a plain (non-async) call; run it in the worker thread pool so it
    # does not stall the event loop, and every other request, while it runs.
    claim_checker_data = await run_in_threadpool(
        ask_pythia,
        input_reference=input_reference,
        input_response=input_response,
        question=question,
        validators_list=validators,
    )
    return claim_checker_data
@app.post('/search-qids')
async def search_qids_api(request: Request):
    """Run ``search_qids`` on the ``question`` from the request body.

    The request body must be a JSON object with a ``question`` key.

    Returns:
        Whatever ``search_qids`` returns for the question.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON object
            with a ``question`` key; 500 carrying the underlying error message if
            ``search_qids`` raises.
    """
    try:
        data = await request.json()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON passed: {e}") from e
    if not isinstance(data, dict) or "question" not in data:
        raise HTTPException(status_code=400, detail="Invalid JSON passed, must include a 'question' key")
    question = data["question"]

    try:
        # search_qids is a plain (non-async) call; keep it off the event loop.
        qids = await run_in_threadpool(search_qids, question)
    except Exception as e:
        # The body was already validated, so this is a server-side failure.
        # Report the real cause, as the other search endpoints do.
        raise HTTPException(status_code=500, detail=f'Error processing request: {e}') from e
    return qids
@app.post('/search_entity')
async def search_entity_api(request: Request):
    """Search entities by name with ``entity_search``.

    Request body (JSON object):

    * ``name`` (required): the name to search for.
    * ``ignore_case`` (optional, default ``False``).
    * ``matching_strategy`` (optional, default ``"FUZZY"``).
    * ``limit`` (optional, default ``10``).

    Optional values are passed through to ``entity_search`` unchanged.

    Returns:
        The ``entity_search`` result as the response body.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON object
            with a ``name`` key; 500 carrying the underlying error message if
            ``entity_search`` raises.
    """
    try:
        data = await request.json()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON passed: {e}") from e
    if not isinstance(data, dict) or "name" not in data:
        raise HTTPException(status_code=400, detail="Invalid JSON passed, must include a 'name' key")
    name = data["name"]
    ignore_case = data.get("ignore_case", False)
    matching_strategy = data.get("matching_strategy", "FUZZY")
    limit = data.get("limit", 10)

    try:
        # entity_search is a plain (non-async) call; keep it off the event loop.
        search_results = await run_in_threadpool(
            entity_search,
            name=name,
            ignore_case=ignore_case,
            matching_strategy=matching_strategy,
            limit=limit,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f'Error processing request: {e}') from e
    # Return the results directly. FastAPI has no Flask-style (body, status) tuple:
    # a returned tuple is serialized as a two-element JSON array. Success is 200 already.
    return search_results
@app.post('/search_predicate')
async def search_predicate_api(request: Request):
    """Search predicates by name with ``predicate_search``.

    Request body (JSON object):

    * ``name`` (required): the name to search for.
    * ``ignore_case`` (optional, default ``True``).
    * ``matching_strategy`` (optional, default ``"CONTAINS"``).
    * ``limit`` (optional, default ``10``).

    Optional values are passed through to ``predicate_search`` unchanged.

    Returns:
        The ``predicate_search`` result as the response body.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON object
            with a ``name`` key; 500 carrying the underlying error message if
            ``predicate_search`` raises.
    """
    try:
        data = await request.json()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON passed: {e}") from e
    if not isinstance(data, dict) or "name" not in data:
        raise HTTPException(status_code=400, detail='Invalid JSON passed, must include a "name" key')
    name = data["name"]
    ignore_case = data.get("ignore_case", True)
    matching_strategy = data.get("matching_strategy", "CONTAINS")
    limit = data.get("limit", 10)

    try:
        # predicate_search is a plain (non-async) call; keep it off the event loop.
        search_results = await run_in_threadpool(
            predicate_search,
            name=name,
            ignore_case=ignore_case,
            matching_strategy=matching_strategy,
            limit=limit,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f'Error processing request: {e}') from e
    # Return the results directly. FastAPI has no Flask-style (body, status) tuple:
    # a returned tuple is serialized as a two-element JSON array. Success is 200 already.
    return search_results
@app.get("/", response_class=HTMLResponse)
async def orpheus_pythia():
    """Serve the landing page: the application title and its version string."""
    page = (
        "\n"
        "<h1> Orpheus Pythia Application </h1>\n"
        "<h4> V1.0.0 </h4>\n"
    )
    return page
def get_model_metrics():
    """Check recent traces with Pythia and export each result as a span.

    Traces are fetched with ``get_traces`` for the service named by the
    ``JAEGER_SERVICE_NAME`` environment variable. For each trace:

    1. ``extract_prompt_and_completion`` extracts the system message, user prompt
       and completion.
    2. ``ask_pythia`` checks the completion against the system message. The user
       prompt is the question, and the validators are those enabled in
       ``ValidatorPool``.
    3. A non-None result is passed to ``trace_pythia_response``.

    Work is best-effort per trace. A trace missing any of the three parts is
    skipped. An exception raised while handling one trace is printed and does not
    stop the remaining traces. Exceptions from ``get_traces`` itself propagate to
    the caller.

    Returns:
        None. Results are emitted only as spans; no metrics dictionary is built here.
    """
    service_name = os.getenv("JAEGER_SERVICE_NAME")
    traces = get_traces(service_name)
    for trace_obj in traces:
        try:
            print("Process Trace with id: {}".format(trace_obj["traceID"]))
            system_message, user_prompt, completion = extract_prompt_and_completion(trace_obj)
            if system_message is None or user_prompt is None or completion is None:
                # Skip only this trace; the rest of the batch still gets checked.
                print("Trace {} has no prompt/completion, skipped".format(trace_obj["traceID"]))
                continue
            validators = ValidatorPool().enabled_validators
            claim = ask_pythia(
                input_reference=system_message,
                input_response=completion,
                question=user_prompt,
                validators_list=validators,
            )
            if claim is not None:
                trace_pythia_response(claim)
            else:
                print("Pythia result are none trace {} was not process".format(trace_obj["traceID"]))
        except Exception as e:
            # Best effort: one bad trace must not abort the whole polling run.
            print("Error Processing Trace {}".format(e))
    return None
# Seconds between polling runs of update_metrics_job. Read from the READ_INTERVAL
# environment variable (default "10") and kept as a string; converted with int() at use.
READ_INTERVAL = os.environ.get("READ_INTERVAL", "10")
def update_metrics_job():
    """Run one polling pass, then schedule the next pass READ_INTERVAL seconds later.

    ``get_model_metrics`` is called once. If it returns a non-None result, that
    result's ``accuracy``, ``entailment``, ``contradiction`` and ``neutral`` values
    are recorded in the matching Prometheus histograms. Any exception from the pass
    is printed rather than raised, so the polling loop keeps running.
    """
    try:
        print("Read Metrics ....")
        data = get_model_metrics()
        if data is not None:
            # Record the result's scores in the Prometheus histograms.
            accuracy_metric.observe(data['accuracy'])
            entailment_metric.observe(data['entailment'])
            contradiction_metric.observe(data['contradiction'])
            neutral_metric.observe(data['neutral'])
            print("Metrics updated")
    except Exception as e:
        print("Fail to update Metrics {}".format(e))
    # Daemon timer: the ever-re-arming poll must not keep the interpreter alive after
    # the server stops (a non-daemon Timer blocks process exit).
    timer = threading.Timer(int(READ_INTERVAL), update_metrics_job)
    timer.daemon = True
    timer.start()
# Jaeger agent that trace_pythia_response exports spans to. Overridable through the
# JAEGER_HOST / JAEGER_PORT environment variables; the port is kept as a string.
JAEGER_HOST, JAEGER_PORT = (
    os.environ.get("JAEGER_HOST", "jaeger"),
    os.environ.get("JAEGER_PORT", "6831"),
)
# Tracer for Pythia result spans. It is built on first use and then reused, so
# repeated calls do not each create (and leak) an exporter, a provider and a
# batch-processor worker thread.
_pythia_tracer = None


def _get_pythia_tracer():
    """Return the Pythia tracer, building its Jaeger export pipeline on first call."""
    global _pythia_tracer
    if _pythia_tracer is None:
        jaeger_exporter = JaegerExporter(
            agent_host_name=JAEGER_HOST,
            agent_port=int(JAEGER_PORT),
        )
        tracer_provider = TracerProvider(
            resource=Resource.create({"service.name": "pythia-service"})
        )
        tracer_provider.add_span_processor(BatchSpanProcessor(jaeger_exporter))
        # Also publish it as the global provider. OpenTelemetry honours only the first
        # set_tracer_provider call, so the tracer is taken from this provider directly
        # instead of via trace.get_tracer(). That keeps spans on the Jaeger pipeline.
        trace.set_tracer_provider(tracer_provider)
        _pythia_tracer = tracer_provider.get_tracer(__name__)
    return _pythia_tracer


def trace_pythia_response(pythia_response):
    """Emit one ``ask-pythia`` span carrying the values of a Pythia result.

    Each entry of ``pythia_response["metrics"]`` becomes a span attribute. For each
    item of ``pythia_response["validatorsResults"]``, the attributes
    ``<validator name>.isValid`` and ``<validator name>.riskScore`` are added. A
    validator entry without a ``["validator"]["name"]`` is skipped. The span status
    is set to OK.

    Args:
        pythia_response: a Pythia result mapping. It must contain ``metrics``
            (a mapping) and ``validatorsResults`` (an iterable).
    """
    tracer = _get_pythia_tracer()
    with tracer.start_as_current_span("ask-pythia") as span:
        print("Create new Trace ...")
        for key, value in pythia_response.get("metrics").items():
            span.set_attribute(key, value)
        for validator in pythia_response.get("validatorsResults"):
            try:
                validator_name = validator["validator"]["name"]
            except (KeyError, TypeError):
                # Best effort: an unnamed validator result cannot be labelled; skip it.
                continue
            span.set_attribute("{}.isValid".format(validator_name), validator.get("isValid"))
            span.set_attribute("{}.riskScore".format(validator_name), validator.get("riskScore"))
        span.set_status(Status(status_code=StatusCode.OK))
    print("Tracing complete.")
# Start the periodic trace-polling loop when the module is imported. The first pass
# runs on a daemon timer thread, so importing the module (and therefore server
# startup) does not wait for the trace query and the Pythia checks.
_first_poll = threading.Timer(0, update_metrics_job)
_first_poll.daemon = True
_first_poll.start()

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)