Skip to content

Commit

Permalink
add pfami and drc for ldm
Browse files Browse the repository at this point in the history
  • Loading branch information
caradryanl committed May 14, 2024
1 parent 8b286d8 commit 349737b
Show file tree
Hide file tree
Showing 7 changed files with 506 additions and 59 deletions.
7 changes: 5 additions & 2 deletions diffusers/scripts/exp_ldm.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
# python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
# python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32

python scripts/train_drc.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
5 changes: 4 additions & 1 deletion diffusers/scripts/exp_ldm_demo.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
# python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
# python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True

python scripts/train_drc.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
86 changes: 53 additions & 33 deletions diffusers/scripts/train_drc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from PIL import Image


from stable_copyright import benchmark, collate_fn, Dataset, DRCStableDiffusionInpaintPipeline
from stable_copyright import benchmark, collate_fn, Dataset, DRCStableDiffusionInpaintPipeline, DRCLatentDiffusionPipeline
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers import CLIPModel, CLIPImageProcessor, CLIPTokenizer

Expand All @@ -24,19 +24,29 @@
std=[0.26862954, 0.26130258, 0.27577711]),
])

def load_dataset_drc(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6):
resolution = 512
transform = transforms.Compose(
[
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(resolution),
transforms.ToTensor(),
# transforms.Normalize([0.5], [0.5]), Do not need to normalize for inpainting
]
)
tokenizer = CLIPTokenizer.from_pretrained(
ckpt_path, subfolder="tokenizer", revision=None
)
def load_dataset_drc(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6, model_type: str='sd'):
if model_type != 'ldm':
resolution = 512
transform = transforms.Compose(
[
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(resolution),
transforms.ToTensor(),
]
)
tokenizer = CLIPTokenizer.from_pretrained(
ckpt_path, subfolder="tokenizer", revision=None
)
else:
resolution = 256
transform = transforms.Compose(
[
transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
# transforms.CenterCrop(resolution),
transforms.ToTensor(),
]
)
tokenizer = None
train_dataset = Dataset(
dataset=dataset,
img_root=dataset_root,
Expand All @@ -47,13 +57,21 @@ def load_dataset_drc(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k
)
return train_dataset, train_dataloader

def load_pipeline(ckpt_path, device='cuda:0'):
pipe = DRCStableDiffusionInpaintPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
    """Load the DRC evaluation pipeline for the requested model family.

    Args:
        ckpt_path: Path to the pretrained diffusers checkpoint directory.
        device: Torch device string the pipeline is moved to (e.g. 'cuda:0').
        model_type: One of 'sd' (Stable Diffusion inpainting) or 'ldm'
            (latent diffusion); 'sdxl' is reserved but not implemented.

    Returns:
        The loaded pipeline, moved to ``device``.

    Raises:
        NotImplementedError: If ``model_type`` is 'sdxl' or unrecognized.
    """
    if model_type == 'sd':
        pipe = DRCStableDiffusionInpaintPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
        # Swap in DDIM scheduling in place of the checkpoint's default scheduler.
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(device)
    elif model_type == 'ldm':
        pipe = DRCLatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
        # Bug fix: previously the LDM branch never called .to(device), so the
        # pipeline silently stayed on CPU regardless of the `device` argument.
        pipe = pipe.to(device)
    elif model_type == 'sdxl':
        raise NotImplementedError('SDXL not implemented yet')
    else:
        raise NotImplementedError(f'Unrecognized model type {model_type}')
    return pipe

def get_reverse_denoise_results(pipe, dataloader, device, output_path, mem_or_nonmem):
def get_reverse_denoise_results(pipe, dataloader, device, output_path, mem_or_nonmem, demo):
model_id = "../models/diffusers/clip-vit-large-patch14"
model = CLIPModel.from_pretrained(model_id).to(device)

Expand Down Expand Up @@ -101,8 +119,8 @@ def get_reverse_denoise_results(pipe, dataloader, device, output_path, mem_or_no
mean_l2 += scores[-1]
print(f'[{batch_idx}/{len(dataloader)}] mean l2-sum: {mean_l2 / (batch_idx + 1):.8f}')

# if batch_idx > 8:
# break
if demo and batch_idx > 0:
break

return torch.stack(scores, dim=0), path_log

Expand Down Expand Up @@ -130,25 +148,25 @@ def compute_corr_score(member_scores, nonmember_scores):
def main(args):
start_time = time.time()

_, holdout_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size)
_, member_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size)
_, holdout_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size, args.model_type)
_, member_loader = load_dataset_drc(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size, args.model_type)

pipe = load_pipeline(args.ckpt_path, args.device)
pipe = load_pipeline(args.ckpt_path, args.device, args.model_type)

if not args.use_ddp:

if not os.path.exists(args.output):
os.mkdir(args.output)

