From 3f49d57e370a82955a443a81a9d3e36f6a2161f1 Mon Sep 17 00:00:00 2001
From: AnyISalIn
Date: Tue, 25 Jul 2023 23:17:56 +0800
Subject: [PATCH] release 0.1.5 (#23)

* feature: support load cloud controlnet models
* fix: optimize hijack
* feature: add controlnet and inpaint support

---------

Signed-off-by: AnyISalIn
---
 extension/api.py                | 481 ++++++++++++++------------------
 extension/version.py            |   2 +-
 scripts/{proxy.py => hijack.py} | 263 +++++++++++++----
 scripts/main_ui.py              | 232 +++++++++------
 4 files changed, 570 insertions(+), 408 deletions(-)
 rename scripts/{proxy.py => hijack.py} (50%)

diff --git a/extension/api.py b/extension/api.py
index 6c48ef4..9e744b6 100644
--- a/extension/api.py
+++ b/extension/api.py
@@ -16,6 +16,8 @@
 OMNIINFER_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 '.omniinfer.json')
 
+OMNIINFER_API_ENDPOINT = "https://api.omniinfer.io"
+
 
 def _user_agent(model_name=None):
     if model_name:
@@ -71,14 +73,9 @@ def append_child(self, child):
     @property
     def display_name(self):
         # format -> [] []
-        kind = self.kind
-        if self.kind == 'checkpoint':
-            kind = 'ckpt'
-
-        n = "[{}] ".format(kind)
-
-        if self.tags is not None and len(self.tags) != 0:
-            n += "[{}] ".format(self.tags[0])
+        n = ""
+        if len(self.tags) > 0:
+            n = "[{}] ".format(",".join(self.tags))
         return n + os.path.splitext(self.name)[0]
 
     def to_json(self):
