Skip to content

Commit

Permalink
Merge pull request #55 from NexaAI/zack/E2E-flow
Browse files Browse the repository at this point in the history
add retry logic for stable diffusion
  • Loading branch information
zhiyuan8 authored Aug 28, 2024
2 parents d4f3bcd + 838310a commit 13a45c7
Showing 1 changed file with 60 additions and 38 deletions.
98 changes: 60 additions & 38 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
RETRY_ATTEMPTS = (
3 # a temporary fix for the issue of segmentation fault for stable-diffusion-cpp
)


class NexaImageInference:
Expand All @@ -48,7 +51,6 @@ class NexaImageInference:
"""


def __init__(self, model_path, local_path=None, **kwargs):
self.model_path = model_path
self.downloaded_path = local_path
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(self, model_path, local_path=None, **kwargs):
def _load_model(self, model_path: str):
with suppress_stdout_stderr():
from nexa.gguf.sd.stable_diffusion import StableDiffusion

self.model = StableDiffusion(
model_path=self.downloaded_path,
lora_model_dir=self.params.get("lora_dir", ""),
Expand All @@ -105,16 +108,28 @@ def _save_images(self, images):
image.save(file_path)
print(f"\nImage {i+1} saved to: {os.path.abspath(file_path)}")

def txt2img(self,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9):
def _retry(self, func, *args, **kwargs):
for attempt in range(RETRY_ATTEMPTS):
try:
return func(*args, **kwargs)
except Exception as e:
logging.error(f"Attempt {attempt + 1} failed with error: {e}")
time.sleep(1)
logging.error("All retry attempts failed.")
return None

def txt2img(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9,
):
"""
Used for SDK. Generate images from text.
Expand All @@ -125,7 +140,8 @@ def txt2img(self,
Returns:
list: List of generated images.
"""
images = self.model.txt_to_img(
images = self._retry(
self.model.txt_to_img,
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
Expand Down Expand Up @@ -157,25 +173,28 @@ def run_txt2img(self):
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)
self._save_images(images)
if images:
self._save_images(images)
except Exception as e:
logging.error(f"Error during text to image generation: {e}")
except KeyboardInterrupt:
print(EXIT_REMINDER)
except Exception as e:
logging.error(f"Error during generation: {e}", exc_info=True)

def img2img(self,
image_path,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9):
def img2img(
self,
image_path,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9,
):
"""
Used for SDK. Generate images from an image.
Expand All @@ -187,7 +206,8 @@ def img2img(self,
Returns:
list: List of generated images.
"""
images = self.model.img_to_img(
images = self._retry(
self.model.img_to_img,
image=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
Expand All @@ -209,19 +229,21 @@ def run_img2img(self):
negative_prompt = nexa_prompt(
"Enter your negative prompt (press Enter to skip): "
)
images = self.img2img(image_path,
prompt,
negative_prompt,
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)
images = self.img2img(
image_path,
prompt,
negative_prompt,
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)

self._save_images(images)
if images:
self._save_images(images)
except KeyboardInterrupt:
print(EXIT_REMINDER)
except Exception as e:
Expand Down Expand Up @@ -309,4 +331,4 @@ def run_streamlit(self, model_path: str):
if args.img2img:
inference.run_img2img()
else:
inference.run_txt2img()
inference.run_txt2img()

0 comments on commit 13a45c7

Please sign in to comment.