member_scores, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, args.output, 'member')
torch.save(member_scores, args.output + 'member_scores.pth')
member_scores, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, args.output, 'member', args.demo)
torch.save(member_scores, args.output + f'drc_{args.model_type}_member_scores.pth')

nonmember_scores, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, args.output, 'nonmember')
torch.save(nonmember_scores, args.output + 'nonmember_scores.pth')
nonmember_scores, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, args.output, 'nonmember', args.demo)
torch.save(nonmember_scores, args.output + f'drc_{args.model_type}_nonmember_scores.pth')

benchmark(member_scores, nonmember_scores, 'drc_score', args.output)
benchmark(member_scores, nonmember_scores, f'drc_{args.model_type}_score', args.output)

with open(args.output + 'drc_image_log.json', 'w') as file:
with open(args.output + f'drc_{args.model_type}_image_log.json', 'w') as file:
json.dump(dict(member=member_path_log, nonmember=nonmember_path_log), file, indent=4)

else:
Expand All @@ -158,7 +176,7 @@ def main(args):
elapsed_time = end_time - start_time
running_time = dict(running_time=elapsed_time)

with open(args.output + 'drc_running_time.json', 'w') as file:
with open(args.output + f'drc_{args.model_type}_running_time.json', 'w') as file:
json.dump(running_time, file, indent=4)


Expand All @@ -174,15 +192,17 @@ def fix_seed(seed):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k', choices=['laion-aesthetic-2-5k'])
parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k', choices=['coco2017-val-2-5k'])
parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k')
parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k')
parser.add_argument('--dataset-root', default='datasets/', type=str)
parser.add_argument('--seed', type=int, default=10)
parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/stable-diffusion-v1-5/')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--output', type=str, default='outputs/')
parser.add_argument('--batch-size', type=int, default=5)
parser.add_argument('--use-ddp', type=bool, default=False)
parser.add_argument('--model-type', type=str, choices=['sd', 'sdxl', 'ldm'], default='sd')
parser.add_argument('--demo', type=bool, default=False)
args = parser.parse_args()

fix_seed(args.seed)
Expand Down
52 changes: 31 additions & 21 deletions diffusers/scripts/train_pfami.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from copy import deepcopy
import time, json

from stable_copyright import PFAMIStableDiffusionPipeline, SecMIDDIMScheduler
from stable_copyright import PFAMIStableDiffusionPipeline, SecMIDDIMScheduler, PFAMILatentDiffusionPipeline
from stable_copyright import load_dataset, benchmark

def image_perturbation(image, strength, image_size=512):
Expand All @@ -22,14 +22,22 @@ def image_perturbation(image, strength, image_size=512):
])
return perturbation(image)

def load_pipeline(ckpt_path, device='cuda:0'):
pipe = PFAMIStableDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
def load_pipeline(ckpt_path, device='cuda:0', model_type='sd'):
    """Load the PFAMI evaluation pipeline for the requested model family.

    Args:
        ckpt_path: Path to the pretrained diffusers checkpoint directory.
        device: Torch device string the pipeline is moved to (e.g. 'cuda:0').
        model_type: One of 'sd' (Stable Diffusion) or 'ldm' (latent
            diffusion); 'sdxl' is reserved but not implemented.

    Returns:
        The loaded pipeline, moved to ``device``.

    Raises:
        NotImplementedError: If ``model_type`` is 'sdxl' or unrecognized.
    """
    if model_type == 'sd':
        pipe = PFAMIStableDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
        # Swap in SecMI's DDIM scheduler in place of the checkpoint default.
        pipe.scheduler = SecMIDDIMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(device)
    elif model_type == 'ldm':
        # NOTE(review): the LDM pipeline keeps its checkpoint scheduler; the
        # SecMIDDIMScheduler swap was deliberately left out here — confirm.
        pipe = PFAMILatentDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch.float32)
        # Bug fix: previously the LDM branch never called .to(device), so the
        # pipeline silently stayed on CPU regardless of the `device` argument.
        pipe = pipe.to(device)
    elif model_type == 'sdxl':
        raise NotImplementedError('SDXL not implemented yet')
    else:
        raise NotImplementedError(f'Unrecognized model type {model_type}')
    return pipe

# difference from secmi: we return the sum of intermediate differences here
def get_reverse_denoise_results(pipe, dataloader, device, strengths):
def get_reverse_denoise_results(pipe, dataloader, device, strengths, demo):
weight_dtype = torch.float32
mean_l2 = 0
scores_sum, scores_all_steps, path_log, = [], [], []
Expand Down Expand Up @@ -92,8 +100,8 @@ def get_reverse_denoise_results(pipe, dataloader, device, strengths):
mean_l2 += scores_sum[-1].item()
print(f'[{batch_idx}/{len(dataloader)}] mean l2-sum: {mean_l2 / (batch_idx + 1):.8f}')

