This repository has been archived by the owner on Apr 2, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
app_interruptible.py
73 lines (62 loc) · 2.63 KB
/
app_interruptible.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# !pip freeze
# !pip install 'git+https://github.com/Lightning-AI/LAI-API-Access-UI-Component.git'
# !pip install streamlit pandas
# !curl https://raw.githubusercontent.com/Lightning-AI/stablediffusion/lit/configs/stable-diffusion/v1-inference.yaml -o v1-inference.yaml
import lightning as L
import os, base64, io, torch
from diffusion_with_autoscaler import AutoScaler, BatchText, BatchImage, Text, Image, IntervalReplacement
# Hard-coded URL of a deployed prediction endpoint (path `/api/predict`).
# NOTE(review): nothing in this module's visible code reads it — presumably it
# is consumed by a client/UI component elsewhere; confirm it is still needed
# and not a stale deployment address.
PROXY_URL = "https://ulhcn-01gd3c9epmk5xj2y9a9jrrvgt8.litng-ai-03.litng.ai/api/predict"
class FlashAttentionBuildConfig(L.BuildConfig):
    """Build configuration that installs Lightning's stablediffusion fork.

    Despite the name, the only thing this config does is pip-install the
    ``lit`` branch of ``Lightning-AI/stablediffusion``. Presumably that package
    provides the Triton flash-attention path selected by
    ``flash_attention="triton"`` in the server — TODO confirm.
    """

    def build_commands(self):
        """Return the list of shell commands for the cloud build step.

        The commands are presumably executed by Lightning while building the
        cloud machine's environment.
        """
        install_stablediffusion = "pip install 'git+https://github.com/Lightning-AI/stablediffusion.git@lit'"
        return [install_stablediffusion]
class DiffusionServer(L.app.components.PythonServer):
    """Serve batched Stable Diffusion text-to-image requests.

    Accepts a ``BatchText`` of prompts and answers with a ``BatchImage`` whose
    ``outputs`` are ``{"image": <base64-encoded PNG>}`` dicts, one per image
    produced by the model.
    """

    # Model weights, downloaded on first ``setup`` if not already on disk.
    _CHECKPOINT_URL = "https://pl-public-data.s3.amazonaws.com/dream_stable_diffusion/v1-5-pruned-emaonly.ckpt"
    _CHECKPOINT_PATH = "v1-5-pruned-emaonly.ckpt"
    # Model config; it is NOT downloaded here and must already be present in
    # the working directory — NOTE(review): confirm it ships with the app.
    _CONFIG_PATH = "v1-inference.yaml"

    def __init__(self, *args, **kwargs):
        # Pin the request/response schema and the custom cloud build; all other
        # arguments are forwarded unchanged to PythonServer.
        super().__init__(
            *args,
            input_type=BatchText,
            output_type=BatchImage,
            cloud_build_config=FlashAttentionBuildConfig(),
            **kwargs,
        )

    def setup(self):
        """Ensure the checkpoint is on disk, then load the model.

        Uses CUDA when available, otherwise the CPU.

        Raises:
            subprocess.CalledProcessError: if the checkpoint download fails.
        """
        # Imported lazily: ``ldm`` is installed by the cloud build config.
        import ldm

        if not os.path.exists(self._CHECKPOINT_PATH):
            self._download_checkpoint()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = ldm.lightning.LightningStableDiffusion(
            config_path=self._CONFIG_PATH,
            checkpoint_path=self._CHECKPOINT_PATH,
            device=device,
            deepspeed=True,  # Supported on Ampere and RTX, skipped otherwise.
            context="no_grad",
            flash_attention="triton",
            steps=30,
        )

    def _download_checkpoint(self):
        """Download the checkpoint atomically, failing loudly on any error."""
        import subprocess

        partial_path = self._CHECKPOINT_PATH + ".part"
        # ``-C -`` resumes a previously interrupted download of the .part file.
        # ``--fail`` makes HTTP errors a non-zero exit instead of silently
        # saving the server's error body as the "checkpoint".
        # ``check=True`` surfaces that exit status, which os.system() ignored.
        subprocess.run(
            ["curl", "--fail", "-C", "-", self._CHECKPOINT_URL, "-o", partial_path],
            check=True,
        )
        # Only a completed download receives the final name. Otherwise the
        # os.path.exists() guard in setup() would trust a truncated file forever.
        os.replace(partial_path, self._CHECKPOINT_PATH)

    def predict(self, requests):
        """Generate images for every prompt in ``requests.inputs``.

        Returns a ``BatchImage`` with one ``{"image": <base64 PNG>}`` entry per
        image returned by the model, in the model's output order. That order is
        presumably the same as the prompt order — TODO confirm.
        """
        prompts = [request.text for request in requests.inputs]
        images = self._model.predict_step(prompts=prompts, batch_idx=0)
        return BatchImage(outputs=[{"image": self._encode_png(image)} for image in images])

    @staticmethod
    def _encode_png(image):
        """Serialize ``image`` to PNG and return it as a base64 UTF-8 string.

        ``image`` must expose ``save(buffer, format=...)``, presumably a PIL
        image.
        """
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
# Wrap DiffusionServer in an autoscaler running on interruptible GPU machines,
# then expose it as the Lightning app entry point.
component = AutoScaler(
    DiffusionServer,  # The component to scale
    # --- replica machines and replacement policy ---
    cloud_compute=L.CloudCompute("gpu-rtx", interruptible=True, disk_size=80),
    # 30 * 60: presumably seconds, i.e. replicas are rotated every 30 minutes — TODO confirm units.
    strategy=IntervalReplacement(interval=30 * 60),
    enable_dashboard=True,
    # --- replica count bounds ---
    min_replicas=1,
    max_replicas=4,
    # --- public route ---
    endpoint="/predict",
    # --- scaling cadence (presumably seconds) ---
    scale_out_interval=0,
    scale_in_interval=5 * 60,
    # --- request batching ---
    max_batch_size=6,
    timeout_batching=0.3,
    # --- per-request schema (the server itself receives batches) ---
    input_type=Text,
    output_type=Image,
)

app = L.LightningApp(component)