diff --git a/cyto_dl/models/base_model.py b/cyto_dl/models/base_model.py index 08c76e8eb..3dedcba35 100644 --- a/cyto_dl/models/base_model.py +++ b/cyto_dl/models/base_model.py @@ -61,9 +61,7 @@ def __call__(cls, *args, **kwargs): # instantiate class with instantiated `init_args` # hydra doesn't change the original dict, so we can use it after this # with `save_hyperparameters` - obj = type.__call__( - cls, **instantiate(init_args, _recursive_=True, _convert_=True) - ) + obj = type.__call__(cls, **instantiate(init_args, _recursive_=True, _convert_=True)) # make sure only primitives get stored in the ckpt ignore = [arg for arg, v in init_args.items() if not _is_primitive(v)] diff --git a/cyto_dl/models/vae/base_vae.py b/cyto_dl/models/vae/base_vae.py index 830cda14c..b388ba152 100644 --- a/cyto_dl/models/vae/base_vae.py +++ b/cyto_dl/models/vae/base_vae.py @@ -171,9 +171,7 @@ def calculate_rcl_dict(self, x, xhat, z): rcl_per_input_dimension = {} rcl_reduced = {} for key in xhat.keys(): - rcl_per_input_dimension[key] = self.reconstruction_loss[key]( - xhat[key], x[key] - ) + rcl_per_input_dimension[key] = self.reconstruction_loss[key](xhat[key], x[key]) if len(rcl_per_input_dimension[key].shape) > 0: rcl = ( rcl_per_input_dimension[key] @@ -191,12 +189,9 @@ def calculate_rcl_dict(self, x, xhat, z): def calculate_elbo(self, x, xhat, z): rcl_reduced = self.calculate_rcl_dict(x, xhat, z) kld_per_part = { - part: prior(z[part], mode="kl", reduction="none") - for part, prior in self.prior.items() - } - kld_per_part_summed = { - part: kl.sum(dim=-1).mean() for part, kl in kld_per_part.items() + part: prior(z[part], mode="kl", reduction="none") for part, prior in self.prior.items() } + kld_per_part_summed = {part: kl.sum(dim=-1).mean() for part, kl in kld_per_part.items()} total_kld = sum(kld_per_part_summed.values()) total_recon = sum(rcl_reduced.values()) @@ -217,9 +212,7 @@ def sample_z(self, z_parts_params, inference=False): z = {} for part, part_params in z_parts_params.items(): if part in self.prior: - z[part] = self.prior[part]( - part_params, mode="sample", inference=inference - ) + z[part] = self.prior[part](part_params, mode="sample", inference=inference) else: # if prior for this part isn't in the dict, assume dirac prior # i.e. just return the params, and it won't contribute to kl @@ -249,9 +242,7 @@ def decode(self, z): for part, decoder in self.decoder.items() } - def forward( - self, batch, decode=False, inference=True, return_params=False, **kwargs - ): + def forward(self, batch, decode=False, inference=True, return_params=False, **kwargs): is_inference = inference or not self.training z_params = self.encode(batch, **kwargs) diff --git a/cyto_dl/models/vae/image_encoder.py b/cyto_dl/models/vae/image_encoder.py index 7469abac2..b8e75c2a9 100644 --- a/cyto_dl/models/vae/image_encoder.py +++ b/cyto_dl/models/vae/image_encoder.py @@ -45,9 +45,7 @@ def __init__( self.num_res_units = num_res_units if group not in ("so2", "so3", None): - raise ValueError( - f"`gspace` should be one of ('so2', 'so3', None). Got {group!r}" - ) + raise ValueError(f"`gspace` should be one of ('so2', 'so3', None). Got {group!r}") if group == "so2": self.gspace = ( @@ -60,9 +58,7 @@ def __init__( raise ValueError("The SO3 group only works for spatial_dims=3") self.gspace = gspaces.rot3dOnR3(maximum_frequency=maximum_frequency) else: - self.gspace = ( - gspaces.trivialOnR2() if spatial_dims == 2 else gspaces.trivialOnR3() - ) + self.gspace = gspaces.trivialOnR2() if spatial_dims == 2 else gspaces.trivialOnR3() self.in_type = nn.FieldType(self.gspace, [self.gspace.trivial_repr]) @@ -286,9 +282,7 @@ def __init__( scalar_fields = nn.FieldType(gspace, out_channels * [gspace.trivial_repr]) if type(group).__name__ in ("SO2", "SO3"): - vector_fields = nn.FieldType( - gspace, out_vector_channels * [gspace.irrep(1)] - ) + vector_fields = nn.FieldType(gspace, out_vector_channels * [gspace.irrep(1)]) out_type = scalar_fields + vector_fields else: vector_fields = [] diff --git a/cyto_dl/models/vae/image_vae.py b/cyto_dl/models/vae/image_vae.py index e65653b9b..6e8d95c25 100644 --- a/cyto_dl/models/vae/image_vae.py +++ b/cyto_dl/models/vae/image_vae.py @@ -127,9 +127,7 @@ def __init__( assert len(_strides) + 1 == len(_channels) decode_blocks = [] - for i, (s, c_in, c_out) in enumerate( - zip(_strides, _channels[:-1], _channels[1:]) - ): + for i, (s, c_in, c_out) in enumerate(zip(_strides, _channels[:-1], _channels[1:])): last_block = i + 1 == len(_strides) size = None if not last_block else in_shape @@ -164,9 +162,7 @@ def __init__( decode_blocks.append(nn.Sequential(upsample, res)) - init_shape = ( - self.final_size if decoder_initial_shape is None else decoder_initial_shape - ) + init_shape = self.final_size if decoder_initial_shape is None else decoder_initial_shape first_upsample = nn.Sequential( nn.Linear(latent_dim, _channels[0] * int(np.product(init_shape))), @@ -204,9 +200,7 @@ def __init__( ) if group is not None: - self.rotation_module = RotationModule( - group, spatial_dims, background_value, eps - ) + self.rotation_module = RotationModule(group, spatial_dims, background_value, eps) else: self.rotation_module = None diff --git a/cyto_dl/models/vae/point_cloud_vae.py b/cyto_dl/models/vae/point_cloud_vae.py index dc04d8f0d..10a4e2945 100644 --- a/cyto_dl/models/vae/point_cloud_vae.py +++ b/cyto_dl/models/vae/point_cloud_vae.py @@ -8,11 +8,10 @@ from cyto_dl.models.vae.base_vae import BaseVAE from cyto_dl.models.vae.priors import IdentityPrior, IsotropicGaussianPrior -from cyto_dl.nn.losses import ChamferLoss -from cyto_dl.nn.point_cloud import DGCNN, FoldingNet # from topologylayer.nn import AlphaLayer, BarcodePolyFeature -from cyto_dl.nn.losses import TopoLoss +from cyto_dl.nn.losses import ChamferLoss, TopoLoss +from cyto_dl.nn.point_cloud import DGCNN, FoldingNet Array = Union[torch.Tensor, np.ndarray, Sequence[float]] logger = logging.getLogger("lightning") @@ -248,13 +247,9 @@ def decode(self, z_parts, return_canonical=False, batch=None): batch[self.point_label], z_parts["grid_feats"] ) else: - base_xhat = self.decoder[self.hparams.x_label]( - z_parts[self.hparams.x_label] - ) + base_xhat = self.decoder[self.hparams.x_label](z_parts[self.hparams.x_label]) else: - base_xhat = self.decoder[self.hparams.x_label]( - z_parts[self.hparams.x_label] - ) + base_xhat = self.decoder[self.hparams.x_label](z_parts[self.hparams.x_label]) if self.get_rotation: rotation = z_parts["rotation"] @@ -274,9 +269,7 @@ def encoder_compose_function(self, z_parts, batch): if self.basal_head: z_parts[self.hparams.x_label + "_basal"] = z_parts[self.hparams.x_label] for key in self.basal_head.keys(): - z_parts[key] = self.basal_head[key]( - z_parts[self.hparams.x_label + "_basal"] - ) + z_parts[key] = self.basal_head[key](z_parts[self.hparams.x_label + "_basal"]) if self.condition_keys: for j, key in enumerate([self.hparams.x_label] + self.condition_keys): @@ -294,18 +287,16 @@ def encoder_compose_function(self, z_parts, batch): if f"{key}" in self.mask_keys: # mask is 1 for batch elements to mask, 0 otherwise this_mask = ( - batch[f"{key}_mask"] - .byte() - .repeat(1, this_z_parts.shape[-1]) + batch[f"{key}_mask"].byte().repeat(1, this_z_parts.shape[-1]) ) # multiply inverse mask with batch part, so every mask element of 1 is set to 0 this_z_parts = this_z_parts * ~this_mask.bool() cond_feats = torch.cat((cond_feats, this_z_parts), dim=1) # shared encoder - z_parts[self.hparams.x_label] = self.condition_encoder[ - self.hparams.x_label - ](cond_feats) + z_parts[self.hparams.x_label] = self.condition_encoder[self.hparams.x_label]( + cond_feats + ) if self.embedding_head: for key in self.embedding_head.keys(): z_parts[key] = self.embedding_head[key](z_parts[self.hparams.x_label]) @@ -322,9 +313,7 @@ def decoder_compose_function(self, z_parts, batch): # if mask, then mask this batch part if f"{key}" in self.mask_keys: this_mask = ( - batch[f"{key}_mask"] - .byte() - .repeat(1, this_batch_part.shape[-1]) + batch[f"{key}_mask"].byte().repeat(1, this_batch_part.shape[-1]) ) # multiply inverse mask with batch part, so every mask element of 1 is set to 0 this_batch_part = this_batch_part * ~this_mask.bool() @@ -339,9 +328,9 @@ def decoder_compose_function(self, z_parts, batch): # ) cond_feats = torch.cat((cond_inputs, z_parts[self.hparams.x_label]), dim=1) # shared decoder - z_parts[self.hparams.x_label] = self.condition_decoder[ - self.hparams.x_label - ](cond_feats) + z_parts[self.hparams.x_label] = self.condition_decoder[self.hparams.x_label]( + cond_feats + ) return z_parts def calculate_rcl(self, batch, xhat, input_key, target_key=None): @@ -383,15 +372,15 @@ def calculate_rcl_dict(self, batch, xhat, z): if self.embedding_head_loss: for key in self.embedding_head_loss.keys(): - rcl_reduced[key] = self.embedding_head_weight[ - key - ] * self.embedding_head_loss[key](z[key], x[key]) + rcl_reduced[key] = self.embedding_head_weight[key] * self.embedding_head_loss[key]( + z[key], batch[key] + ) if self.basal_head_loss: for key in self.basal_head_loss.keys(): - rcl_reduced[key] = self.basal_head_weight[key] * self.basal_head_loss[ - key - ](z[key], x[key]) + rcl_reduced[key] = self.basal_head_weight[key] * self.basal_head_loss[key]( + z[key], batch[key] + ) # if (self.include_top_loss) & (self.current_epoch > 10): if self.include_top_loss: @@ -432,18 +421,12 @@ def parse_batch(self, batch): if j == 0: batch["target"] = batch[f"{key}"] batch["target_mask"] = ( - torch.zeros(batch[f"{key}"].shape) - .bernoulli_(this_mask) - .byte() + torch.zeros(batch[f"{key}"].shape).bernoulli_(this_mask).byte() ) else: - batch["target"] = torch.cat( - [batch["target"], batch[f"{key}"]], dim=1 - ) + batch["target"] = torch.cat([batch["target"], batch[f"{key}"]], dim=1) this_part_mask = ( - torch.zeros(batch[f"{key}"].shape) - .bernoulli_(this_mask) - .byte() + torch.zeros(batch[f"{key}"].shape).bernoulli_(this_mask).byte() ) batch["target_mask"] = torch.cat( [batch["target_mask"], this_part_mask], dim=1