-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Object Detection on Beam | ||
|
||
This guide demonstrates how to use Beam to perform object detection on a base64-encoded image string. We will cover how to set up the API endpoint, load the pre-trained model, and visualize the results. | ||
|
||
## Overview | ||
|
||
We'll create an endpoint that: | ||
|
||
- Accepts a base64-encoded image string | ||
- Decodes the image | ||
- Performs object detection using a pre-trained Faster R-CNN model | ||
- Visualizes the detection results | ||
- Returns the resulting image and bounding box coordinates | ||
|
||
## Prerequisites
|
||
- An active Beam account | ||
- Basic knowledge of Python | ||
- An image file to test the endpoint | ||
|
||
## Test the endpoint | ||
|
||
``` | ||
beam serve app.py:predict | ||
``` | ||
|
||
This will print a URL in your shell. Be sure to update `request.py` with your unique URL and auth token: | ||
|
||
``` | ||
url = 'https://app.beam.cloud/endpoint/id/[ENDPOINT-ID]' | ||
headers = { | ||
'Connection': 'keep-alive', | ||
'Content-Type': 'application/json', | ||
'Authorization': 'Bearer [AUTH_TOKEN]' | ||
} | ||
``` | ||
|
||
You can run `python request.py` to send a request to the API. | ||
|
||
It returns a pre-signed URL with the bounding boxes added to the image: | ||
|
||
<img src="https://app.beam.cloud/output/id/95ea6071-2c4a-4618-9397-117345f3e8f2" alt="beam image"/> | ||
|
||
## Deploy the endpoint | ||
|
||
To deploy a persistent endpoint for production use, run this command: | ||
|
||
``` | ||
beam deploy app.py:predict | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from beam import Image, endpoint, Output, env | ||
import io | ||
import base64 | ||
|
||
# These packages are only installed remotely on Beam; this guard keeps the
# local interpreter from trying (and failing) to import them.
if env.is_remote():
    import torch
    import matplotlib.patches as patches
    import matplotlib.pyplot as plt
    from PIL import Image as PILImage
    from torchvision import models, transforms

    # Preprocessing pipeline: converts a PIL image into a float tensor
    # scaled to [0, 1], which is what torchvision detection models expect.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
|
||
|
||
# Container image the endpoint runs on; these packages are installed
# remotely on Beam, not on the local machine.
image = Image(
    python_version="python3.9",
    python_packages=[
        "torch",
        "torchvision",
        "pillow",
        "matplotlib",
    ],
)
|
||
|
||
# Pre-load the model onto the container (wired up via `on_start` below) so
# weights are downloaded and loaded before the first request arrives.
def load_model():
    """Load a pre-trained Faster R-CNN (ResNet-50 FPN) detection model.

    Returns:
        The model switched to eval mode, ready for inference.
    """
    # `pretrained=True` is deprecated in torchvision >= 0.13;
    # `weights="DEFAULT"` selects the best available pre-trained weights.
    model = models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    model.eval()
    return model
|
||
|
||
def detect_objects(model, image, threshold=0.5):
    """Run the detector on a PIL image and return confident detections.

    Args:
        model: A torchvision detection model in eval mode.
        image: The input PIL image.
        threshold: Minimum confidence score for a detection to be kept
            (default 0.5, matching the original hard-coded value).

    Returns:
        Tuple of (boxes, labels) as numpy arrays; each box is
        [x1, y1, x2, y2] in pixel coordinates.
    """
    # Add a batch dimension: the model expects a batch of 3xHxW tensors.
    image_tensor = transform(image).unsqueeze(0)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        predictions = model(image_tensor)

    # Extract the bounding boxes, confidence scores, and class labels.
    boxes = predictions[0]["boxes"].cpu().numpy()
    scores = predictions[0]["scores"].cpu().numpy()
    labels = predictions[0]["labels"].cpu().numpy()

    # Drop low-confidence detections; using one mask keeps boxes and labels
    # aligned with each other.
    keep = scores >= threshold
    return boxes[keep], labels[keep]
|
||
|
||
def visualize_detection(image, boxes):
    """Draw bounding boxes on the image and return the result as a PIL image.

    Args:
        image: The source PIL image.
        boxes: Iterable of [x1, y1, x2, y2] box coordinates.

    Returns:
        A PIL image of the rendered matplotlib figure.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image)

    # Draw each bounding box as a red, unfilled rectangle.
    for box in boxes:
        rect = patches.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Close the figure explicitly: matplotlib keeps figures alive otherwise,
    # which leaks memory across repeated endpoint invocations.
    plt.close(fig)
    buf.seek(0)
    return PILImage.open(buf)
|
||
|
||
@endpoint(
    image=image,
    on_start=load_model,
    keep_warm_seconds=60,
    cpu=2,
    gpu="A10G",
    memory="16Gi",
)
def predict(context, image_base64: str):
    """Detect objects in a base64-encoded image.

    Args:
        context: Beam invocation context; carries the model pre-loaded
            by ``load_model`` on container start.
        image_base64: The input image encoded as a base64 string.

    Returns:
        Dict with a pre-signed "image" URL of the annotated image, the
        detected "boxes" ([x1, y1, x2, y2] lists), and their class "labels".
    """
    # Retrieve the pre-loaded model from the on_start loader.
    model = context.on_start_value

    # Decode the base64 payload; force RGB so grayscale or RGBA inputs
    # don't feed the detector the wrong number of channels.
    image_data = base64.b64decode(image_base64)
    image = PILImage.open(io.BytesIO(image_data)).convert("RGB")

    # Perform object detection.
    boxes, labels = detect_objects(model, image)

    # Render the detections onto the image.
    result_image = visualize_detection(image, boxes)

    # Save the annotated image and expose it via a pre-signed URL.
    output = Output.from_pil_image(result_image).save()
    url = output.public_url()

    return {"image": url, "boxes": boxes.tolist(), "labels": labels.tolist()}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import requests | ||
import base64 | ||
|
||
# Beam API connection details. Replace the [ENDPOINT-ID] and
# [YOUR-AUTH-TOKEN] placeholders with your own credentials before running.
url = 'https://app.beam.cloud/endpoint/id/[ENDPOINT-ID]'
headers = {
    'Connection': 'keep-alive',
    'Content-Type': 'application/json',
    'Authorization': 'Bearer [YOUR-AUTH-TOKEN]'
}
|
||
# Read an image file from disk and encode it as base64 text.
def load_image_as_base64(image_path):
    """Return the contents of *image_path* as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw).decode('utf-8')
|
||
# Send a POST request to the Beam endpooint with the encoded image.
def call_beam_api(image_base64):
    """POST the encoded image to the Beam endpoint and return the JSON reply.

    Args:
        image_base64: Base64-encoded image payload.

    Returns:
        The decoded JSON response body (dict with "image" URL and "boxes").
    """
    data = {
        "image_base64": image_base64
    }
    # `requests` has no default timeout; without one a stalled connection
    # would hang this script forever.
    response = requests.post(url, headers=headers, json=data, timeout=60)
    return response.json()
|
||
|
||
if __name__ == "__main__":
    # Encode a local test image, send it to the endpoint, and print the
    # detection results returned by the API.
    payload = load_image_as_base64("example.jpg")
    response = call_beam_api(payload)
    print(response)