Skip to content

Commit

Permalink
fix fatal bugs in data_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
caradryanl committed May 16, 2024
1 parent 50e8f15 commit 4d6e2a9
Show file tree
Hide file tree
Showing 18 changed files with 321 additions and 44 deletions.
279 changes: 279 additions & 0 deletions diffusers/dev/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
import torch
from torchvision import transforms
from transformers import CLIPTokenizer

from sklearn import metrics
import torch
import numpy as np
import json
from PIL import Image
import os
from typing import Callable, Optional, Any, Tuple, List

def _rates_at_threshold(member_scores, nonmember_scores, threshold):
    """Return ``(TPR, FPR)`` when every score ``<= threshold`` is predicted "member".

    Both values are 0-dim tensors produced by dividing element counts.
    """
    TP = (member_scores <= threshold).sum()
    TN = (nonmember_scores > threshold).sum()
    FP = (nonmember_scores <= threshold).sum()
    FN = (member_scores > threshold).sum()
    return TP / (TP + FN), FP / (FP + TN)


def test(member_scores, nonmember_scores, experiment, output_path, threshold_path):
    """Evaluate membership scores at thresholds chosen by an earlier run.

    The thresholds ``best_threshold_at_1_FPR`` and ``best_threshold_at_01_FPR``
    are read from ``threshold_path + experiment + '_result.json'`` (the file
    name and keys that ``benchmark`` writes). The function then sweeps evenly
    spaced thresholds between the smallest and largest score to build a ROC
    curve, prints the summary metrics and writes everything to
    ``output_path + experiment + '_result_test.json'``.

    Args:
        member_scores: scores of member samples; lower means "more likely a
            member". Expected to be a 1-D torch tensor (``.min()``, ``.max()``
            and elementwise comparisons are used).
        nonmember_scores: scores of non-member samples, same kind of object.
        experiment: file-name prefix of the experiment.
        output_path: directory prefix for the result file; concatenated
            as-is, so it must end with a path separator.
        threshold_path: directory prefix of the threshold file; concatenated
            as-is, so it must end with a path separator.
    """
    # The threshold file is an input: open it for reading and parse it.
    with open(threshold_path + experiment + '_result.json', 'r') as file:
        result = json.load(file)

    best_threshold_at_1_FPR = result['best_threshold_at_1_FPR']
    best_threshold_at_01_FPR = result['best_threshold_at_01_FPR']

    min_score = min(member_scores.min(), nonmember_scores.min())
    max_score = max(member_scores.max(), nonmember_scores.max())

    TPR_list = []
    FPR_list = []
    threshold_list = []

    # ROC sweep: one (FPR, TPR) point per candidate threshold.
    for threshold in torch.arange(min_score, max_score, (max_score - min_score) / 10000):
        TPR, FPR = _rates_at_threshold(member_scores, nonmember_scores, threshold)
        TPR_list.append(TPR.item())
        FPR_list.append(FPR.item())
        threshold_list.append(threshold.item())

    # Operating points at the two thresholds loaded from disk. `.item()` turns
    # the 0-dim tensors into plain floats so that json.dump can serialize them.
    TPR_at_1, FPR_at_1 = _rates_at_threshold(member_scores, nonmember_scores, best_threshold_at_1_FPR)
    TPR_at_1_threshold, FPR_at_1_threshold = TPR_at_1.item(), FPR_at_1.item()

    TPR_at_01, FPR_at_01 = _rates_at_threshold(member_scores, nonmember_scores, best_threshold_at_01_FPR)
    TPR_at_01_threshold, FPR_at_01_threshold = TPR_at_01.item(), FPR_at_01.item()

    auc = metrics.auc(np.asarray(FPR_list), np.asarray(TPR_list))
    print(f'AUROC: {auc}')
    print(f'TPR_at_1_threshold: {TPR_at_1_threshold}, FPR_at_1_threshold: {FPR_at_1_threshold}')
    print(f'TPR_at_01_threshold: {TPR_at_01_threshold}, FPR_at_01_threshold: {FPR_at_01_threshold}')

    output = {
        'TPR_at_1_threshold': TPR_at_1_threshold,
        'FPR_at_1_threshold': FPR_at_1_threshold,
        'TPR_at_01_threshold': TPR_at_01_threshold,
        # Stored under its own key so it no longer overwrites the 1% FPR value.
        'FPR_at_01_threshold': FPR_at_01_threshold,
        'AUROC': auc,
        'TPR': TPR_list,
        'FPR': FPR_list,
        'threshold': threshold_list,
    }

    with open(output_path + experiment + '_result_test.json', 'w') as file:
        json.dump(output, file, indent=4)

