Updates telephone example #145

Merged: 1 commit, Apr 9, 2024
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,7 @@ burr/tracking/server/build
.idea
# added by macOS
.DS_store

# statemachine images
examples/*/statemachine
examples/*/*/statemachine
56 changes: 56 additions & 0 deletions examples/image-telephone/README.md
@@ -0,0 +1,56 @@
# Image Telephone

This example demonstrates how to play telephone with DALL-E and ChatGPT. See some examples of the outputs
in this [streamlit app](https://image-telephone.streamlit.app).

It's a fun example of how to use Burr! It shows a simple way to define an application
that talks to itself. The game is simple:

1. You provide an initial image to ChatGPT, which then generates a caption. The caption is saved to state.
2. That caption is then provided to DALL-E, which generates an image based on the caption, which is saved to state.
3. The loop repeats -- and you have encoded the game of telephone!
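The loop above can be sketched in plain Python. This is a conceptual sketch only: `caption_image` and `generate_image` are hypothetical stand-ins for the real ChatGPT and DALL-E calls, and the state dict stands in for Burr's `State`.

```python
def caption_image(image: str) -> str:
    # placeholder for the ChatGPT captioning call
    return f"caption of {image}"

def generate_image(caption: str) -> str:
    # placeholder for the DALL-E generation call
    return f"image from '{caption}'"

def play_telephone(starting_image: str, iterations: int = 4) -> dict:
    state = {"current_image": starting_image, "captions": [], "images": []}
    while len(state["captions"]) < iterations:
        # step 1: caption the current image and save the caption to state
        caption = caption_image(state["current_image"])
        state["captions"].append(caption)
        state["images"].append(state["current_image"])
        # step 2: generate a new image from the caption and save it to state
        state["current_image"] = generate_image(caption)
    return state

state = play_telephone("statemachine.png")
print(len(state["captions"]))  # 4
```

The real application delegates both steps to Hamilton dataflows and lets Burr manage the state and transitions.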

Specifically, each action here in Burr is delegated to the [Hamilton](https://github.com/dagworks-inc/hamilton) micro-framework to run.
Hamilton is a great replacement for tools like LCEL because it's built to provide a great SDLC experience, in addition
to being lightweight, extensible, and more general
purpose (e.g., it's great for expressing data processing, ML, and web-request logic). We're using
off-the-shelf dataflows from the [Hamilton hub](https://hub.dagworks.io) to do the work of captioning and generating images.

Right now the terminal condition is set to four iterations, so the game ends after four images are captioned:

![Telephone](statemachine.png)
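The stopping condition can be expressed in plain Python; in `application.py`, Burr encodes it with `expr(f"len(image_caption_history) == {number_of_images_to_caption}")`:

```python
def should_terminate(image_caption_history: list, number_of_images_to_caption: int = 4) -> bool:
    # mirrors the Burr transition condition:
    #   expr(f"len(image_caption_history) == {number_of_images_to_caption}")
    return len(image_caption_history) == number_of_images_to_caption

print(should_terminate(["a", "b"]))            # False
print(should_terminate(["a", "b", "c", "d"]))  # True
```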

## Running the Example
We recommend starting with the notebook.

### notebook.ipynb
You can use [notebook.ipynb](./notebook.ipynb) to run things. Or
<a target="_blank" href="https://colab.research.google.com/github/DAGWorks-Inc/burr/blob/main/examples/image-telephone/notebook.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### Running application.py

To run the basics do:
```bash
python application.py
```
To adapt it for your purposes, adjust the code to point at the image you want to start with.

## Modifying the telephone game
There are two levels you can modify:

1. The high-level orchestration and state management
2. What each action actually does.

For the high-level orchestration you can add more nodes, modify the actions (e.g. to save the images),
change conditions, etc.

For the low-level actions, you can also change the prompt, the template, etc. To do so, see the
documentation for the Hamilton dataflows that are used: [captioning](https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/) and
[generating images](https://hub.dagworks.io/docs/Users/elijahbenizzy/generate_images/). You can easily modify the prompt and
template by overriding values, or by copying the code and modifying it yourself in a few minutes; see the instructions on the [hub](https://hub.dagworks.io/).
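The override mechanism works by short-circuiting a dataflow node with a precomputed value, which is how `caption_embeddings` in `application.py` reuses the caption from the previous action instead of re-calling the model. Below is a conceptual sketch of those semantics, not Hamilton's actual implementation; the toy `execute` function and node names are illustrative only.

```python
def execute(requested: list, funcs: dict, inputs: dict, overrides: dict = None) -> dict:
    # conceptual sketch of override semantics: an overridden node is never recomputed
    results = dict(overrides or {})
    for name in requested:
        if name not in results:
            results[name] = funcs[name](**inputs)
    return results

calls = []

def generated_caption(image_url: str) -> str:
    calls.append(image_url)  # track whether the expensive node actually ran
    return f"caption of {image_url}"

# with an override supplied, the captioning node is skipped entirely
out = execute(
    ["generated_caption"],
    {"generated_caption": generated_caption},
    inputs={"image_url": "img.png"},
    overrides={"generated_caption": "cached caption"},
)
print(out["generated_caption"])  # cached caption
print(calls)                     # [] -- the node was not recomputed
```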

## Hamilton code
For more details on the [Hamilton](https://github.com/dagworks-inc/hamilton) code and
this [streamlit app](https://image-telephone.streamlit.app), see [this example in the Hamilton repo](https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/LLM_Workflows/image_telephone).
Empty file.
201 changes: 201 additions & 0 deletions examples/image-telephone/application.py
@@ -0,0 +1,201 @@
"""
This module demonstrates a telephone application
using Burr that:
- captions an image
- creates caption embeddings (for analysis)
- creates a new image based on the created caption

We use pre-defined Hamilton DAGs to perform the
image captioning and image generation tasks. Unlike other frameworks,
Hamilton doesn't hide the contents of the defined DAG from the user:
you can easily introspect and modify the code as needed.

The Hamilton DAGs used in this example can be found here:

- https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/
- https://hub.dagworks.io/docs/Users/elijahbenizzy/generate_images/
"""
import os
import uuid

import requests
from hamilton import dataflows, driver

from burr.core import Action, ApplicationBuilder, State, default, expr
from burr.core.action import action
from burr.lifecycle import PostRunStepHook

# import hamilton modules that define the DAGs for image captioning and image generation
caption_images = dataflows.import_module("caption_images", "elijahbenizzy")
generate_images = dataflows.import_module("generate_images", "elijahbenizzy")


class ImageSaverHook(PostRunStepHook):
"""Class to save images to a directory.
This is an example of a custom way to interact indirectly.

This is one way you could save to S3 by writing something like this.
"""

def __init__(self, save_dir: str = "saved_images"):
self.save_dir = save_dir
self.run_id = str(uuid.uuid4())[0:8]
self.path = os.path.join(self.save_dir, self.run_id)
if not os.path.exists(self.path):
os.makedirs(self.path)

def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
"""Pulls the image URL from the state and saves it to the save directory."""
if action.name == "generate":
image_url = state["current_image_location"]
image_name = "image_" + str(state["__SEQUENCE_ID"]) + ".png"
with open(os.path.join(self.path, image_name), "wb") as f:
f.write(requests.get(image_url).content)
print(f"Saved image to {self.path}/{image_name}")


@action(
reads=["current_image_location"],
writes=["current_image_caption", "image_location_history"],
)
def image_caption(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
"""Action to caption an image.
This delegates to the Hamilton DAG for image captioning.
For more details go here: https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/.
"""
current_image = state["current_image_location"]
result = caption_image_driver.execute(
["generated_caption"], inputs={"image_url": current_image}
)
updates = {
"current_image_caption": result["generated_caption"],
}
# You could save to S3 here
return result, state.update(**updates).append(image_location_history=current_image)


@action(
reads=["current_image_caption"],
writes=["caption_analysis"],
)
def caption_embeddings(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
"""Action to analyze the caption and create embeddings for analysis.

This delegates to the Hamilton DAG for getting embeddings for the caption.
For more details go here: https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/.

This uses the overrides functionality to use the result of the prior Hamilton DAG run
to avoid re-computation.
"""
result = caption_image_driver.execute(
["metadata"],
inputs={"image_url": state["current_image_location"]},
overrides={"generated_caption": state["current_image_caption"]},
)
# You could save to S3 here
return result, state.append(caption_analysis=result["metadata"])


@action(
reads=["current_image_caption"],
writes=["current_image_location", "image_caption_history"],
)
def image_generation(state: State, generate_image_driver: driver.Driver) -> tuple[dict, State]:
"""Action to create an image.

This delegates to the Hamilton DAG for image generation.
For more details go here: https://hub.dagworks.io/docs/Users/elijahbenizzy/generate_images/.
"""
current_caption = state["current_image_caption"]
result = generate_image_driver.execute(
["generated_image"], inputs={"image_generation_prompt": current_caption}
)
updates = {
"current_image_location": result["generated_image"],
}
# You could save to S3 here
return result, state.update(**updates).append(image_caption_history=current_caption)


@action(reads=["image_location_history", "image_caption_history", "caption_analysis"], writes=[])
def terminal_step(state: State) -> tuple[dict, State]:
"""This is a terminal step. We can do any final processing here."""
result = {
"image_location_history": state["image_location_history"],
"image_caption_history": state["image_caption_history"],
"caption_analysis": state["caption_analysis"],
}
# Could save everything to S3 here.
return result, state


def build_application(
starting_image: str = "statemachine.png", number_of_images_to_caption: int = 4
):
"""This builds the Burr application and returns it.

:param starting_image: the starting image to use
:param number_of_images_to_caption: the number of iterations to go through
:return: the built application
"""
# instantiate hamilton drivers and then bind them to the actions.
caption_image_driver = (
driver.Builder()
.with_config({"include_embeddings": True})
.with_modules(caption_images)
.build()
)
generate_image_driver = driver.Builder().with_config({}).with_modules(generate_images).build()
app = (
ApplicationBuilder()
.with_state(
current_image_location=starting_image,
current_image_caption="",
image_location_history=[],
image_caption_history=[],
caption_analysis=[],
)
.with_actions(
caption=image_caption.bind(caption_image_driver=caption_image_driver),
analyze=caption_embeddings.bind(caption_image_driver=caption_image_driver),
generate=image_generation.bind(generate_image_driver=generate_image_driver),
terminal=terminal_step,
)
.with_transitions(
("caption", "analyze", default),
(
"analyze",
"terminal",
expr(f"len(image_caption_history) == {number_of_images_to_caption}"),
),
("analyze", "generate", default),
("generate", "caption", default),
)
.with_entrypoint("caption")
.with_hooks(ImageSaverHook())
.with_tracker(project="image-telephone")
.build()
)
return app


if __name__ == "__main__":
import random

coin_flip = random.choice([True, False])
# app = build_application("path/to/my/image.png")
app = build_application()
app.visualize(output_file_path="statemachine", include_conditions=True, view=True, format="png")
if coin_flip:
_last_action, _result, _state = app.run(halt_after=["terminal"])
# save to S3 / download images etc.
else:
# alternate way to run:
while True:
_action, _result, _state = app.step()
print("action=====\n", _action)
print("result=====\n", _result)
# you could save to S3 / download images etc. here.
if _action.name == "terminal":
break
print(_state)