works + work on infusion
RolandBERTINJOHANNET committed Sep 5, 2024
1 parent d150119 commit ce888d7
Showing 10 changed files with 713 additions and 192 deletions.
21 changes: 6 additions & 15 deletions examples/imgnet_example/dataset.py
@@ -16,20 +16,16 @@

 # Paths to data
 IMAGE_LATENTS_PATH_TRAIN = (
-    "/home/rbertin/pyt_scripts/full_imgnet/full_size/"
-    "vae_full_withbigger_disc/384_combined_standardized_embeddings.npy"
+    "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger_disc/384_combined_standardized_embeddings.npy"
 )
 IMAGE_LATENTS_PATH_VAL = (
-    "/home/rbertin/pyt_scripts/full_imgnet/full_size/"
-    "vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy"
+    "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy"
 )
 CAPTION_EMBEDDINGS_PATH_TRAIN = (
-    "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/"
-    "bge_fullsized_captions_norm.npy"
+    "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_train.npy"
 )
 CAPTION_EMBEDDINGS_PATH_VAL = (
-    "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/"
-    "fullsized_captions_norm_val.npy"
+    "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_val.npy"
 )


@@ -111,7 +107,7 @@ def __init__(

         # Set up the ImageNet dataset transformations
         transform = transforms.Compose([
-            transforms.Resize(256),
+            transforms.Resize(224),
             transforms.CenterCrop(224),
             transforms.ToTensor(),
         ])
@@ -143,11 +139,6 @@ def train_dataloader(self):
     def val_dataloader(self):
         val_loaders = self.setup_dataloaders(self.val_datasets)
         return CombinedLoader(val_loaders, mode="sequential")
-
-
-
-
-
     def get_samples(self, split: Literal["train", "val"], amount: int) -> dict[frozenset, dict[str, torch.Tensor]]:
         """Fetches a specified number of samples from the specified split ('train' or 'val')."""
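
Note: the train and val pairs above are expected to stay row-aligned across domains. A minimal sketch of how these constants might be consumed (assuming numpy is available and the .npy files exist; the names mirror the constants in dataset.py):

    import numpy as np

    # Load the paired pre-computed embeddings for the training split
    image_latents = np.load(IMAGE_LATENTS_PATH_TRAIN)            # e.g. (N, 384)
    caption_embeddings = np.load(CAPTION_EMBEDDINGS_PATH_TRAIN)
    # Paired domains must align row-wise for matched image/caption training
    assert image_latents.shape[0] == caption_embeddings.shape[0]
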
56 changes: 51 additions & 5 deletions examples/imgnet_example/domains.py
@@ -8,6 +8,8 @@

 from shimmer import DomainModule, LossOutput
 
+from diffusers.models import AutoencoderKL
+
 
 class ImageDomain(DomainModule):
     def __init__(self, latent_dim: int):
@@ -26,15 +28,11 @@ def __init__(self, latent_dim: int):


     def encode(self, x: torch.Tensor) -> torch.Tensor:
-        # Add random noise to x between 0 and 0.03
-        noise_level = torch.rand(1).item() * 0.03
-        return x  # + noise
+        return x
 
     def decode(self, z: torch.Tensor) -> torch.Tensor:
         self.eval()
 
-        print("stats before decoding : ", z.mean(), z.std())
-
         # Decode using VAE model
         val = self.vae_model.decode(z)
         return val
@@ -61,6 +59,54 @@ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
         return LossOutput(loss=F.mse_loss(pred, target))
 
 
+
+class SDImageDomain(DomainModule):
+    def __init__(self, latent_dim: int):
+        super().__init__(latent_dim)
+        if latent_dim != 1024:
+            raise ValueError("vision latent_dim must be 1024")
+
+        # load the model parameters
+        self.vae_model = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+
+        self.vae_model.eval()
+        print(
+            "nb params in the vae model : ",
+            sum(p.numel() for p in self.vae_model.parameters())
+        )
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+    def decode(self, z: torch.Tensor) -> torch.Tensor:
+        self.eval()
+        z = z.reshape(z.shape[0], 4, 16, 16)
+        val = self.vae_model.decode(z).sample
+        return val
+
+    def training_step(self, batch: torch.Tensor, batch_idx: int):
+        (domain,) = batch
+        decoded = self.decode(self.encode(domain))
+        loss = F.mse_loss(domain, decoded)
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch: torch.Tensor, batch_idx: int):
+        (domain,) = batch
+        decoded = self.decode(self.encode(domain))
+        loss = F.mse_loss(domain, decoded)
+        self.log("val_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        return AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
+
+    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+        # Computes an illustrative loss; can be tailored for specific use cases
+        return LossOutput(loss=F.mse_loss(pred, target))
+
+
 class TextDomain(DomainModule):
     def __init__(self, latent_dim: int):
         super().__init__(latent_dim)
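
The reshape in SDImageDomain.decode implies the flattened workspace vector packs the SDXL latent grid: 1024 = 4 channels × 16 × 16, which lines up with the 128×128 crops used in get_sd_embed.py below, given the VAE's 8× spatial downsampling. A minimal usage sketch (assuming shimmer and diffusers are installed; the 128×128 output size is an inference from those numbers, not stated in the diff):

    import torch

    domain = SDImageDomain(latent_dim=1024)
    z = torch.randn(2, 1024)       # two flattened latents from the workspace
    with torch.no_grad():
        imgs = domain.decode(z)    # reshaped internally to (2, 4, 16, 16)
    print(imgs.shape)              # expected: torch.Size([2, 3, 128, 128])
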
88 changes: 88 additions & 0 deletions examples/imgnet_example/get_sd_embed.py
@@ -0,0 +1,88 @@
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from diffusers.models import AutoencoderKL
from tqdm import tqdm

# Set the device
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

# Load the VAE model
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
vae.eval()

# Set the root directory where ImageNet is located
DATASET_DIR = os.environ.get('DATASET_DIR', '.')  # Defaults to '.' if the environment variable is not set
root_dir = DATASET_DIR

# Define your transformations
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
])

# Create datasets
train_dataset = ImageFolder(root=os.path.join(root_dir, 'imagenet/train'), transform=transform)
val_dataset = ImageFolder(root=os.path.join(root_dir, 'imagenet/val'), transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2)

print("Images loaded")

# Check inputs distribution
print("Checking out inputs distribution:\n")
data, _ = next(iter(train_loader))
print("Min:", data.min())
print("Max:", data.max())
print("Mean:", data.mean(dim=[0, 2, 3]))
print("Std:", data.std(dim=[0, 2, 3]))

# Loop over the training set, get all the embeddings, and save them to .npy format.
all_embeddings = {}
all_indices = {}

for batch_idx, (inputs, _) in enumerate(tqdm(train_loader)):
    inputs = inputs.to(device)
    with torch.no_grad():
        encoded_latents = vae.encode(inputs).latent_dist.sample()

    # Get the shape of the current embeddings
    shape = encoded_latents.shape
    shape_str = str(shape)  # Convert shape to a string for dictionary keys

    # Check if this shape already exists in the dictionary
    if shape_str not in all_embeddings:
        print("Adding shape:", shape_str)
        all_embeddings[shape_str] = []
        all_indices[shape_str] = []

    # Add embeddings and corresponding indices to the respective lists
    all_embeddings[shape_str].append(encoded_latents.detach().cpu().numpy())
    all_indices[shape_str].append(batch_idx)

root_dir = "sd_image_embeddings"

# Save each list of embeddings to a separate .npy file
os.makedirs(root_dir, exist_ok=True)

for shape_str, embeddings in all_embeddings.items():
    embeddings = np.concatenate(embeddings)
    np.save(os.path.join(root_dir, f'image_embeddings_{shape_str}_sd.npy'), embeddings)
    print(f"Saved embeddings of shape {shape_str} to", os.path.join(root_dir, f'image_embeddings_{shape_str}_sd.npy'))
    print(f"all_embeddings shape for {shape_str}: ", embeddings.shape)

    mean = np.mean(embeddings, axis=0)
    std = np.std(embeddings, axis=0)
    print(f"embeddings distribution for {shape_str}: mean: {mean}, std: {std}")

# Print the shapes and corresponding indices
for shape_str, indices in all_indices.items():
    print(f"Shape: {shape_str}, Indices: {indices}")
2 changes: 1 addition & 1 deletion examples/imgnet_example/imnet_logging.py
@@ -116,7 +116,7 @@ def on_callback(
         pl_module.eval()
         prediction_demi_cycles = batch_demi_cycles(
             pl_module.gw_mod, selection_mod, latent_groups
-            )
+        )
         prediction_cycles = batch_cycles(
             pl_module.gw_mod,
             selection_mod,
