diff --git a/diffusers/scripts/exp_ldm.sh b/diffusers/scripts/exp_ldm.sh
index 1a8d81b..7dcb896 100755
--- a/diffusers/scripts/exp_ldm.sh
+++ b/diffusers/scripts/exp_ldm.sh
@@ -1,2 +1,5 @@
-python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
-python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
\ No newline at end of file
+# python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
+# python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
+
+python scripts/train_drc.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
+python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
\ No newline at end of file
diff --git a/diffusers/scripts/exp_ldm_demo.sh b/diffusers/scripts/exp_ldm_demo.sh
index 53e9bb4..9ea80e1 100755
--- a/diffusers/scripts/exp_ldm_demo.sh
+++ b/diffusers/scripts/exp_ldm_demo.sh
@@ -1,2 +1,5 @@
 # python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
-python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
\ No newline at end of file
+# python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
+
+python scripts/train_drc.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
+python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
\ No newline at end of file
diff --git a/diffusers/scripts/train_drc.py b/diffusers/scripts/train_drc.py
index 2c47ab4..6f96b7c 100644
--- a/diffusers/scripts/train_drc.py
+++ b/diffusers/scripts/train_drc.py
@@ -12,7 +12,7 @@ from PIL import Image
 
-from stable_copyright import benchmark, collate_fn, Dataset, DRCStableDiffusionInpaintPipeline
+from stable_copyright import benchmark, collate_fn, Dataset, DRCStableDiffusionInpaintPipeline, DRCLatentDiffusionPipeline
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from transformers import CLIPModel, CLIPImageProcessor, CLIPTokenizer
 
@@ -24,19 +24,29 @@
         std=[0.26862954, 0.26130258, 0.27577711]),
 ])
 
-def load_dataset_drc(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6):
-    resolution = 512
-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(resolution),
-            transforms.ToTensor(),
-            # transforms.Normalize([0.5], [0.5]), Do not need to normalize for inpainting
-        ]
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        ckpt_path, subfolder="tokenizer", revision=None
-    )
+def load_dataset_drc(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6, model_type: str='sd'):
+    if model_type != 'ldm':
+        resolution = 512
+        transform = transforms.Compose(
+            [
+                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(resolution),
+                transforms.ToTensor(),
+            ]
+        )
+        tokenizer = CLIPTokenizer.from_pretrained(
+            ckpt_path, subfolder="tokenizer", revision=None
+        )
+    else:
+        resolution = 256
+        transform = transforms.Compose(
+            [
+                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+                # transforms.CenterCrop(resolution),
+                transforms.ToTensor(),
+            ]
+        )
+        tokenizer = None
     train_dataset = Dataset(
         dataset=dataset,
         img_root=dataset_root,
@@ -47,13 +57,21 @@ def load_dataset_drc(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k
     )
     return train_dataset, train_dataloader
 
-def load_pipeline(ckpt_path, device='cuda:0'):
-    pipe = DRCStableDiffusionInpaintPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
-    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-    pipe = pipe.to(device)
+def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
+    if model_type == 'sd':
+        pipe = DRCStableDiffusionInpaintPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to(device)
+    elif model_type == 'ldm':
+        pipe = DRCLatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
+        # pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
+    elif model_type == 'sdxl':
+        raise NotImplementedError('SDXL not implemented yet')
+    else:
+        raise NotImplementedError(f'Unrecognized model type {model_type}')
     return pipe
 
-def get_reverse_denoise_results(pipe, dataloader, device, output_path, mem_or_nonmem):
+def get_reverse_denoise_results(pipe, dataloader, device, output_path, mem_or_nonmem, demo):
     model_id = "../models/diffusers/clip-vit-large-patch14"
     model = CLIPModel.from_pretrained(model_id).to(device)
 
@@ -101,8 +119,8 @@ def get_reverse_denoise_results(pipe, dataloader, device, output_path, mem_or_no
         mean_l2 += scores[-1]
         print(f'[{batch_idx}/{len(dataloader)}] mean l2-sum: {mean_l2 / (batch_idx + 1):.8f}')
 
-        # if batch_idx > 8:
-        #     break
+        if demo and batch_idx > 0:
+            break
 
     return torch.stack(scores, dim=0), path_log
 
@@ -130,25 +148,25 @@ def compute_corr_score(member_scores, nonmember_scores):
 def main(args):
     start_time = time.time()
 
-    _, holdout_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size)
-    _, member_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size)
+    _, holdout_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size, args.model_type)
+    _, member_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size, args.model_type)
 
-    pipe = load_pipeline(args.ckpt_path, args.device)
+    pipe = load_pipeline(args.ckpt_path, args.device, args.model_type)
 
     if not args.use_ddp:
 
        if not os.path.exists(args.output):
            os.mkdir(args.output)
 
-        member_scores, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, args.output, 'member')
-        torch.save(member_scores, args.output + 'member_scores.pth')
+        member_scores, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, args.output, 'member', args.demo)
+        torch.save(member_scores, args.output + f'drc_{args.model_type}_member_scores.pth')
 
-        nonmember_scores, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, args.output, 'nonmember')
-        torch.save(nonmember_scores, args.output + 'nonmember_scores.pth')
+        nonmember_scores, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, args.output, 'nonmember', args.demo)
+        torch.save(nonmember_scores, args.output + f'drc_{args.model_type}_nonmember_scores.pth')
 
-        benchmark(member_scores, nonmember_scores, 'drc_score', args.output)
+        benchmark(member_scores, nonmember_scores, f'drc_{args.model_type}_score', args.output)
 
-        with open(args.output + 'drc_image_log.json', 'w') as file:
+        with open(args.output + f'drc_{args.model_type}_image_log.json', 'w') as file:
             json.dump(dict(member=member_path_log, nonmember=nonmember_path_log), file, indent=4)
 
     else:
@@ -158,7 +176,7 @@ def main(args):
 
     elapsed_time = end_time - start_time
     running_time = dict(running_time=elapsed_time)
 
-    with open(args.output + 'drc_running_time.json', 'w') as file:
+    with open(args.output + f'drc_{args.model_type}_running_time.json', 'w') as file:
         json.dump(running_time, file, indent=4)
 
@@ -174,8 +192,8 @@ def fix_seed(seed):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k', choices=['laion-aesthetic-2-5k'])
-    parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k', choices=['coco2017-val-2-5k'])
+    parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k')
+    parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k')
     parser.add_argument('--dataset-root', default='datasets/', type=str)
     parser.add_argument('--seed', type=int, default=10)
     parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/stable-diffusion-v1-5/')
@@ -183,6 +201,8 @@ def fix_seed(seed):
     parser.add_argument('--output', type=str, default='outputs/')
     parser.add_argument('--batch-size', type=int, default=5)
     parser.add_argument('--use-ddp', type=bool, default=False)
+    parser.add_argument('--model-type', type=str, choices=['sd', 'sdxl', 'ldm'], default='sd')
+    parser.add_argument('--demo', type=bool, default=False)
     args = parser.parse_args()
 
     fix_seed(args.seed)
diff --git a/diffusers/scripts/train_pfami.py b/diffusers/scripts/train_pfami.py
index 15c4608..7e33e22 100644
--- a/diffusers/scripts/train_pfami.py
+++ b/diffusers/scripts/train_pfami.py
@@ -12,7 +12,7 @@ from copy import deepcopy
 import time, json
 
-from stable_copyright import PFAMIStableDiffusionPipeline, SecMIDDIMScheduler
+from stable_copyright import PFAMIStableDiffusionPipeline, SecMIDDIMScheduler, PFAMILatentDiffusionPipeline
 from stable_copyright import load_dataset, benchmark
 
@@ -22,14 +22,22 @@ def image_perturbation(image, strength, image_size=512):
     ])
     return perturbation(image)
 
-def load_pipeline(ckpt_path, device='cuda:0'):
-    pipe = PFAMIStableDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
-    pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
-    pipe = pipe.to(device)
+def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
+    if model_type == 'sd':
+        pipe = PFAMIStableDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
+        pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to(device)
+    elif model_type == 'ldm':
+        pipe = PFAMILatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
+        # pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
+    elif model_type == 'sdxl':
+        raise NotImplementedError('SDXL not implemented yet')
+    else:
+        raise NotImplementedError(f'Unrecognized model type {model_type}')
     return pipe
 
 # difference from secmi: we return the sum of intermediate differences here
-def get_reverse_denoise_results(pipe, dataloader, device, strengths):
+def get_reverse_denoise_results(pipe, dataloader, device, strengths, demo):
     weight_dtype = torch.float32
     mean_l2 = 0
     scores_sum, scores_all_steps, path_log, = [], [], []
@@ -92,8 +100,8 @@ def get_reverse_denoise_results(pipe, dataloader, device, strengths):
         mean_l2 += scores_sum[-1].item()
         print(f'[{batch_idx}/{len(dataloader)}] mean l2-sum: {mean_l2 / (batch_idx + 1):.8f}')
 
-        # if batch_idx > 0:
-        #     break
+        if demo and batch_idx > 0:
+            break
 
     return torch.stack(scores_sum, dim=0), torch.stack(scores_all_steps, dim=0), path_log
 
@@ -121,10 +129,10 @@ def compute_corr_score(member_scores, nonmember_scores):
 def main(args):
     start_time = time.time()
 
-    _, holdout_loader = load_dataset(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size)
-    _, member_loader = load_dataset(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size)
+    _, holdout_loader = load_dataset(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size, args.model_type)
+    _, member_loader = load_dataset(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size, args.model_type)
 
-    pipe = load_pipeline(args.ckpt_path, args.device)
+    pipe = load_pipeline(args.ckpt_path, args.device, args.model_type)
 
     strengths = np.linspace(args.start_strength, args.end_strength, args.perturbation_number)
 
@@ -133,18 +141,18 @@ def main(args):
         if not os.path.exists(args.output):
             os.mkdir(args.output)
 
-        member_scores_sum_step, member_scores_all_steps, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, strengths)
-        torch.save(member_scores_all_steps, args.output + 'pfami_member_scores_all_steps.pth')
+        member_scores_sum_step, member_scores_all_steps, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, strengths, args.demo)
+        torch.save(member_scores_all_steps, args.output + f'pfami_{args.model_type}_member_scores_all_steps.pth')
 
-        nonmember_scores_sum_step, nonmember_scores_all_steps, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, strengths)
-        torch.save(nonmember_scores_all_steps, args.output + 'pfami_nonmember_scores_all_steps.pth')
+        nonmember_scores_sum_step, nonmember_scores_all_steps, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, strengths, args.demo)
+        torch.save(nonmember_scores_all_steps, args.output + f'pfami_{args.model_type}_nonmember_scores_all_steps.pth')
 
         member_corr_scores, nonmember_corr_scores = compute_corr_score(member_scores_all_steps, nonmember_scores_all_steps)
 
-        benchmark(member_scores_sum_step, nonmember_scores_sum_step, 'pfami_sum_score', args.output)
-        benchmark(member_corr_scores, nonmember_corr_scores, 'pfami_corr_score', args.output)
+        benchmark(member_scores_sum_step, nonmember_scores_sum_step, f'pfami_{args.model_type}_sum_score', args.output)
+        benchmark(member_corr_scores, nonmember_corr_scores, f'pfami_{args.model_type}_corr_score', args.output)
 
-        with open(args.output + 'pfami_image_log.json', 'w') as file:
+        with open(args.output + f'pfami_{args.model_type}_image_log.json', 'w') as file:
             json.dump(dict(member=member_path_log, nonmember=nonmember_path_log), file, indent=4)
 
     else:
@@ -154,7 +162,7 @@ def main(args):
 
     elapsed_time = end_time - start_time
     running_time = dict(running_time=elapsed_time)
 
-    with open(args.output + 'pfami_running_time.json', 'w') as file:
+    with open(args.output + f'pfami_{args.model_type}_running_time.json', 'w') as file:
         json.dump(running_time, file, indent=4)
 
@@ -171,8 +179,8 @@ def fix_seed(seed):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k', choices=['laion-aesthetic-2-5k'])
-    parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k', choices=['coco2017-val-2-5k'])
+    parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k')
+    parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k')
     parser.add_argument('--dataset-root', default='datasets/', type=str)
     parser.add_argument('--seed', type=int, default=10)
     parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/stable-diffusion-v1-5/')
@@ -184,6 +192,8 @@ def fix_seed(seed):
     parser.add_argument('--perturbation-number', type=int, default=10)
     parser.add_argument('--start-strength', type=float, default=0.95)
     parser.add_argument('--end-strength', type=float, default=0.7)
+    parser.add_argument('--model-type', type=str, choices=['sd', 'sdxl', 'ldm'], default='sd')
+    parser.add_argument('--demo', type=bool, default=False)
     args = parser.parse_args()
 
     fix_seed(args.seed)
diff --git a/diffusers/stable_copyright/__init__.py b/diffusers/stable_copyright/__init__.py
index 2e1b57b..3724925 100644
--- a/diffusers/stable_copyright/__init__.py
+++ b/diffusers/stable_copyright/__init__.py
@@ -9,11 +9,11 @@
 from .pia_pipeline_sdxl import *
 
 from .pfami_pipeline_stable_diffusion import PFAMIStableDiffusionPipeline
-from .pfami_pipeline_latent_diffusion import *
+from .pfami_pipeline_latent_diffusion import PFAMILatentDiffusionPipeline
 from .pfami_pipeline_sdxl import *
 
 from .drc_dino_utils import *
 from .drc_dino_vision_transformer import *
 from .drc_pipeline_stable_diffusion_inpaint import DRCStableDiffusionInpaintPipeline
-from .drc_pipeline_latent_diffusion import *
+from .drc_pipeline_latent_diffusion import DRCLatentDiffusionPipeline
 from .drc_pipeline_sdxl import *
\ No newline at end of file
diff --git a/diffusers/stable_copyright/drc_pipeline_latent_diffusion.py b/diffusers/stable_copyright/drc_pipeline_latent_diffusion.py
index e69de29..b7e1ba1 100644
--- a/diffusers/stable_copyright/drc_pipeline_latent_diffusion.py
+++ b/diffusers/stable_copyright/drc_pipeline_latent_diffusion.py
@@ -0,0 +1,244 @@
+import os
+import torch
+import numpy as np
+import PIL.Image
+
+from typing import Any, Callable, Dict, List, Optional, Union
+from dataclasses import dataclass
+from diffusers import UNet2DModel, DDIMScheduler, VQModel
+
+from diffusers.pipelines.stable_diffusion.pipeline_output import BaseOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import \
+    (
+        StableDiffusionInpaintPipeline,
+        PipelineImageInput,
+        deprecate,
+        retrieve_timesteps,
+        randn_tensor,
+        VaeImageProcessor,
+        DiffusionPipeline
+    )
+
+
+from .secmi_pipeline_stable_diffusion import SecMIStableDiffusionPipelineOutput
+from .secmi_scheduling_ddim import SecMIDDIMScheduler
+
+class DRCLatentDiffusionPipeline(
+    DiffusionPipeline
+):
+    def __init__(self, unet, vae, scheduler, device, generator):
+
+        self.unet = unet
+        self.vae = vae
+        self.scheduler = scheduler
+        self.generator = generator
+        self.execution_device = device
+
+        self.vae_scale_factor = 2 ** (len(self.vae.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+    @classmethod
+    def from_pretrained(self,
+        pretrained_model_name_or_path: Union[str, os.PathLike]="CompVis/ldm-celebahq-256",
+        torch_dtype: torch.dtype=torch.float32,
+    ):
+        unet = UNet2DModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype)
+        vae = VQModel.from_pretrained(pretrained_model_name_or_path, subfolder="vqvae", torch_dtype=torch_dtype)
+        scheduler = SecMIDDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+
+        # cuda and seed
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        seed = 2024
+        unet.to(device)
+        vae.to(device)
+
+        # generate gaussian noise to be decoded
+        generator = torch.manual_seed(seed)
+        scheduler.set_timesteps(num_inference_steps=50)
+
+        return DRCLatentDiffusionPipeline(unet=unet, vae=vae, scheduler=scheduler, device=device, generator=generator)
+
+    @torch.no_grad()
+    def prepare_inputs(self, batch, weight_dtype, device):
+        pixel_values = batch["pixel_values"].to(weight_dtype)
+        if device == 'cuda':
+            pixel_values = pixel_values.cuda()
+        latents = pixel_values
+        encoder_hidden_states = None
+
+        masks = []
+        for mask in batch["mask"]:
+            masks.append(torch.tensor(mask))
+        masks = torch.stack(masks, dim=0).cuda()
+
+        return latents, encoder_hidden_states, masks
+
+    # borrow from Image2Image
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] # order=1
+        # [601, 581, ..., 21, 1]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+    def prepare_latents(
+        self,
+        image,
+        dtype,
+        device,
+        generator,
+        timestep=None,
+        is_strength_max=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        if return_image_latents:
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self.vae.encode(image)[0]
+
+        noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
+        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+
+        outputs = (latents,)
+        if return_noise:
+            outputs += (noise,)
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    def prepare_mask_latents(
+        self, mask, masked_image, height, width, dtype, device, generator
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        masked_image = masked_image.to(device=device, dtype=dtype)
+
+        if masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = self.vae.encode(masked_image)[0]
+
+        # aligning device to prevent device errors when concating it with the latent model input
+        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+        return mask, masked_image_latents
+
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attack_timesteps: List[int] = [0, 50, 100, 150, 200, 250, 300, 350, 400, 450],
+        normalized: bool=False,
+        prompt: Optional[Union[str, List[str]]] = None,
+        guidance_scale: float = 7.5,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ):
+
+        device = self.execution_device
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        attack_timesteps = [torch.tensor(attack_timestep).to(device=device) for attack_timestep in attack_timesteps]
+
+        crops_coords = None
+        resize_mode = "default"
+        original_image = image.detach().clone()
+        init_image = self.image_processor.preprocess(
+            image, height=256, width=256, crops_coords=crops_coords, resize_mode=resize_mode
+        )
+        init_image = init_image.to(dtype=torch.float32)
+
+        # 6. Prepare latent variables
+        latents_outputs = self.prepare_latents(
+            image,
+            torch.float32,
+            device,
+            generator,
+            is_strength_max=True,
+            return_noise=True,
+            return_image_latents=True,
+        )
+        latents, noise, image_latents = latents_outputs
+
+        # 7. Prepare mask latent variables
+        mask_condition = self.mask_processor.preprocess(
+            mask_image, height=256, width=256, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+        masked_image = init_image * (mask_condition < 0.5)
+
+
+        mask_condition = mask_condition.to(dtype=masked_image.dtype)
+        mask, _ = self.prepare_mask_latents(
+            mask_condition,
+            masked_image,
+            256,
+            256,
+            torch.float32,
+            device,
+            generator,
+        )
+
+        # 7. Denoising loop
+        denoising_results = []
+        unit_t = timesteps[0] - timesteps[1]
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = latents
+                with torch.no_grad():
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        return_dict=False,
+                    )[0]
+
+                # compute the previous noisy sample x_t -> x_t-1
+                denoising_results.append(noise_pred.detach().clone())
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                init_latents_proper = image_latents
+                init_mask = mask
+
+                if i < len(timesteps) - 1:
+                    noise_timestep = timesteps[i + 1]
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_latents_proper, noise, torch.tensor([noise_timestep])
+                    )
+
+                latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+                # print(f"{timesteps[i]} timestep denoising: {torch.sum(latents)}")
+
+                if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if not output_type == "latent":
+            with torch.no_grad():
+                image = self.vae.decode(latents)[0]
+        else:
+            image = latents
+
+        if not return_dict:
+            return (image,)
+
+        return SecMIStableDiffusionPipelineOutput(images=image, posterior_results=None, denoising_results=None)
\ No newline at end of file
diff --git a/diffusers/stable_copyright/pfami_pipeline_latent_diffusion.py b/diffusers/stable_copyright/pfami_pipeline_latent_diffusion.py
index e69de29..4bf4bed 100644
--- a/diffusers/stable_copyright/pfami_pipeline_latent_diffusion.py
+++ b/diffusers/stable_copyright/pfami_pipeline_latent_diffusion.py
@@ -0,0 +1,167 @@
+import os
+import torch
+import numpy as np
+import PIL.Image
+
+from typing import Any, Callable, Dict, List, Optional, Union
+from dataclasses import dataclass
+from diffusers import UNet2DModel, DDIMScheduler, VQModel
+
+from diffusers.pipelines.stable_diffusion.pipeline_output import BaseOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import \
+    (
+        replace_example_docstring,
+        EXAMPLE_DOC_STRING,
+        PipelineImageInput,
+        deprecate,
+        retrieve_timesteps,
+        randn_tensor,
+        DiffusionPipeline
+    )
+
+from .secmi_pipeline_stable_diffusion import SecMIStableDiffusionPipelineOutput
+from .secmi_scheduling_ddim import SecMIDDIMScheduler
+
+class PFAMILatentDiffusionPipeline(
+    DiffusionPipeline
+):
+    def __init__(self, unet, vae, scheduler, device, generator):
+        super().__init__()
+
+        self.unet = unet
+        self.vae = vae
+        self.scheduler = scheduler
+        self.generator = generator
+        self.execution_device = device
+
+    @classmethod
+    def from_pretrained(self,
+        pretrained_model_name_or_path: Union[str, os.PathLike]="CompVis/ldm-celebahq-256",
+        torch_dtype: torch.dtype=torch.float32,
+    ):
+        unet = UNet2DModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype)
+        vae = VQModel.from_pretrained(pretrained_model_name_or_path, subfolder="vqvae", torch_dtype=torch_dtype)
+        scheduler = SecMIDDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+
+        # cuda and seed
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        seed = 2024
+        unet.to(device)
+        vae.to(device)
+
+        # generate gaussian noise to be decoded
+        generator = torch.manual_seed(seed)
+        scheduler.set_timesteps(num_inference_steps=50)
+
+        return PFAMILatentDiffusionPipeline(unet=unet, vae=vae, scheduler=scheduler, device=device, generator=generator)
+
+    @torch.no_grad()
+    def prepare_inputs(self, batch, weight_dtype, device):
+        pixel_values = batch["pixel_values"].to(weight_dtype)
+        if device == 'cuda':
+            pixel_values = pixel_values.cuda()
+        latents = self.vae.encode(pixel_values)[0]
+        encoder_hidden_states = None
+
+        return latents, encoder_hidden_states
+
+
+    # borrow from Image2Image
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] # order=1
+        # [601, 581, ..., 21, 1]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        latents: torch.FloatTensor,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attack_timesteps: List[int] = [0, 50, 100, 150, 200, 250, 300, 350, 400, 450],
+        normalized: bool=False,
+        prompt: Optional[Union[str, List[str]]] = None,
+        guidance_scale: float = 7.5,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ):
+
+        device = self.execution_device
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        attack_timesteps = [torch.tensor(attack_timestep).to(device=device) for attack_timestep in attack_timesteps]
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6.1 Add image embeds for IP-Adapter
+        # 6.2 Optionally get Guidance Scale Embedding
+
+        # get the intermediate at t in attack_timesteps [x_201, x_181, ..., x_1]
+        # print(timesteps)
+        original_latents = latents.detach().clone()
+        posterior_results = []
+        for i, t in enumerate(attack_timesteps): # from t_max to t_min
+            noise = randn_tensor(original_latents.shape, generator=generator, device=device, dtype=original_latents.dtype)
+            posterior_results.append(noise)
+            # print(f"{t} timestep posterior: {torch.sum(posterior_latents)}")
+
+
+        # 7. Denoising loop
+        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        denoising_results = []
+        unit_t = attack_timesteps[1] - attack_timesteps[0]
+        with self.progress_bar(total=len(attack_timesteps)) as progress_bar:
+            for i, t in enumerate(attack_timesteps):
+                noise = posterior_results[i]
+                t = t + unit_t
+                # expand the latents if we are doing classifier free guidance
+                # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = original_latents.detach().clone()
+                latent_model_input = self.scheduler.add_noise(latent_model_input, noise, t)
+
+                # predict the noise residual
+                with torch.no_grad():
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        return_dict=False,
+                    )[0]
+
+                # no classifier free guidance
+                # # perform guidance
+                # if self.do_classifier_free_guidance:
+                #     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                #     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                denoising_results.append(noise_pred.detach().clone())
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                # print(f"{timesteps[i]} timestep denoising: {torch.sum(latents)}")
+
+                if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if not output_type == "latent":
+            with torch.no_grad():
+                image = self.vae.decode(latents)[0]
+        else:
+            image = latents
+
+        if not return_dict:
+            return (image,)
+
+        return SecMIStableDiffusionPipelineOutput(images=image, posterior_results=posterior_results, denoising_results=denoising_results)
\ No newline at end of file
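
A minimal usage sketch for the two LDM pipelines this patch adds. It only exercises entry points visible in the diff above (the custom from_pretrained() hooks and PFAMILatentDiffusionPipeline.prepare_inputs()); the checkpoint path mirrors exp_ldm.sh, and the random batch is a stand-in assumption rather than a real celeba-hq-2-5k batch produced by the repo's dataset loaders.

import torch
from stable_copyright import DRCLatentDiffusionPipeline, PFAMILatentDiffusionPipeline

# Local LDM checkpoint, same path the exp_ldm.sh scripts pass via --ckpt-path (assumed present).
ckpt = "../models/diffusers/ldm-celebahq-256/"

# Each custom from_pretrained() loads the unet/vqvae/scheduler subfolders itself and
# places them on CUDA if available, otherwise CPU.
drc_pipe = DRCLatentDiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float32)
pfami_pipe = PFAMILatentDiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float32)

# Hypothetical dummy batch: one 256x256 RGB tensor standing in for what
# load_dataset()/load_dataset_drc() would yield with model_type='ldm'.
batch = {"pixel_values": torch.randn(1, 3, 256, 256)}
latents, _ = pfami_pipe.prepare_inputs(batch, torch.float32, pfami_pipe.execution_device)
print(latents.shape)  # VQ-VAE latents that can then be passed to pfami_pipe(latents=latents)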