object detection example (#25)
mernit authored Aug 3, 2024
1 parent 09af5e6 commit 82af9e7
Showing 4 changed files with 192 additions and 0 deletions.
50 changes: 50 additions & 0 deletions 09_image_generation/object-detection/README.md
@@ -0,0 +1,50 @@
# Object Detection on Beam

This guide demonstrates how to use Beam to perform object detection on a base64-encoded image string. We will cover how to set up the API endpoint, load the pre-trained model, and visualize the results.

## Overview

We'll create an endpoint that:

- Accepts a base64-encoded image string
- Decodes the image
- Performs object detection using a pre-trained Faster R-CNN model
- Visualizes the detection results
- Returns the resulting image and bounding box coordinates
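
A rough sketch of the request and response shapes, based on `request.py` and the dictionary returned by `predict` in `app.py` (the values below are placeholders, not output from a real run):

```
# Request body sent by request.py
{"image_base64": "<base64-encoded image bytes>"}

# Response body returned by app.py (placeholder values)
{
  "image": "https://app.beam.cloud/output/id/<output-id>",
  "boxes": [[x1, y1, x2, y2], ...]
}
```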

## Prerequisites

- An active Beam account
- Basic knowledge of Python
- An image file to test the endpoint

## Test the endpoint

```
beam serve app.py:predict
```

This will print a URL in your shell. Be sure to update `request.py` with your unique URL and auth token:

```
url = 'https://app.beam.cloud/endpoint/id/[ENDPOINT-ID]'
headers = {
'Connection': 'keep-alive',
'Content-Type': 'application/json',
'Authorization': 'Bearer [AUTH_TOKEN]'
}
```

You can run `python request.py` to send a request to the API.

The response includes a pre-signed URL to the image with the detected bounding boxes drawn on it:

<img src="https://app.beam.cloud/output/id/95ea6071-2c4a-4618-9397-117345f3e8f2" alt="beam image"/>
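
If you want to keep a local copy of the annotated image, here is a minimal sketch of downloading it from the pre-signed URL. It assumes `result` is the parsed JSON returned by `call_beam_api` in `request.py`; the output filename is just an example:

```
import requests

# `result` is the parsed JSON response from request.py
image_url = result["image"]

resp = requests.get(image_url)
resp.raise_for_status()

# Write the annotated PNG to disk (filename is arbitrary)
with open("detections.png", "wb") as f:
    f.write(resp.content)
```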

## Deploy the endpoint

To deploy a persistent endpoint for production use, run this command:

```
beam deploy app.py:predict
```
111 changes: 111 additions & 0 deletions 09_image_generation/object-detection/app.py
@@ -0,0 +1,111 @@
from beam import Image, endpoint, Output, env
import io
import base64

# Since these packages are only installed remotely on Beam, this block ensures the interpreter doesn't try to import them locally
if env.is_remote():
from torchvision import models, transforms
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch

# Define the image transforms
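    # ToTensor converts the PIL image to a float tensor in [0, 1] with shape (C, H, W),
    # which is the input format torchvision's detection models expect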
transform = transforms.Compose(
[
transforms.ToTensor(),
]
)


# The Beam container image that this code will run on
image = Image(
python_version="python3.9",
python_packages=[
"torch",
"torchvision",
"pillow",
"matplotlib",
],
)


# Pre-load models onto the container
def load_model():
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
return model


def detect_objects(model, image):
image_tensor = transform(image).unsqueeze(0)

# Perform object detection
with torch.no_grad():
predictions = model(image_tensor)

# Extract the bounding boxes and labels
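    # Each box is [x1, y1, x2, y2] in pixel coordinates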
boxes = predictions[0]["boxes"].cpu().numpy()
scores = predictions[0]["scores"].cpu().numpy()
labels = predictions[0]["labels"].cpu().numpy()

# Filter out low-confidence detections
threshold = 0.5
boxes = boxes[scores >= threshold]
labels = labels[scores >= threshold]

return boxes, labels


def visualize_detection(image, boxes):
fig, ax = plt.subplots(1)
ax.imshow(image)

# Draw the bounding boxes
for box in boxes:
rect = patches.Rectangle(
(box[0], box[1]),
box[2] - box[0],
box[3] - box[1],
linewidth=2,
edgecolor="r",
facecolor="none",
)
ax.add_patch(rect)

    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)  # release the figure so repeated warm requests don't leak memory
    buf.seek(0)
    return PILImage.open(buf)


@endpoint(
image=image,
on_start=load_model,
keep_warm_seconds=60,
cpu=2,
gpu="A10G",
memory="16Gi",
)
def predict(context, image_base64: str):
# Retrieve pre-loaded model from loader
model = context.on_start_value

# Decode the base64 image
image_data = base64.b64decode(image_base64)
image = PILImage.open(io.BytesIO(image_data))

# Perform object detection
boxes, labels = detect_objects(model, image)

# Visualize the results
result_image = visualize_detection(image, boxes)

# Save image file
output = Output.from_pil_image(result_image).save()

# Retrieve pre-signed URL for output file
url = output.public_url()

return {"image": url, "boxes": boxes.tolist()}
Binary file added 09_image_generation/object-detection/example.jpg
31 changes: 31 additions & 0 deletions 09_image_generation/object-detection/request.py
@@ -0,0 +1,31 @@
import requests
import base64

# Beam API details -- make sure to replace with your own credentials
url = 'https://app.beam.cloud/endpoint/id/[ENDPOINT-ID]'
headers = {
'Connection': 'keep-alive',
'Content-Type': 'application/json',
'Authorization': 'Bearer [YOUR-AUTH-TOKEN]'
}

# Load image and encode it to base64
def load_image_as_base64(image_path):
with open(image_path, "rb") as image_file:
base64_string = base64.b64encode(image_file.read()).decode('utf-8')
return base64_string

# Send a POST request to the Beam endpoint
def call_beam_api(image_base64):
data = {
"image_base64": image_base64
}
response = requests.post(url, headers=headers, json=data)
return response.json()


if __name__ == "__main__":
image_path = "example.jpg"
image_base64 = load_image_as_base64(image_path)
result = call_beam_api(image_base64)
print(result)
