diff --git a/cyto_dl/models/vae/base_vae.py b/cyto_dl/models/vae/base_vae.py index d2eda552c..b388ba152 100644 --- a/cyto_dl/models/vae/base_vae.py +++ b/cyto_dl/models/vae/base_vae.py @@ -139,7 +139,10 @@ def __init__( self.latent_dim = latent_dim if decoder_latent_parts is None: - self.decoder_latent_parts = {key: self.prior.keys() for key in self.decoder.keys()} + pass + # self.decoder_latent_parts = { + # key: self.prior.keys() for key in self.decoder.keys() + # } else: self.decoder_latent_parts = decoder_latent_parts for key in self.decoder.keys(): @@ -157,9 +160,11 @@ def __init__( def calculate_rcl(self, x, xhat, input_key, target_key=None): if not target_key: target_key = input_key + rcl_per_input_dimension = self.reconstruction_loss[input_key]( x[target_key], xhat[input_key] ) + return rcl_per_input_dimension def calculate_rcl_dict(self, x, xhat, z): @@ -190,6 +195,10 @@ def calculate_elbo(self, x, xhat, z): total_kld = sum(kld_per_part_summed.values()) total_recon = sum(rcl_reduced.values()) + if len(total_recon.shape) > 0: + total_recon = total_recon.mean() + for key in rcl_reduced.keys(): + rcl_reduced[key] = rcl_reduced[key].mean() return ( total_recon + self.beta * total_kld, diff --git a/cyto_dl/models/vae/image_encoder.py b/cyto_dl/models/vae/image_encoder.py index 1374c43c7..b8e75c2a9 100644 --- a/cyto_dl/models/vae/image_encoder.py +++ b/cyto_dl/models/vae/image_encoder.py @@ -123,6 +123,8 @@ def make_block( subunits=self.num_res_units, bias=bias, ) + if padding is None: + padding = same_padding(kernel_size) return Convolution( spatial_dims=self.spatial_dims, diff --git a/cyto_dl/models/vae/image_vae.py b/cyto_dl/models/vae/image_vae.py index 5d65d398f..6e8d95c25 100644 --- a/cyto_dl/models/vae/image_vae.py +++ b/cyto_dl/models/vae/image_vae.py @@ -34,6 +34,7 @@ def forward(self, x): class ImageVAE(BaseVAE): def __init__( self, + x_label: str, latent_dim: int, spatial_dims: int, in_shape: Sequence[int], @@ -61,14 +62,15 @@ def __init__( num_res_units: int = 2, up_kernel_size: int = 3, first_conv_padding_mode: str = "replicate", - encoder_padding: Optional[Union[int, Sequence[int]]] = None, eps: float = 1e-8, + encoder_padding: Optional[Union[int, Sequence[int]]] = None, + metric_keys: Optional[list] = None, **base_kwargs, ): in_channels, *in_shape = in_shape self.out_channels = out_channels if out_channels is not None else in_channels - + self.x_label = x_label self.spatial_dims = spatial_dims self.final_size = np.asarray(in_shape, dtype=int) self.up_kernel_size = up_kernel_size @@ -207,6 +209,8 @@ def __init__( decoder=decoder, latent_dim=latent_dim, prior=prior, + x_label=x_label, + metric_keys=metric_keys, **base_kwargs, ) diff --git a/cyto_dl/models/vae/point_cloud_vae.py b/cyto_dl/models/vae/point_cloud_vae.py index 909319e09..10a4e2945 100644 --- a/cyto_dl/models/vae/point_cloud_vae.py +++ b/cyto_dl/models/vae/point_cloud_vae.py @@ -8,7 +8,9 @@ from cyto_dl.models.vae.base_vae import BaseVAE from cyto_dl.models.vae.priors import IdentityPrior, IsotropicGaussianPrior -from cyto_dl.nn.losses import ChamferLoss + +# from topologylayer.nn import AlphaLayer, BarcodePolyFeature +from cyto_dl.nn.losses import ChamferLoss, TopoLoss from cyto_dl.nn.point_cloud import DGCNN, FoldingNet Array = Union[torch.Tensor, np.ndarray, Sequence[float]] @@ -65,11 +67,27 @@ def __init__( condition_encoder: Optional[dict] = None, condition_decoder: Optional[dict] = None, condition_keys: Optional[list] = None, + mask_keys: Optional[list] = None, + masking_ratio: Optional[float] = None, disable_metrics: Optional[bool] = False, + metric_keys: Optional[list] = None, + include_top_loss: Optional[bool] = False, + topo_lambda: Optional[float] = None, + topo_num_groups: Optional[int] = None, + farthest_point: Optional[bool] = True, + inference_mask_dict: Optional[dict] = None, + target_key: Optional[list] = None, + target_mask_keys: Optional[list] = None, + parse: Optional[bool] = False, + mean: Optional[bool] = True, + freeze_encoder: Optional[bool] = False, **base_kwargs, ): self.get_rotation = get_rotation self.symmetry_breaking_axis = symmetry_breaking_axis + self.target_key = target_key + self.target_mask_keys = target_mask_keys + self.metric_keys = metric_keys self.scalar_inds = scalar_inds self.decoder_type = decoder_type self.generate_grid_feats = generate_grid_feats @@ -83,6 +101,14 @@ def __init__( self.basal_head_loss = basal_head_loss self.basal_head_weight = basal_head_weight self.disable_metrics = disable_metrics + self.include_top_loss = include_top_loss + self.topo_lambda = topo_lambda + self.topo_num_groups = topo_num_groups + self.farthest_point = farthest_point + self.parse = parse + self.mask_keys = mask_keys + self.masking_ratio = masking_ratio + self.freeze_encoder = freeze_encoder if embedding_prior == "gaussian": self.encoder_out_size = 2 * latent_dim @@ -154,6 +180,7 @@ def __init__( optimizer=optimizer, prior=prior, disable_metrics=disable_metrics, + metric_keys=metric_keys, ) self.condition_encoder = nn.ModuleDict(condition_encoder) @@ -162,6 +189,56 @@ def __init__( self.embedding_head_loss = nn.ModuleDict(embedding_head_loss) self.basal_head = nn.ModuleDict(basal_head) self.basal_head_loss = nn.ModuleDict(basal_head_loss) + self.inference_mask_dict = inference_mask_dict + self.target_label = None + self.mean = mean + if self.include_top_loss: + # self.top_layer = nn.ModuleDict({x_label: AlphaLayer(maxdim=1)}) + # self.top_loss = nn.ModuleDict({x_label: BarcodePolyFeature(1,2,0)}) + # self.top_loss = nn.ModuleDict({x_label: TopoLoss(topo_lambda=0.1)}) # 0.1 works well for earthmovers + # self.top_loss = nn.ModuleDict({x_label: TopoLoss(topo_lambda=0.1, farthest_point=True, num_groups=256)}) + # self.top_loss = nn.ModuleDict({x_label: TopoLoss(topo_lambda=10, farthest_point=True, num_groups=256)}) + self.top_loss = nn.ModuleDict( + { + x_label: TopoLoss( + topo_lambda=self.topo_lambda, + farthest_point=self.farthest_point, + num_groups=self.topo_num_groups, + mean=self.mean, + ) + } + ) + # self.top_loss = nn.ModuleDict( + # { + # x_label: TopoLoss( + # topo_lambda=0.001, + # farthest_point=True, + # num_groups=256, + # mean=self.mean, + # ) + # } + # ) + + if freeze_encoder: + for part, encoder in self.encoder.items(): + for param in self.encoder[part].parameters(): + param.requires_grad = False + + def encode(self, batch, **kwargs): + ret_dict = {} + for part, encoder in self.encoder.items(): + this_batch_part = batch[part] + this_ret = encoder( + this_batch_part, + **{k: v for k, v in kwargs.items() if k in self.encoder_args[part]}, + ) + + if isinstance(this_ret, dict): # deal with multiple outputs for an encoder + for key in this_ret.keys(): + ret_dict[key] = this_ret[key] + else: + ret_dict[part] = this_ret + return ret_dict def decode(self, z_parts, return_canonical=False, batch=None): if hasattr(self.encoder[self.hparams.x_label], "generate_grid_feats"): @@ -187,7 +264,8 @@ def decode(self, z_parts, return_canonical=False, batch=None): return {self.hparams.x_label: xhat} - def encoder_compose_function(self, z_parts): + def encoder_compose_function(self, z_parts, batch): + batch_size = z_parts[self.hparams.x_label].shape[0] if self.basal_head: z_parts[self.hparams.x_label + "_basal"] = z_parts[self.hparams.x_label] for key in self.basal_head.keys(): @@ -200,10 +278,21 @@ def encoder_compose_function(self, z_parts): this_z_parts = torch.squeeze(z_parts[key], dim=(-1)) z_parts[key] = this_z_parts # this_z_parts = this_z_parts.argmax(dim=1) + this_z_parts = this_z_parts.view(batch_size, -1) if j == 0: cond_feats = this_z_parts else: + if self.mask_keys: + # if mask, then mask this batch part + if f"{key}" in self.mask_keys: + # mask is 1 for batch elements to mask, 0 otherwise + this_mask = ( + batch[f"{key}_mask"].byte().repeat(1, this_z_parts.shape[-1]) + ) + # multiply inverse mask with batch part, so every mask element of 1 is set to 0 + this_z_parts = this_z_parts * ~this_mask.bool() cond_feats = torch.cat((cond_feats, this_z_parts), dim=1) + # shared encoder z_parts[self.hparams.x_label] = self.condition_encoder[self.hparams.x_label]( cond_feats @@ -215,27 +304,57 @@ def encoder_compose_function(self, z_parts): return z_parts def decoder_compose_function(self, z_parts, batch): - # import ipdb - # ipdb.set_trace() + # if (self.condition_keys is not None) & (len(self.condition_decoder.keys()) != 0): if self.condition_keys: for j, key in enumerate(self.condition_keys): + this_batch_part = batch[key] + this_batch_part = this_batch_part.view(this_batch_part.shape[0], -1) + if self.mask_keys: + # if mask, then mask this batch part + if f"{key}" in self.mask_keys: + this_mask = ( + batch[f"{key}_mask"].byte().repeat(1, this_batch_part.shape[-1]) + ) + # multiply inverse mask with batch part, so every mask element of 1 is set to 0 + this_batch_part = this_batch_part * ~this_mask.bool() + if j == 0: - cond_inputs = batch[key] + cond_inputs = this_batch_part # cond_inputs = torch.squeeze(batch[key], dim=(-1)) else: - cond_inputs = torch.cat((cond_inputs, batch[key]), dim=1) - cond_feats = torch.cat((cond_inputs, z_parts[self.hparams.x_label]), dim=1) + cond_inputs = torch.cat((cond_inputs, this_batch_part), dim=1) + # cond_feats = torch.cat( + # (z_parts[self.hparams.x_label],cond_inputs), dim=1 + # ) + cond_feats = torch.cat((cond_inputs, z_parts[self.hparams.x_label]), dim=1) # shared decoder z_parts[self.hparams.x_label] = self.condition_decoder[self.hparams.x_label]( cond_feats ) return z_parts - def calculate_rcl_dict(self, x, xhat, z): + def calculate_rcl(self, batch, xhat, input_key, target_key=None): + if not target_key: + target_key = input_key + # import ipdb + # ipdb.set_trace() + rcl_per_input_dimension = self.reconstruction_loss[input_key]( + batch[target_key], xhat[input_key] + ) + + if (self.mask_keys is not None) and (self.target_mask_keys is not None): + this_mask = batch["target_mask"].type_as(rcl_per_input_dimension).byte() + rcl_per_input_dimension = rcl_per_input_dimension * ~this_mask.bool() + + return rcl_per_input_dimension + + def calculate_rcl_dict(self, batch, xhat, z): rcl_per_input_dimension = {} rcl_reduced = {} for key in xhat.keys(): - rcl_per_input_dimension[key] = self.calculate_rcl(x, xhat, key, self.occupancy_label) + rcl_per_input_dimension[key] = self.calculate_rcl( + batch, xhat, key, self.target_label # used to be self.occupancy label + ) if len(rcl_per_input_dimension[key].shape) > 0: rcl = ( rcl_per_input_dimension[key] @@ -244,43 +363,110 @@ def calculate_rcl_dict(self, x, xhat, z): # and sum across each batch element's dimensions .sum(dim=1) ) - - rcl_reduced[key] = rcl.mean() + if self.mean: + rcl_reduced[key] = rcl.mean() + else: + rcl_reduced[key] = rcl else: rcl_reduced[key] = rcl_per_input_dimension[key] if self.embedding_head_loss: for key in self.embedding_head_loss.keys(): rcl_reduced[key] = self.embedding_head_weight[key] * self.embedding_head_loss[key]( - z[key], x[key] + z[key], batch[key] ) if self.basal_head_loss: for key in self.basal_head_loss.keys(): rcl_reduced[key] = self.basal_head_weight[key] * self.basal_head_loss[key]( - z[key], x[key] + z[key], batch[key] ) + + # if (self.include_top_loss) & (self.current_epoch > 10): + if self.include_top_loss: + for key in self.top_loss.keys(): + # top_losses = [] + # for i in range(xhat[key].shape[0]): + # top_losses.append(self.top_loss[key](self.top_layer[ + # key + # ](xhat[key][i]))) + # rcl_reduced['top'] = torch.stack(top_losses).mean() + rcl_reduced["top"] = self.top_loss[key](batch[key], xhat[key]) + return rcl_reduced - def forward(self, batch, decode=False, inference=True, return_params=False): - is_inference = inference or not self.training + def parse_batch(self, batch): + if self.parse: + for key in batch.keys(): + if len(batch[key].shape) == 1: + batch[key] = batch[key].unsqueeze(dim=-1) + + if self.mask_keys is not None: + for key in self.mask_keys: + C = batch[key] + if self.inference_mask_dict: + this_mask = self.inference_mask_dict[key] + else: + this_mask = 0 + # get random mask and save to batch + C_mask = torch.zeros(C.shape[0]).bernoulli_(this_mask).byte() + batch[f"{key}_mask"] = C_mask.unsqueeze(dim=-1).float().type_as(C) + + if self.target_key is not None: + for j, key in enumerate(self.target_key): + this_mask = 0 + if key in self.target_mask_keys: + this_mask = 1 + + if j == 0: + batch["target"] = batch[f"{key}"] + batch["target_mask"] = ( + torch.zeros(batch[f"{key}"].shape).bernoulli_(this_mask).byte() + ) + else: + batch["target"] = torch.cat([batch["target"], batch[f"{key}"]], dim=1) + this_part_mask = ( + torch.zeros(batch[f"{key}"].shape).bernoulli_(this_mask).byte() + ) + batch["target_mask"] = torch.cat( + [batch["target_mask"], this_part_mask], dim=1 + ) + + self.target_label = "target" + else: + self.target_label = self.hparams.x_label + + return batch + def get_embeddings(self, batch, inference=True): + # torch.isnan(z_params['pcloud']).any() + batch = self.parse_batch(batch) z_params = self.encode(batch, get_rotation=self.get_rotation) - z_params = self.encoder_compose_function(z_params) + z_params = self.encoder_compose_function(z_params, batch) z = self.sample_z(z_params, inference=inference) - z = self.decoder_compose_function(z, batch) - - if not decode: - return z + return z, z_params + def decode_embeddings(self, z, batch, decode=True, return_canonical=False): + z = self.decoder_compose_function(z, batch) if hasattr(self.encoder[self.hparams.x_label], "generate_grid_feats"): if self.encoder[self.hparams.x_label].generate_grid_feats: - xhat = self.decode(z, batch=batch) + xhat = self.decode(z, return_canonical=return_canonical, batch=batch) else: - xhat = self.decode(z) + xhat = self.decode(z, return_canonical=return_canonical) else: - xhat = self.decode(z) + xhat = self.decode(z, return_canonical=return_canonical) + + return xhat + + def forward(self, batch, decode=False, inference=True, return_params=False): + is_inference = inference or not self.training + z, z_params = self.get_embeddings(batch, inference) + + if not decode: + return z + + xhat = self.decode_embeddings(z, batch) if return_params: return xhat, z, z_params diff --git a/cyto_dl/nn/losses/weighted_mse_loss.py b/cyto_dl/nn/losses/weighted_mse_loss.py index 268ea350b..eee6ec9b4 100644 --- a/cyto_dl/nn/losses/weighted_mse_loss.py +++ b/cyto_dl/nn/losses/weighted_mse_loss.py @@ -6,13 +6,26 @@ class WeightedMSELoss(Loss): - def __init__(self, reduction="none", weights=1): + def __init__(self, reduction="none"): super().__init__(None, None, reduction) self.reduction = reduction - self.weights = torch.tensor(weights).unsqueeze(0) + # self.weights = torch.tensor(weights).unsqueeze(0) + self.bins = torch.linspace(-2, 2, steps=21) + self.bins = [(i, j) for i, j, in zip(self.bins, self.bins[1:])] + self.weights = list(torch.linspace(0, 100, steps=11)) + list( + torch.linspace(100, 0, steps=11) + ) def forward(self, input: Tensor, target: Tensor) -> Tensor: - loss = F.mse_loss(input, target, reduction="none") * self.weights + weights = torch.ones(*input.shape) + for j, bin in enumerate(self.bins): + bin_1 = bin[0] + bin_2 = bin[1] + this_mask = (input > bin_1) & (input < bin_2) + weights[this_mask] = self.weights[j] + + loss = F.mse_loss(input, target, reduction="none") + loss = loss * weights.type_as(loss) if self.reduction == "mean": loss = loss.mean(axis=1)