Updates multi-agent-collaboration to be mostly compliant
@Stefan will get the rest of this
commit f408f6b (1 parent: 18d7c25)
Showing 7 changed files with 242 additions and 15 deletions.
Empty file.
File renamed without changes
Empty file.
examples/other-examples/hamilton-multi-modal/application.py (76 additions, 0 deletions)
@@ -0,0 +1,76 @@
from typing import List

import dag
from hamilton import driver

from burr.core import ApplicationBuilder, default, when
from burr.integrations.hamilton import Hamilton, append_state, from_state, update_state
from burr.lifecycle import LifecycleAdapter


def application(hooks: List[LifecycleAdapter], app_id: str, storage_dir: str, project_id: str):
    dr = driver.Driver({"provider": "openai"}, dag)  # TODO -- add modules
    Hamilton.set_driver(dr)
    application = (
        ApplicationBuilder()
        .with_state(chat_history=[], prompt="Draw an image of a turtle saying 'hello, world'")
        .with_entrypoint("prompt")
        .with_actions(
            prompt=Hamilton(
                inputs={"prompt": from_state("prompt")},
                outputs={"processed_prompt": append_state("chat_history")},
            ),
            check_safety=Hamilton(
                inputs={"prompt": from_state("prompt")},
                outputs={"safe": update_state("safe")},
            ),
            decide_mode=Hamilton(
                inputs={"prompt": from_state("prompt")},
                outputs={"mode": update_state("mode")},
            ),
            generate_image=Hamilton(
                inputs={"prompt": from_state("prompt")},
                outputs={"generated_image": update_state("response")},
            ),
            generate_code=Hamilton(  # TODO -- implement
                inputs={"chat_history": from_state("chat_history")},
                outputs={"generated_code": update_state("response")},
            ),
            answer_question=Hamilton(  # TODO -- implement
                inputs={"chat_history": from_state("chat_history")},
                outputs={"answered_question": update_state("response")},
            ),
            prompt_for_more=Hamilton(
                inputs={},
                outputs={"prompt_for_more": update_state("response")},
            ),
            response=Hamilton(
                inputs={
                    "response": from_state("response"),
                    "safe": from_state("safe"),
                    "mode": from_state("mode"),
                },
                outputs={"processed_response": append_state("chat_history")},
            ),
        )
        .with_transitions(
            ("prompt", "check_safety", default),
            ("check_safety", "decide_mode", when(safe=True)),
            ("check_safety", "response", default),
            ("decide_mode", "generate_image", when(mode="generate_image")),
            ("decide_mode", "generate_code", when(mode="generate_code")),
            ("decide_mode", "answer_question", when(mode="answer_question")),
            ("decide_mode", "prompt_for_more", default),
            (
                ["generate_image", "answer_question", "generate_code", "prompt_for_more"],
                "response",
            ),
            ("response", "prompt", default),
        )
        .with_hooks(*hooks)
        .with_identifiers(app_id=app_id)
        .with_tracker("local", project=project_id, params={"storage_dir": storage_dir})
        .build()
    )
    return application
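
A quick way to sanity-check the file above is to build and run the application directly. This is a minimal sketch, not part of the commit: it assumes burr's Application.run(halt_after=...) API (returning the last action, its result, and the final state) and uses placeholder identifiers throughout.

if __name__ == "__main__":
    # Build with no lifecycle hooks; the ids and storage path are illustrative placeholders.
    app = application(
        hooks=[],
        app_id="example-app-id",
        storage_dir="~/.burr",
        project_id="demo_hamilton_multi_modal",
    )
    # Run one full loop of the graph: prompt -> check_safety -> decide_mode -> ... -> response.
    action, result, state = app.run(halt_after=["response"])
    print(state["chat_history"][-1])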
@@ -0,0 +1,156 @@
import copy
from typing import Dict, Optional, TypedDict

import openai
from hamilton.function_modifiers import config

ChatContents = TypedDict("ChatContents", {"role": str, "content": str, "type": str})


def processed_prompt(prompt: str) -> dict:
    return {"role": "user", "content": prompt, "type": "text"}


@config.when(provider="openai")
def client() -> openai.Client:
    return openai.Client()


def text_model() -> str:
    return "gpt-3.5-turbo"


def image_model() -> str:
    return "dall-e-2"


def safe(prompt: str) -> bool:
    if "unsafe" in prompt:
        return False
    return True


def modes() -> Dict[str, str]:
    return {
        "answer_question": "text",
        "generate_image": "image",
        "generate_code": "code",
        "unknown": "text",
    }

def find_mode_prompt(prompt: str, modes: Dict[str, str]) -> str:
    return (
        f"You are a chatbot. You've been prompted this: {prompt}. "
        f"You have the capability of responding in the following modes: {', '.join(modes)}. "
        "Please respond with *only* a single word representing the mode that most accurately"
        " corresponds to the prompt. For instance, if the prompt is 'draw a picture of a cat', "
        "the mode would be 'generate_image'. If the prompt is 'what is the capital of France', "
        "the mode would be 'answer_question'. "
        "If none of these modes apply, please respond with 'unknown'."
    )

def suggested_mode(find_mode_prompt: str, client: openai.Client, text_model: str) -> str:
    result = client.chat.completions.create(
        model=text_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": find_mode_prompt},
        ],
    )
    content = result.choices[0].message.content
    return content


def mode(suggested_mode: str, modes: Dict[str, str]) -> str:
    # TODO -- use instructor!
    print(f"Mode: {suggested_mode}")
    lowercase = suggested_mode.lower()
    if lowercase not in modes:
        return "unknown"  # default to unknown
    return lowercase


def generated_text(chat_history: list[dict], text_model: str, client: openai.Client) -> str:
    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    result = client.chat.completions.create(
        model=text_model,
        messages=chat_history_api_format,
    )
    return result.choices[0].message.content


# def generated_text_error(generated_text) -> dict:
#     ...


def generated_image(prompt: str, image_model: str, client: openai.Client) -> str:
    result = client.images.generate(
        model=image_model, prompt=prompt, size="1024x1024", quality="standard", n=1
    )
    return result.data[0].url


def answered_question(chat_history: list[dict], text_model: str, client: openai.Client) -> str:
    chat_history = copy.deepcopy(chat_history)
    chat_history[-1]["content"] = (
        f"Please answer the following question: {chat_history[-1]['content']}"
    )

    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    response = client.chat.completions.create(
        model=text_model,
        messages=chat_history_api_format,
    )
    return response.choices[0].message.content


def generated_code(chat_history: list[dict], text_model: str, client: openai.Client) -> str:
    chat_history = copy.deepcopy(chat_history)
    chat_history[-1]["content"] = (
        f"Please respond to the following with *only* code: {chat_history[-1]['content']}"
    )

    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    response = client.chat.completions.create(
        model=text_model,
        messages=chat_history_api_format,
    )
    return response.choices[0].message.content


def prompt_for_more(modes: Dict[str, str]) -> str:
    return (
        "I can't find a mode that applies to your input. Can you"
        f" please clarify? I support: {', '.join(modes)}."
    )


def processed_response(
    response: Optional[str], mode: str, modes: Dict[str, str], safe: bool
) -> ChatContents:
    if not safe:
        return {"role": "assistant", "content": "I'm sorry, I can't do that.", "type": "text"}
    return {"role": "assistant", "type": modes[mode], "content": response}


def processed_error(error: str) -> dict:
    return {"role": "assistant", "error": error, "type": "text"}
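
For reference, the Hamilton functions above can be exercised without burr by running the DAG directly with Hamilton's driver. A minimal sketch, assuming this module is saved as dag.py (the name application.py imports); requesting only the prompt-derived nodes keeps the example free of any OpenAI calls.

import dag
from hamilton import driver

# Build a driver over the module; the config key selects the @config.when(provider="openai") client.
dr = driver.Driver({"provider": "openai"}, dag)
# Compute only nodes that depend on the raw prompt -- no API key or network access needed.
out = dr.execute(["processed_prompt", "safe"], inputs={"prompt": "draw a picture of a cat"})
print(out)  # e.g. {'processed_prompt': {...}, 'safe': True}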