diff --git a/examples/imgnet_example/dataset.py b/examples/imgnet_example/dataset.py index 53e70f14..1a9fc31e 100644 --- a/examples/imgnet_example/dataset.py +++ b/examples/imgnet_example/dataset.py @@ -16,20 +16,16 @@ # Paths to data IMAGE_LATENTS_PATH_TRAIN = ( - "/home/rbertin/pyt_scripts/full_imgnet/full_size/" - "vae_full_withbigger_disc/384_combined_standardized_embeddings.npy" - ) + "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger_disc/384_combined_standardized_embeddings.npy" +) IMAGE_LATENTS_PATH_VAL = ( - "/home/rbertin/pyt_scripts/full_imgnet/full_size/" - "vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy" + "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy" ) CAPTION_EMBEDDINGS_PATH_TRAIN = ( - "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/" - "bge_fullsized_captions_norm.npy" + "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_train.npy" ) CAPTION_EMBEDDINGS_PATH_VAL = ( - "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/" - "fullsized_captions_norm_val.npy" + "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_val.npy" ) @@ -111,7 +107,7 @@ def __init__( # Set up the ImageNet dataset transformations transform = transforms.Compose([ - transforms.Resize(256), + transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), ]) @@ -143,11 +139,6 @@ def train_dataloader(self): def val_dataloader(self): val_loaders = self.setup_dataloaders(self.val_datasets) return CombinedLoader(val_loaders, mode="sequential") - - - - - def get_samples(self, split: Literal["train", "val"], amount: int) -> dict[frozenset, dict[str, torch.Tensor]]: """Fetches a specified number of samples from the specified split ('train' or 'val').""" diff --git a/examples/imgnet_example/domains.py b/examples/imgnet_example/domains.py index 9eb44143..4e41f09c 100644 --- a/examples/imgnet_example/domains.py +++ b/examples/imgnet_example/domains.py @@ -8,6 +8,8 @@ from shimmer import DomainModule, LossOutput +from diffusers.models import AutoencoderKL + class ImageDomain(DomainModule): def __init__(self, latent_dim: int): @@ -26,15 +28,11 @@ def __init__(self, latent_dim: int): def encode(self, x: torch.Tensor) -> torch.Tensor: - # Add random noise to x between 0 and 0.03 - noise_level = torch.rand(1).item() * 0.03 - return x #+ noise + return x def decode(self, z: torch.Tensor) -> torch.Tensor: self.eval() - print("stats before decoding : ",z.mean(), z.std()) - # Decode using VAE model val = self.vae_model.decode(z) return val @@ -61,6 +59,54 @@ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: return LossOutput(loss=F.mse_loss(pred, target)) + +class SDImageDomain(DomainModule): + def __init__(self, latent_dim: int): + super().__init__(latent_dim) + if latent_dim != 1024: + raise ValueError("vision latent_dim must be 1024") + + # load the model parameters + self.vae_model = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") + + self.vae_model.eval() + print( + "nb params in the vae model : ", + sum(p.numel() for p in self.vae_model.parameters()) + ) + + + def encode(self, x: torch.Tensor) -> torch.Tensor: + return x + + def decode(self, z: torch.Tensor) -> torch.Tensor: + self.eval() + z = z.reshape(z.shape[0],4,16,16) + val = self.vae_model.decode(z).sample + return val + + def training_step(self, batch: torch.Tensor, batch_idx: int): + (domain,) = batch + decoded = self.decode(self.encode(domain)) + loss = F.mse_loss(domain, decoded) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + (domain,) = batch + decoded = self.decode(self.encode(domain)) + loss = F.mse_loss(domain, decoded) + self.log("val_loss", loss) + return loss + + def configure_optimizers(self): + return AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6) + + def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + # Computes an illustrative loss, can be tailored for specific use cases + return LossOutput(loss=F.mse_loss(pred, target)) + + class TextDomain(DomainModule): def __init__(self, latent_dim: int): super().__init__(latent_dim) diff --git a/examples/imgnet_example/get_sd_embed.py b/examples/imgnet_example/get_sd_embed.py new file mode 100644 index 00000000..09a3168c --- /dev/null +++ b/examples/imgnet_example/get_sd_embed.py @@ -0,0 +1,88 @@ +import os +import torch +import numpy as np +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torchvision.datasets import ImageFolder +import matplotlib.pyplot as plt +from diffusers.models import AutoencoderKL +from tqdm import tqdm + +# Set the device +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +print("device:", device) + +# Load the VAE model +vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device) +vae.eval() + +# Set the root directory where ImageNet is located +DATASET_DIR = os.environ.get('DATASET_DIR', '.') # Get the environment variable, if not set, default to '.' +root_dir = DATASET_DIR + +# Define your transformations +transform = transforms.Compose([ + transforms.Resize(128), + transforms.CenterCrop(128), + transforms.ToTensor(), +]) + +# Create datasets +train_dataset = ImageFolder(root=os.path.join(root_dir, 'imagenet/train'), transform=transform) +val_dataset = ImageFolder(root=os.path.join(root_dir, 'imagenet/val'), transform=transform) + +# Create data loaders +train_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=2) +val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=2) + +print("Images loaded") + +# Check inputs distribution +print("Checking out inputs distribution:\n") +data, _ = next(iter(train_loader)) +print("Min:", data.min()) +print("Max:", data.max()) +print("Mean:", data.mean(dim=[0, 2, 3])) +print("Std:", data.std(dim=[0, 2, 3])) + +# Loop over the validation set, get all the embeddings, and save them to .npy format. +all_embeddings = {} +all_indices = {} + +for batch_idx, (inputs, _) in enumerate(tqdm(train_loader)): + inputs = inputs.to(device) + with torch.no_grad(): + encoded_latents = vae.encode(inputs).latent_dist.sample() + + # Get the shape of the current embeddings + shape = encoded_latents.shape + shape_str = str(shape) # Convert shape to a string for dictionary keys + + # Check if this shape already exists in the dictionary + if shape_str not in all_embeddings: + print("Adding shape:", shape_str) + all_embeddings[shape_str] = [] + all_indices[shape_str] = [] + + # Add embeddings and corresponding indices to the respective list + all_embeddings[shape_str].append(encoded_latents.detach().cpu().numpy()) + all_indices[shape_str].append(batch_idx) + +root_dir = "sd_image_embeddings" + +# Save each list of embeddings to a separate .npy file +os.makedirs(root_dir, exist_ok=True) + +for shape_str, embeddings in all_embeddings.items(): + embeddings = np.concatenate(embeddings) + np.save(os.path.join(root_dir, f'image_embeddings_{shape_str}_sd.npy'), embeddings) + print(f"Saved embeddings of shape {shape_str} to", os.path.join(root_dir, f'image_embeddings_{shape_str}_sd.npy')) + print(f"all_embeddings shape for {shape_str}: ", embeddings.shape) + + mean = np.mean(embeddings, axis=0) + std = np.std(embeddings, axis=0) + print(f"embeddings distribution for {shape_str}: mean: {mean}, std: {std}") + +# Print the shapes and corresponding indices +for shape_str, indices in all_indices.items(): + print(f"Shape: {shape_str}, Indices: {indices}") diff --git a/examples/imgnet_example/imnet_logging.py b/examples/imgnet_example/imnet_logging.py index a1e8fb88..74605e25 100644 --- a/examples/imgnet_example/imnet_logging.py +++ b/examples/imgnet_example/imnet_logging.py @@ -116,7 +116,7 @@ def on_callback( pl_module.eval() prediction_demi_cycles = batch_demi_cycles( pl_module.gw_mod, selection_mod, latent_groups - ) + ) prediction_cycles = batch_cycles( pl_module.gw_mod, selection_mod, diff --git a/examples/imgnet_example/inference_infusion.py b/examples/imgnet_example/inference_infusion.py new file mode 100644 index 00000000..61adcacd --- /dev/null +++ b/examples/imgnet_example/inference_infusion.py @@ -0,0 +1,337 @@ +import pandas as pd +import numpy as np +import os + +from tqdm import tqdm + +from sentence_transformers import SentenceTransformer +import matplotlib.pyplot as plt +import torch +from imnet_logging import LogGWImagesCallback +from dataset import make_datamodule +from domains import ImageDomain, TextDomain, SDImageDomain +from lightning.pytorch import Trainer, Callback +from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor +from shimmer import GWDecoder, GWEncoder, BroadcastLossCoefs +from shimmer.modules.global_workspace import GlobalWorkspace, SchedulerArgs +from lightning.pytorch.loggers.wandb import WandbLogger +import torch.nn as nn +import numpy as np +from torchvision import transforms +from torchvision.datasets import ImageFolder +from torch.utils.data import DataLoader, Subset +import random + +import matplotlib.pyplot as plt +from torchvision.utils import save_image +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + +# Define dropout layers +def dropout_get_n_layers(n_layers: int, hidden_dim: int, dropout_rate: float = 0.5) -> list[nn.Module]: + layers: list[nn.Module] = [] + for _ in range(n_layers): + layers.extend([ + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout_rate) + ]) + return layers + +# Define decoder with dropout layers +class dropout_GWDecoder(nn.Sequential): + def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_layers: int, dropout_rate: float = 0.5): + super().__init__( + nn.Linear(in_dim, hidden_dim), + nn.ReLU(), + *dropout_get_n_layers(n_layers, hidden_dim, dropout_rate), + nn.Linear(hidden_dim, out_dim), + ) + +# Define encoder with dropout layers +class dropout_GWEncoder(dropout_GWDecoder): + def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_layers: int, dropout_rate: float = 0.5): + super().__init__(in_dim, hidden_dim, out_dim, n_layers, dropout_rate) +def text_to_image(text, width=256, height=128, font_size=8): + """Convert a string to an image.""" + from PIL import Image, ImageDraw, ImageFont + + def wrap_text(text, num_words=6): + words = text.split() + lines = [' '.join(words[i:i+num_words]) for i in range(0, len(words), num_words)] + return '\n'.join(lines) + + image = Image.new('RGB', (width, height), color=(255, 255, 255)) + draw = ImageDraw.Draw(image) + try: + font = ImageFont.truetype("arial.ttf", font_size) + except IOError: + font = ImageFont.load_default() + + wrapped_text = wrap_text(text) + draw.text((10, 10), wrapped_text, font=font, fill=(0, 0, 0)) + return image + +def plot_and_save_images(random_latent_domains, decoded_images, iteration_titles, captions, vae_decoder, folder, final_image_name): + """Plot and save a single image with all results given a list of tensors and titles.""" + if not os.path.exists(folder): + os.makedirs(folder) + + num_titles = len(iteration_titles) + num_columns = 5 + num_rows = 1 + num_titles # One row for captions and one row per iteration + + fig, axes = plt.subplots(num_rows, num_columns, figsize=(15, num_rows * 3)) + canvas = FigureCanvas(fig) + + # Add a row for captions + for j in range(min(len(captions), num_columns)): # Ensure not to exceed the number of columns + caption_img = text_to_image(captions[j], font_size=6) + axes[0, j].imshow(caption_img) + axes[0, j].axis('off') + if j == 0: + axes[0, j].set_title('Captions', fontsize=12) + + # Loop through decoded images for each iteration + for i, (predicted_image_latents, title) in enumerate(zip(decoded_images, iteration_titles)): + # Decode the predicted image latents + decoded_imgs = vae_decoder(predicted_image_latents).cpu().detach() + + for j in range(min(decoded_imgs.size(0), num_columns)): # Ensure not to exceed the number of columns + img = decoded_imgs[j] + axes[i + 1, j].imshow(img.permute(1, 2, 0).clamp(0, 1)) # Convert from CHW to HWC and clamp values to [0, 1] + axes[i + 1, j].axis('off') + if j == 0: + axes[i + 1, j].set_title(title, fontsize=12) + + # Hide any unused subplots + for ax in axes.flat: + if not ax.has_data(): + ax.axis('off') + + # Save the final image + save_path = os.path.join(folder, f"{final_image_name}.png") + canvas.print_figure(save_path) + +def run_inference_and_plot(global_workspace, random_latent_domains, captions, final_image_name, num_iterations=100): + # Initialize device + device = 'cuda' if torch.cuda.is_available() else 'cpu' + global_workspace.to(device) + + # Initialize the VAE model + vae_model = global_workspace.domain_mods["image_latents"] + + # Prepare the input domains for the first iteration: Caption embeddings only + current_domains = {frozenset(["caption_embeddings"]): random_latent_domains[frozenset(["caption_embeddings"])]} + + # List to store the decoded images + decoded_images = [] + iteration_titles = [] + + plot_interval = max(1, num_iterations // 10) # Calculate the interval for plotting results + + for i in range(num_iterations): + # Determine the domains to encode based on the iteration + if i == 0: + # Encode only caption embeddings on the first iteration + latent_group = current_domains[frozenset(["caption_embeddings"])] + else: + # On subsequent iterations, encode both caption and image latents + latent_group = { + **current_domains[frozenset(["caption_embeddings"])], + **current_domains[frozenset(["image_latents"])] + } + print("image side: ", current_domains[frozenset(["image_latents"])]["image_latents"].shape) + + # 1. Encode and Fuse the latent representations manually + encoded_latents = global_workspace.gw_mod.encode(latent_group) + + # Manually set specific selection weights for text and image + batch_size = list(encoded_latents.values())[0].shape[0] + selection_weights = {} + for domain in encoded_latents: + if domain == "caption_embeddings": + selection_weights[domain] = torch.full((batch_size,), .95, device=device) + elif domain == "image_latents": + selection_weights[domain] = torch.full((batch_size,), 0.05, device=device) + + # Fuse the representations + fused_state = global_workspace.gw_mod.fuse(encoded_latents, selection_weights) + print("fused state shape: ", fused_state.shape) + + # 2. Manually perform the broadcast + all_domains = list(global_workspace.gw_mod.domain_mods.keys()) + predictions = {} + for domain in all_domains: + # Decode the fused state back to each domain + predictions[domain] = global_workspace.gw_mod.decode(fused_state, domains=[domain])[domain] + if domain == "image_latents": + print("\n\ngot here \n\n") + print("predictions shape: ", predictions[domain].shape) + + # 4. Construct the manual output structure using a dictionary + output = { + 'states': fused_state, + 'broadcasts': {frozenset(latent_group.keys()): predictions}, + } + + # Extract the predicted image latents + predicted_image_latents = output['broadcasts'][frozenset(latent_group.keys())]['image_latents'] + + # Decode and store the predicted image latents only at specified intervals + if i % plot_interval == 0 or i == num_iterations - 1: # Also ensure the last iteration is included + decoded_images.append(predicted_image_latents) + iteration_titles.append(f"Iteration {i+1}") + + # Prepare the input for the next iteration: Predicted image latents + original caption embeddings + current_domains = { + frozenset(["caption_embeddings"]): random_latent_domains[frozenset(["caption_embeddings"])], + frozenset(["image_latents"]): {'image_latents': predicted_image_latents} + } + + # Plot and save the results + folder = "inference_results" + plot_and_save_images(random_latent_domains, decoded_images, iteration_titles, captions, vae_model.decode, folder, final_image_name) + + +def inference_gw(): + batch_size = 2056 + + # Prepare data modules from the dataset script + data = make_datamodule(batch_size) + + # Initialize the domain modules with the specific latent dimensions and model parameters + image_domain = ImageDomain(latent_dim=384) + text_domain = TextDomain(latent_dim=384) + + domain_mods = { + "image_latents": image_domain, + "caption_embeddings": text_domain, + } + + workspace_dim = 512 # Define the dimension of the global workspace + + # Define modality encoders and decoders + gw_encoders = {} + gw_decoders = {} + for name, mod in domain_mods.items(): + gw_encoders[name] = dropout_GWEncoder( + in_dim=mod.latent_dim, + hidden_dim=1024, + out_dim=workspace_dim, + n_layers=4, + dropout_rate=0.0 # Example dropout rate + ) + gw_decoders[name] = dropout_GWDecoder( + in_dim=workspace_dim, + hidden_dim=1024, + out_dim=mod.latent_dim, + n_layers=4, + dropout_rate=0.0 # Example dropout rate + ) + + # Loss coefficients setup + loss_coefs: BroadcastLossCoefs = { + "translations": 2.0, + "demi_cycles": 1.0, + "cycles": 1.0, + "contrastives": .05, + "fused": 1.0 + } + + # Initialize device + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + n_epochs = 2000 # Number of training epochs + + global_workspace = GlobalWorkspace( + domain_mods, + gw_encoders, + gw_decoders, + workspace_dim, + loss_coefs, + scheduler_args=SchedulerArgs( + max_lr=0.0002, + total_steps=n_epochs * len(iter(data.train_dataloader())) + ), + ).to(device) + + # Set precision for matrix multiplication + torch.set_float32_matmul_precision("medium") + + # Initialize the BGE model + bge_model = SentenceTransformer("BAAI/bge-small-en-v1.5") + + # Load the model checkpoint + CHECKPOINT_PATH = "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/wandb_output_weight_decay/simple_shapes_fusion/bqumes0b/checkpoints/epoch=299-step=187200.ckpt" + checkpoint = torch.load(CHECKPOINT_PATH) + + # Assuming the model has a method to load from checkpoint state_dict + def load_state_dict_without_vae(global_workspace, checkpoint): + # Create a new state dictionary excluding keys with "vae" + filtered_state_dict = {k: v for k, v in checkpoint['state_dict'].items() if 'vae' not in k} + + # Print the keys that are being loaded + print("Loading the following keys:") + for key in filtered_state_dict.keys(): + print(key) + + # Load the filtered state dictionary + global_workspace.load_state_dict(filtered_state_dict, strict=False) + + # Example usage + load_state_dict_without_vae(global_workspace, checkpoint) + + print("got checkpoint :)") + csv_file = "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/captions_gemma_valimgnet.csv" + df = pd.read_csv(csv_file) + + # Set up the ImageNet dataset transformations + transform = transforms.Compose([ + transforms.Resize(128), + transforms.CenterCrop(128), + transforms.ToTensor(), + ]) + + val_dir = os.environ.get('DATASET_DIR', '.') + '/imagenet/train' + imagenet_val_dataset = ImageFolder(root=val_dir, transform=transform) + + # Process the textual train samples through the model + fig, axes = plt.subplots(2, 5, figsize=(30, 18)) + + # Load image latents + IMAGE_LATENTS_PATH_VAL = "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy" + image_latents_val = np.load(IMAGE_LATENTS_PATH_VAL) + image_latents_val = torch.tensor(image_latents_val).to('cuda') + + CAPTION_EMBEDDINGS_PATH_VAL = "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_val.npy" + caption_embeddings_val = np.load(CAPTION_EMBEDDINGS_PATH_VAL) + caption_embeddings_val = torch.tensor(caption_embeddings_val).to('cuda') + + # Set a fixed seed for reproducibility + np.random.seed(42) + + # Generate multiple figures with different random indices + num_figures = 10 + for fig_idx in range(num_figures): + random_indices = np.random.choice(image_latents_val.shape[0], 5, replace=False) + print("Selected indices for figure", fig_idx, ":", random_indices) + + print("shape for image latents : ", image_latents_val[random_indices].shape) + + random_latent_domains = { + frozenset(["image_latents"]): { + "image_latents": image_latents_val[random_indices] + }, + frozenset(["caption_embeddings"]): { + "caption_embeddings": caption_embeddings_val[random_indices] + } + } + + captions = df.iloc[random_indices]['Caption'].tolist() + print("captions : ", captions) + + # Call the updated run_inference_and_plot for each set of random indices + run_inference_and_plot(global_workspace, random_latent_domains, captions, final_image_name=f"final_output_image_{fig_idx}", num_iterations=10) + +if __name__ == "__main__": + inference_gw() diff --git a/examples/imgnet_example/inference_translation.py b/examples/imgnet_example/inference_translation.py index 19b6d43b..c81e3cfe 100644 --- a/examples/imgnet_example/inference_translation.py +++ b/examples/imgnet_example/inference_translation.py @@ -9,7 +9,7 @@ import torch from imnet_logging import LogGWImagesCallback from dataset import make_datamodule -from domains import ImageDomain, TextDomain +from domains import ImageDomain, TextDomain, SDImageDomain from lightning.pytorch import Trainer, Callback from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor from shimmer import GWDecoder, GWEncoder, BroadcastLossCoefs @@ -22,6 +22,117 @@ from torch.utils.data import DataLoader, Subset import random +import matplotlib.pyplot as plt +from torchvision.utils import save_image +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + +def text_to_image(text, width=256, height=128, font_size=8): + """Convert a string to an image.""" + from PIL import Image, ImageDraw, ImageFont + + def wrap_text(text, num_words=6): + words = text.split() + lines = [' '.join(words[i:i+num_words]) for i in range(0, len(words), num_words)] + return '\n'.join(lines) + + image = Image.new('RGB', (width, height), color=(255, 255, 255)) + draw = ImageDraw.Draw(image) + try: + font = ImageFont.truetype("arial.ttf", font_size) + except IOError: + font = ImageFont.load_default() + + wrapped_text = wrap_text(text) + draw.text((10, 10), wrapped_text, font=font, fill=(0, 0, 0)) + return image + +def plot_and_save_images(input_tensors, tensors, titles, captions, vae_decoder, folder, final_image_name): + """Plot and save a single image with all results given a list of tensors and titles.""" + if not os.path.exists(folder): + os.makedirs(folder) + + num_titles = len(titles) + fig, axes = plt.subplots(num_titles + 2, 5, figsize=(15, (num_titles + 2) * 3)) # Adjust for captions row + canvas = FigureCanvas(fig) + + # Add a row for captions + for j, caption in enumerate(captions): + caption_img = text_to_image(caption, font_size=6) + axes[0, j].imshow(caption_img) + axes[0, j].axis('off') + if j == 0: + axes[0, j].set_title('Captions', fontsize=12) + + # Add a row for the input latents decoded directly + input_images = vae_decoder(input_tensors[frozenset(["image_latents"])]["image_latents"]).cpu().detach() + for j, img in enumerate(input_images): + axes[1, j].imshow(img.permute(1, 2, 0).clamp(0, 1)) # Convert from CHW to HWC and clamp values to [0, 1] + axes[1, j].axis('off') + if j == 0: + axes[1, j].set_title('Directly Decoded Input Latents', fontsize=12) + + for i, (tensor, title) in enumerate(zip(tensors, titles)): + imgs = vae_decoder(tensor).cpu().detach() + row_index = i + 2 # Adjust row index to account for the new top row + + for j, img in enumerate(imgs): + axes[row_index, j].imshow(img.permute(1, 2, 0).clamp(0, 1)) # Convert from CHW to HWC and clamp values to [0, 1] + axes[row_index, j].axis('off') + if j == 0: + axes[row_index, j].set_title(title, fontsize=12) + + # Save the final image + save_path = os.path.join(folder, f"{final_image_name}.png") + canvas.print_figure(save_path) + + +def run_inference_and_plot(global_workspace, random_latent_domains, captions): + # Initialize device + device = 'cuda' if torch.cuda.is_available() else 'cpu' + global_workspace.to(device) + + # Initialize the VAE model + vae_model = global_workspace.domain_mods["image_latents"] + + # Prepare the input domains for different cases + # 1. Caption embeddings only + caption_only_domains = {frozenset(["caption_embeddings"]): random_latent_domains[frozenset(["caption_embeddings"])]} + + # 2. Image latents only + image_only_domains = {frozenset(["image_latents"]): random_latent_domains[frozenset(["image_latents"])]} + + # 3. Both caption embeddings and image latents + both_domains = { + frozenset(["caption_embeddings"]): random_latent_domains[frozenset(["caption_embeddings"])], + frozenset(["image_latents"]): random_latent_domains[frozenset(["image_latents"])] + } + + # Perform inference + caption_only_output = global_workspace.forward(caption_only_domains) + image_only_output = global_workspace.forward(image_only_domains) + both_output = global_workspace.forward(both_domains) + + # Collect the results to be plotted, excluding those going to caption embeddings + results = [ + (image_only_output['broadcasts'][frozenset({'image_latents'})].get('image_latents', None), 'Broadcast - Image Latents from Image Only'), + (image_only_output['cycles'][frozenset({'image_latents'})].get('image_latents', None), 'Cycle - Image Latents from Image Only'), + + (caption_only_output['broadcasts'][frozenset({'caption_embeddings'})].get('image_latents', None), 'Broadcast - Image Latents from Caption Only'), + + (both_output['broadcasts'][frozenset({'image_latents'})].get('image_latents', None), 'Broadcast - Image Latents from Both'), + ] + + # Filter out None results and those going to caption embeddings + filtered_results = [(tensor, title) for tensor, title in results if tensor is not None] + + # Plot and save the results + folder = "inference_results" + final_image_name = "final_output_image" + input_tensors = random_latent_domains + tensors = [result[0] for result in filtered_results] + titles = [result[1] for result in filtered_results] + plot_and_save_images(input_tensors, tensors, titles, captions, vae_model.decode, folder, final_image_name) + # Define dropout layers def dropout_get_n_layers(n_layers: int, hidden_dim: int, dropout_rate: float = 0.5) -> list[nn.Module]: layers: list[nn.Module] = [] @@ -47,6 +158,7 @@ def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_layers: int, dr class dropout_GWEncoder(dropout_GWDecoder): def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_layers: int, dropout_rate: float = 0.5): super().__init__(in_dim, hidden_dim, out_dim, n_layers, dropout_rate) + def print_shapes(nested_dict, indent=0): """Recursively print the shapes of tensors in a nested dictionary.""" for key, value in nested_dict.items(): @@ -125,89 +237,75 @@ def inference_gw(): bge_model = SentenceTransformer("BAAI/bge-small-en-v1.5") # Load the model checkpoint - CHECKPOINT_PATH = "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/wandb_output_bigger_vae_regularized/simple_shapes_fusion/w2syv1ph/checkpoints/epoch=122-step=76752.ckpt" + CHECKPOINT_PATH = "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/wandb_output_weight_decay/simple_shapes_fusion/bqumes0b/checkpoints/epoch=299-step=187200.ckpt" checkpoint = torch.load(CHECKPOINT_PATH) - # Assuming the model has a method to load from checkpoint state_dict - global_workspace.load_state_dict(checkpoint['state_dict']) - print("got checkpoint :)") - csv_file = "captions_fullimgnet_val_noshuffle.csv" - df = pd.read_csv(csv_file) + # Set a fixed seed for reproducibility + np.random.seed(42) - # Function to get random samples - def get_random_samples(df, num_samples=5): - return df.sample(n=num_samples) - - # Get 5 random samples - train_samples = get_random_samples(df, num_samples=5) - - print("got samples :)") + # Assuming the model has a method to load from checkpoint state_dict + def load_state_dict_without_vae(global_workspace, checkpoint): + + # Create a new state dictionary excluding keys with "vae" + filtered_state_dict = {k: v for k, v in checkpoint['state_dict'].items() if 'vae' not in k} - # Path to the directory where the embeddings file is stored and the filename - root_dir = '' # Change this to your directory path - embeddings_file = '../../../../../pyt_scripts/BLIP_TEST/get_embed/bge_downsampled_captions.npy' - file_path = os.path.join(root_dir, embeddings_file) + # Print the keys that are being loaded + print("Loading the following keys:") + for key in filtered_state_dict.keys(): + print(key) - # Load embeddings - embeddings = np.load(file_path) + # Load the filtered state dictionary + global_workspace.load_state_dict(filtered_state_dict, strict=False) - # Print original mean and std - mean_tensor = np.mean(embeddings, axis=0) - std_tensor = np.std(embeddings, axis=0) + # Example usage + load_state_dict_without_vae(global_workspace, checkpoint) - def normalize_embedding(embedding, mean, std): - return (embedding - mean) / std + print("got checkpoint :)") + csv_file = "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/captions_gemma_valimgnet.csv" + df = pd.read_csv(csv_file) # Set up the ImageNet dataset transformations transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), + transforms.Resize(128), + transforms.CenterCrop(128), transforms.ToTensor(), ]) val_dir = os.environ.get('DATASET_DIR', '.') + '/imagenet/train' imagenet_val_dataset = ImageFolder(root=val_dir, transform=transform) - # Get the image indices from the sampled captions - image_indices = train_samples['Image Index'].tolist() - - # Create a subset of the dataset with the selected image indices - subset_dataset = Subset(imagenet_val_dataset, image_indices) - subset_loader = DataLoader(subset_dataset, batch_size=1, shuffle=False) - # Process the textual train samples through the model fig, axes = plt.subplots(2, 5, figsize=(30, 18)) - # Move mean and std tensors to device - mean_tensor = torch.tensor(mean_tensor).to('cuda') - std_tensor = torch.tensor(std_tensor).to('cuda') - - # Load image latents - IMAGE_LATENTS_PATH_VAL = ( - "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_bigdisc_goodbeta_50ep" - "/combined_standardized_embeddings.npy" - ) - # Load image latents + IMAGE_LATENTS_PATH_VAL = "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy" image_latents_val = np.load(IMAGE_LATENTS_PATH_VAL) image_latents_val = torch.tensor(image_latents_val).to('cuda') + CAPTION_EMBEDDINGS_PATH_VAL = "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_val.npy" + caption_embeddings_val = np.load(CAPTION_EMBEDDINGS_PATH_VAL) + caption_embeddings_val = torch.tensor(caption_embeddings_val).to('cuda') + + # Select five random indices + random_indices = np.random.choice(image_latents_val.shape[0], 5, replace=False) + + print("distribution : ",image_latents_val.mean(), image_latents_val.std()) + print("distribution : ",caption_embeddings_val.mean(), caption_embeddings_val.std()) + # Create random latents domain groups for testing random_latent_domains = { frozenset(["image_latents"]): { - "image_latents": torch.rand((1, 384)).to(device) # Example random tensor + "image_latents": image_latents_val[random_indices] }, frozenset(["caption_embeddings"]): { - "caption_embeddings": torch.rand((1, 384)).to(device) # Example random tensor + "caption_embeddings": caption_embeddings_val[random_indices] } } - - # Call the forward function on random samples - gw_predictions = global_workspace.forward(random_latent_domains) + captions = df.iloc[random_indices]['Caption'].tolist() # Print the shapes of the output tensors - print_shapes(gw_predictions) + run_inference_and_plot(global_workspace, random_latent_domains, captions) if __name__ == "__main__": - inference_gw() + inference_gw() \ No newline at end of file diff --git a/examples/imgnet_example/normalize_embeddings.py b/examples/imgnet_example/normalize_embeddings.py index d1631191..bec90a31 100644 --- a/examples/imgnet_example/normalize_embeddings.py +++ b/examples/imgnet_example/normalize_embeddings.py @@ -3,7 +3,7 @@ # Path to the directory where the embeddings file is stored and the filename root_dir = '' # Change this to your directory path -embeddings_file = '/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/bge_fullsized_captions_val.npy' +embeddings_file = '/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_bge_captions_val.npy' file_path = os.path.join(root_dir, embeddings_file) # Load embeddings @@ -25,6 +25,6 @@ print("Normalized Std:", normalized_std) # Save normalized embeddings -normalized_file_path = os.path.join(root_dir, 'fullsized_captions_norm_val.npy') +normalized_file_path = "/home/rbertin/pyt_scripts/BLIP_TEST/gemma/gemma_norm_bge_captions_val.npy" np.save(normalized_file_path, normalized_embeddings) print("Saved normalized embeddings to", normalized_file_path) diff --git a/examples/imgnet_example/temp_uniteembeddings.py b/examples/imgnet_example/temp_uniteembeddings.py new file mode 100644 index 00000000..50a90ceb --- /dev/null +++ b/examples/imgnet_example/temp_uniteembeddings.py @@ -0,0 +1,35 @@ +import numpy as np + +# Function to normalize embeddings to N(0,1) using a single mean and std for all dimensions +def standardize_embeddings_global(embeddings): + mean = np.mean(embeddings) + std = np.std(embeddings) + + print("standardize based on ", mean, std) + return (embeddings - mean) / std + +# Load the first file +file1 = '/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/sd_image_embeddings/image_embeddings_torch.Size([256, 4, 16, 16])_sd.npy' +embeddings1 = np.load(file1) + +# Load the second file +file2 = '/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/sd_image_embeddings/image_embeddings_torch.Size([143, 4, 16, 16])_sd.npy' +embeddings2 = np.load(file2) + + + +# Standardize the embeddings +standardized_embeddings1 = standardize_embeddings_global(embeddings1) +standardized_embeddings2 = standardize_embeddings_global(embeddings2) + +# Concatenate the standardized embeddings +combined_embeddings = np.concatenate((standardized_embeddings1, standardized_embeddings2), axis=0) +combined_embeddings = combined_embeddings.reshape(combined_embeddings.shape[0], -1) + +print("combined embeddings ", combined_embeddings.mean(),combined_embeddings.std(), combined_embeddings.shape) + +# Save the combined standardized embeddings to a new file +output_file = 'sd_image_embeddings/train_united.npy' +np.save(output_file, combined_embeddings) + +print(f'Combined standardized embeddings saved to {output_file}') \ No newline at end of file diff --git a/examples/imgnet_example/test_indices.py b/examples/imgnet_example/test_indices.py index 31fb295b..d8474c8d 100644 --- a/examples/imgnet_example/test_indices.py +++ b/examples/imgnet_example/test_indices.py @@ -1,122 +1,48 @@ -import torch -from imnet_logging import LogGWImagesCallback - -from dataset import make_datamodule -from domains import ImageDomain, TextDomain -from lightning.pytorch import Trainer, Callback -from lightning.pytorch.callbacks import ModelCheckpoint,LearningRateMonitor - -from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, BroadcastLossCoefs -from shimmer.modules.global_workspace import GlobalWorkspaceFusion, SchedulerArgs - -from lightning.pytorch.loggers.wandb import WandbLogger - -#put in utils later -import torch.nn as nn -import torch -import numpy as np +import os import random +import pandas as pd +from torchvision import datasets, transforms +from PIL import Image +import matplotlib.pyplot as plt -# Define the load_data function -def load_data(path): - return np.load(path) - -# Paths to data -IMAGE_LATENTS_PATH_TRAIN = ( - "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger__disc/" - "val_image_embeddings.npy" -) -IMAGE_LATENTS_PATH_VAL = ( - "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger__disc/" - "val_image_embeddings_val.npy" -) -CAPTION_EMBEDDINGS_PATH_TRAIN = ( - "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/" - "bge_fullsized_captions_norm_fixed.npy" -) -CAPTION_EMBEDDINGS_PATH_VAL = ( - "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/" - "bge_fullsized_captions_norm_val.npy" -) - -# Load training data -image_latents_train = load_data(IMAGE_LATENTS_PATH_TRAIN)[:50000] -caption_embeddings_train = load_data(CAPTION_EMBEDDINGS_PATH_TRAIN)[:50000] - - -# Load validation data -image_latents_val = load_data(IMAGE_LATENTS_PATH_VAL) -caption_embeddings_val = load_data(CAPTION_EMBEDDINGS_PATH_VAL) - - - -print("got this far !") - -# Assuming `make_datamodule` and `data_module` are defined somewhere in your code -data = make_datamodule(0.8, batch_size=2056) - -print("got this far !") -train_loader = data.train_dataloader() -val_loader = data.val_dataloader() +# ImageNet DataLoader setup +imagenet_data_path = '/shared/datasets/imagenet/train' # Path to the ImageNet dataset -# Function to get a random pair of matched latents -def get_random_matched_pair(batch): - indices = list(range(len(batch['image_latents']))) - random_index = random.choice(indices) - return batch['image_latents'][random_index], batch['caption_embeddings'][random_index] +# Define transformations +transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), +]) -# Function to find the index of a latent in a dataset -def find_index(latent, dataset): - for idx, item in enumerate(dataset): - if np.allclose(latent, item, atol=1e-15): # Using a tolerance for floating point comparison - return idx - return -1 +# Load the dataset +imagenet_dataset = datasets.ImageFolder(root=imagenet_data_path, transform=transform) -print("got this far !") +# Load the no-duplicates csv +captions_df = pd.read_csv('../../../../../pyt_scripts/BLIP_TEST/gemma/no_duplicates_gemma_captions.csv') -# Perform the comparison 1000 times on the validation set -print("Validation set comparisons:") -for i in range(1000): - batch = next(iter(val_loader)) +# Check if the lengths match +assert len(captions_df) == len(imagenet_dataset), "The lengths of the dataset and captions CSV do not match." - image_latent, text_latent = get_random_matched_pair(batch) - - if image_latent is None or text_latent is None: - print("No matched pair found in the batch.") - continue +# Pick randomly five indices between 0 and the length of the captions csv +random_indices = random.sample(range(len(captions_df)), 5) - image_latent_np = image_latent.cpu().numpy() - text_latent_np = text_latent.cpu().numpy() - - image_index = find_index(image_latent_np, image_latents_val) - text_index = find_index(text_latent_np, caption_embeddings_val) - - if image_index == -1 or text_index == -1: - print("Index not found in the respective dataset.") - continue - if image_index != text_index or i%10==0: - print(f"Image index: {image_index}, Text index: {text_index}, Match: {image_index == text_index}") +# Plot the images and captions +fig, axes = plt.subplots(5, 1, figsize=(10, 20)) -# Perform the comparison 1000 times on the training set -print("Training set comparisons:") -for i in range(1000): - batch = next(iter(train_loader)) - image_latent, text_latent = get_random_matched_pair(batch) +for i, idx in enumerate(random_indices): + image_path, _ = imagenet_dataset.samples[idx] + image = Image.open(image_path) + caption = captions_df.loc[idx, 'Caption'] + axes[i].imshow(image) + axes[i].axis('off') + # Split the caption into lines with up to 5 words each + caption_lines = '\n'.join([' '.join(caption.split()[j:j+5]) for j in range(0, len(caption.split()), 5)]) + axes[i].set_title(caption_lines, fontsize=8) - - if image_latent is None or text_latent is None: - print("No matched pair found in the batch.") - continue +# Save the plot to a file +plt.tight_layout() +plt.savefig('random_images_with_captions.png') - image_latent_np = image_latent.cpu().numpy() - text_latent_np = text_latent.cpu().numpy() - - image_index = find_index(image_latent_np, image_latents_train) - text_index = find_index(text_latent_np, caption_embeddings_train) - - if image_index == -1 or text_index == -1: - print("Index not found in the respective dataset.") - continue - if image_index != text_index or i%10==0: - print(f"Image index: {image_index}, Text index: {text_index}, Match: {image_index == text_index}") +print("Images and captions plotted and saved to 'random_images_with_captions.png'.") diff --git a/examples/imgnet_example/train_gw.py b/examples/imgnet_example/train_gw.py index 466b8520..cd47056c 100644 --- a/examples/imgnet_example/train_gw.py +++ b/examples/imgnet_example/train_gw.py @@ -2,7 +2,7 @@ from imnet_logging import LogGWImagesCallback from dataset import make_datamodule -from domains import ImageDomain, TextDomain +from domains import ImageDomain, TextDomain, SDImageDomain from lightning.pytorch import Trainer, Callback from lightning.pytorch.callbacks import ModelCheckpoint,LearningRateMonitor @@ -112,14 +112,14 @@ def train_gw(): hidden_dim=1024, out_dim=workspace_dim, n_layers=4, - dropout_rate=0.01 # Example dropout rate + dropout_rate=0.0 # Example dropout rate ) gw_decoders[name] = dropout_GWDecoder( in_dim=workspace_dim, hidden_dim=1024, out_dim=mod.latent_dim, n_layers=4, - dropout_rate=0.01 # Example dropout rate + dropout_rate=0.0 # Example dropout rate ) # Loss coefficients setup @@ -131,7 +131,7 @@ def train_gw(): "fused": 1.0 } - n_epochs = 200 # Number of training epochs + n_epochs = 2000 # Number of training epochs global_workspace = GlobalWorkspace( domain_mods, @@ -144,11 +144,11 @@ def train_gw(): total_steps=n_epochs * len(iter(data.train_dataloader())) ), - optim_weight_decay=0.05 + optim_weight_decay=0. ) wandb_logger = None - run_name = "bigger_vae_regularized" + run_name = "ourvae_gemma_384_512" wandb_logger = WandbLogger( save_dir=f"wandb_output_{run_name}", project="simple_shapes_fusion", @@ -180,7 +180,7 @@ def train_gw(): # Trainer setup trainer = Trainer( - logger=wandb_logger, + #logger=wandb_logger, devices=1, # assuming training on 1 GPU max_epochs=n_epochs, log_every_n_steps=100,