# mflux-server
# Server for image generation with mflux (https://github.com/filipstrand/mflux)
# (C) 2024 by @orbiter Michael Peter Christen
# This code is licensed under the Apache License, Version 2.0
import os
import io
import gc
import argparse
import time
import base64
import hashlib
import threading
import mlx.core.metal as metal
from PIL import Image
from flask import Flask, request, Response, jsonify
from flask_restx import Api, Resource, fields
from flask_cors import CORS
from mflux import Flux1, Config, ModelConfig, ImageUtil
from flask import send_file, redirect
import requests
# monkey patch requests.Session.request to disable SSL certificate verification
old_request = requests.Session.request
def new_request(self, *args, **kwargs):
kwargs['verify'] = False
return old_request(self, *args, **kwargs)
requests.Session.request = new_request
app = Flask(__name__)
api = Api(app, version='1.0', title='MFLUX API Server',
description='An image generation server. Workflow: /generate -> /status -> /image',
doc='/swagger',
prefix='/api')
CORS(app, resources={r"/*": {"origins": "*"}})
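# Illustrative end-to-end client workflow (a sketch only; it assumes the default
# host/port from main() and a hypothetical task_id, so actual values will differ):
#
#   curl -s -X POST http://127.0.0.1:4030/api/generate \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "A beautiful landscape", "steps": 4, "format": "JPEG"}'
#   # -> {"task_id": "a1b2c3d4", "task_length": 0, "expected_time_seconds": 80.0}
#
#   curl -s 'http://127.0.0.1:4030/api/status?task_id=a1b2c3d4'
#   # -> {"status": "waiting", "pos": 0, "wait_remaining": 42}   (poll until "done")
#
#   curl -s 'http://127.0.0.1:4030/api/image?task_id=a1b2c3d4' -o result.jpg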
tasklist = [] # list which holds the image computation tasks
flux = None # the flux object, initialized in main()
pixels = 1024 * 1024 # the number of pixels in all of the computed images (start value)
ctime = 80 # the total computation time for all images in seconds (start value)
metal_cache_limit = 0 # the cache limit for the metal library
apppath = os.path.dirname(__file__)
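# The queue-time estimate used below is a simple throughput model: pixels and ctime
# accumulate the total pixels produced and the total seconds spent on them, so
#   expected_time_seconds = ctime * wait_for_pixels / pixels
# e.g. with the start values above (80 s for 1024*1024 pixels), a single queued
# 512x512 request is estimated at 80 * (512*512) / (1024*1024) = 20 seconds.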
# we implement image generation as an asynchronous task
# which is executed in a separate worker thread
def compute_image_task():
global flux, tasklist, pixels, ctime
# we loop forever and in every iteration we check if there is a task to process
while True:
        if flux is None or len(tasklist) == 0:
time.sleep(1)
continue
# loop through the tasklist and get the first task which has no image assigned
foundimage = False
for task in tasklist:
if 'image' in task: continue
# found a task without image
compute_time = time.time()
task['compute_time'] = compute_time
# generate the image
metal.set_cache_limit(metal_cache_limit)
init_image = task['init_image']
# make a temporary file path for the init_image
if init_image:
init_image_path = f"/tmp/init_image_{task['task_id']}.png"
init_image.save(init_image_path)
else:
init_image_path = None
generated_image = flux.generate_image(
seed=task['seed'],
prompt=task['prompt'],
config=Config(
num_inference_steps=task['steps'] or 4,
height=task['height'],
width=task['width'],
guidance=task['guidance'] or 3.5,
init_image_path=init_image_path,
init_image_strength=0.4
)
)
# remove the temporary init_image file
if init_image_path: os.remove(init_image_path)
# statistics
end_time = time.time()
ctime += end_time - compute_time
pixels += task['height'] * task['width']
            # convert the image (intentionally not counted toward the computation time)
            # we do this here and not during retrieval to save memory in the tasklist
format = task.get('format', 'JPEG').upper()
if format not in ['PNG', 'JPEG']: format = 'JPEG'
if format == 'PNG':
png_image = io.BytesIO()
generated_image.image.save(png_image, format='PNG')
png_image.seek(0)
task['image'] = png_image
del png_image
else:
quality = task['quality']
jpeg_image = io.BytesIO()
generated_image.image.save(jpeg_image, format='JPEG', quality=quality)
jpeg_image.seek(0)
task['image'] = jpeg_image
del jpeg_image
# Free resources
del generated_image
metal.clear_cache()
gc.collect()
task['end_time'] = end_time # end time of the task
foundimage = True
break
        # if we did not find any task without an image, sleep for 1 second
if not foundimage: time.sleep(1)
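# Note on the worker design: a single polling thread serializes all generations,
# which keeps Metal memory pressure bounded; finished images remain in the tasklist
# (already encoded as JPEG/PNG) until they are fetched via /image or cancelled.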
def str_to_bool(value):
return value.lower() in ['true', '1', 't', 'y', 'yes']
# generate image endpoint
task_model = api.model('TaskInput', {
'prompt': fields.String(description='The textual description of the image to generate.', default='A beautiful landscape', required=True),
'seed': fields.String(description='Entropy Seed', default=str(int(time.time())), required=False),
'height': fields.Integer(description='Image height', default=1024, required=False),
'width': fields.Integer(description='Image width', default=1024, required=False),
'steps': fields.Integer(description='Inference Steps', default=4, required=False),
'guidance': fields.Float(description='Guidance Scale', default=3.5, required=False),
'format': fields.String(description='The image format (JPEG or PNG), default is JPEG', default="JPEG", required=False),
'quality': fields.Integer(description='JPEG compression quality (1-100) if format is JPEG, default is 85', default=85, required=False),
'priority': fields.Boolean(description='Set to true to put this task to the head of the queue', default=False, required=False)
})
generate_response_model = api.model('GenerateResponse', {
'task_id': fields.String(description='ID of the image generation task'),
'task_length': fields.Integer(description='Length of the image generation task queue excluding this new one'),
'expected_time_seconds': fields.Float(description='Expected time in seconds for the image generation task to complete')
})
# counts the number of pixels of all pending (not yet computed) images in the tasklist up to a certain index
def count_pixels(index):
global tasklist
pixels = 0
for i in range(index):
if i >= len(tasklist): break
task = tasklist[i]
        if 'image' not in task:
pixels += task['width'] * task['height']
return pixels
@api.route('/generate')
class GenerateImage(Resource):
@api.expect(task_model, validate=True)
@api.response(200, 'Success', generate_response_model)
@api.response(404, 'Cannot append task')
def post(self):
"""
The /generate endpoint is used to generate an image as an asynchronous task.
This will put the task in the queue and return the task ID.
The task is either at the end of the queue or at the beginning if priority is set to true.
        To save memory, the image is not stored in its raw form but in the format requested by the client.
        Therefore the format has to be declared in the request at generation time in this endpoint.
"""
global tasklist, pixels, ctime
# Parse the JSON body into a dictionary
args = request.json
prompt = args.get('prompt', 'A beautiful landscape')
seed = args.get('seed', str(int(time.time())))
height = int(args.get('height', 1024))
width = int(args.get('width', 1024))
steps = int(args.get('steps', 4))
guidance = float(args.get('guidance', 3.5))
format = args.get('format', 'JPEG').upper()
quality = args.get('quality', 85)
priority = args.get('priority', False)
# Decode init_image if it is provided
init_image = None
if 'init_image' in args:
try:
init_image_data = base64.b64decode(args['init_image'])
init_image = Image.open(io.BytesIO(init_image_data))
# log properties of the init_image, width, height, mode
print("init_image", init_image.size, init_image.mode)
            except Exception:
                pass # ignore an invalid init_image and proceed without it
start_time = time.time()
        # the task_id is an 8-character hex hash derived from the submission time
md5 = hashlib.md5()
md5.update(str(start_time).encode())
task_id = md5.hexdigest()[:8]
task_metadata = {
'task_id': task_id,
'prompt': prompt,
'seed': seed,
'height': height,
'width': width,
'steps': steps,
'guidance': guidance,
'format': format,
'quality': quality,
'priority': priority,
'start_time': start_time,
'init_image': init_image
}
# compute waiting time based on the number of pixels in the queue
wait_for_pixels = width * height # include the current task
if priority and len(tasklist) > 1:
wait_for_pixels += count_pixels(1)
tasklist.insert(1, task_metadata)
else:
wait_for_pixels += count_pixels(len(tasklist))
tasklist.append(task_metadata)
expected_time_seconds = ctime * wait_for_pixels / pixels
return {
'task_id': task_id,
'task_length': len(tasklist) - 1,
'expected_time_seconds': expected_time_seconds
}, 200
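# The request body may also carry an optional 'init_image' field (not declared in
# task_model above): a base64-encoded image used as the starting point for generation.
# A sketch of how a client could attach it, assuming a hypothetical local file 'seed.png':
#
#   import base64, requests
#   payload = {'prompt': 'A beautiful landscape',
#              'init_image': base64.b64encode(open('seed.png', 'rb').read()).decode('ascii')}
#   requests.post('http://127.0.0.1:4030/api/generate', json=payload)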
# status endpoint
status_model = api.model('Status', {
'status': fields.String(description='Status of the image generation task'),
'pos': fields.Integer(description='Position in queue')
})
@api.route('/status')
class GetStatus(Resource):
@api.doc(params={'task_id': 'The ID of the image generation task'})
@api.response(200, 'Success', status_model)
@api.response(404, 'Task not found')
def get(self):
"""
The /status endpoint is used to check the image generation progress of a task.
        When the task is not yet ready, the returned status looks like this, e.g. at position 3 in the queue with an estimated 43 seconds remaining:
        { "status": "waiting", "pos": 3, "wait_remaining": 43}
        ... or, when the task is done:
        { "status": "done"}
When the status is "done", the image can be retrieved with the /image endpoint.
If the task / the task_id is unknown, the endpoint returns a 404 status code.
"""
task_id = request.args.get('task_id', default='')
c = -1
for i, task in enumerate(tasklist):
            if 'image' not in task: c += 1
if task['task_id'] == task_id:
if 'image' in task:
return jsonify({'status': 'done'})
else:
# compute the remaining time
wait_remaining = count_pixels(i + 1) * ctime / pixels
start_time = task.get('start_time', 0)
compute_time = task.get('compute_time', start_time)
wait_remaining = int(wait_remaining - (time.time() - compute_time))
if wait_remaining < 1: wait_remaining = 1
return jsonify({'status': 'waiting', 'pos': c, 'wait_remaining': wait_remaining})
return Response(status=404)
# image retrieval endpoint; image format was already defined in the generate endpoint
@api.route('/image')
class GetImage(Resource):
@api.doc(params={
'task_id': 'The ID of the image generation task',
'base64': 'Set to true to return the image as base64 encoded string, default false',
'delete': 'Set to true to delete the task after getting the image, default is true'
})
@api.response(200, 'Success')
@api.response(404, 'Task not found')
def get(self):
"""
The /image endpoint is used to get the produced image after a task has completed.
        The image is already encoded as PNG or JPEG according to the format given in the /generate endpoint.
The image can be returned as base64 encoded string or as binary data.
By default calling this endpoint will delete the task from the queue;
this means the image can only be retrieved once. To keep the task in the queue set delete to false.
If the image is not ready at the time of the request, the endpoint returns a 404 status code.
"""
task_id = request.args.get('task_id', default='')
for task in tasklist:
if task['task_id'] == task_id:
if 'image' in task:
image = task['image']
format = task['format']
base64p = str_to_bool(request.args.get('base64', default='false'))
deletep = str_to_bool(request.args.get('delete', default='true'))
if deletep:
tasklist.remove(task)
gc.collect()
if base64p:
return Response(base64.b64encode(image.getvalue()), mimetype='text/plain; charset=utf-8')
else:
return Response(image.getvalue(), mimetype='image/png' if format == 'PNG' else 'image/jpeg')
return Response(status=404)
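# Example retrieval calls (illustrative, with a hypothetical task_id):
#   curl 'http://127.0.0.1:4030/api/image?task_id=a1b2c3d4' -o result.jpg
#   curl 'http://127.0.0.1:4030/api/image?task_id=a1b2c3d4&base64=true&delete=false'
# The second form returns the base64-encoded bytes as text and keeps the task queued
# so the image can be fetched again later.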
# cancel task endpoint
@api.route('/cancel')
class CancelTask(Resource):
@api.doc(params={'task_id': 'The ID of the image generation task'})
@api.response(200, 'Success')
@api.response(404, 'Task not found')
def get(self):
"""
The /cancel endpoint is used to cancel a task.
"""
task_id = request.args.get('task_id', default='')
for task in tasklist:
if task['task_id'] == task_id:
tasklist.remove(task)
return Response(status=200)
return Response(status=404)
# task listing endpoint
task_output_model = api.inherit('TaskOutput', task_model, {
'task_id': fields.String(description='ID of the image generation task', default=None, required=False),
'start_time': fields.String(description='Time when the image generation task was submitted', default=None, required=False),
'compute_time': fields.String(description='Time when the image computation started', default=None, required=False),
'end_time': fields.String(description='Time when the image generation task ended', default=None, required=False)
})
tasks_model = api.model('Tasks', {
'tasks': fields.List(fields.Nested(task_output_model), description='List of tasks')
})
@api.route('/tasks')
class GetTasks(Resource):
@api.response(200, 'Success', tasks_model)
def get(self):
"""
The /tasks endpoint is used to list all tasks.
This can be used to implement a task manager.
"""
tasklist0 = []
for task in tasklist:
task0 = task.copy()
if 'image' in task0: del task0['image']
tasklist0.append(task0)
return jsonify(tasklist0)
# clear tasks endpoint
@api.route('/clear')
class ClearTasks(Resource):
@api.response(200, 'Success')
def get(self):
"""
The /clear endpoint is used to clear all tasks.
"""
tasklist.clear()
return Response(status=200)
# convenience file endpoints for testing
@app.route('/')
def redirect_to_index():
return redirect('/index.html')
@app.route('/index.html')
def serve_index():
return send_file(os.path.join(apppath, 'clients/web-ui/index.html'))
#@api.route('/index.html')
#class Root(Resource):
# @api.response(200, 'Success')
# def get(self):
# return send_file(os.path.join(apppath, 'clients/web-ui/index.html'))
def main():
parser = argparse.ArgumentParser(description='Start a server to generate images with mflux.')
parser.add_argument('--model', "-m", type=str, default="schnell", choices=["dev", "schnell"], help='The model to use ("schnell" or "dev").')
parser.add_argument('--quantize', "-q", type=int, choices=[4, 8], default=None, help='Quantize the model (4 or 8, Default is None)')
parser.add_argument('--path', type=str, default=None, help='Local path for loading a model from disk')
parser.add_argument('--host', type=str, default='127.0.0.1', help='The host to listen on')
parser.add_argument('--port', type=int, default=4030, help='The port to listen on')
parser.add_argument('--cache_limit', type=int, default=0, help='The metal cache limit in bytes')
args = parser.parse_args()
if args.path and args.model is None:
parser.error("--model must be specified when using --path")
    # metal_cache_limit must also be declared global here, otherwise the assignment
    # below would only create a local variable
    global flux, metal_cache_limit
    # note: args.path is parsed above but not used here; the model is loaded from its alias
flux = Flux1.from_alias(
alias=args.model,
quantize=args.quantize
)
    metal_cache_limit = args.cache_limit
    # run the worker as a daemon thread so the process can exit when the server stops
    threading.Thread(target=compute_image_task, daemon=True).start()
print(f"Server started, view swagger API documentation at http://{args.host}:{args.port}/swagger")
app.run(host=args.host, port=args.port)
if __name__ == '__main__':
main()
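# A minimal polling client, as a sketch (it assumes the requests package and the
# default host/port; the helper name poll_image is illustrative, not part of this API):
#
#   import time, requests
#
#   def poll_image(prompt, base='http://127.0.0.1:4030/api'):
#       task_id = requests.post(f'{base}/generate', json={'prompt': prompt}).json()['task_id']
#       while requests.get(f'{base}/status', params={'task_id': task_id}).json().get('status') != 'done':
#           time.sleep(2)
#       return requests.get(f'{base}/image', params={'task_id': task_id}).content
#
#   with open('result.jpg', 'wb') as f:
#       f.write(poll_image('A beautiful landscape'))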