def benchmark(member_scores, nonmember_scores, experiment, output_path):
    """Sweep score thresholds, pick the best low-FPR thresholds and save them.

    A sample is predicted "member" when its score is ``<= threshold``. Evenly
    spaced thresholds between the smallest and largest score are evaluated;
    for each one the TPR and FPR are recorded. The threshold with the highest
    TPR subject to ``FPR <= 0.01`` and the one subject to ``FPR <= 0.001`` are
    kept (both default to ``0.0`` with TPR ``0.0`` when no threshold
    qualifies). The ROC points, the AUROC and the selected thresholds are
    printed and written to ``output_path + experiment + '_result.json'``.

    Args:
        member_scores: scores of member samples. Expected to be a 1-D torch
            tensor (``.min()``, ``.max()`` and elementwise comparisons are
            used).
        nonmember_scores: scores of non-member samples, same kind of object.
        experiment: file-name prefix of the experiment.
        output_path: directory prefix for the result file; concatenated
            as-is, so it must end with a path separator.
    """
    min_score = min(member_scores.min(), nonmember_scores.min())
    max_score = max(member_scores.max(), nonmember_scores.max())

    TPR_list = []
    FPR_list = []
    threshold_list = []

    best_TPR_at_1_FPR, best_TPR_at_01_FPR = 0.0, 0.0
    best_threshold_at_1_FPR, best_threshold_at_01_FPR = 0.0, 0.0

    for threshold in torch.arange(min_score, max_score, (max_score - min_score) / 10000):
        TP = (member_scores <= threshold).sum()
        TN = (nonmember_scores > threshold).sum()
        FP = (nonmember_scores <= threshold).sum()
        FN = (member_scores > threshold).sum()

        TPR = TP / (TP + FN)
        FPR = FP / (FP + TN)

        TPR_list.append(TPR.item())
        FPR_list.append(FPR.item())
        threshold_list.append(threshold.item())

        # Track the highest TPR seen while staying within each FPR budget.
        if FPR <= 0.01 and TPR > best_TPR_at_1_FPR:
            best_TPR_at_1_FPR = TPR.item()
            best_threshold_at_1_FPR = threshold.item()
        if FPR <= 0.001 and TPR > best_TPR_at_01_FPR:
            best_TPR_at_01_FPR = TPR.item()
            best_threshold_at_01_FPR = threshold.item()

    auc = metrics.auc(np.asarray(FPR_list), np.asarray(TPR_list))
    print(f'AUROC: {auc}')
    print(f'best_TPR_at_1_FPR: {best_TPR_at_1_FPR}, best_threshold_at_1_FPR: {best_threshold_at_1_FPR}')
    print(f'best_TPR_at_01_FPR: {best_TPR_at_01_FPR}, best_threshold_at_01_FPR: {best_threshold_at_01_FPR}')

    output = {
        'AUROC': auc,
        'best_TPR_at_1_FPR': best_TPR_at_1_FPR,
        'best_threshold_at_1_FPR': best_threshold_at_1_FPR,
        'best_TPR_at_01_FPR': best_TPR_at_01_FPR,
        'best_threshold_at_01_FPR': best_threshold_at_01_FPR,
        'TPR': TPR_list,
        'FPR': FPR_list,
        'threshold': threshold_list,
    }

    with open(output_path + experiment + '_result.json', 'w') as file:
        json.dump(output, file, indent=4)