@@ -118,7 +115,6 @@ def __init__(self,
 class OmniinferAPI(BaseAPI):
 
     def __init__(self, token=None):
-        self._endpoint = 'https://api.omniinfer.io'
         self._token = None
         if self._token is not None:
             self._token = token
@@ -197,7 +193,7 @@ def update_key_to_config(cls, key):
     def test_connection(cls, token):
         if token == "":
             raise Exception("Token is empty")
-        res = requests.get('https://api.omniinfer.io/v2/progress',
+        res = requests.get('{}/v2/progress'.format(OMNIINFER_API_ENDPOINT),
                            params={'key': token})
         if res.status_code >= 400:
             raise Exception("Request failed: {}".format(res.text))
@@ -206,128 +202,6 @@ def test_connection(cls, token):
         return "Omniinfer Ready... 
now you can inference on cloud" - def _txt2img(self, model_name, prompts, neg_prompts, sampler_name, - batch_size, steps, n_iter, cfg_scale, seed, height, width, - controlnet_args): - - if self._token is None: - raise Exception( - "Please configure your omniinfer key in the `Cloud Inference` Tab" - ) - - # TODO: workaround - if isinstance(sampler_name, int): - sampler_name = sd_samplers[sampler_name] - payload = { - "prompt": prompts, - "negative_prompt": neg_prompts, - "sampler_name": sampler_name or "Euler a", - "batch_size": batch_size or 1, - "n_iter": n_iter or 1, - "steps": steps or 30, - "cfg_scale": cfg_scale or 7.5, - "seed": int(seed) or -1, - "height": height or 512, - "width": width or 512, - "model_name": model_name, - "controlnet_units": controlnet_args - } - - # print( - # '[cloud-inference] call api txt2img: payload: {}'.format({ - # key: value - # for key, value in payload.items() if key != "controlnet_units" - # }), ) - - headers = { - "accept": "application/json", - "content-type": "application/json", - 'Accept-Encoding': 'gzip, deflate', - "X-OmniInfer-Source": _user_agent(model_name), - "User-Agent": _user_agent(model_name) - } - - try: - res = self._session.post("http://api.omniinfer.io/v2/txt2img", - json=payload, - headers=headers, - params={"key": self._token}) - except Exception as exp: - raise Exception("Request failed: {}, res: {}".format( - exp, res.text if res is not None else "")) - - json_data = res.json() - - if json_data['code'] != 0: - raise Exception("Request failed: {}".format(res.text)) - - return json_data['data']['task_id'] - - def _img2img(self, model_name, prompts, neg_prompts, sampler_name, - batch_size, steps, n_iter, cfg_scale, seed, height, width, - restore_faces, denoising_strength, mask_blur_x, mask_blur_y, - inpainting_fill, inpaint_full_res, inpaint_full_res_padding, - inpainting_mask_invert, initial_noise_multiplier, init_images, - controlnet_units): - - if self._token is None: - raise Exception( - "Please configure your omniinfer key in the `Cloud Inference` Tab" - ) - - if isinstance(sampler_name, int): - sampler_name = sd_samplers[sampler_name] - - payload = { - "prompt": prompts, - "negative_prompt": neg_prompts, - "sampler_name": sampler_name or "Euler a", - "batch_size": batch_size or 1, - "n_iter": n_iter or 1, - "steps": steps or 30, - "cfg_scale": cfg_scale or 7.5, - "seed": int(seed) or -1, - "height": height or 512, - "width": width or 512, - "model_name": model_name, - "restore_faces": restore_faces, - "denoising_strength": denoising_strength, - "mask_blur_x": mask_blur_x, - "mask_blur_y": mask_blur_y, - "inpainting_fill": inpainting_fill, - "inpaint_full_res": inpaint_full_res, - "inpaint_full_res_padding": inpaint_full_res_padding, - "inpainting_mask_invert": inpainting_mask_invert, - "initial_noise_multiplier": initial_noise_multiplier, - "init_images": init_images, - "controlnet_units": controlnet_units - } - headers = { - "accept": "application/json", - "content-type": "application/json", - 'Accept-Encoding': 'gzip, deflate', - "X-OmniInfer-Source": _user_agent(model_name), - "User-Agent": _user_agent(model_name) - } - - # print( - # '[cloud-inference] call api txt2img: payload: {}'.format({ - # key: value - # for key, value in payload.items() if key != "controlnet_units" - # }), ) - - res = requests.post("http://api.omniinfer.io/v2/img2img", - json=payload, - headers=headers, - params={"key": self._token}) - - try: - json_data = res.json() - except Exception: - raise Exception("Request failed: {}".format(res.text)) - - 
return json_data['data']['task_id'] - def _wait_task_completed(self, task_id): STATUS_CODE_PENDING = 0 STATUS_CODE_PROGRESSING = 1 @@ -344,7 +218,7 @@ def _wait_task_completed(self, task_id): raise Exception("Interrupted") task_res = self._session.get( - "http://api.omniinfer.io/v2/progress", + "{}/v2/progress".format(OMNIINFER_API_ENDPOINT), params={ "key": self._token, "task_id": task_id, @@ -404,105 +278,139 @@ def img2img( buffered = io.BytesIO() i.save(buffered, format=live_previews_image_format, **save_kwargs) - base64_image = base64.b64encode( - buffered.getvalue()).decode('ascii') + base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') images_base64.append(base64_image) - img_urls = [] + def _req(p: processing.StableDiffusionProcessingImg2Img, controlnet_units): + req = { + "model_name": p._cloud_inference_settings['sd_checkpoint'], + "init_images": [image_to_base64(_) for _ in p.init_images], + "mask": image_to_base64(p.image_mask) if p.image_mask else None, + "resize_mode": p.resize_mode, + "denoising_strength": p.denoising_strength, + "cfg_scale": p.image_cfg_scale, + "mask_blur": p.mask_blur_x, + "inpainting_fill": p.inpainting_fill, + "inpaint_full_res": p.inpaint_full_res, + "inpaint_full_res_padding": p.inpaint_full_res_padding, + "inpainting_mask_invert": p.inpainting_mask_invert, + "initial_noise_multiplier": p.initial_noise_multiplier, + "prompt": p.prompt, + "seed": int(p.seed) or -1, + "negative_prompt": p.negative_prompt, + "batch_size": p.batch_size, + "n_iter": p.n_iter, + "steps": p.steps, + "width": p.width, + "height": p.height, + "restore_faces": p.restore_faces, + "clip_skip": opts.CLIP_stop_at_last_layers, + } + if 'CLIP_stop_at_last_layers' in p.override_settings: + req['clip_skip'] = p.override_settings['CLIP_stop_at_last_layers'] + + if 'sd_vae' in p._cloud_inference_settings: + req['sd_vae'] = p._cloud_inference_settings['sd_vae'] + + if len(controlnet_units) > 0: + req['controlnet_units'] = controlnet_units + + headers = { + "accept": "application/json", + "content-type": "application/json", + 'Accept-Encoding': 'gzip, deflate', + "X-OmniInfer-Source": _user_agent(p._cloud_inference_settings['sd_checkpoint']), + "X-OmniInfer-Key": self._token, + "User-Agent": _user_agent(p._cloud_inference_settings['sd_checkpoint']) + } + + res = self._session.post("{}/v2/img2img".format(OMNIINFER_API_ENDPOINT), + json=req, + headers=headers, + params={"key": self._token}) + + try: + json_data = res.json() + except Exception: + raise Exception("Request failed: {}".format(res.text)) + + if json_data['code'] != 0: + raise Exception("Request failed: {}".format(res.text)) + + return self._wait_task_completed(json_data['data']['task_id']) + + controlnet_batchs = get_controlnet_arg(p) + + imgs = [] if len(controlnet_batchs) > 0: for c in controlnet_batchs: - img_urls.extend( - self._wait_task_completed( - self._img2img( - model_name=p._remote_model_name, - prompts=p.prompt, - neg_prompts=p.negative_prompt, - sampler_name=p.sampler_name, - batch_size=p.batch_size, - steps=p.steps, - n_iter=p.n_iter, - cfg_scale=p.cfg_scale, - seed=p.seed, - height=p.height, - width=p.width, - restore_faces=p.restore_faces, - denoising_strength=p.denoising_strength, - mask_blur_x=p.mask_blur_x, - mask_blur_y=p.mask_blur_y, - inpaint_full_res=bool(p.inpaint_full_res), - inpaint_full_res_padding=p. 
- inpaint_full_res_padding, - inpainting_fill=p.inpainting_fill, - inpainting_mask_invert=p.inpainting_mask_invert, - initial_noise_multiplier=p.initial_noise_multiplier, - init_images=images_base64, - controlnet_units=c))) + imgs.extend(_req(p, c)) else: - img_urls.extend( - self._wait_task_completed( - self._img2img( - model_name=p._remote_model_name, - prompts=p.prompt, - neg_prompts=p.negative_prompt, - sampler_name=p.sampler_name, - batch_size=p.batch_size, - steps=p.steps, - n_iter=p.n_iter, - cfg_scale=p.cfg_scale, - seed=p.seed, - height=p.height, - width=p.width, - restore_faces=p.restore_faces, - denoising_strength=p.denoising_strength, - mask_blur_x=p.mask_blur_x, - mask_blur_y=p.mask_blur_y, - inpaint_full_res=bool(p.inpaint_full_res), - inpaint_full_res_padding=p.inpaint_full_res_padding, - inpainting_fill=p.inpainting_fill, - inpainting_mask_invert=p.inpainting_mask_invert, - initial_noise_multiplier=p.initial_noise_multiplier, - init_images=images_base64, - controlnet_units=[]))) - return retrieve_images(img_urls) + imgs.extend(_req(p, [])) + return retrieve_images(imgs) def txt2img(self, p: processing.StableDiffusionProcessingTxt2Img): controlnet_batchs = get_controlnet_arg(p) - img_urls = [] + def _req(p: processing.StableDiffusionProcessingTxt2Img, controlnet_units): + req = { + "model_name": p._cloud_inference_settings['sd_checkpoint'], + "prompt": p.prompt, + "negative_prompt": p.negative_prompt, + "sampler_name": p.sampler_name or "Euler a", + "batch_size": p.batch_size or 1, + "n_iter": p.n_iter or 1, + "steps": p.steps or 30, + "cfg_scale": p.cfg_scale or 7.5, + "seed": int(p.seed) or -1, + "height": p.height or 512, + "width": p.width or 512, + "restore_faces": p.restore_faces, + "clip_skip": opts.CLIP_stop_at_last_layers, + } + + if 'CLIP_stop_at_last_layers' in p.override_settings: + req['clip_skip'] = p.override_settings['CLIP_stop_at_last_layers'] + + if 'sd_vae' in p._cloud_inference_settings: + req['sd_vae'] = p._cloud_inference_settings['sd_vae'] + + if len(controlnet_units) > 0: + req['controlnet_units'] = controlnet_units + + headers = { + "accept": "application/json", + "content-type": "application/json", + 'Accept-Encoding': 'gzip, deflate', + "X-OmniInfer-Source": _user_agent(p._cloud_inference_settings['sd_checkpoint']), + "X-OmniInfer-Key": self._token, + "User-Agent": _user_agent(p._cloud_inference_settings['sd_checkpoint']) + } + + res = self._session.post("{}/v2/txt2img".format(OMNIINFER_API_ENDPOINT), + json=req, + headers=headers, + params={"key": self._token}) + try: + json_data = res.json() + except Exception: + raise Exception("Request failed: {}".format(res.text)) + + if json_data['code'] != 0: + raise Exception("Request failed: {}".format(res.text)) + + return self._wait_task_completed(json_data['data']['task_id']) + + imgs = [] if len(controlnet_batchs) > 0: for c in controlnet_batchs: - img_urls.extend( - self._wait_task_completed( - self._txt2img(model_name=p._remote_model_name, - prompts=p.prompt, - neg_prompts=p.negative_prompt, - sampler_name=p.sampler_name, - batch_size=p.batch_size, - steps=p.steps, - n_iter=p.n_iter, - cfg_scale=p.cfg_scale, - seed=p.seed, - height=p.height, - width=p.width, - controlnet_args=c))) + imgs.extend(_req(p, c)) else: - img_urls.extend( - self._wait_task_completed( - self._txt2img(model_name=p._remote_model_name, - prompts=p.prompt, - neg_prompts=p.negative_prompt, - sampler_name=p.sampler_name, - batch_size=p.batch_size, - steps=p.steps, - n_iter=p.n_iter, - cfg_scale=p.cfg_scale, - seed=p.seed, - 
height=p.height, - width=p.width, - controlnet_args=[]))) + imgs.extend(_req(p, [])) state.textinfo = "downloading images..." - return retrieve_images(img_urls) + + return retrieve_images(imgs) def list_models(self): if self._models is None or len(self._models) == 0: @@ -510,51 +418,73 @@ def list_models(self): return sorted(self._models, key=lambda x: x.rating, reverse=True) def refresh_models(self): - url = "http://api.omniinfer.io/v2/models" - headers = { - "accept": "application/json", - 'Accept-Encoding': 'gzip, deflate', - "X-OmniInfer-Source": _user_agent(), - "User-Agent": _user_agent() - } - print("[cloud-inference] refreshing models...") - sd_models = [] + def get_models(kind): + url = "{}/v2/models".format(OMNIINFER_API_ENDPOINT) + headers = { + "accept": "application/json", + 'Accept-Encoding': 'gzip, deflate', + "X-OmniInfer-Source": _user_agent(), + "User-Agent": _user_agent() + } + + res = requests.get(url, headers=headers, params={"type": kind}) + if res.status_code >= 400: + return [] + + models = [] + if res.json()["data"]["models"] is not None: + models = res.json()["data"]["models"] + + for item in models: + model = StableDiffusionModel(kind=item["type"], + name=item["sd_name"]) + model.rating = item.get("civitai_download_count", 0) + civitai_tags = item["civitai_tags"].split(",") if item.get( + "civitai_tags", None) is not None else [] + + if model.tags is None: + model.tags = [] + + if len(civitai_tags) > 0: + model.tags.append(civitai_tags[0]) + + if item.get('civitai_nsfw', False): + model.tags.append("nsfw") + + if len(item.get('civitai_images', + [])) > 0 and item['civitai_images'][0]['meta'].get( + 'prompt') is not None: + first_image = item['civitai_images'][0] + first_image_meta = item['civitai_images'][0]['meta'] + model.example = StableDiffusionModelExample( + prompts=first_image_meta['prompt'], + neg_prompt=first_image_meta.get('negative_prompt', None), + width=first_image_meta.get('width', None), + height=first_image_meta.get('height', None), + sampler_name=first_image_meta.get('sampler_name', None), + cfg_scale=first_image_meta.get('cfg_scale', None), + seed=first_image_meta.get('seed', None), + preview=first_image.get('url', None) + ) - res = requests.get(url, headers=headers) - if res.status_code >= 400: - return [] - for item in res.json()["data"]["models"]: - model = StableDiffusionModel(kind=item["type"], - name=item["sd_name"]) - model.rating = item.get("civitai_download_count", 0) - model.tags = item["civitai_tags"].split(",") if item.get( - "civitai_tags", None) is not None else [] - - if len(item.get('civitai_images', - [])) > 0 and item['civitai_images'][0]['meta'].get( - 'prompt') is not None: - first_image = item['civitai_images'][0] - first_image_meta = item['civitai_images'][0]['meta'] - model.example = StableDiffusionModelExample( - prompts=first_image_meta['prompt'], - neg_prompt=first_image_meta.get('negative_prompt', None), - width=first_image_meta.get('width', None), - height=first_image_meta.get('height', None), - sampler_name=first_image_meta.get('sampler_name', None), - cfg_scale=first_image_meta.get('cfg_scale', None), - seed=first_image_meta.get('seed', None), - preview=first_image.get('url', None) - ) - - if item['type'] == 'lora': - civitai_dependency_model_name = item.get( - 'civitai_dependency_model_name', None) - if civitai_dependency_model_name is not None: - model.dependency_model_name = civitai_dependency_model_name - sd_models.append(model) + if item['type'] == 'lora': + civitai_dependency_model_name = item.get( + 
+                    'civitai_dependency_model_name', None)
+                if civitai_dependency_model_name is not None:
+                    model.dependency_model_name = civitai_dependency_model_name
+            sd_models.append(model)
 
         m = {}
+        sd_models = []
+        print("[cloud-inference] refreshing models...")
+
+        get_models("checkpoint")
+        get_models("lora")
+        get_models("controlnet")
+        get_models("vae")
+
+        # build lora and checkpoint relationship
         for model in sd_models:
             m[model.name] = model
 
@@ -563,6 +493,7 @@ def refresh_models(self):
             if m.get(model.dependency_model_name) is not None:
                 m[model.dependency_model_name].append_child(model.name)
 
+        self.__class__.update_models_to_config(sd_models)
         self._models = sd_models
         return sd_models
 
@@ -601,8 +532,27 @@ def get_controlnet_arg(p: processing.StableDiffusionProcessing):
 
         controlnet_arg = {}
         controlnet_arg['weight'] = c.weight
-        controlnet_arg['model'] = "control_v11f1e_sd15_tile"  # TODO
+        controlnet_arg['model'] = c.model.replace("[cloud] ", "")
         controlnet_arg['module'] = c.module
+        if c.resize_mode == "Just Resize":
+            controlnet_arg['resize_mode'] = 0
+        elif c.resize_mode == "Resize and Crop":
+            controlnet_arg['resize_mode'] = 1
+        elif c.resize_mode == "Envelope (Outer Fit)":
+            controlnet_arg['resize_mode'] = 2
+
+        if 'processor_res' in c.__dict__:
+            if c.processor_res > 0:
+                controlnet_arg['processor_res'] = c.processor_res
+
+        if 'threshold_a' in c.__dict__:
+            controlnet_arg['threshold_a'] = int(c.threshold_a)
+        if 'threshold_b' in c.__dict__:
+            controlnet_arg['threshold_b'] = int(c.threshold_b)
+        if 'guidance_start' in c.__dict__:
+            controlnet_arg['guidance_start'] = c.guidance_start
+        if 'guidance_end' in c.__dict__:
+            controlnet_arg['guidance_end'] = c.guidance_end
 
         if c.control_mode == "Balanced":
             controlnet_arg['control_mode'] = 0
@@ -614,15 +564,13 @@ def get_controlnet_arg(p: processing.StableDiffusionProcessing):
             return
 
         if getattr(c.input_mode, 'value', '') == "simple":
-            base64_str = ""
-            if controlnet_units[0].image:
-                if "mask" in controlnet_units[0].image:
-                    mask = Image.fromarray(
-                        controlnet_units[0].image["mask"])
+            if c.image:
+                if "mask" in c.image:
+                    mask = Image.fromarray(c.image["mask"])
                     controlnet_arg['mask'] = image_to_base64(mask)
 
                 controlnet_arg['input_image'] = image_to_base64(
-                    Image.fromarray(controlnet_units[0].image["image"]))
+                    Image.fromarray(c.image["image"]))
 
         if len(controlnet_batchs) <= 1:
             controlnet_batchs.append([])
@@ -667,3 +615,4 @@ def _download(img_url):
             applied.append(pool.apply_async(_download, (img_url, )))
     ret = [r.get() for r in applied]
     return [_ for _ in ret if _ is not None]
+
diff --git a/extension/version.py b/extension/version.py
index 51e0a06..de49d1f 100644
--- a/extension/version.py
+++ b/extension/version.py
@@ -1 +1 @@
-__version__ = "0.1.4"
\ No newline at end of file
+__version__ = "0.1.5"
\ No newline at end of file
diff --git a/scripts/proxy.py b/scripts/hijack.py
similarity index 50%
rename from scripts/proxy.py
rename to scripts/hijack.py
index 9a11954..71fb58b 100644
--- a/scripts/proxy.py
+++ b/scripts/hijack.py
@@ -1,5 +1,5 @@
 import modules.scripts as scripts
-import html
+import os
 import sys
 import gradio as gr
 import importlib
@@ -7,22 +7,53 @@
 from modules import images, script_callbacks, errors, processing, ui, shared
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
 from modules.shared import opts, state, prompt_styles
+from collections import OrderedDict
 
 from extension import api
-from inspect import getmembers, isfunction
+from inspect import 
getmembers, isfunction, ismodule import random -import traceback -import os -class _Proxy(object): +class _HijackManager(object): - def __init__(self, fn): - self._fn = fn - self._patched = False + def __init__(self): + self._hijacked = False + self._xyz_hijacked = False self._binding = None + self.hijack_map = {} + + def hijack_one(self, name, new_fn): + tmp = name.rsplit('.', 1) + if len(tmp) < 2: + raise Exception('invalid module.func name: {}'.format(name)) + + module_name, func_name = tmp + old_fn = _hijack_func(module_name, func_name, new_fn) + if old_fn is None: + print('[cloud-inference] hijack failed: {}'.format(name)) + return False + + self.hijack_map[name] = { + 'old': old_fn, + 'new': new_fn, + } + + print('[cloud-inference] hijack {}, old: <{}>, new: <{}>'.format( + name, old_fn.__module__ + '.' + old_fn.__name__, new_fn.__module__ + '.' + new_fn.__name__)) + + def hijack_all(self, *args, **kwargs): + if self._hijacked: + return + + self.hijack_one('modules.processing.process_images', + self._hijack_process_images) + self.hijack_one('extensions.sd-webui-controlnet.scripts.global_state.update_cn_models', self._hijack_update_cn_models) + print('[cloud-inference] hijack finished') + def _apply_xyz(self): + if self._xyz_hijacked: + return def find_module(module_names): if isinstance(module_names, str): @@ -35,62 +66,70 @@ def find_module(module_names): xyz_grid = find_module("xyz_grid.py, xy_grid.py") if xyz_grid: + def xyz_checkpoint_apply(p: StableDiffusionProcessing, opt, v): + if '_cloud_inference_settings' not in p.__dict__: + p._cloud_inference_settings = {} - def xyz_model_apply(p: StableDiffusionProcessing, opt, v): m = self._binding.choice_to_model(opt) - if m.kind == 'lora': - p._remote_model_name = m.dependency_model_name - p.prompt = self._binding._update_lora_in_prompt( - p.prompt, m.name) - else: - p._remote_model_name = m.name - - def xyz_model_confirm(p, opt): + # if m.kind == 'lora': + # p._cloud_inference_settings['sd_checkpoint'] = m.dependency_model_name + # p.prompt = self._binding._update_lora_in_prompt( + # p.prompt, m.name) + # else: + p._cloud_inference_settings['sd_checkpoint'] = m.name + + def xyz_checkpoint_confirm(p, opt): return - def xyz_model_format(p, opt, v): + def xyz_checkpoint_format(p, opt, v): return self._binding.choice_to_model(v).name.rsplit(".", 1)[0] - xyz_grid.axis_options.append( - xyz_grid.AxisOption('[Cloud Inference] Model Name', - str, - apply=xyz_model_apply, - confirm=xyz_model_confirm, - format_value=xyz_model_format, - choices=self._binding.get_model_ckpt_choices)) + def xyz_vae_apply(p: StableDiffusionProcessing, opt, v): + if '_cloud_inference_settings' not in p.__dict__: + p._cloud_inference_settings = {} - def monkey_patch(self): - if self._patched: - return + p._cloud_inference_settings['sd_vae'] = opt - processing.process_images = self + def xyz_vae_confirm(p, opt): + return - keys = list(sys.modules.keys()) - for name in keys: - # if (name.startswith('modules') - # or name.startswith('scripts')) and name != 'modules.processing': - if name.startswith('modules') and name != 'modules.processing': - if 'process_images' in dict( - getmembers(sys.modules[name], isfunction)).keys(): - print('[cloud-inference] reloading', name) - importlib.reload(sys.modules[name]) + def xyz_vae_format(p, opt, v): + return v + print('[cloud-inference] hijack xyz_grid') + xyz_grid.axis_options.append( + xyz_grid.AxisOption('[Cloud Inference] Checkpoint', + str, + apply=xyz_checkpoint_apply, + confirm=xyz_checkpoint_confirm, + 
format_value=xyz_checkpoint_format, + choices=lambda: [_.display_name for _ in self._binding.remote_model_checkpoints])) + xyz_grid.axis_options.append( + xyz_grid.AxisOption('[Cloud Inference] VAE', + str, + apply=xyz_vae_apply, + confirm=xyz_vae_confirm, + format_value=xyz_vae_format, + choices=lambda: ["Automatic", "None"] + [_.name for _ in self._binding.remote_model_vaes])) + self._xyz_hijacked = True + + def _hijack_update_cn_models(self): from modules.scripts import scripts_data for script in scripts_data: - if hasattr(script.module, 'process_images'): - script.module.process_images = self - if hasattr( - script.module, 'processing' - ) and script.module.processing.__name__ == 'modules.processing': - script.module.processing.process_images = self - - self._apply_xyz() - print('[cloud-inference] monkey patched') - - self._patched = True - - def __call__(self, *args, **kwargs) -> Processed: + if script.module.__name__ == 'controlnet.py': + if self._binding.remote_inference_enabled: + script.module.global_state.cn_models.clear() + cn_models_keys = ["None"] + [_.name for _ in self._binding.remote_model_controlnet] + cn_models_dict = {k: None for k in cn_models_keys} + + script.module.global_state.cn_models.update(cn_models_dict) + script.module.global_state.cn_models_names.clear() + script.module.global_state.cn_models_names.update(cn_models_dict) + break + else: + self.hijack_map['extensions.sd-webui-controlnet.scripts.global_state.update_cn_models']['old']() + def _hijack_process_images(self, *args, **kwargs) -> Processed: if len(args) > 0 and isinstance(args[0], processing.StableDiffusionProcessing): p = args[0] @@ -98,9 +137,11 @@ def __call__(self, *args, **kwargs) -> Processed: raise Exception( 'process_images: first argument must be a processing object') - remote_inference_enabled, selected_model_index = get_visible_extension_args(p, 'cloud inference') + remote_inference_enabled, selected_checkpoint_index, selected_vae_name = get_visible_extension_args( + p, 'cloud inference') + if not remote_inference_enabled: - return self._fn(*args, **kwargs) + return self.hijack_map['modules.processing.process_images']['old'](*args, **kwargs) # random seed locally if not specified if p.seed == -1: @@ -112,8 +153,14 @@ def __call__(self, *args, **kwargs) -> Processed: state.textinfo = "remote inferencing ({})".format( api.get_instance().__class__.__name__) - if not getattr(p, '_remote_model_name', None): # xyz_grid - p._remote_model_name = self._binding.remote_sd_models[selected_model_index].name + + if '_cloud_inference_settings' not in p.__dict__: + p._cloud_inference_settings = {} + + if 'sd_checkpoint' not in p._cloud_inference_settings: + p._cloud_inference_settings['sd_checkpoint'] = self._binding.remote_model_checkpoints[selected_checkpoint_index].name + if 'sd_vae' not in p._cloud_inference_settings: + p._cloud_inference_settings['sd_vae'] = selected_vae_name if isinstance(p, StableDiffusionProcessingTxt2Img): generated_images = api.get_instance().txt2img(p) @@ -251,7 +298,7 @@ def create_infotext(p, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", - "Model": (None if not opts.add_model_name_to_info or not p._remote_model_name else p._remote_model_name.replace(',', '').replace(':', '')), + "Model": (None if not opts.add_model_name_to_info or not p._cloud_inference_settings['sd_checkpoint'] else p._cloud_inference_settings['sd_checkpoint'].replace(',', '').replace(':', '')), "Variation seed": (None 
if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), @@ -300,4 +347,106 @@ def get_visible_extension_args(p: processing.StableDiffusionProcessing, name): return [] -_proxy = _Proxy(processing.process_images) +def _hijack_func(module_name, func_name, new_func): + old_func = None + extension_mode = False + extension_prefix = "" + if module_name.startswith('extensions.'): + extension_mode = True + + # from modules.processing import process_images + search_names = [module_name] + search_names.append(func_name) + tmp = module_name.split(".") + if len(tmp) >= 2: + search_names.append(".".join(tmp[-2:])) # import modules.processing + search_names.append(tmp[-1]) # from modules import processing + # from modules import processing.process_images + search_names.append("{}.{}".format(tmp[-1], func_name)) + + if not extension_mode: + # hajiack for normal module + + # case 1: import module, replace function, return old function + module = importlib.import_module(module_name) + old_func = getattr(module, func_name) + setattr(module, func_name, new_func) + + # case 2: from module import func_name + keys = list(sys.modules.keys()) + for name in keys: + # if (name.startswith('modules') + # or name.startswith('scripts')) and name != 'modules.processing': + if name.startswith('modules') and name != module_name: + members = getmembers(sys.modules[name], isfunction) + if func_name in dict(members): + # func_fullname = '{}.{}'.format(members[func_name].__module__, members[func_name].__name__) + # print(func_fullname, '{}.{}'.format(module_name, func_name)) + # if func_fullname == '{}.{}'.format(module_name, func_name): + print('[cloud-inference] reloading', name) + importlib.reload(sys.modules[name]) + + from modules.scripts import scripts_data + for script in scripts_data: + for name in search_names: + if name in script.module.__dict__: + obj = script.module.__dict__[name] + replace = False + if ismodule(obj) and obj.__file__ == module.__file__: + replace = True + elif isfunction(obj) and obj.__module__ == module_name: # ?? + replace = True + + if replace: + if name == func_name: + print( + '[cloud-inference] reloading {} - {}'.format(script.module.__name__, func_name)) + setattr(script.module, name, new_func) + else: + print( + '[cloud-inference] reloading {} - {}'.format(script.module.__name__, name)) + t = getattr(script.module, name) + setattr(t, func_name, new_func) + # setattr(script.module, name, t) # ? + return old_func + else: + # hijack for extension module + + from modules.scripts import scripts_data + tmp1, tmp2, extension_suffix = module_name.split( + '.', 2) # scripts internal module name + extension_prefix = "{}.{}".format(tmp1, tmp2) + module_name = "modules.{}".format(extension_suffix) + + for script in scripts_data: + if extension_mode and os.path.join(*extension_prefix.split('.')) not in script.basedir: + continue + for name in search_names: + if name in script.module.__dict__: + obj = script.module.__dict__[name] + replace = False + if ismodule(obj) and module_name.endswith(obj.__name__): + replace = True + elif isfunction(obj) and obj.__module__ == module_name: # ?? 
+ replace = True + + if replace: + # import pkgutil + # s = [_ for _ in pkgutil.iter_modules([os.path.dirname(script.module.__file__)])] + if name == func_name: + print( + '[cloud-inference] reloading {} - {}'.format(script.module.__name__, func_name)) + old_func = getattr(script.module, name) + setattr(script.module, name, new_func) + else: + print( + '[cloud-inference] reloading {} - {}'.format(script.module.__name__, name)) + t = getattr(script.module, name) + old_func = getattr(t, func_name) + setattr(t, func_name, new_func) + setattr(script.module, name, t) + # print(importlib.reload(importlib.import_module('extensions.sd-webui-controlnet.scripts.xyz_grid_support'))) + return old_func + + +_hijack_manager = _HijackManager() diff --git a/scripts/main_ui.py b/scripts/main_ui.py index 88ed9e0..7a22e2b 100644 --- a/scripts/main_ui.py +++ b/scripts/main_ui.py @@ -1,23 +1,16 @@ import modules.scripts as scripts -import html -import sys import gradio as gr -from modules import images, script_callbacks, errors, processing, ui, shared -from modules.processing import Processed, StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing -from modules.shared import opts, state, prompt_styles +from modules import script_callbacks, shared from extension import api -from inspect import getmembers, isfunction import random -import traceback import os -DEMO_MODE = os.getenv("CLOUD_INFERENCE_DEMO_MODE") - refresh_symbol = '\U0001f504' # 🔄 +favorite_symbol = '\U0001f49e' # 💞 class FormComponent: @@ -54,31 +47,30 @@ def __init__(self): self.img2img_cloud_inference_checkbox = None self.txt2img_cloud_inference_checkbox = None + self.txt2img_cloud_inference_vae_dropdown = None + self.img2img_cloud_inference_vae_dropdown = None + self.txt2img_cloud_inference_suggest_prompts_checkbox = None self.img2img_cloud_inference_suggest_prompts_checkbox = None - # self.cloud_api_dropdown = None + self.remote_inference_enabled = False + + self.remote_models = None + self.remote_model_checkpoints = None + self.remote_model_loras = None + self.remote_model_controlnet = None + self.remote_model_vaes = None + + # third component + self.txt2img_controlnet_model_dropdown = None + self.img2img_controlnet_model_dropdown = None - self.remote_sd_models = None self.default_remote_model = None self.initialized = False - def update_selected_model(self, name_index: int, selected_loras: list[str], suggest_prompts_enabled, prompt: str, neg_prompt: str): - selected: api.StableDiffusionModel = self.remote_sd_models[name_index] - selected_checkpoint: api.StableDiffusionModel = None - selected_checkpoint_index: int = 0 - - # if selected model is lora, then we need to get base model of it and set selected model to base model - if selected.kind == 'lora': - for idx, model in enumerate(self.remote_sd_models): - if model.name == selected.dependency_model_name: - selected_checkpoint = model - # selected_checkpoint_index = idx - selected_loras = [selected.name] - break - else: - selected_checkpoint = selected - # selected_checkpoint_index = name_index + def on_selected_model(self, name_index: int, suggest_prompts_enabled, prompt: str, neg_prompt: str): + selected: api.StableDiffusionModel = self.remote_model_checkpoints[name_index] + selected_checkpoint = selected # name = self.remote_sd_models[name_index].name prompt = prompt @@ -88,19 +80,28 @@ def update_selected_model(self, name_index: int, selected_loras: list[str], sugg if selected.example.prompts is not None and suggest_prompts_enabled: prompt = 
selected.example.prompts prompt = prompt.replace("\n", "") - if len(selected_loras) > 0: - prompt = self._update_lora_in_prompt( - selected.example.prompts, selected_loras) - prompt = prompt.replace("\n", "") if selected.example.neg_prompt is not None and suggest_prompts_enabled: neg_prompt = selected.example.neg_prompt return gr.Dropdown.update( - choices=[_.display_name for _ in self.remote_sd_models], - value=selected_checkpoint.display_name), gr.update(value=selected_loras), gr.update(value=prompt), gr.update(value=neg_prompt) + choices=[_.display_name for _ in self.remote_model_checkpoints], + value=selected_checkpoint.display_name), gr.update(value=prompt), gr.update(value=neg_prompt) + + def update_models(self): + _binding.remote_model_loras = _get_kind_from_remote_models(_binding.remote_models, "lora") + _binding.remote_model_checkpoints = _get_kind_from_remote_models(_binding.remote_models, "checkpoint") + _binding.remote_model_vaes = _get_kind_from_remote_models(_binding.remote_models, "vae") + _binding.remote_model_controlnet = _get_kind_from_remote_models(_binding.remote_models, "controlnet") + + @staticmethod - def _update_lora_in_prompt(prompt, lora_names, weight=1): + def _update_lora_in_prompt(prompt, _lora_names, weight=1): + lora_names = [] + for lora_name in _lora_names: + lora_names.append( + _binding.find_model_by_display_name(lora_name).name) + prompt = prompt add_lora_prompts = [] @@ -123,36 +124,23 @@ def _update_lora_in_prompt(prompt, lora_names, weight=1): return ", ".join(prompt_split) def update_selected_lora(self, lora_names, prompt): - print("[cloud-inference] set_selected_lora", lora_names) return gr.update(value=self._update_lora_in_prompt(prompt, lora_names)) def update_cloud_api(self, v): - # TODO: support multiple cloud api provider - print("[cloud-inference] set_cloud_api", v) self.cloud_api = v - def get_selected_model_loras(self): - ret = [] - for ckpt in self.remote_sd_models: - if ckpt.name == self.selected_checkpoint.name: - for lora_name in ckpt.child: - ret.append(lora_name) - return ret - - def get_model_loras_cohices(self, base=None): - ret = [] - for model in self.remote_sd_models: - if model.kind == 'lora': - ret.append(model.name) - return ret - - def choice_to_model(self, choice): # display_name -> sd_name - for model in self.remote_sd_models: + def find_model_by_display_name(self, choice): # display_name -> sd_name + for model in self.remote_models: if model.display_name == choice: return model - def get_model_ckpt_choices(self): - return [_.display_name for _ in self.remote_sd_models] + +def _get_kind_from_remote_models(models, kind): + t = [] + for model in models: + if model.kind == kind: + t.append(model) + return t class CloudInferenceScript(scripts.Script): @@ -169,21 +157,26 @@ def ui(self, is_img2img): tabname = "img2img" # data initialize, TODO: move - if _binding.remote_sd_models is None or len(_binding.remote_sd_models) == 0: - _binding.remote_sd_models = api.get_instance().list_models() + if _binding.remote_models is None or len(_binding.remote_models) == 0: + _binding.remote_models = api.get_instance().list_models() + _binding.update_models() - top_n = min(len(_binding.remote_sd_models), 50) + top_n = min(len(_binding.remote_model_checkpoints), 50) if _binding.default_remote_model is None: _binding.default_remote_model = random.choice( - _binding.remote_sd_models[:top_n]).display_name if len(_binding.remote_sd_models) > 0 else None + _binding.remote_model_checkpoints[:top_n]).display_name if 
len(_binding.remote_model_checkpoints) > 0 else None
+
+        default_enabled = shared.opts.data.get(
+            "cloud_inference_default_enabled", False)
+        if default_enabled:
+            _binding.remote_inference_enabled = True
 
         # define ui layouts
         with gr.Accordion('Cloud Inference', open=True):
             with gr.Row():
                 cloud_inference_checkbox = gr.Checkbox(
                     label="Enable Cloud Inference",
-                    value=lambda: shared.opts.data.get(
-                        "cloud_inference_default_enabled", False),
+                    value=lambda: default_enabled,
                     visible=not shared.opts.data.get(
                         "cloud_inference_checkbox_hidden", False),
                     elem_id="{}_cloud_inference_checkbox".format(tabname))
