Skip to content

Commit

Permalink
place strength back to gsa
Browse files Browse the repository at this point in the history
  • Loading branch information
caradryanl committed May 23, 2024
1 parent dac7213 commit 9ea5c0d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 2 deletions.
6 changes: 4 additions & 2 deletions diffusers/scripts/train_gsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sklearn.linear_model import SGDClassifier
from sklearn import preprocessing

from stable_copyright import GSALatentDiffusionPipeline, SecMIDDIMScheduler, GSAStableDiffusionPipeline
from stable_copyright import GSALatentDiffusionPipeline, SecMIDDIMScheduler, GSAStableDiffusionPipeline, GSAStableDiffusionXLPipeline
from stable_copyright import load_dataset, benchmark, test


Expand All @@ -28,7 +28,9 @@ def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
pipe = GSALatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float16)
# pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
elif model_type == 'sdxl':
raise NotImplementedError('SDXL not implemented yet')
pipe = GSAStableDiffusionXLPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float16)
pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
else:
raise NotImplementedError(f'Unrecognized model type {model_type}')
return pipe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __call__(

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

posterior_results = []
original_latents = latents.detach().clone()
Expand Down
2 changes: 2 additions & 0 deletions diffusers/stable_copyright/gsa_pipeline_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
strength: float=1.0,
gsa_mode: int = 1,
**kwargs,
):
Expand Down Expand Up @@ -169,6 +170,7 @@ def __call__(

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
Expand Down
2 changes: 2 additions & 0 deletions diffusers/stable_copyright/gsa_pipeline_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __call__(
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
strength: float=1.0,
gsa_mode: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
Expand Down Expand Up @@ -239,6 +240,7 @@ def __call__(

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
Expand Down

0 comments on commit 9ea5c0d

Please sign in to comment.