def collate_fn(examples):
    """Merge a list of per-sample dicts into one batch dict.

    Args:
        examples: non-empty list of dicts with the keys ``"pixel_values"``
            (tensor), ``"input_ids"`` (tensor or ``None``), ``"path"`` and
            ``"mask"``.

    Returns:
        A dict with ``"pixel_values"`` stacked along a new first dimension and
        cast to float, ``"input_ids"`` stacked the same way (or ``None`` when
        the first example has no input ids), and ``"path"`` / ``"mask"`` as
        plain lists in input order.
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    # Identity check: `== None` on a tensor goes through tensor comparison
    # instead of asking "is this value absent?".
    if examples[0]["input_ids"] is None:
        input_ids = None
    else:
        input_ids = torch.stack([example["input_ids"] for example in examples])

    path = [example["path"] for example in examples]
    mask = [example["mask"] for example in examples]
    return {"pixel_values": pixel_values, "input_ids": input_ids, "path": path, "mask": mask}

class StandardTransform:
    """Pair of optional callables: one for the input, one for the target.

    Calling the instance applies each callable that is set to its respective
    argument and returns both results as a tuple; a callable left as ``None``
    passes its argument through untouched.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform, self.target_transform = transform, target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        new_input = input if self.transform is None else self.transform(input)
        new_target = target if self.target_transform is None else self.target_transform(target)
        return new_input, new_target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the first repr line with *head* and indent the rest to match."""
        rendered = transform.__repr__().splitlines()
        pad = " " * len(head)
        formatted = [f"{head}{rendered[0]}"]
        formatted.extend(f"{pad}{line}" for line in rendered[1:])
        return formatted

    def __repr__(self) -> str:
        sections = [self.__class__.__name__]
        labelled = (
            (self.transform, "Transform: "),
            (self.target_transform, "Target transform: "),
        )
        for fn, label in labelled:
            if fn is not None:
                sections += self._format_transform_repr(fn, label)
        return "\n".join(sections)


class Dataset(torch.utils.data.Dataset):
    """Image / caption / mask dataset described by a ``caption.json`` file.

    Layout read by this class, relative to ``<img_root>/<dataset>/``:

    * ``caption.json`` - a JSON object whose values are per-image records
      with at least the keys ``'path'`` (image file name) and ``'caption'``
      (a possibly empty sequence of captions; only the first one is used).
    * ``images/<path>`` - the image files.
    * ``masks/<path without extension>.npy`` - one NumPy mask per image.

    Args:
        dataset: name of the dataset sub-directory.
        img_root: directory that contains the dataset sub-directory.
        transforms: optional callable applied to the RGB image of each item.
        tokenizer: optional tokenizer used once, at construction time, to
            turn the first caption of every record into input ids.
    """

    def __init__(
        self,
        dataset: str,
        img_root: str,
        transforms: Optional[Callable] = None,
        tokenizer=None,
    ) -> None:
        self.dataset = dataset
        self.img_root = img_root
        self.tokenizer = tokenizer
        self.transforms = transforms

        caption_path = os.path.join(img_root, dataset, 'caption.json')
        # Only the record values are kept; the JSON keys are not used.
        with open(caption_path, 'r') as json_file:
            self.img_info = list(json.load(json_file).values())

        self._init_tokenize_captions()

    def __len__(self):
        return len(self.img_info)

    @staticmethod
    def _first_caption(metadata):
        """Return the first caption of a record, or ``None`` if it has none."""
        captions = metadata['caption']
        return captions[0] if captions else None

    def _mask_path(self, img_name: str) -> str:
        """Return the ``.npy`` mask path that belongs to image file *img_name*."""
        # splitext handles extensions of any length ('.jpeg', '.png', ...),
        # whereas slicing off the last four characters only works for 3-letter ones.
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.img_root, self.dataset, 'masks', stem + '.npy')

    def _init_tokenize_captions(self):
        """Tokenize the first caption of every record into ``self.input_ids``.

        ``self.input_ids`` is set to ``None`` when no record has a caption or
        when no tokenizer was supplied (``load_dataset`` passes ``None`` for
        ``model_type='ldm'``).
        """
        captions = [self._first_caption(metadata) for metadata in self.img_info]

        if self.tokenizer is None or all(caption is None for caption in captions):
            self.input_ids = None
            return

        # Records without a caption are tokenized as the empty string so the
        # batch handed to the tokenizer contains only strings (a None entry is
        # presumably rejected by the tokenizer - TODO confirm).
        texts = ["" if caption is None else caption for caption in captions]
        inputs = self.tokenizer(
            texts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True,
            return_tensors="pt"
        )
        self.input_ids = inputs.input_ids

    def _load_input_id(self, id: int):
        return self.input_ids[id] if self.input_ids is not None else None

    def __getitem__(self, index: int):
        """Load one sample.

        Returns:
            A dict with ``"pixel_values"`` (the RGB image, after
            ``transforms`` if set), ``"input_ids"`` (row of the pre-tokenized
            captions or ``None``), ``'caption'`` (first caption or ``None``),
            ``'path'`` (image file name) and ``'mask'`` (array loaded from
            the ``.npy`` mask file).
        """
        metadata = self.img_info[index]
        img_name = metadata['path']

        # image: convert() yields an independent copy, so the file handle
        # can be closed right away.
        img_path = os.path.join(self.img_root, self.dataset, 'images', img_name)
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # mask
        mask = np.load(self._mask_path(img_name))

        input_id = self._load_input_id(index)
        caption = self._first_caption(metadata)

        if self.transforms is not None:
            image = self.transforms(image)

        return {"pixel_values": image, "input_ids": input_id, 'caption': caption, 'path': img_name, 'mask': mask}


def load_dataset(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6, model_type='sd'):
    """Build the dataset and a non-shuffling dataloader for one image folder.

    For ``model_type == 'ldm'`` images are resized to 256 without a center
    crop and no tokenizer is used. For every other model type images are
    resized to 512, center-cropped to 512, and a ``CLIPTokenizer`` is loaded
    from the ``tokenizer`` subfolder of ``ckpt_path``. In both cases the image
    is then converted to a tensor and normalized with mean 0.5 and std 0.5.

    Args:
        dataset_root: directory that contains the dataset sub-directory.
        ckpt_path: checkpoint directory the tokenizer is loaded from.
        dataset: name of the dataset sub-directory.
        batch_size: batch size of the returned dataloader.
        model_type: ``'ldm'`` selects the 256-pixel, tokenizer-free setup.

    Returns:
        ``(train_dataset, train_dataloader)``.
    """
    is_ldm = model_type == 'ldm'
    resolution = 256 if is_ldm else 512

    steps = [transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)]
    if not is_ldm:
        # Only the non-ldm pipeline center-crops after resizing.
        steps.append(transforms.CenterCrop(resolution))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.5], [0.5]))
    transform = transforms.Compose(steps)

    if is_ldm:
        tokenizer = None
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            ckpt_path, subfolder="tokenizer", revision=None
        )

    train_dataset = Dataset(
        dataset=dataset,
        img_root=dataset_root,
        transforms=transform,
        tokenizer=tokenizer,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=batch_size,
    )
    return train_dataset, train_dataloader
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 3 additions & 7 deletions diffusers/scripts/exp_ldm.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
python scripts/train_drc.py --model-type sd --batch-size 20

# python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
# python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32

python scripts/train_drc.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
# python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --batch-size 32
python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k
python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k
python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k


8 changes: 3 additions & 5 deletions diffusers/scripts/exp_ldm_demo.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
# python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True

python scripts/train_drc.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
# python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --demo True
6 changes: 6 additions & 0 deletions diffusers/scripts/exp_ldm_sd_eval.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --eval True
python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --eval True
python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --eval True
python scripts/train_secmi.py --model-type sd --ckpt-path ../models/diffusers/stable-diffusion-v1-5/ --eval True
python scripts/train_pia.py --model-type sd --ckpt-path ../models/diffusers/stable-diffusion-v1-5/ --eval True
python scripts/train_pfami.py --model-type sd --ckpt-path ../models/diffusers/stable-diffusion-v1-5/ --eval True
6 changes: 6 additions & 0 deletions diffusers/scripts/exp_ldm_sd_eval_demo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
python scripts/train_secmi.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --eval True --demo True
python scripts/train_pia.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --eval True --demo True
python scripts/train_pfami.py --model-type ldm --ckpt-path ../models/diffusers/ldm-celebahq-256/ --member-dataset celeba-hq-2-5k --holdout-dataset ffhq-2-5k --eval True --demo True
python scripts/train_secmi.py --model-type sd --ckpt-path ../models/diffusers/stable-diffusion-v1-5/ --eval True --demo True
python scripts/train_pia.py --model-type sd --ckpt-path ../models/diffusers/stable-diffusion-v1-5/ --eval True --demo True
python scripts/train_pfami.py --model-type sd --ckpt-path ../models/diffusers/stable-diffusion-v1-5/ --eval True --demo True
9 changes: 5 additions & 4 deletions diffusers/scripts/train_pfami.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,13 @@ def main(args):
os.mkdir(args.output)

member_scores_sum_step, member_scores_all_steps, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, strengths, args.demo)
torch.save(member_scores_all_steps, args.output + f'pfami_{args.model_type}_member_scores_all_steps.pth')

nonmember_scores_sum_step, nonmember_scores_all_steps, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, strengths, args.demo)
torch.save(nonmember_scores_all_steps, args.output + f'pfami_{args.model_type}_nonmember_scores_all_steps.pth')

member_corr_scores, nonmember_corr_scores = compute_corr_score(member_scores_all_steps, nonmember_scores_all_steps)

if not args.eval:
torch.save(member_scores_all_steps, args.output + f'pfami_{args.model_type}_member_scores_all_steps.pth')
torch.save(nonmember_scores_all_steps, args.output + f'pfami_{args.model_type}_nonmember_scores_all_steps.pth')

benchmark(member_scores_sum_step, nonmember_scores_sum_step, f'pfami_{args.model_type}_sum_score', args.output)
benchmark(member_corr_scores, nonmember_corr_scores, f'pfami_{args.model_type}_corr_score', args.output)

Expand All @@ -162,6 +161,8 @@ def main(args):
with open(args.output + f'pfami_{args.model_type}_running_time.json', 'w') as file:
json.dump(running_time, file, indent=4)
else:
torch.save(member_scores_all_steps, args.output + f'pfami_{args.model_type}_member_scores_all_steps_test.pth')
torch.save(nonmember_scores_all_steps, args.output + f'pfami_{args.model_type}_nonmember_scores_all_steps_test.pth')
threshold_path = args.threshold_root + f'{args.model_type}/pfami/'

test(member_scores_sum_step, member_scores_sum_step, f'pfami_{args.model_type}_sum_score', args.output, threshold_path)
Expand Down
14 changes: 7 additions & 7 deletions diffusers/scripts/train_pia.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,14 @@ def main(args):
os.mkdir(args.output)

member_scores_sum_step, member_scores_all_steps, member_path_log = get_reverse_denoise_results(pipe, member_loader, args.device, args.normalized, args.demo)
torch.save(member_scores_all_steps, args.output + f'{pia_or_pian}_{args.model_type}_member_scores_all_steps.pth')

nonmember_scores_sum_step, nonmember_scores_all_steps, nonmember_path_log = get_reverse_denoise_results(pipe, holdout_loader, args.device, args.normalized, args.demo)
torch.save(nonmember_scores_all_steps, args.output + f'{pia_or_pian}_{args.model_type}_nonmember_scores_all_steps.pth')

member_corr_scores, nonmember_corr_scores = compute_corr_score(member_scores_all_steps, nonmember_scores_all_steps)


if not args.eval:
torch.save(member_scores_all_steps, args.output + f'{pia_or_pian}_{args.model_type}_member_scores_all_steps.pth')
torch.save(nonmember_scores_all_steps, args.output + f'{pia_or_pian}_{args.model_type}_nonmember_scores_all_steps.pth')

benchmark(member_scores_sum_step, nonmember_scores_sum_step, f'{pia_or_pian}_{args.model_type}_sum_score', args.output)
benchmark(member_corr_scores, nonmember_corr_scores, f'{pia_or_pian}_{args.model_type}_corr_score', args.output)

Expand All @@ -124,8 +123,9 @@ def main(args):
with open(args.output + f'{pia_or_pian}_{args.model_type}_running_time.json', 'w') as file:
json.dump(running_time, file, indent=4)
else:
torch.save(member_scores_all_steps, args.output + f'{pia_or_pian}_{args.model_type}_member_scores_all_steps_test.pth')
torch.save(nonmember_scores_all_steps, args.output + f'{pia_or_pian}_{args.model_type}_nonmember_scores_all_steps_test.pth')
threshold_path = args.threshold_root + f'{args.model_type}/{pia_or_pian}/'

test(member_scores_sum_step, member_scores_sum_step, f'{pia_or_pian}_{args.model_type}_sum_score', args.output, threshold_path)
test(member_corr_scores, nonmember_corr_scores, f'{pia_or_pian}_{args.model_type}_corr_score', args.output, threshold_path)

Expand Down Expand Up @@ -154,8 +154,8 @@ def fix_seed(seed):
parser.add_argument('--holdout-dataset', default='coco2017-val-2-5k')
parser.add_argument('--dataset-root', default='datasets/', type=str)
parser.add_argument('--seed', type=int, default=10)
# parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/stable-diffusion-v1-5/')
parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/ldm-celebahq-256/')
parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/stable-diffusion-v1-5/')
# parser.add_argument('--ckpt-path', type=str, default='../models/diffusers/ldm-celebahq-256/')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--output', type=str, default='outputs/')
parser.add_argument('--batch-size', type=int, default=10)
Expand Down
Loading

0 comments on commit 4d6e2a9

Please sign in to comment.