@@ -202,28 +195,49 @@ def ui(self, is_img2img):
                 )
 
                 cloud_inference_model_dropdown = gr.Dropdown(
-                    label="Quick Select (Checkpoint/Lora)",
-                    choices=_binding.get_model_ckpt_choices(),
+                    label="Checkpoint",
+                    choices=[
+                        _.display_name for _ in _binding.remote_model_checkpoints],
                     value=lambda: _binding.default_remote_model,
                     type="index",
                     elem_id="{}_cloud_inference_model_dropdown".format(tabname))
                 refresh_button = ToolButton(
                     value=refresh_symbol, elem_id="{}_cloud_inference_refersh_button".format(tabname))
+                # favorite_button = ToolButton(
+                #     value=favorite_symbol, elem_id="{}_cloud_inference_favorite_button".format(tabname))
 
             with gr.Row():
                 cloud_inference_lora_dropdown = gr.Dropdown(
-                    choices=_binding.get_model_loras_cohices(),
+                    choices=[_.display_name for _ in _binding.remote_model_loras],
                     label="Lora",
-                    elem_id="{}_cloud_inference_lora_dropdown", multiselect=True)
+                    elem_id="{}_cloud_inference_lora_dropdown", multiselect=True, scale=4)
+
+                cloud_inference_extra_checkbox = gr.Checkbox(
+                    label="Extra",
+                    value=False,
+                    elem_id="{}_cloud_inference_extra_subseed_show".format(tabname),
+                    scale=1
+                )
+
+            with gr.Row(visible=False) as extra_row:
+                cloud_inference_vae_dropdown = gr.Dropdown(
+                    choices=["Automatic", "None"] + [
+                        _.name for _ in _binding.remote_model_vaes],
+                    value="Automatic",
+                    label="VAE",
+                    elem_id="{}_cloud_inference_vae_dropdown".format(tabname),
+                )
+
+            cloud_inference_extra_checkbox.change(lambda x: gr.update(visible=x), inputs=[
+                cloud_inference_extra_checkbox], outputs=[extra_row])
 
             # define events of components.
            # auto fill prompt after select model
             cloud_inference_model_dropdown.select(
-                fn=_binding.update_selected_model,
+                fn=_binding.on_selected_model,
                 inputs=[
                     cloud_inference_model_dropdown,
-                    cloud_inference_lora_dropdown,
                     cloud_inference_suggest_prompts_checkbox,
                     getattr(_binding, "{}_prompt".format(tabname)),
                     getattr(_binding, "{}_neg_prompt".format(tabname))
@@ -231,7 +245,6 @@ def ui(self, is_img2img):
                 ],
                 outputs=[
                     cloud_inference_model_dropdown,
-                    cloud_inference_lora_dropdown,
                     getattr(_binding, "{}_prompt".format(tabname)),
                     getattr(_binding, "{}_neg_prompt".format(tabname))
                 ])
