From 9ea5c0d599ecee09dce4a2bae68b14b570927913 Mon Sep 17 00:00:00 2001
From: CaradryanLiang
Date: Thu, 23 May 2024 08:04:52 -0700
Subject: [PATCH] place strength back to gsa

---
 diffusers/scripts/train_gsa.py                              | 6 ++++--
 diffusers/stable_copyright/gsa_pipeline_latent_diffusion.py | 1 +
 diffusers/stable_copyright/gsa_pipeline_sdxl.py             | 2 ++
 diffusers/stable_copyright/gsa_pipeline_stable_diffusion.py | 2 ++
 4 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/diffusers/scripts/train_gsa.py b/diffusers/scripts/train_gsa.py
index ee37b2d..daf245d 100644
--- a/diffusers/scripts/train_gsa.py
+++ b/diffusers/scripts/train_gsa.py
@@ -15,7 +15,7 @@
 from sklearn.linear_model import SGDClassifier
 from sklearn import preprocessing
 
-from stable_copyright import GSALatentDiffusionPipeline, SecMIDDIMScheduler, GSAStableDiffusionPipeline
+from stable_copyright import GSALatentDiffusionPipeline, SecMIDDIMScheduler, GSAStableDiffusionPipeline, GSAStableDiffusionXLPipeline
 from stable_copyright import load_dataset, benchmark, test
 
 
@@ -28,7 +28,9 @@ def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
         pipe = GSALatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float16)
         # pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
     elif model_type == 'sdxl':
-        raise NotImplementedError('SDXL not implemented yet')
+        pipe = GSAStableDiffusionXLPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float16)
+        pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to(device)
     else:
         raise NotImplementedError(f'Unrecognized model type {model_type}')
     return pipe
diff --git a/diffusers/stable_copyright/gsa_pipeline_latent_diffusion.py b/diffusers/stable_copyright/gsa_pipeline_latent_diffusion.py
index 887003d..2fd68df 100644
--- a/diffusers/stable_copyright/gsa_pipeline_latent_diffusion.py
+++ b/diffusers/stable_copyright/gsa_pipeline_latent_diffusion.py
@@ -123,6 +123,7 @@ def __call__(
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
         posterior_results = []
         original_latents = latents.detach().clone()
diff --git a/diffusers/stable_copyright/gsa_pipeline_sdxl.py b/diffusers/stable_copyright/gsa_pipeline_sdxl.py
index 9eec747..115824d 100644
--- a/diffusers/stable_copyright/gsa_pipeline_sdxl.py
+++ b/diffusers/stable_copyright/gsa_pipeline_sdxl.py
@@ -79,6 +79,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        strength: float=1.0,
         gsa_mode: int = 1,
         **kwargs,
     ):
@@ -169,6 +170,7 @@
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
diff --git a/diffusers/stable_copyright/gsa_pipeline_stable_diffusion.py b/diffusers/stable_copyright/gsa_pipeline_stable_diffusion.py
index 2f8510a..b68be34 100644
--- a/diffusers/stable_copyright/gsa_pipeline_stable_diffusion.py
+++ b/diffusers/stable_copyright/gsa_pipeline_stable_diffusion.py
@@ -63,6 +63,7 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
+        strength: float=1.0,
         gsa_mode: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -239,6 +240,7 @@
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
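
Note: the added self.get_timesteps(num_inference_steps, strength, device) calls
follow the usual diffusers img2img convention, in which strength controls how
much of the scheduler's timestep schedule is actually run (strength=1.0 keeps
every step; smaller values skip the earliest, noisiest steps). The sketch below
only illustrates that mapping under this assumption; it is not code from this
patch, and the helper name sliced_timesteps is made up for the example.

    import torch

    def sliced_timesteps(scheduler_timesteps: torch.Tensor,
                         num_inference_steps: int,
                         strength: float):
        # Keep only the last `int(num_inference_steps * strength)` steps of the
        # schedule and report how many steps remain to be run.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        return scheduler_timesteps[t_start:], num_inference_steps - t_start

    full = torch.arange(999, -1, -100)        # toy 10-step schedule: 999, 899, ..., 99
    print(sliced_timesteps(full, 10, 1.0))    # all 10 timesteps survive
    print(sliced_timesteps(full, 10, 0.3))    # only the last 3 timesteps survive

If the GSA pipelines follow this convention, the restored default strength=1.0
leaves the schedule returned by retrieve_timesteps unchanged, so callers that
never pass strength should see identical behavior.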