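"""dag.py (forked from DAGWorks-Inc/burr): a Hamilton dataflow for a multi-modal chatbot.

The nodes below classify a user prompt into a mode (answer_question,
generate_image, generate_code, or unknown), call the matching OpenAI
endpoint, and wrap the result as a chat message.
"""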
import copy
from typing import Dict, Optional, TypedDict

import openai

from hamilton.function_modifiers import config

ChatContents = TypedDict("ChatContents", {"role": str, "content": str, "type": str})


def processed_prompt(prompt: str) -> dict:
    return {"role": "user", "content": prompt, "type": "text"}


@config.when(provider="openai")
def client() -> openai.Client:
    return openai.Client()


def text_model() -> str:
    return "gpt-3.5-turbo"


def image_model() -> str:
    return "dall-e-2"


def safe(prompt: str) -> bool:
    if "unsafe" in prompt:
        return False
    return True


def modes() -> Dict[str, str]:
    return {
        "answer_question": "text",
        "generate_image": "image",
        "generate_code": "code",
        "unknown": "text",
    }


def find_mode_prompt(prompt: str, modes: Dict[str, str]) -> str:
    return (
        f"You are a chatbot. You've been prompted this: {prompt}. "
        f"You have the capability of responding in the following modes: {', '.join(modes)}. "
        "Please respond with *only* a single word representing the mode that most accurately"
        " corresponds to the prompt. For instance, if the prompt is 'draw a picture of a cat', "
        "the mode would be 'generate_image'. If the prompt is 'what is the capital of France', "
        "the mode would be 'answer_question'. "
        "If none of these modes apply, please respond with 'unknown'."
    )


def suggested_mode(find_mode_prompt: str, client: openai.Client, text_model: str) -> str:
    """Asks the LLM which mode best matches the prompt."""
    result = client.chat.completions.create(
        model=text_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": find_mode_prompt},
        ],
    )
    content = result.choices[0].message.content
    return content


def mode(suggested_mode: str, modes: Dict[str, str]) -> str:
    """Normalizes the suggested mode, falling back to 'unknown' if it isn't recognized."""
    # TODO -- use instructor!
    print(f"Mode: {suggested_mode}")
    lowercase = suggested_mode.lower()
    if lowercase not in modes:
        return "unknown"  # default to unknown
    return lowercase


def generated_text(chat_history: list[dict], text_model: str, client: openai.Client) -> str:
    """Passes the chat history straight through to the text model."""
    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    result = client.chat.completions.create(
        model=text_model,
        messages=chat_history_api_format,
    )
    return result.choices[0].message.content


# def generated_text_error(generated_text) -> dict:
#     ...


def generated_image(prompt: str, image_model: str, client: openai.Client) -> str:
    """Generates an image from the prompt and returns its URL."""
    result = client.images.generate(
        model=image_model, prompt=prompt, size="1024x1024", quality="standard", n=1
    )
    return result.data[0].url


def answered_question(chat_history: list[dict], text_model: str, client: openai.Client) -> str:
    """Rewrites the last message as an explicit question and asks the text model."""
    chat_history = copy.deepcopy(chat_history)
    chat_history[-1]["content"] = (
        f"Please answer the following question: {chat_history[-1]['content']}"
    )
    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    response = client.chat.completions.create(
        model=text_model,
        messages=chat_history_api_format,
    )
    return response.choices[0].message.content


def generated_code(chat_history: list[dict], text_model: str, client: openai.Client) -> str:
    """Rewrites the last message to request code-only output and asks the text model."""
    chat_history = copy.deepcopy(chat_history)
    chat_history[-1]["content"] = (
        f"Please respond to the following with *only* code: {chat_history[-1]['content']}"
    )
    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    response = client.chat.completions.create(
        model=text_model,
        messages=chat_history_api_format,
    )
    return response.choices[0].message.content


def prompt_for_more(modes: Dict[str, str]) -> str:
    """Asks the user to clarify when no supported mode applies."""
    return (
        f"I can't find a mode that applies to your input. Can you"
        f" please clarify? I support: {', '.join(modes)}."
    )


def processed_response(
    response: Optional[str], mode: str, modes: Dict[str, str], safe: bool
) -> ChatContents:
    """Wraps the model response as an assistant message, refusing unsafe prompts."""
    if not safe:
        return {"role": "assistant", "content": "I'm sorry, I can't do that.", "type": "text"}
    return {"role": "assistant", "type": modes[mode], "content": response}


def processed_error(error: str) -> dict:
    return {"role": "assistant", "error": error, "type": "text"}