@@ -249,24 +262,29 @@ def ui(self, is_img2img):
 
         def _model_refresh():
             api.get_instance().refresh_models()
-            _binding.remote_sd_models = api.get_instance().list_models()
-            return gr.update(choices=[_.display_name for _ in _binding.remote_sd_models]), gr.update(choices=[_.name for _ in _binding.remote_sd_models if _.kind == 'lora'])
+            # TODO: fix name_index out of range
+            _binding.remote_models = api.get_instance().list_models()
+            _binding.update_models()
+
+            return gr.update(choices=[_.display_name for _ in _binding.remote_model_checkpoints]), gr.update(choices=[_.display_name for _ in _binding.remote_model_loras]), gr.update(choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes])
 
         refresh_button.click(
             fn=_model_refresh,
             inputs=[],
             outputs=[cloud_inference_model_dropdown,
-                     cloud_inference_lora_dropdown])
+                     cloud_inference_lora_dropdown,
+                     cloud_inference_vae_dropdown
+                     ])
 
-        return [cloud_inference_checkbox, cloud_inference_model_dropdown]
+        return [cloud_inference_checkbox, cloud_inference_model_dropdown, cloud_inference_vae_dropdown]
 
 
 _binding = None
 if _binding is None:
     _binding = DataBinding()