# if batch_idx > 0:
# break
if demo and batch_idx > 0:
break

return torch.stack(scores_sum, dim=0), torch.stack(scores_all_steps, dim=0), path_log

Expand Down Expand Up @@ -121,10 +129,10 @@ def compute_corr_score(member_scores, nonmember_scores):
def main(args):
start_time = time.time()

_, holdout_loader = load_dataset(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size)
_, member_loader = load_dataset(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size)
_, holdout_loader = load_dataset(args.dataset_root, args.ckpt_path, args.holdout_dataset, args.batch_size, args.model_type)
_, member_loader = load_dataset(args.dataset_root, args.ckpt_path, args.member_dataset, args.batch_size, args.model_type)

pipe = load_pipeline(args.ckpt_path, args.device)
pipe = load_pipeline(args.ckpt_path, args.device, args.model_type)

strengths = np.linspace(args.start_strength, args.end_strength, args.perturbation_number)

Expand All @@ -133,18 +141,18 @@ def main(args):
if not os.path.exists(args.output):
os.mkdir(args.output)

member_scores_sum_step, member_scores_all_steps, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, strengths)
torch.save(member_scores_all_steps, args.output + 'pfami_member_scores_all_steps.pth')
member_scores_sum_step, member_scores_all_steps, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, strengths, args.demo)
torch.save(member_scores_all_steps, args.output + f'pfami_{args.model_type}_member_scores_all_steps.pth')

nonmember_scores_sum_step, nonmember_scores_all_steps, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, strengths)
torch.save(nonmember_scores_all_steps, args.output + 'pfami_nonmember_scores_all_steps.pth')
nonmember_scores_sum_step, nonmember_scores_all_steps, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, strengths, args.demo)
torch.save(nonmember_scores_all_steps, args.output + f'pfami_{args.model_type}_nonmember_scores_all_steps.pth')

member_corr_scores, nonmember_corr_scores = compute_corr_score(member_scores_all_steps, nonmember_scores_all_steps)

benchmark(member_scores_sum_step, nonmember_scores_sum_step, 'pfami_sum_score', args.output)
benchmark(member_corr_scores, nonmember_corr_scores, 'pfami_corr_score', args.output)
benchmark(member_scores_sum_step, nonmember_scores_sum_step, f'pfami_{args.model_type}_sum_score', args.output)
benchmark(member_corr_scores, nonmember_corr_scores, f'pfami_{args.model_type}_corr_score', args.output)

with open(args.output + 'pfami_image_log.json', 'w') as file:
with open(args.output + f'pfami_{args.model_type}_image_log.json', 'w') as file:
json.dump(dict(member=member_path_log, nonmember=nonmember_path_log), file, indent=4)

else:
Expand All @@ -154,7 +162,7 @@ def main(args):
elapsed_time = end_time - start_time
running_time = dict(running_time=elapsed_time)

with open(args.output + 'pfami_running_time.json', 'w') as file:
with open(args.output + f'pfami_{args.model_type}_running_time.json', 'w') as file:
json.dump(running_time, file, indent=4)


Expand All @@ -171,8 +179,8 @@ def fix_seed(seed):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k', choices=['laion-aesthetic-2-5k'])
parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k', choices=['coco2017-val-2-5k'])
parser.add_argument('--member-dataset', default='laion-aesthetic-2-5k')
parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k')
parser.add_argument('--dataset-root', default='datasets/', type=str)
parser.add_argument('--seed', type=int, default=10)
parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/stable-diffusion-v1-5/')
Expand All @@ -184,6 +192,8 @@ def fix_seed(seed):
parser.add_argument('--perturbation-number', type=int, default=10)
parser.add_argument('--start-strength', type=float, default=0.95)
parser.add_argument('--end-strength', type=float, default=0.7)
parser.add_argument('--model-type', type=str, choices=['sd', 'sdxl', 'ldm'], default='sd')
parser.add_argument('--demo', type=bool, default=False)
args = parser.parse_args()

fix_seed(args.seed)
Expand Down
4 changes: 2 additions & 2 deletions diffusers/stable_copyright/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from .pia_pipeline_sdxl import *

from .pfami_pipeline_stable_diffusion import PFAMIStableDiffusionPipeline
from .pfami_pipeline_latent_diffusion import *
from .pfami_pipeline_latent_diffusion import PFAMILatentDiffusionPipeline
from .pfami_pipeline_sdxl import *

from .drc_dino_utils import *
from .drc_dino_vision_transformer import *
from .drc_pipeline_stable_diffusion_inpaint import DRCStableDiffusionInpaintPipeline
from .drc_pipeline_latent_diffusion import *
from .drc_pipeline_latent_diffusion import DRCLatentDiffusionPipeline
from .drc_pipeline_sdxl import *
Loading

0 comments on commit 349737b

Please sign in to comment.