Skip to content

Commit

Permalink
update experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
caradryanl committed May 15, 2024
1 parent 4016a42 commit cba88f6
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions diffusers/stable_copyright/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,71 @@
import os
from typing import Callable, Optional, Any, Tuple, List

def test(member_scores, nonmember_scores, experiment, output_path, threshold_path):

with open(threshold_path + experiment + '_result.json', 'w') as file:
result = json.dump(file, indent=4)

best_threshold_at_1_FPR = result['best_threshold_at_1_FPR']
best_threshold_at_01_FPR = result['best_threshold_at_01_FPR']

min_score = min(member_scores.min(), nonmember_scores.min())
max_score = max(member_scores.max(), nonmember_scores.max())

TPR_list = []
FPR_list = []
threshold_list = []
output = {}

total = member_scores.size(0) + nonmember_scores.size(0)
for threshold in torch.arange(min_score, max_score, (max_score - min_score) / 10000):
acc = ((member_scores <= threshold).sum() + (nonmember_scores > threshold).sum()) / total

TP = (member_scores <= threshold).sum()
TN = (nonmember_scores > threshold).sum()
FP = (nonmember_scores <= threshold).sum()
FN = (member_scores > threshold).sum()

TPR = TP / (TP + FN)
FPR = FP / (FP + TN)

TPR_list.append(TPR.item())
FPR_list.append(FPR.item())
threshold_list.append(threshold.item())

TP = (member_scores <= best_threshold_at_1_FPR).sum()
TN = (nonmember_scores > best_threshold_at_1_FPR).sum()
FP = (nonmember_scores <= best_threshold_at_1_FPR).sum()
FN = (member_scores > best_threshold_at_1_FPR).sum()
TPR_at_1_threshold = TP / (TP + FN)
FPR_at_1_threshold = FP / (FP + TN)

TP = (member_scores <= best_threshold_at_01_FPR).sum()
TN = (nonmember_scores > best_threshold_at_01_FPR).sum()
FP = (nonmember_scores <= best_threshold_at_01_FPR).sum()
FN = (member_scores > best_threshold_at_01_FPR).sum()
TPR_at_01_threshold = TP / (TP + FN)
FPR_at_01_threshold = FP / (FP + TN)

# print(f'Score threshold = {threshold:.16f} \t ASR: {acc:.8f} \t TPR: {TPR:.8f} \t FPR: {FPR:.8f}')
auc = metrics.auc(np.asarray(FPR_list), np.asarray(TPR_list))
print(f'AUROC: {auc}')
print(f'TPR_at_1_threshold: {TPR_at_1_threshold}, FPR_at_1_threshold: {FPR_at_1_threshold}')
print(f'TPR_at_01_threshold: {TPR_at_01_threshold}, FPR_at_01_threshold: {FPR_at_01_threshold}')

output['TPR_at_1_threshold'] = TPR_at_1_threshold
output['FPR_at_1_threshold'] = FPR_at_1_threshold
output['TPR_at_01_threshold'] = TPR_at_01_threshold
output['FPR_at_1_threshold'] = FPR_at_01_threshold
output['AUROC'] = auc
output['TPR'] = TPR_list
output['FPR'] = FPR_list
output['threshold'] = threshold_list


with open(output_path + experiment + '_result_test.json', 'w') as file:
json.dump(output, file, indent=4)

def benchmark(member_scores, nonmember_scores, experiment, output_path):

min_score = min(member_scores.min(), nonmember_scores.min())
Expand Down

0 comments on commit cba88f6

Please sign in to comment.