Commit

add prompts for standardized framework

caradryanl committed May 23, 2024
1 parent 53e5988 commit b8bc643
Showing 17 changed files with 103 additions and 186 deletions.
1 change: 1 addition & 0 deletions diffusers/scripts/exp_sdxl_demo.sh
@@ -0,0 +1 @@
+python scripts/train_secmi.py --model-type sdxl --ckpt-path ../models/diffusers/Kohaku-XL-Epsilon/ --member-dataset hakubooru-2-5k-member --holdout-dataset hakubooru-2-5k-nonmember --batch-size 3 --demo True
10 changes: 10 additions & 0 deletions diffusers/scripts/test.py
@@ -0,0 +1,10 @@
+import numpy as np
+from sklearn import preprocessing
+
+member_features = np.array([np.nan, np.inf, -np.inf, 1000, 0, 4, 30000, -899])
+membermax, membermin = np.nanmax(member_features[~np.isposinf(member_features)]), np.nanmin(member_features[~np.isneginf(member_features)])  # nanmax/nanmin skip the NaN entry
+member_features = np.nan_to_num(member_features, nan=0, posinf=membermax, neginf=membermin)
+
+x = preprocessing.scale(member_features)
+x = np.nan_to_num(x, nan=0)  # sanitize the scaled result, not the raw features
+print(x, member_features)
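test.py reads like a scratch check of the feature-sanitization step (clamp ±inf, zero out NaN, then standardize) that the attack scripts presumably apply before scoring. A minimal sketch of that pattern as a reusable helper; the helper name is ours, not part of the commit:

import numpy as np
from sklearn import preprocessing

def sanitize_and_scale(features: np.ndarray) -> np.ndarray:
    # clamp +inf / -inf to the largest / smallest finite values seen elsewhere, zero out NaN, then standardize
    fmax = np.nanmax(features[~np.isposinf(features)])
    fmin = np.nanmin(features[~np.isneginf(features)])
    features = np.nan_to_num(features, nan=0, posinf=fmax, neginf=fmin)
    return np.nan_to_num(preprocessing.scale(features), nan=0)

print(sanitize_and_scale(np.array([np.nan, np.inf, -np.inf, 1000, 0, 4, 30000, -899])))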
4 changes: 2 additions & 2 deletions diffusers/scripts/train_gsa.py
@@ -50,9 +50,9 @@ def get_reverse_denoise_results(pipe, dataloader, device, gsa_mode, demo):
     features, path_log = [], []
     for batch_idx, batch in enumerate(tqdm.tqdm(dataloader)):
         path_log.extend(batch['path'])
-        latents, encoder_hidden_states = pipe.prepare_inputs(batch, weight_dtype, device)
+        latents, encoder_hidden_states, prompts = pipe.prepare_inputs(batch, weight_dtype, device)
         out = pipe(\
-            accelerator=accelerator, optimizer=optimizer, prompt=None, latents=latents, \
+            accelerator=accelerator, optimizer=optimizer, prompt=prompts, latents=latents, \
             prompt_embeds=encoder_hidden_states, guidance_scale=1.0, num_inference_steps=20, gsa_mode=gsa_mode)
         gsa_features = out.gsa_features # # [bsz x Tensor(num_p)]
         # print(f"gsa: {gsa_features}")
4 changes: 2 additions & 2 deletions diffusers/scripts/train_pfami.py
@@ -66,8 +66,8 @@ def get_reverse_denoise_results(pipe, dataloader, device, strengths, demo):
         for strength in strengths:
             input_batch = deepcopy(batch)
             input_batch["pixel_values"] = image_perturbation(input_batch["pixel_values"], strength)
-            latents, encoder_hidden_states = pipe.prepare_inputs(input_batch, weight_dtype, device)
-            out = pipe(prompt=None, latents=latents, prompt_embeds=encoder_hidden_states, \
+            latents, encoder_hidden_states, prompts = pipe.prepare_inputs(input_batch, weight_dtype, device)
+            out = pipe(prompt=prompts, latents=latents, prompt_embeds=encoder_hidden_states, \
                 guidance_scale=1.0, num_inference_steps=100)
             _, posterior_results, denoising_results = out.images, out.posterior_results, out.denoising_results
             # [len(attack_timesteps) x [B, 4, 64, 64]]
4 changes: 2 additions & 2 deletions diffusers/scripts/train_pia.py
@@ -35,9 +35,9 @@ def get_reverse_denoise_results(pipe, dataloader, device, normalized, demo):
     scores_sum, scores_all_steps, path_log = [], [], []
     for batch_idx, batch in enumerate(tqdm.tqdm(dataloader)):
         path_log.extend(batch['path'])
-        latents, encoder_hidden_states = pipe.prepare_inputs(batch, weight_dtype, device)
+        latents, encoder_hidden_states, prompts = pipe.prepare_inputs(batch, weight_dtype, device)
         out = pipe(\
-            prompt=None, latents=latents, prompt_embeds=encoder_hidden_states, \
+            prompt=prompts, latents=latents, prompt_embeds=encoder_hidden_states, \
             guidance_scale=1.0, num_inference_steps=100, normalized=normalized, strength=0.5)
         _, posterior_results, denoising_results = out.images, out.posterior_results, out.denoising_results

10 changes: 6 additions & 4 deletions diffusers/scripts/train_secmi.py
@@ -9,7 +9,7 @@
 import argparse
 import json,time

-from stable_copyright import SecMILatentDiffusionPipeline, SecMIStableDiffusionPipeline, SecMIDDIMScheduler
+from stable_copyright import SecMILatentDiffusionPipeline, SecMIStableDiffusionPipeline, SecMIDDIMScheduler, SecMIStableDiffusionXLPipeline
 from stable_copyright import load_dataset, benchmark, test


@@ -22,7 +22,9 @@ def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
         pipe = SecMILatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
         # pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
     elif model_type == 'sdxl':
-        raise NotImplementedError('SDXL not implemented yet')
+        pipe = SecMIStableDiffusionXLPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
+        pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to(device)
     else:
         raise NotImplementedError(f'Unrecognized model type {model_type}')
     return pipe
@@ -34,8 +36,8 @@ def get_reverse_denoise_results(pipe, dataloader, device, demo=False):
     scores_50_step, scores_all_steps, path_log = [], [], []
     for batch_idx, batch in enumerate(tqdm.tqdm(dataloader)):
         path_log.extend(batch['path'])
-        latents, encoder_hidden_states = pipe.prepare_inputs(batch, weight_dtype, device)
-        out = pipe(prompt=None, latents=latents, prompt_embeds=encoder_hidden_states, guidance_scale=1.0, num_inference_steps=100)
+        latents, encoder_hidden_states, prompts = pipe.prepare_inputs(batch, weight_dtype, device)
+        out = pipe(prompt=prompts, latents=latents, prompt_embeds=encoder_hidden_states, guidance_scale=1.0, num_inference_steps=100)
         _, posterior_results, denoising_results = out.images, out.posterior_results, out.denoising_results

         # print(f'posterior {posterior_results[0].shape}')
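The same caller-side change recurs in train_gsa.py, train_pfami.py, and train_pia.py: prepare_inputs now returns a third value, prompts, which is forwarded to the pipeline as prompt=... . As far as this page shows, the SD and latent-diffusion pipelines return None for it, so the raw captions are presumably consumed only by the new SDXL pipeline. A minimal caller-side sketch under that assumption (pipe stands in for any of the attack pipelines; the comments are our reading, not part of the commit):

# sketch only; mirrors the call pattern introduced by this commit
latents, encoder_hidden_states, prompts = pipe.prepare_inputs(batch, weight_dtype, device)
# at most one textual conditioning source is expected to be non-None:
#   SD / latent-diffusion pipelines -> prompts is None, embeddings (if any) come via prompt_embeds
#   SDXL pipeline (assumed)         -> prompts carries the raw captions from batch["prompts"]
out = pipe(prompt=prompts, prompt_embeds=encoder_hidden_states, latents=latents,
           guidance_scale=1.0, num_inference_steps=100)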
1 change: 1 addition & 0 deletions diffusers/stable_copyright/__init__.py
@@ -3,6 +3,7 @@
 from .secmi_pipeline_stable_diffusion import SecMIStableDiffusionPipeline
 from .secmi_scheduling_ddim import SecMIDDIMScheduler
 from .secmi_pipeline_latent_diffusion import SecMILatentDiffusionPipeline
+from .secmi_pipeline_sdxl import SecMIStableDiffusionXLPipeline

 from .pia_pipeline_stable_diffusion import PIAStableDiffusionPipeline
 from .pia_pipeline_latent_diffusion import PIALatentDiffusionPipeline
5 changes: 3 additions & 2 deletions diffusers/stable_copyright/data_utils.py
@@ -138,7 +138,8 @@ def collate_fn(examples):
     else:
         input_ids = torch.stack([example["input_ids"] for example in examples])
     path = [example["path"] for example in examples]
-    return {"pixel_values": pixel_values, "input_ids": input_ids, "path": path}
+    prompts = [example["prompt"] for example in examples]
+    return {"pixel_values": pixel_values, "input_ids": input_ids, "path": path, "prompts": prompts}

 class StandardTransform:
     def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
@@ -234,7 +235,7 @@ def __getitem__(self, index: int):
         image, input_id = StandardTransform(self.transforms, None)(image, input_id)

         # return image, target
-        return {"pixel_values": image, "input_ids": input_id, 'caption': caption, 'path': img_name}
+        return {"pixel_values": image, "input_ids": input_id, 'prompt': caption, 'path': img_name}


 def load_dataset(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6, model_type='sd'):
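This is the data-side half of the commit: each dataset item now carries its caption under the key 'prompt', and collate_fn gathers them into batch["prompts"]. A minimal sketch of how the new keys travel, assuming a dataset built by load_dataset (variable names below are ours):

# illustrative only; not part of the committed code
item = dataset[0]                        # {"pixel_values": ..., "input_ids": ..., "prompt": caption, "path": img_name}
batch = collate_fn([dataset[0], dataset[1]])
captions = batch["prompts"]              # list of raw captions, one per image
# pipelines that need raw text (presumably the new SDXL pipeline) read batch["prompts"];
# the SD pipelines keep using batch["input_ids"] -> text_encoder -> prompt_embeds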
@@ -80,7 +80,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         for param in self.unet.parameters():
             param.requires_grad = True

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
@@ -33,7 +33,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         for param in self.unet.parameters():
             param.requires_grad = True

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
@@ -63,7 +63,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         latents = self.vae.encode(pixel_values)[0]
         encoder_hidden_states = None

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
@@ -30,7 +30,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         latents = latents * 0.18215
         encoder_hidden_states = self.text_encoder(input_ids)[0]

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
@@ -63,7 +63,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         latents = self.vae.encode(pixel_values)[0]
         encoder_hidden_states = None

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
@@ -30,7 +30,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         latents = latents * 0.18215
         encoder_hidden_states = self.text_encoder(input_ids)[0]

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
@@ -64,7 +64,7 @@ def prepare_inputs(self, batch, weight_dtype, device):
         latents = self.vae.encode(pixel_values)[0]
         encoder_hidden_states = None

-        return latents, encoder_hidden_states
+        return latents, encoder_hidden_states, None


     # borrow from Image2Image
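Every prepare_inputs shown on this page keeps returning None as the new third value, so these SD and latent-diffusion pipelines are unchanged in behavior; presumably the new secmi_pipeline_sdxl.py (added by this commit but not loaded on this page) is the one that actually returns the captions. A purely hypothetical sketch of what that might look like, not the committed code:

# hypothetical; the real SecMIStableDiffusionXLPipeline.prepare_inputs is not shown in this excerpt
def prepare_inputs(self, batch, weight_dtype, device):
    pixel_values = batch["pixel_values"].to(weight_dtype).to(device)
    latents = self.vae.encode(pixel_values).latent_dist.sample()
    latents = latents * self.vae.config.scaling_factor
    prompts = batch["prompts"]          # raw captions gathered by the updated collate_fn
    return latents, None, prompts       # no precomputed prompt_embeds; __call__ encodes the prompts itself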