Commit

run pre-commit
Ritvik Vasan committed Mar 6, 2024
1 parent dc30155 commit a4981a4
Showing 5 changed files with 34 additions and 74 deletions.
cyto_dl/models/base_model.py: 4 changes (1 addition & 3 deletions)
@@ -61,9 +61,7 @@ def __call__(cls, *args, **kwargs):
         # instantiate class with instantiated `init_args`
         # hydra doesn't change the original dict, so we can use it after this
         # with `save_hyperparameters`
-        obj = type.__call__(
-            cls, **instantiate(init_args, _recursive_=True, _convert_=True)
-        )
+        obj = type.__call__(cls, **instantiate(init_args, _recursive_=True, _convert_=True))

         # make sure only primitives get stored in the ckpt
         ignore = [arg for arg, v in init_args.items() if not _is_primitive(v)]
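For readers unfamiliar with the call being reformatted here: `instantiate` is Hydra's utility for building objects from `_target_`-style config dicts. A minimal sketch of the same pattern, with a made-up `init_args` dict (only the `instantiate()` keyword arguments mirror the code above; nothing else is this repository's actual config):

```python
# Hedged sketch of the hydra.utils.instantiate call reformatted above.
# The init_args dict below is hypothetical, invented for illustration.
from hydra.utils import instantiate

init_args = {
    "net": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 2},
    "lr": 1e-3,  # primitives are passed through unchanged
}

# _recursive_=True instantiates nested _target_ dicts; _convert_=True asks Hydra
# to hand back plain Python containers instead of OmegaConf ones.
kwargs = instantiate(init_args, _recursive_=True, _convert_=True)
# kwargs["net"] should now be a torch.nn.Linear instance, kwargs["lr"] still a float.
```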
cyto_dl/models/vae/base_vae.py: 19 changes (5 additions & 14 deletions)
@@ -171,9 +171,7 @@ def calculate_rcl_dict(self, x, xhat, z):
         rcl_per_input_dimension = {}
         rcl_reduced = {}
         for key in xhat.keys():
-            rcl_per_input_dimension[key] = self.reconstruction_loss[key](
-                xhat[key], x[key]
-            )
+            rcl_per_input_dimension[key] = self.reconstruction_loss[key](xhat[key], x[key])
             if len(rcl_per_input_dimension[key].shape) > 0:
                 rcl = (
                     rcl_per_input_dimension[key]
@@ -191,12 +189,9 @@ def calculate_rcl_dict(self, x, xhat, z):
     def calculate_elbo(self, x, xhat, z):
         rcl_reduced = self.calculate_rcl_dict(x, xhat, z)
         kld_per_part = {
-            part: prior(z[part], mode="kl", reduction="none")
-            for part, prior in self.prior.items()
-        }
-        kld_per_part_summed = {
-            part: kl.sum(dim=-1).mean() for part, kl in kld_per_part.items()
+            part: prior(z[part], mode="kl", reduction="none") for part, prior in self.prior.items()
         }
+        kld_per_part_summed = {part: kl.sum(dim=-1).mean() for part, kl in kld_per_part.items()}

         total_kld = sum(kld_per_part_summed.values())
         total_recon = sum(rcl_reduced.values())
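As a side note on the comprehensions joined above: each prior returns an un-reduced KL tensor per latent dimension, which is summed over the last dimension and averaged over the batch before the per-part values are added together. A toy sketch of just that reduction (random stand-in tensors and invented part names, not the model's real priors):

```python
# Toy illustration of the KL reduction in calculate_elbo: sum over the latent
# dimension, mean over the batch, then sum across parts. Shapes are made up.
import torch

kld_per_part = {
    "pcloud": torch.rand(4, 16),    # (batch, latent_dim), un-reduced KL
    "rotation": torch.rand(4, 3),
}

kld_per_part_summed = {part: kl.sum(dim=-1).mean() for part, kl in kld_per_part.items()}
total_kld = sum(kld_per_part_summed.values())  # scalar tensor
```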
@@ -217,9 +212,7 @@ def sample_z(self, z_parts_params, inference=False):
         z = {}
         for part, part_params in z_parts_params.items():
             if part in self.prior:
-                z[part] = self.prior[part](
-                    part_params, mode="sample", inference=inference
-                )
+                z[part] = self.prior[part](part_params, mode="sample", inference=inference)
             else:
                 # if prior for this part isn't in the dict, assume dirac prior
                 # i.e. just return the params, and it won't contribute to kl
@@ -249,9 +242,7 @@ def decode(self, z):
             for part, decoder in self.decoder.items()
         }

-    def forward(
-        self, batch, decode=False, inference=True, return_params=False, **kwargs
-    ):
+    def forward(self, batch, decode=False, inference=True, return_params=False, **kwargs):
         is_inference = inference or not self.training

         z_params = self.encode(batch, **kwargs)
cyto_dl/models/vae/image_encoder.py: 12 changes (3 additions & 9 deletions)
@@ -45,9 +45,7 @@ def __init__(
         self.num_res_units = num_res_units

         if group not in ("so2", "so3", None):
-            raise ValueError(
-                f"`gspace` should be one of ('so2', 'so3', None). Got {group!r}"
-            )
+            raise ValueError(f"`gspace` should be one of ('so2', 'so3', None). Got {group!r}")

         if group == "so2":
             self.gspace = (
@@ -60,9 +58,7 @@ def __init__(
                 raise ValueError("The SO3 group only works for spatial_dims=3")
             self.gspace = gspaces.rot3dOnR3(maximum_frequency=maximum_frequency)
         else:
-            self.gspace = (
-                gspaces.trivialOnR2() if spatial_dims == 2 else gspaces.trivialOnR3()
-            )
+            self.gspace = gspaces.trivialOnR2() if spatial_dims == 2 else gspaces.trivialOnR3()

         self.in_type = nn.FieldType(self.gspace, [self.gspace.trivial_repr])

@@ -286,9 +282,7 @@ def __init__(

         scalar_fields = nn.FieldType(gspace, out_channels * [gspace.trivial_repr])
         if type(group).__name__ in ("SO2", "SO3"):
-            vector_fields = nn.FieldType(
-                gspace, out_vector_channels * [gspace.irrep(1)]
-            )
+            vector_fields = nn.FieldType(gspace, out_vector_channels * [gspace.irrep(1)])
             out_type = scalar_fields + vector_fields
         else:
             vector_fields = []
cyto_dl/models/vae/image_vae.py: 12 changes (3 additions & 9 deletions)
@@ -127,9 +127,7 @@ def __init__(
         assert len(_strides) + 1 == len(_channels)

         decode_blocks = []
-        for i, (s, c_in, c_out) in enumerate(
-            zip(_strides, _channels[:-1], _channels[1:])
-        ):
+        for i, (s, c_in, c_out) in enumerate(zip(_strides, _channels[:-1], _channels[1:])):
             last_block = i + 1 == len(_strides)

             size = None if not last_block else in_shape
@@ -164,9 +162,7 @@ def __init__(

             decode_blocks.append(nn.Sequential(upsample, res))

-        init_shape = (
-            self.final_size if decoder_initial_shape is None else decoder_initial_shape
-        )
+        init_shape = self.final_size if decoder_initial_shape is None else decoder_initial_shape

         first_upsample = nn.Sequential(
             nn.Linear(latent_dim, _channels[0] * int(np.product(init_shape))),
@@ -204,9 +200,7 @@ def __init__(
         )

         if group is not None:
-            self.rotation_module = RotationModule(
-                group, spatial_dims, background_value, eps
-            )
+            self.rotation_module = RotationModule(group, spatial_dims, background_value, eps)
         else:
             self.rotation_module = None

cyto_dl/models/vae/point_cloud_vae.py: 61 changes (22 additions & 39 deletions)
@@ -8,11 +8,10 @@

 from cyto_dl.models.vae.base_vae import BaseVAE
 from cyto_dl.models.vae.priors import IdentityPrior, IsotropicGaussianPrior
-from cyto_dl.nn.losses import ChamferLoss
-from cyto_dl.nn.point_cloud import DGCNN, FoldingNet

 # from topologylayer.nn import AlphaLayer, BarcodePolyFeature
-from cyto_dl.nn.losses import TopoLoss
+from cyto_dl.nn.losses import ChamferLoss, TopoLoss
+from cyto_dl.nn.point_cloud import DGCNN, FoldingNet

 Array = Union[torch.Tensor, np.ndarray, Sequence[float]]
 logger = logging.getLogger("lightning")
@@ -248,13 +247,9 @@ def decode(self, z_parts, return_canonical=False, batch=None):
                     batch[self.point_label], z_parts["grid_feats"]
                 )
             else:
-                base_xhat = self.decoder[self.hparams.x_label](
-                    z_parts[self.hparams.x_label]
-                )
+                base_xhat = self.decoder[self.hparams.x_label](z_parts[self.hparams.x_label])
         else:
-            base_xhat = self.decoder[self.hparams.x_label](
-                z_parts[self.hparams.x_label]
-            )
+            base_xhat = self.decoder[self.hparams.x_label](z_parts[self.hparams.x_label])

         if self.get_rotation:
             rotation = z_parts["rotation"]
@@ -274,9 +269,7 @@ def encoder_compose_function(self, z_parts, batch):
         if self.basal_head:
             z_parts[self.hparams.x_label + "_basal"] = z_parts[self.hparams.x_label]
             for key in self.basal_head.keys():
-                z_parts[key] = self.basal_head[key](
-                    z_parts[self.hparams.x_label + "_basal"]
-                )
+                z_parts[key] = self.basal_head[key](z_parts[self.hparams.x_label + "_basal"])

         if self.condition_keys:
             for j, key in enumerate([self.hparams.x_label] + self.condition_keys):
@@ -294,18 +287,16 @@ def encoder_compose_function(self, z_parts, batch):
                 if f"{key}" in self.mask_keys:
                     # mask is 1 for batch elements to mask, 0 otherwise
                     this_mask = (
-                        batch[f"{key}_mask"]
-                        .byte()
-                        .repeat(1, this_z_parts.shape[-1])
+                        batch[f"{key}_mask"].byte().repeat(1, this_z_parts.shape[-1])
                     )
                     # multiply inverse mask with batch part, so every mask element of 1 is set to 0
                     this_z_parts = this_z_parts * ~this_mask.bool()
                 cond_feats = torch.cat((cond_feats, this_z_parts), dim=1)

             # shared encoder
-            z_parts[self.hparams.x_label] = self.condition_encoder[
-                self.hparams.x_label
-            ](cond_feats)
+            z_parts[self.hparams.x_label] = self.condition_encoder[self.hparams.x_label](
+                cond_feats
+            )
         if self.embedding_head:
             for key in self.embedding_head.keys():
                 z_parts[key] = self.embedding_head[key](z_parts[self.hparams.x_label])
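The masking branch touched in this hunk is easier to follow on a toy tensor: the `{key}_mask` entry flags batch elements to mask, it is repeated across the feature dimension, and multiplying by the inverted mask zeroes those rows. A small sketch under those assumptions (shapes and values invented for the example):

```python
# Toy version of the masking in encoder_compose_function: mask is 1 for batch
# elements to mask out; multiplying by the inverted mask zeroes their features.
import torch

this_z_parts = torch.ones(3, 4)              # (batch, features)
key_mask = torch.tensor([[1], [0], [1]])     # 1 = mask this batch element

this_mask = key_mask.byte().repeat(1, this_z_parts.shape[-1])  # (3, 4)
this_z_parts = this_z_parts * ~this_mask.bool()
# rows 0 and 2 are now all zeros; row 1 is unchanged
```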
@@ -322,9 +313,7 @@ def decoder_compose_function(self, z_parts, batch):
                 # if mask, then mask this batch part
                 if f"{key}" in self.mask_keys:
                     this_mask = (
-                        batch[f"{key}_mask"]
-                        .byte()
-                        .repeat(1, this_batch_part.shape[-1])
+                        batch[f"{key}_mask"].byte().repeat(1, this_batch_part.shape[-1])
                     )
                     # multiply inverse mask with batch part, so every mask element of 1 is set to 0
                     this_batch_part = this_batch_part * ~this_mask.bool()
@@ -339,9 +328,9 @@ def decoder_compose_function(self, z_parts, batch):
             # )
             cond_feats = torch.cat((cond_inputs, z_parts[self.hparams.x_label]), dim=1)
             # shared decoder
-            z_parts[self.hparams.x_label] = self.condition_decoder[
-                self.hparams.x_label
-            ](cond_feats)
+            z_parts[self.hparams.x_label] = self.condition_decoder[self.hparams.x_label](
+                cond_feats
+            )
         return z_parts

     def calculate_rcl(self, batch, xhat, input_key, target_key=None):
@@ -383,15 +372,15 @@ def calculate_rcl_dict(self, batch, xhat, z):

         if self.embedding_head_loss:
             for key in self.embedding_head_loss.keys():
-                rcl_reduced[key] = self.embedding_head_weight[
-                    key
-                ] * self.embedding_head_loss[key](z[key], x[key])
+                rcl_reduced[key] = self.embedding_head_weight[key] * self.embedding_head_loss[key](
+                    z[key], batch[key]
+                )

         if self.basal_head_loss:
             for key in self.basal_head_loss.keys():
-                rcl_reduced[key] = self.basal_head_weight[key] * self.basal_head_loss[
-                    key
-                ](z[key], x[key])
+                rcl_reduced[key] = self.basal_head_weight[key] * self.basal_head_loss[key](
+                    z[key], batch[key]
+                )

         # if (self.include_top_loss) & (self.current_epoch > 10):
         if self.include_top_loss:
@@ -432,18 +421,12 @@ def parse_batch(self, batch):
             if j == 0:
                 batch["target"] = batch[f"{key}"]
                 batch["target_mask"] = (
-                    torch.zeros(batch[f"{key}"].shape)
-                    .bernoulli_(this_mask)
-                    .byte()
+                    torch.zeros(batch[f"{key}"].shape).bernoulli_(this_mask).byte()
                 )
             else:
-                batch["target"] = torch.cat(
-                    [batch["target"], batch[f"{key}"]], dim=1
-                )
+                batch["target"] = torch.cat([batch["target"], batch[f"{key}"]], dim=1)
                 this_part_mask = (
-                    torch.zeros(batch[f"{key}"].shape)
-                    .bernoulli_(this_mask)
-                    .byte()
+                    torch.zeros(batch[f"{key}"].shape).bernoulli_(this_mask).byte()
                 )
                 batch["target_mask"] = torch.cat(
                     [batch["target_mask"], this_part_mask], dim=1
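And for the `bernoulli_` chains rewritten just above: they fill a zero tensor in place with independent Bernoulli draws and store it as bytes, producing the random target mask. A quick hedged sketch (toy shape and probability, not the actual batch):

```python
# Toy version of the target-mask construction in parse_batch. The shape and
# masking probability are made up for illustration.
import torch

this_mask = 0.3  # probability that an entry gets masked
target_mask = torch.zeros(2, 5).bernoulli_(this_mask).byte()
# e.g. tensor([[0, 1, 0, 0, 0],
#              [1, 0, 0, 1, 0]], dtype=torch.uint8)
```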
