From 8b286d865c5ab2fc50be09bfe0dc09ebbbf4e675 Mon Sep 17 00:00:00 2001 From: CaradryanLiang Date: Mon, 13 May 2024 15:11:45 -0700 Subject: [PATCH] correct ldm setups --- diffusers/stable_copyright/data_utils.py | 28 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/diffusers/stable_copyright/data_utils.py b/diffusers/stable_copyright/data_utils.py index dfb1317..5b98bec 100644 --- a/diffusers/stable_copyright/data_utils.py +++ b/diffusers/stable_copyright/data_utils.py @@ -177,20 +177,30 @@ def __getitem__(self, index: int): def load_dataset(dataset_root, ckpt_path, dataset: str='laion-aesthetic-2-5k', batch_size: int=6, model_type='sd'): - resolution = 512 - transform = transforms.Compose( - [ - transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(resolution), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + if model_type != 'ldm': + resolution = 512 + transform = transforms.Compose( + [ + transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) tokenizer = CLIPTokenizer.from_pretrained( ckpt_path, subfolder="tokenizer", revision=None ) else: + resolution = 256 + transform = transforms.Compose( + [ + transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR), + # transforms.CenterCrop(resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) tokenizer = None train_dataset = Dataset(