Updates telephone example
To make it a top-level example.

Simplifies it a little and splits out the two implementations
to make the code simpler to look at.

Adds integrations to be skipped for validation for now.
Proper fix would be to mark directories as example ones
and only validate those...

Adds an example hook to save images locally.
skrawcz committed Apr 9, 2024
1 parent 54683e0 commit c4de7a4
Showing 16 changed files with 831 additions and 180 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -167,3 +167,7 @@ burr/tracking/server/build
.idea
# added by macOS
.DS_store

# statemachine images
examples/*/statemachine
examples/*/*/statemachine
39 changes: 39 additions & 0 deletions examples/image-telephone/README.md
@@ -0,0 +1,39 @@
# Image Telephone

This example demonstrates how to play telephone with DALL-E and ChatGPT. See some examples of the outputs
in this [streamlit app](https://image-telephone.streamlit.app).

It is a fun example of how to use Burr! In this example you'll see a simple way to define an application
that talks to itself to do something fun. The game is simple:

1. You provide an initial image to ChatGPT, which then generates a caption. The caption is saved to state.
2. That caption is then provided to DALL-E, which generates an image based on the caption, which is saved to state.
3. The loop repeats -- and you have encoded the game of telephone!
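
The three steps above can be sketched in plain Python. This is a minimal sketch, not the Burr application from this example: `fake_caption` and `fake_generate` are hypothetical stand-ins for the ChatGPT and DALL-E calls, `play_telephone` is an illustrative name, and the state is an ordinary dictionary rather than a Burr `State`.

```python
def fake_caption(image_location: str) -> str:
    """Hypothetical stand-in for the ChatGPT captioning call."""
    return f"a picture resembling {image_location}"


def fake_generate(caption: str) -> str:
    """Hypothetical stand-in for the DALL-E image generation call."""
    return f"generated_{len(caption)}.png"


def play_telephone(starting_image: str, rounds: int) -> dict:
    """Alternate captioning and generation, recording both in state."""
    state = {
        "current_image_location": starting_image,
        "current_image_caption": "",
        "image_location_history": [],
        "image_caption_history": [],
    }
    for _ in range(rounds):
        # Step 1: caption the current image and save the caption to state.
        state["image_location_history"].append(state["current_image_location"])
        state["current_image_caption"] = fake_caption(state["current_image_location"])
        # Step 2: generate a new image from the caption and save it to state.
        state["image_caption_history"].append(state["current_image_caption"])
        state["current_image_location"] = fake_generate(state["current_image_caption"])
        # Step 3: the loop repeats with the newly generated image.
    return state


final_state = play_telephone("start.png", rounds=3)
print(final_state["image_caption_history"])
```

Each round feeds the previous round's output back in as input, which is what lets the captions and images drift over time.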

Specifically, each action here in Burr is delegated to the [Hamilton](https://github.com/dagworks-inc/hamilton) micro-framework to run.
Hamilton is a great replacement for tools like LCEL, because it's built to provide a great software development lifecycle (SDLC) experience,
in addition to being lightweight, extensible, and more general purpose
(e.g. it's great for expressing things like data processing, ML, and web-request logic). We're using
off-the-shelf dataflows from the [Hamilton hub](https://hub.dagworks.io) to do the work of captioning and generating images.

Right now the terminal condition is set to 4, so the game ends once 4 captions have been turned into images:

![Telephone](statemachine.png)

## Modifying the telephone game
You can modify the game at two levels:

1. The high-level orchestration and state management.
2. What each action actually does.

For the high-level orchestration you can add more nodes, modify the actions (e.g. to save the images),
change conditions, etc.
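
To make "change conditions" concrete, here is a minimal sketch of how an ordered list of transitions with conditions can drive such a loop. It is plain Python and does not use Burr: the `TRANSITIONS` table and the `always`, `caption_history_full`, and `next_action` helpers are hypothetical names, and the first-match-wins evaluation order is an assumption of this sketch, not a statement about Burr's internals.

```python
from typing import Callable, Optional

Condition = Callable[[dict], bool]


def always(state: dict) -> bool:
    """Fallback condition that matches any state."""
    return True


def caption_history_full(limit: int) -> Condition:
    """Build a condition that is true once `limit` captions have been recorded."""
    return lambda state: len(state["image_caption_history"]) == limit


# (from_action, to_action, condition) triples; order matters in this sketch.
TRANSITIONS = [
    ("caption", "analyze", always),
    ("analyze", "terminal", caption_history_full(4)),
    ("analyze", "generate", always),
    ("generate", "caption", always),
]


def next_action(current: str, state: dict) -> Optional[str]:
    """Return the target of the first transition out of `current` whose condition holds."""
    for source, target, condition in TRANSITIONS:
        if source == current and condition(state):
            return target
    return None  # no outgoing transition: the machine halts


print(next_action("analyze", {"image_caption_history": ["a", "b"]}))
print(next_action("analyze", {"image_caption_history": ["a", "b", "c", "d"]}))
```

In this sketch, changing the limit passed to `caption_history_full` changes when the loop exits, and adding a node means adding rows to the table.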

For the low-level actions, you can also change the prompt, the template, and so on. To do so, see the
documentation for the Hamilton dataflows that are used: [captioning](https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/) and
[image generation](https://hub.dagworks.io/docs/Users/elijahbenizzy/generate_images/). You can modify the prompt and
template by overriding values, or by copying the code and modifying it yourself in 2 minutes; see the instructions on the [hub](https://hub.dagworks.io/).
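
The idea of "overriding values" can be illustrated without any framework. The sketch below is a toy resolver, not Hamilton's implementation: `execute`, `NODES`, `prompt_template`, and `prompt` are hypothetical names chosen for illustration. Each function computes one named value from parameters that share names with other values, and an override replaces a value before anything downstream is computed.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def prompt_template() -> str:
    """Default template; a hypothetical stand-in for a dataflow's built-in prompt."""
    return "Describe this image: {image_url}"


def prompt(prompt_template: str, image_url: str) -> str:
    """Fill the template with the input value."""
    return prompt_template.format(image_url=image_url)


# Registry of node functions, keyed by the name of the value each one produces.
NODES: Dict[str, Callable[..., Any]] = {f.__name__: f for f in (prompt_template, prompt)}


def execute(
    target: str, inputs: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None
) -> Any:
    """Compute `target`, preferring overrides, then inputs, then node functions."""
    known = {**inputs, **(overrides or {})}
    if target in known:
        return known[target]
    func = NODES[target]
    # Resolve each parameter by name, recursing into upstream nodes as needed.
    kwargs = {
        name: execute(name, inputs, overrides)
        for name in inspect.signature(func).parameters
    }
    return func(**kwargs)


default_prompt = execute("prompt", inputs={"image_url": "cat.png"})
custom_prompt = execute(
    "prompt",
    inputs={"image_url": "cat.png"},
    overrides={"prompt_template": "Write a haiku about {image_url}"},
)
print(default_prompt)
print(custom_prompt)
```

The `caption_embeddings` action in `application.py` relies on the same idea when it passes `overrides={"generated_caption": ...}` to the driver's `execute` call to avoid recomputing the caption.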

## Hamilton code
For more details on the [Hamilton](https://github.com/dagworks-inc/hamilton) code and
this [streamlit app](https://image-telephone.streamlit.app), see [this example in the Hamilton repo](https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/LLM_Workflows/image_telephone).
Empty file.
201 changes: 201 additions & 0 deletions examples/image-telephone/application.py
@@ -0,0 +1,201 @@
"""
This module demonstrates a telephone application
using Burr that:
 - captions an image
 - creates caption embeddings (for analysis)
 - creates a new image based on the created caption

We use pre-defined Hamilton DAGs to perform the
image captioning and image generation tasks. Unlike other frameworks,
Hamilton doesn't hide the contents of the defined DAG from the user.
You can easily introspect and modify the code as needed.

The Hamilton DAGs used in this example can be found here:
 - https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/
 - https://hub.dagworks.io/docs/Users/elijahbenizzy/generate_images/
"""
import os
import uuid

import requests
from hamilton import dataflows, driver

from burr.core import Action, ApplicationBuilder, State, default, expr
from burr.core.action import action
from burr.lifecycle import PostRunStepHook

# import hamilton modules that define the DAGs for image captioning and image generation
caption_images = dataflows.import_module("caption_images", "elijahbenizzy")
generate_images = dataflows.import_module("generate_images", "elijahbenizzy")


class ImageSaverHook(PostRunStepHook):
    """Hook that saves each generated image to a local directory.

    This is an example of a lifecycle hook that acts on state after a step runs.
    You could save to S3 by writing something similar.
    """

    def __init__(self, save_dir: str = "saved_images"):
        self.save_dir = save_dir
        # short random ID so each run saves into its own subdirectory
        self.run_id = str(uuid.uuid4())[0:8]
        self.path = os.path.join(self.save_dir, self.run_id)
        if not os.path.exists(self.path):
            os.makedirs(self.path)

    def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
        """Pulls the image URL from the state and saves it to the save directory."""
        if action.name == "generate":
            image_url = state["current_image_location"]
            image_name = "image_" + str(state["__SEQUENCE_ID"]) + ".png"
            with open(os.path.join(self.path, image_name), "wb") as f:
                f.write(requests.get(image_url).content)
            print(f"Saved image to {self.path}/{image_name}")


@action(
    reads=["current_image_location"],
    writes=["current_image_caption", "image_location_history"],
)
def image_caption(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Action to caption an image.

    This delegates to the Hamilton DAG for image captioning.
    For more details go here: https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/.
    """
    current_image = state["current_image_location"]
    result = caption_image_driver.execute(
        ["generated_caption"], inputs={"image_url": current_image}
    )
    updates = {
        "current_image_caption": result["generated_caption"],
    }
    # You could save to S3 here
    return result, state.update(**updates).append(image_location_history=current_image)


@action(
    reads=["current_image_caption"],
    writes=["caption_analysis"],
)
def caption_embeddings(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Action to analyze the caption and create embeddings for analysis.

    This delegates to the Hamilton DAG for getting embeddings for the caption.
    For more details go here: https://hub.dagworks.io/docs/Users/elijahbenizzy/caption_images/.

    This uses the overrides functionality to use the result of the prior Hamilton DAG run
    to avoid re-computation.
    """
    result = caption_image_driver.execute(
        ["metadata"],
        inputs={"image_url": state["current_image_location"]},
        overrides={"generated_caption": state["current_image_caption"]},
    )
    # You could save to S3 here
    return result, state.append(caption_analysis=result["metadata"])


@action(
    reads=["current_image_caption"],
    writes=["current_image_location", "image_caption_history"],
)
def image_generation(state: State, generate_image_driver: driver.Driver) -> tuple[dict, State]:
    """Action to create an image.

    This delegates to the Hamilton DAG for image generation.
    For more details go here: https://hub.dagworks.io/docs/Users/elijahbenizzy/generate_images/.
    """
    current_caption = state["current_image_caption"]
    result = generate_image_driver.execute(
        ["generated_image"], inputs={"image_generation_prompt": current_caption}
    )
    updates = {
        "current_image_location": result["generated_image"],
    }
    # You could save to S3 here
    return result, state.update(**updates).append(image_caption_history=current_caption)


@action(reads=["image_location_history", "image_caption_history", "caption_analysis"], writes=[])
def terminal_step(state: State) -> tuple[dict, State]:
    """This is a terminal step. We can do any final processing here."""
    result = {
        "image_location_history": state["image_location_history"],
        "image_caption_history": state["image_caption_history"],
        "caption_analysis": state["caption_analysis"],
    }
    # Could save everything to S3 here.
    return result, state


def build_application(
    starting_image: str = "statemachine.png", number_of_images_to_caption: int = 4
):
    """This builds the Burr application and returns it.

    :param starting_image: the starting image to use
    :param number_of_images_to_caption: the number of iterations to go through
    :return: the built application
    """
    # instantiate hamilton drivers and then bind them to the actions.
    caption_image_driver = (
        driver.Builder()
        .with_config({"include_embeddings": True})
        .with_modules(caption_images)
        .build()
    )
    generate_image_driver = driver.Builder().with_config({}).with_modules(generate_images).build()
    app = (
        ApplicationBuilder()
        .with_state(
            current_image_location=starting_image,
            current_image_caption="",
            image_location_history=[],
            image_caption_history=[],
            caption_analysis=[],
        )
        .with_actions(
            caption=image_caption.bind(caption_image_driver=caption_image_driver),
            analyze=caption_embeddings.bind(caption_image_driver=caption_image_driver),
            generate=image_generation.bind(generate_image_driver=generate_image_driver),
            terminal=terminal_step,
        )
        .with_transitions(
            ("caption", "analyze", default),
            (
                "analyze",
                "terminal",
                expr(f"len(image_caption_history) == {number_of_images_to_caption}"),
            ),
            ("analyze", "generate", default),
            ("generate", "caption", default),
        )
        .with_entrypoint("caption")
        .with_hooks(ImageSaverHook())
        .with_tracker(project="image-telephone")
        .build()
    )
    return app


if __name__ == "__main__":
    import random

    coin_flip = random.choice([True, False])
    # app = build_application("path/to/my/image.png")
    app = build_application()
    app.visualize(output_file_path="statemachine", include_conditions=True, view=True, format="png")
    # if coin_flip:
    #     _last_action, _result, _state = app.run(halt_after=["terminal"])
    #     # save to S3 / download images etc.
    # else:
    #     # alternate way to run:
    #     while True:
    #         _action, _result, _state = app.step()
    #         print("action=====\n", _action)
    #         print("result=====\n", _result)
    #         # you could save to S3 / download images etc. here.
    #         if _action.name == "terminal":
    #             break
    # print(_state)