# app_looping.py
# !pip install 'git+https://github.com/Lightning-AI/LAI-API-Access-UI-Component.git'
# !curl https://raw.githubusercontent.com/Lightning-AI/stablediffusion/lit/configs/stable-diffusion/v1-inference.yaml -o v1-inference.yaml
# !pip install 'git+https://github.com/Lightning-AI/lightning.git'
# !pip install 'git+https://github.com/Lightning-AI/stablediffusion.git@lit'
import lightning as L
import os, base64, io, torch, traceback, asyncio, uuid
from diffusion_with_autoscaler import AutoScaler, BatchText, BatchImage, Text, Image, IntervalReplacement
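
# DiffusionServer implements "looping" batched inference: a single background
# coroutine owns the model and repeatedly calls `in_loop_predict_step` on all
# pending prompts, threading each request's denoising state back into the
# shared `self._requests` dict until its image is ready. New requests can join
# the batch between steps instead of waiting for the previous batch to finish.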
class DiffusionServer(L.app.components.PythonServer):
    def __init__(self, *args, **kwargs):
        super().__init__(
            input_type=BatchText,
            output_type=BatchImage,
            cloud_build_config=L.BuildConfig(image="ghcr.io/gridai/lightning-stable-diffusion:v0.4"),
            *args,
            **kwargs,
        )
        # Shared state between `predict` and the background prediction loop.
        self._requests = {}
        self._predictor_task = None
        self._lock = None

    def setup(self):
        # Download the Stable Diffusion v1.5 weights on first run
        # (`curl -C -` resumes a partial download).
        if not os.path.exists("v1-5-pruned-emaonly.ckpt"):
            cmd = "curl -C - https://pl-public-data.s3.amazonaws.com/dream_stable_diffusion/v1-5-pruned-emaonly.ckpt -o v1-5-pruned-emaonly.ckpt"
            os.system(cmd)

        # `ldm` is provided by the Lightning fork of stablediffusion installed above.
        import ldm

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = ldm.lightning.LightningStableDiffusion(
            config_path="v1-inference.yaml",
            checkpoint_path="v1-5-pruned-emaonly.ckpt",
            device=device,
            fp16=True,  # Supported on GPU, skipped otherwise.
            deepspeed=True,  # Supported on Ampere and RTX, skipped otherwise.
            context="no_grad",
            flash_attention="hazy",
            steps=30,
        )

    def apply_model(self, requests):
        # Advance every pending request by one denoising step; completed
        # samples are returned keyed by request id.
        return self._model.in_loop_predict_step(requests)

    def sanitize_data(self, request):
        # Resume from the stored denoising state if one exists,
        # otherwise start from the raw prompt text.
        if "state" in request:
            return request["state"]
        return request["data"].text

    def sanitize_results(self, result):
        # Encode the generated image as base64 PNG for the JSON response.
        image = result
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return BatchImage(outputs=[{"image": image_str}])

    async def predict_fn(self):
        # Background loop: runs for the lifetime of the server and drives the
        # model one step at a time across all pending requests.
        try:
            while True:
                async with self._lock:
                    keys = list(self._requests)
                    if len(keys) == 0:
                        await asyncio.sleep(0.0001)
                        continue
                    inputs = {key: self.sanitize_data(self._requests[key]) for key in keys}
                    results = self.apply_model(inputs)
                    # Store the updated per-request (and global) state for the next step.
                    for key, state in inputs.items():
                        if key == "global_state":
                            self._requests["global_state"] = {"state": state}
                        else:
                            self._requests[key]["state"] = state
                    # Resolve the futures of any requests that finished this step.
                    if results:
                        for key in results:
                            self._requests[key]["response"].set_result(self.sanitize_results(results[key]))
                            del self._requests[key]
                await asyncio.sleep(0.0001)
        except Exception:
            traceback.print_exc()
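
    # `predict` is the per-request entrypoint: it registers the prompt in
    # `self._requests` and awaits a future that the loop above resolves.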
    async def predict(self, request: BatchText):
        # Lazily create the lock and start the shared prediction loop on the
        # first request (both must be created inside the running event loop).
        if self._lock is None:
            self._lock = asyncio.Lock()
        if self._predictor_task is None:
            self._predictor_task = asyncio.create_task(self.predict_fn())
        # Batching is delegated to the loop above, so each call carries one prompt.
        assert len(request.inputs) == 1
        future = asyncio.Future()
        async with self._lock:
            self._requests[uuid.uuid4().hex] = {"data": request.inputs[0], "response": future}
        return await future
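
# A minimal client sketch, kept as a comment so it does not run at import
# time. The URL/port and the request payload shape are assumptions for
# illustration; the "image" response key matches `sanitize_results` above:
#
#   import base64, requests
#   resp = requests.post(
#       "http://127.0.0.1:7501/predict",
#       json={"text": "a portrait of a cat, oil painting"},
#   )
#   with open("out.png", "wb") as f:
#       f.write(base64.b64decode(resp.json()["image"]))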

component = AutoScaler(
    DiffusionServer,  # The component to scale.
    cloud_compute=L.CloudCompute("gpu-rtx", interruptible=True, disk_size=80),
    strategy=IntervalReplacement(interval=30 * 60),  # Rotate replicas every 30 minutes (30 * 60 s).
    batching="streamed",
    # Autoscaler args.
    min_replicas=1,
    max_replicas=1,
    endpoint="/predict",
    scale_out_interval=0,
    scale_in_interval=600,
    max_batch_size=6,
    timeout_batching=0,
    input_type=Text,
    output_type=Image,
)
# component = DiffusionServer()  # Uncomment to run a single server without autoscaling.
app = L.LightningApp(component)