-    from scripts.proxy import _proxy
-    _proxy._binding = _binding
-    _proxy.monkey_patch()
+    from scripts.hijack import _hijack_manager
+    _hijack_manager._binding = _binding
+    _hijack_manager._apply_xyz()  # TODO
 
     print('Loading extension: sd-webui-cloud-inference')
 
@@ -302,6 +320,16 @@ def on_after_component_callback(component, **_kwargs):
     if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_cloud_inference_model_dropdown':
         _binding.img2img_cloud_inference_model_dropdown = component
 
+    if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_controlnet_ControlNet-0_controlnet_model_dropdown':
+        _binding.txt2img_controlnet_model_dropdown = component
+    if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_controlnet_ControlNet-0_controlnet_model_dropdown':
+        _binding.img2img_controlnet_model_dropdown = component
+
+    # if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_cloud_inference_vae_dropdown':
+    #     _binding.txt2img_cloud_inference_vae_dropdown = component
+    # if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_cloud_inference_vae_dropdown':
+    #     _binding.img2img_cloud_inference_vae_dropdown = component
+
     if _binding.txt2img_cloud_inference_checkbox and \
             _binding.img2img_cloud_inference_checkbox and \
             _binding.txt2img_cloud_inference_model_dropdown and \
@@ -310,6 +338,8 @@ def on_after_component_callback(component, **_kwargs):
             _binding.img2img_cloud_inference_suggest_prompts_checkbox and \
             _binding.txt2img_generate and \
             _binding.img2img_generate and \
