Commit

progress

RolandBERTINJOHANNET committed Jul 9, 2024
1 parent d011ea3 commit d604342

Showing 9 changed files with 448 additions and 126 deletions.
10 changes: 5 additions & 5 deletions examples/imgnet_example/dataset.py
@@ -16,12 +16,12 @@

 # Paths to data
 IMAGE_LATENTS_PATH_TRAIN = (
-    "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_bigdisc_goodbeta_50ep"
-    "/combined_standardized_embeddings.npy"
-)
+    "/home/rbertin/pyt_scripts/full_imgnet/full_size/"
+    "vae_full_withbigger_disc/384_combined_standardized_embeddings.npy"
+)
 IMAGE_LATENTS_PATH_VAL = (
-    "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_bigdisc_goodbeta_50ep"
-    "/val_combined_standardized_embeddings.npy"
+    "/home/rbertin/pyt_scripts/full_imgnet/full_size/"
+    "vae_full_withbigger_disc/384_val_combined_standardized_embeddings.npy"
 )
 CAPTION_EMBEDDINGS_PATH_TRAIN = (
     "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/"
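Note: the embedding paths above are machine-specific. A minimal sketch of how these precomputed latents can be loaded and sanity-checked, assuming local copies of the .npy files:

    import numpy as np
    import torch

    # Hypothetical local copies of the files referenced above.
    TRAIN_PATH = "384_combined_standardized_embeddings.npy"
    VAL_PATH = "384_val_combined_standardized_embeddings.npy"

    # Each file is expected to hold one latent row per image.
    train_latents = torch.from_numpy(np.load(TRAIN_PATH)).float()
    val_latents = torch.from_numpy(np.load(VAL_PATH)).float()
    print(train_latents.shape, val_latents.shape)  # e.g. (N_train, 384), (N_val, 384)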
24 changes: 2 additions & 22 deletions examples/imgnet_example/domains.py
@@ -13,49 +13,29 @@ class ImageDomain(DomainModule):
     def __init__(self, latent_dim: int):
         super().__init__(latent_dim)
         # load the model parameters
-        checkpoint_path = "vae_model.pth"
+        checkpoint_path = "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_full_withbigger__disc/nearest_lpips_latent_dim=384/vae_model.pth"
         self.vae_model = VanillaVAE(
-            in_channels=3, latent_dim=512, upsampling="bilinear", loss_type="lpips"
+            in_channels=3, latent_dim=latent_dim, upsampling="nearest", loss_type="lpips"
         )
         self.vae_model.load_state_dict(torch.load(checkpoint_path))
         self.vae_model.eval()
-        print(
-            "nb params in the vae model : ",
-            sum(p.numel() for p in self.vae_model.parameters())
-        )
-
-        # Load the non-normalized embeddings to get their stats
-        IMAGE_LATENTS_PATH_TRAIN = (
-            "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_bigdisc_goodbeta_50ep"
-            "/combined_embeddings.npy"
-        )
-        embeddings = np.load(IMAGE_LATENTS_PATH_TRAIN)
-        flattened_embeddings = embeddings.flatten()
-
-        self.mean = torch.tensor(np.mean(flattened_embeddings), dtype=torch.float32).to('cuda')
-        self.std = torch.tensor(np.std(flattened_embeddings), dtype=torch.float32).to('cuda')
-
-        print("extracted mean and std : ", self.mean, self.std)

     def encode(self, x: torch.Tensor) -> torch.Tensor:
         # Add random noise to x between 0 and 0.03
         noise_level = torch.rand(1).item() * 0.03
         noise = torch.randn_like(x) * noise_level
         # Normalize the noise
         noise = (noise - self.mean.to(noise.device)) / self.std.to(noise.device)
         return x  # + noise

     def decode(self, z: torch.Tensor) -> torch.Tensor:
         self.eval()
-
-        # Denormalize the embeddings
-        z = z * self.std + self.mean
-
-        print("stats before decoding : ", z.mean(), z.std())

         # Decode using VAE model
         z = z.reshape(-1, 32, 4, 4)
         val = self.vae_model.decode(z)
         return val
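Note: a minimal decode sketch for the updated ImageDomain, assuming the checkpoint above is available locally; shapes are illustrative. decode() still reshapes to (-1, 32, 4, 4), i.e. a 512-dimensional latent (32 * 4 * 4), independent of the latent_dim argument:

    import torch

    domain = ImageDomain(latent_dim=384)
    z = torch.randn(1, 512)  # stand-in for one stored VAE latent row
    img = domain.decode(z)   # internally reshaped to (1, 32, 4, 4)
    # With a decoder that accepts 32x4x4 feature maps, img should be (1, 3, 224, 224).
    print(img.shape)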
7 changes: 1 addition & 6 deletions examples/imgnet_example/imnet_logging.py
@@ -107,13 +107,8 @@ def on_callback(
         for domain, tensor in domains.items():
             if domain == "image_latents":
                 pl_module.domain_mods["image_latents"].vae_model.eval()
-                mean, std = pl_module.domain_mods["image_latents"].mean, pl_module.domain_mods["image_latents"].std
-                latent_groups[domain_group][domain] = (pl_module.domain_mods["image_latents"].vae_model.encode(tensor)[0].flatten(start_dim=1) - mean) / std
-
-                print("stats before predicting : ", latent_groups[domain_group][domain].mean(), latent_groups[domain_group][domain].std())
-
-                #mu, log_var = pl_module.domain_mods["image_latents"].vae_model.encode(tensor)
-                #print("stats before forwardbyhand : ", mu.mean(), mu.std())
+                latent_groups[domain_group][domain] = pl_module.domain_mods["image_latents"].vae_model.encode(tensor)[0].flatten(start_dim=1)
                 self.log_samples(loggers[0] if loggers else None, pl_module, pl_module.domain_mods["image_latents"].vae_model(tensor)[0], domain, "forward_by_hand", trainer)
         selection_mod = SingleDomainSelection()
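Note: a small sketch of the encode step kept above. VanillaVAE.encode returns (mu, log_var), so [0] selects mu, and flatten(start_dim=1) keeps the batch dimension while flattening the rest; shapes here are illustrative:

    import torch

    mu = torch.randn(8, 32, 4, 4)    # hypothetical conv-head output
    flat = mu.flatten(start_dim=1)   # (8, 512); a no-op on an already-flat (8, 384)
    print(flat.shape)                # torch.Size([8, 512])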
75 changes: 24 additions & 51 deletions examples/imgnet_example/inference_translation.py
@@ -47,16 +47,23 @@ def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_layers: int, dropout_rate: float = 0.5):
 class dropout_GWEncoder(dropout_GWDecoder):
     def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, n_layers: int, dropout_rate: float = 0.5):
         super().__init__(in_dim, hidden_dim, out_dim, n_layers, dropout_rate)

+def print_shapes(nested_dict, indent=0):
+    """Recursively print the shapes of tensors in a nested dictionary."""
+    for key, value in nested_dict.items():
+        print(' ' * indent + str(key) + ':')
+        if isinstance(value, dict):
+            print_shapes(value, indent + 2)
+        else:
+            print(' ' * (indent + 2) + str(value.shape))
+
 # Define inference function
 def inference_gw():
     batch_size = 2056

     # Prepare data modules from the dataset script
     data = make_datamodule(batch_size)

     # Initialize the domain modules with the specific latent dimensions and model parameters
-    image_domain = ImageDomain(latent_dim=512)
+    image_domain = ImageDomain(latent_dim=384)
     text_domain = TextDomain(latent_dim=384)

     domain_mods = {
@@ -118,7 +125,7 @@ def inference_gw():
     bge_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

     # Load the model checkpoint
-    CHECKPOINT_PATH = "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/wandb_output_bigger_vae/simple_shapes_fusion/7t80u3rm/checkpoints/epoch=499-step=312000.ckpt"
+    CHECKPOINT_PATH = "/home/rbertin/cleaned/git_synced/shimmer/examples/imgnet_example/wandb_output_bigger_vae_regularized/simple_shapes_fusion/w2syv1ph/checkpoints/epoch=122-step=76752.ckpt"
     checkpoint = torch.load(CHECKPOINT_PATH)

     # Assuming the model has a method to load from checkpoint state_dict
@@ -172,26 +179,11 @@ def normalize_embedding(embedding, mean, std):
     # Process the textual train samples through the model
     fig, axes = plt.subplots(2, 5, figsize=(30, 18))

-    # Move mean and std tensors to device
-    mean_tensor = torch.tensor(mean_tensor).to('cuda')
-    std_tensor = torch.tensor(std_tensor).to('cuda')
     # Set the directory for the dataset
     val_dir = os.environ.get('DATASET_DIR', '.') + '/imagenet/train'
     imagenet_val_dataset = ImageFolder(root=val_dir, transform=transform)
-
-    # Randomly select five indices from the length of the dataset
-    random_indices = random.sample(range(len(imagenet_val_dataset)), 5)
-
-    # Initialize the DataLoader
-    data_loader = DataLoader(imagenet_val_dataset, batch_size=1, shuffle=False)
-
-    # Process the textual train samples through the model
-    fig, axes = plt.subplots(2, 5, figsize=(30, 18))
-
-    # Move mean and std tensors to device
-    mean_tensor = torch.tensor(mean_tensor).to('cuda')
-    std_tensor = torch.tensor(std_tensor).to('cuda')

+    # Load image latents
     IMAGE_LATENTS_PATH_VAL = (
         "/home/rbertin/pyt_scripts/full_imgnet/full_size/vae_bigdisc_goodbeta_50ep"
         "/combined_standardized_embeddings.npy"
@@ -201,40 +193,21 @@ def normalize_embedding(embedding, mean, std):
     image_latents_val = np.load(IMAGE_LATENTS_PATH_VAL)
     image_latents_val = torch.tensor(image_latents_val).to('cuda')

-    for i, image_index in tqdm(enumerate(random_indices)):
-        # Get the image and target at the selected index
-        image, _ = imagenet_val_dataset[image_index]
-
-        # Get the image latents at the same index as the original image in the dataset
-        file_latents = image_latents_val[image_index].unsqueeze(0).to('cuda')
-        print("file_latents stats before modif : ", file_latents.mean(), file_latents.std())
-        file_latents = (file_latents * global_workspace.domain_mods["image_latents"].std.to(file_latents.device)) - global_workspace.domain_mods["image_latents"].mean.to(file_latents.device)
-
-        print("image.shape : ", image.shape)
-        encoded = image_domain.vae_model.encode(image.unsqueeze(0).to('cuda'))[0]
-
-        print("difference : ", (file_latents - encoded.flatten(start_dim=1)).mean())
-        print("encoded stats : ", encoded.mean(), encoded.std())
-        print("file_latents stats : ", file_latents.mean(), file_latents.std())
-
-        # Decode image latents to image space using the VAE model
-        image_output_1 = global_workspace.domain_mods["image_latents"].vae_model.decode(file_latents.reshape(-1, 32, 4, 4))
-        image_output_2 = global_workspace.domain_mods["image_latents"].vae_model.decode(encoded.reshape(-1, 32, 4, 4))
-
-        # Plot the demi-cycled image output tensor
-        axes[0, i].imshow(image_output_1.squeeze().permute(1, 2, 0).cpu().detach().numpy())
-        axes[0, i].set_title("Demi-cycled Image", fontsize=10, pad=10)
-        axes[0, i].axis('off')
+    # Create random latents domain groups for testing
+    random_latent_domains = {
+        frozenset(["image_latents"]): {
+            "image_latents": torch.rand((1, 384)).to(device)  # Example random tensor
+        },
+        frozenset(["caption_embeddings"]): {
+            "caption_embeddings": torch.rand((1, 384)).to(device)  # Example random tensor
+        }
+    }

-        # Plot the original image
-        image = image.to('cuda')
-        axes[1, i].imshow(image_output_2.squeeze().permute(1, 2, 0).cpu().detach().numpy())
-        axes[1, i].set_title("Original Image", fontsize=10, pad=10)
-        axes[1, i].axis('off')
+    # Call the forward function on random samples
+    gw_predictions = global_workspace.forward(random_latent_domains)

-    plt.subplots_adjust(top=0.85, wspace=0.4)  # Adjust top margin and width spacing
-    plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust rect parameter to make space for titles
-    plt.savefig("demi_cycle_plot_val.png")
+    # Print the shapes of the output tensors
+    print_shapes(gw_predictions)

 if __name__ == "__main__":
     inference_gw()
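Note: a quick illustration of print_shapes (defined above) on a nested mapping shaped like gw_predictions; keys and sizes here are hypothetical:

    import torch

    fake_predictions = {
        frozenset(["image_latents"]): {
            "image_latents": torch.rand(1, 384),
        },
    }
    print_shapes(fake_predictions)
    # frozenset({'image_latents'}):
    #   image_latents:
    #     torch.Size([1, 384])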
39 changes: 19 additions & 20 deletions examples/imgnet_example/noise_regularisation.py
@@ -15,6 +15,9 @@
 from torchvision.datasets import ImageFolder
 from torchvision import transforms

+device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+
+print("device : ", device)

 class ResidualBlock(nn.Module):
     def __init__(self, channels):
@@ -31,19 +34,15 @@ def forward(self, x):
         return F.relu(x + self.conv(x), inplace=True)


-device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
-
-print("device : ", device)
-
 class VanillaVAE(nn.Module):
     def __init__(self, in_channels: int, latent_dim: int, hidden_dims=None, beta=1.0, upsampling='bilinear', loss_type='lpips'):
         super(VanillaVAE, self).__init__()
         self.lpips_model = lpips.LPIPS(net='vgg', lpips=False) if loss_type == 'lpips' else None
         self.latent_dim = latent_dim
-        hidden_dims = hidden_dims or [64, 128, 128, 256, 256, 512]
         self.beta = beta
         self.upsampling = upsampling
         self.loss_type = loss_type
+        hidden_dims = hidden_dims or [128, 256, 512, 256, 128, 64]

         # Encoder setup
         modules = []
@@ -54,31 +53,31 @@ def __init__(self, in_channels: int, latent_dim: int, hidden_dims=None, beta=1.0, upsampling='bilinear', loss_type='lpips'):
                 nn.BatchNorm2d(h_dim),
                 nn.LeakyReLU())
             )
-            modules.append(ResidualBlock(h_dim))
+            modules.append(ResidualBlock(h_dim))  # Add a residual block after each conv layer
             in_channels = h_dim

         modules.append(torch.nn.Tanh())
         self.encoder = nn.Sequential(*modules)

-        # Assuming a 4x4 spatial size at the bottleneck, this can be adjusted depending on your input size
-        self.fc_mu = nn.Conv2d(hidden_dims[-1], 32, kernel_size=3, padding=1)  # bottleneck will be 32,4,4 featuremaps
-        self.fc_var = nn.Conv2d(hidden_dims[-1], 32, kernel_size=3, padding=1)
+        self.fc_mu = nn.Linear(hidden_dims[-1]*4*4, latent_dim)
+        self.fc_var = nn.Linear(hidden_dims[-1]*4*4, latent_dim)

-        # No need for a decoder input reshape since we're keeping spatial dimensions
-        self.decoder_input = nn.Conv2d(32, hidden_dims[-1], 3, 1, 1)  # get back to hidden_dims[-1] channels
+        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4 * 4)

         # Decoder setup
         modules = []
-        hidden_dims.reverse()
+        hidden_dims.reverse()  # Ensure hidden_dims is in the correct order for building up
         for i in range(len(hidden_dims)-1):
-            final_layer = i == len(hidden_dims) - 2
-            modules.append(self._deconv_block(hidden_dims[i], hidden_dims[i + 1], final=final_layer))
-
-        modules.append(nn.Upsample(size=(224, 224), mode=upsampling))
+            final_layer = i == len(hidden_dims) - 2  # Check if it's the layer before the last
+            modules.append(self._deconv_block(hidden_dims[i], hidden_dims[i + 1], final=final_layer))

+        # Add an explicit Upsample to 224x224 as the last upsampling step
+        modules.append(nn.Upsample(size=(224, 224), mode=upsampling))  # Adjust mode as needed

+        # Final convolution to produce the output image
         modules.append(nn.Sequential(
             nn.Conv2d(hidden_dims[-2], 3, kernel_size=3, padding=1),
             nn.Tanh()
         ))

         self.decoder = nn.Sequential(*modules)

     def _deconv_block(self, in_channels, out_channels, final=False):
@@ -100,11 +99,11 @@ def _deconv_block(self, in_channels, out_channels, final=False):


     def encode(self, input):
-        result = self.encoder(input)
+        result = torch.flatten(self.encoder(input), start_dim=1)
         return self.fc_mu(result), self.fc_var(result)

     def decode(self, z):
-        result = self.decoder_input(z)
+        result = self.decoder_input(z).view(-1, 512, 4, 4)
         return self.decoder(result)

     def reparameterize(self, mu, logvar):
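Note: a standalone sketch of the shape flow through the new Linear mu/log-var heads, assuming the encoder ends in 512 channels at a 4x4 bottleneck; the actual channel count depends on the hidden_dims passed in:

    import torch
    import torch.nn as nn

    latent_dim = 384
    feat = torch.randn(2, 512, 4, 4)         # hypothetical encoder output
    flat = torch.flatten(feat, start_dim=1)  # (2, 8192)

    fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
    fc_var = nn.Linear(512 * 4 * 4, latent_dim)
    mu, log_var = fc_mu(flat), fc_var(flat)  # each (2, 384)

    decoder_input = nn.Linear(latent_dim, 512 * 4 * 4)
    grid = decoder_input(mu).view(-1, 512, 4, 4)  # back to a 4x4 feature map
    print(mu.shape, grid.shape)                   # (2, 384) and (2, 512, 4, 4)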