+            _binding.txt2img_controlnet_model_dropdown and \
+            _binding.img2img_controlnet_model_dropdown and \
             not _binding.initialized:
 
         sync_cloud_model(_binding.txt2img_cloud_inference_model_dropdown,
@@ -317,9 +347,13 @@ def on_after_component_callback(component, **_kwargs): sync_two_component(_binding.txt2img_cloud_inference_suggest_prompts_checkbox, _binding.img2img_cloud_inference_suggest_prompts_checkbox, 'change') + # sync_two_component(_binding.txt2img_cloud_inference_vae_dropdown, + # _binding.img2img_cloud_inference_vae_dropdown, + # 'select' + # ) sync_cloud_inference_checkbox(_binding.txt2img_cloud_inference_checkbox, - _binding.img2img_cloud_inference_checkbox, _binding.txt2img_generate, _binding.img2img_generate) + _binding.img2img_cloud_inference_checkbox, _binding.txt2img_generate, _binding.img2img_generate, _binding.txt2img_controlnet_model_dropdown, _binding.img2img_controlnet_model_dropdown) _binding.initialized = True @@ -338,22 +372,22 @@ def mirror(a, b): if a != b: b = a - target_model = _binding.remote_sd_models[b] + target_model = _binding.remote_model_checkpoints[b] # TODO if target_model.kind == 'lora' and target_model.dependency_model_name != None: - for model in _binding.remote_sd_models: + for model in _binding.remote_models: if model.name == target_model.dependency_model_name: b = model.display_name break elif target_model.kind == 'checkpoint': b = target_model.display_name - return _binding.remote_sd_models[a].display_name, b + return _binding.remote_model_checkpoints[a].display_name, b getattr(a, "select")(fn=mirror, inputs=[a, b], outputs=[a, b]) getattr(b, "select")(fn=mirror, inputs=[b, a], outputs=[b, a]) -def sync_cloud_inference_checkbox(txt2img_checkbox, img2img_checkbox, txt2img_generate_button, img2img_generate_button): +def sync_cloud_inference_checkbox(txt2img_checkbox, img2img_checkbox, txt2img_generate_button, img2img_generate_button, txt2img_controlnet_model_dropdown, img2img_controlnet_model_dropdown): def mirror(source, target): enabled = source @@ -362,13 +396,42 @@ def mirror(source, target): button_text = "Generate" if enabled: + _binding.remote_inference_enabled = True button_text = "Generate (cloud)" - return source, target, button_text, button_text + else: + _binding.remote_inference_enabled = False + + # TODO + # controlnet_models = [ + # "None", + # "[cloud] control_v11e_sd15_ip2p", + # "[cloud] control_v11e_sd15_shuffle", + # "[cloud] control_v11f1e_sd15_tile", + # "[cloud] control_v11f1p_sd15_depth", + # "[cloud] control_v11p_sd15_canny", + # "[cloud] control_v11p_sd15_inpaint", + # "[cloud] control_v11p_sd15_lineart", + # "[cloud] control_v11p_sd15_mlsd", + # "[cloud] control_v11p_sd15_normalbae", + # "[cloud] control_v11p_sd15_openpose", + # "[cloud] control_v11p_sd15_scribble", + # "[cloud] control_v11p_sd15_seg", + # "[cloud] control_v11p_sd15_softedge", + # "[cloud] control_v11p_sd15s2_lineart_anime", + # ] + + controlnet_models = ["None"] + \ + [_.name for _ in _binding.remote_model_controlnet] + + if not enabled: + return source, target, button_text, button_text, None, None + + return source, target, button_text, button_text, gr.update(value=controlnet_models[0], choices=controlnet_models), gr.update(value=controlnet_models[0], choices=controlnet_models) txt2img_checkbox.change(fn=mirror, inputs=[txt2img_checkbox, img2img_checkbox], outputs=[ - txt2img_checkbox, img2img_checkbox, txt2img_generate_button, img2img_generate_button]) + txt2img_checkbox, img2img_checkbox, txt2img_generate_button, img2img_generate_button, txt2img_controlnet_model_dropdown, img2img_controlnet_model_dropdown]) img2img_checkbox.change(fn=mirror, inputs=[img2img_checkbox, txt2img_checkbox], outputs=[ - img2img_checkbox, txt2img_checkbox, 
txt2img_generate_button, img2img_generate_button])
+                            img2img_checkbox, txt2img_checkbox, txt2img_generate_button, img2img_generate_button, txt2img_controlnet_model_dropdown, img2img_controlnet_model_dropdown])
 
 
 def on_ui_settings():
@@ -381,3 +444,4 @@
 
 script_callbacks.on_after_component(on_after_component_callback)
 script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_app_started(_hijack_manager.hijack_all)