diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06880e5f0..0d39192bd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v5.0.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace @@ -26,35 +26,35 @@ repos: # python code formatting - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 24.10.0 hooks: - id: black args: [--line-length, "99"] # python import sorting - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black", "--filter-files"] # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade - rev: v2.32.1 + rev: v3.17.0 hooks: - id: pyupgrade args: [--py38-plus] # python docstring formatting - repo: https://github.com/myint/docformatter - rev: v1.4 + rev: v1.7.5 hooks: - id: docformatter args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] # python check (PEP8), programming errors and code complexity - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 7.1.1 hooks: - id: flake8 args: @@ -67,14 +67,14 @@ repos: # python security linter - repo: https://github.com/PyCQA/bandit - rev: "1.7.1" + rev: "1.7.10" hooks: - id: bandit args: ["-s", "B101"] # yaml formatting - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.7.1 + rev: v4.0.0-alpha.8 hooks: - id: prettier types: [yaml] @@ -82,13 +82,13 @@ repos: # shell scripts linter - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.8.0.4 + rev: v0.10.0.1 hooks: - id: shellcheck # md formatting - repo: https://github.com/executablebooks/mdformat - rev: 0.7.14 + rev: 0.7.17 hooks: - id: mdformat args: ["--number"] @@ -96,12 +96,10 @@ repos: - mdformat-gfm - mdformat-tables - mdformat_frontmatter - # - mdformat-toc - # - mdformat-black # word spelling linter - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.3.0 hooks: - id: codespell args: @@ -110,13 +108,13 @@ repos: # jupyter notebook cell output clearing - repo: https://github.com/kynan/nbstripout - rev: 0.5.0 + rev: 0.7.1 hooks: - id: nbstripout # jupyter notebook linting - repo: https://github.com/nbQA-dev/nbQA - rev: 1.4.0 + rev: 1.8.7 hooks: - id: nbqa-black args: ["--line-length=99"] @@ -130,17 +128,18 @@ repos: ] - repo: https://github.com/dosisod/refurb - rev: v1.3.0 + rev: v2.0.0 hooks: - id: refurb language_version: python3.10 args: - --ignore - FURB120 + - --ignore + - FURB123 - repo: https://github.com/asottile/blacken-docs - rev: v1.12.1 + rev: 1.19.0 hooks: - id: blacken-docs args: [--line-length=120] - additional_dependencies: [black==21.12b0] diff --git a/configs/model/im2im/ijepa.yaml b/configs/model/im2im/ijepa.yaml index 40ea24201..0ae285162 100644 --- a/configs/model/im2im/ijepa.yaml +++ b/configs/model/im2im/ijepa.yaml @@ -6,7 +6,7 @@ save_dir: ${paths.output_dir} encoder: _target_: cyto_dl.nn.vits.encoder.JEPAEncoder - patch_size: 2 # patch_size * num_patches should equl data._aux.patch_shape + patch_size: 2 # patch_size * num_patches should equal data._aux.patch_shape num_patches: ${model._aux.num_patches} emb_dim: 16 num_layer: 2 diff --git a/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py b/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py index a06c4a740..31cb8347c 100644 --- a/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py +++ b/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py @@ -36,9 +36,9 @@ def _get_experiment_type(cls) -> ExperimentType: def from_existing_config(cls, config_filepath: Path): """Returns a model from an existing config. - :param config_filepath: path to a .yaml config file that will be used as the basis - for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants - to use it). + :param config_filepath: path to a .yaml config file that will be used as the basis for this + CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants to use + it). """ return cls(OmegaConf.load(config_filepath)) diff --git a/cyto_dl/callbacks/outlier_detection.py b/cyto_dl/callbacks/outlier_detection.py index 9e66694f2..d80f05982 100644 --- a/cyto_dl/callbacks/outlier_detection.py +++ b/cyto_dl/callbacks/outlier_detection.py @@ -64,7 +64,7 @@ def fn(_, __, output): return fn def _update_covariance(self, output, layer_name): - """record spatial mean and cov of channel activations per image in batch.""" + """Record spatial mean and cov of channel activations per image in batch.""" output = self.flatten_activations(output) if self.mu[layer_name] is None: self.mu[layer_name] = np.zeros(output.shape[1]) @@ -76,7 +76,7 @@ def _update_covariance(self, output, layer_name): self.n += 1 def on_train_epoch_start(self, trainer, pl_module): - """set forward hook.""" + """Set forward hook.""" if trainer.current_epoch == trainer.max_epochs - self.n_epochs: named_modules = dict([*pl_module.backbone.named_modules()]) for layer_name in self.layer_names: @@ -101,7 +101,7 @@ def _calculate_mahalanobis(self, output, layer_name): self.activations[layer_name].append(out) def _inference_start(self, pl_module): - """add mahalanobis calculation hook and calculate inverse covariance matrix.""" + """Add mahalanobis calculation hook and calculate inverse covariance matrix.""" if self._run: named_modules = dict([*pl_module.backbone.named_modules()]) for layer_name in self.layer_names: diff --git a/cyto_dl/dataframe/readers.py b/cyto_dl/dataframe/readers.py index 03305178a..9f1b54ce5 100644 --- a/cyto_dl/dataframe/readers.py +++ b/cyto_dl/dataframe/readers.py @@ -149,7 +149,7 @@ def read_dataframe( include_columns = sorted(list(include_columns)) required_columns = sorted(list(required_columns)) - if len(include_columns) == 0: + if not include_columns: include_columns = None if isinstance(dataframe, str): diff --git a/cyto_dl/datamodules/dataframe/dataframe_datamodule.py b/cyto_dl/datamodules/dataframe/dataframe_datamodule.py index 384aba2db..9c2fcfe14 100644 --- a/cyto_dl/datamodules/dataframe/dataframe_datamodule.py +++ b/cyto_dl/datamodules/dataframe/dataframe_datamodule.py @@ -163,7 +163,7 @@ def get_dataset(self, split): return self.datasets[split][sample] def make_dataloader(self, split): - kwargs = dict(**self.dataloader_kwargs) + kwargs = {**self.dataloader_kwargs} kwargs["shuffle"] = kwargs.get("shuffle", True) and split == "train" kwargs["batch_size"] = self.batch_size diff --git a/cyto_dl/datamodules/dataframe/grouped_dataframe_datamodule.py b/cyto_dl/datamodules/dataframe/grouped_dataframe_datamodule.py index 288abb0c8..37a2aa0de 100644 --- a/cyto_dl/datamodules/dataframe/grouped_dataframe_datamodule.py +++ b/cyto_dl/datamodules/dataframe/grouped_dataframe_datamodule.py @@ -116,7 +116,7 @@ def __init__( self.target_columns = target_columns def make_dataloader(self, split): - kwargs = dict(**self.dataloader_kwargs) + kwargs = {**self.dataloader_kwargs} kwargs["shuffle"] = kwargs.get("shuffle", True) and split == "train" subset = self.get_dataset(split) diff --git a/cyto_dl/datamodules/smartcache.py b/cyto_dl/datamodules/smartcache.py index db9ddad80..f1787a0c1 100644 --- a/cyto_dl/datamodules/smartcache.py +++ b/cyto_dl/datamodules/smartcache.py @@ -1,3 +1,4 @@ +from itertools import chain from pathlib import Path from typing import Optional, Union @@ -118,14 +119,18 @@ def _get_file_args(self, row): for timepoint in timepoints: img_data.append( { - "dimension_order_out": "ZYX"[-self.spatial_dims :] - if not use_neighbors - else "T" + "ZYX"[-self.spatial_dims :], + "dimension_order_out": ( + "ZYX"[-self.spatial_dims :] + if not use_neighbors + else "T" + "ZYX"[-self.spatial_dims :] + ), "C": row[self.channel_column], "scene": scene, - "T": timepoint - if not use_neighbors - else [timepoint + i for i in range(self.num_neighbors + 1)], + "T": ( + timepoint + if not use_neighbors + else [timepoint + i for i in range(self.num_neighbors + 1)] + ), "original_path": row[self.img_path_column], } ) @@ -136,7 +141,7 @@ def get_per_file_args(self, df): timepoints/channels/scenes for each file in the dataframe.""" with ProgressBar(): img_data = dask.compute(*[self._get_file_args(row) for row in df.itertuples()]) - img_data = [item for sublist in img_data for item in sublist] + img_data = list(chain.from_iterable(img_data)) return img_data def prepare_data(self): diff --git a/cyto_dl/image/io/aicsimage_loader.py b/cyto_dl/image/io/aicsimage_loader.py index 0ca9b6ea8..4c7b5a220 100644 --- a/cyto_dl/image/io/aicsimage_loader.py +++ b/cyto_dl/image/io/aicsimage_loader.py @@ -9,9 +9,9 @@ class AICSImageLoaderd(Transform): """Enumerates scenes and timepoints for dictionary with format. - {path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}. - Differs from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of - fixed at initialization. + {path_key: path, channel_key: channel, scene_key: scene, timepoint_key: timepoint}. Differs + from monai_bio_reader in that reading kwargs are passed in the dictionary, instead of fixed at + initialization. """ def __init__( diff --git a/cyto_dl/image/io/monai_bio_reader.py b/cyto_dl/image/io/monai_bio_reader.py index 89e7e0726..88d11c318 100644 --- a/cyto_dl/image/io/monai_bio_reader.py +++ b/cyto_dl/image/io/monai_bio_reader.py @@ -30,10 +30,7 @@ def __init__(self, dask_load: bool = True, **reader_kwargs): def read(self, data: Union[Sequence[PathLike], PathLike]): filenames: Sequence[PathLike] = ensure_tuple(data) - img_ = [] - for name in filenames: - img_.append(BioImage(f"{name}")) - + img_ = [BioImage(name) for name in filenames] return img_ if len(filenames) > 1 else img_[0] def get_data(self, img) -> Tuple[np.ndarray, Dict]: diff --git a/cyto_dl/image/io/skimage_reader.py b/cyto_dl/image/io/skimage_reader.py index 5844a0e46..c2a6ee6a4 100644 --- a/cyto_dl/image/io/skimage_reader.py +++ b/cyto_dl/image/io/skimage_reader.py @@ -25,7 +25,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike]): filenames: Sequence[PathLike] = ensure_tuple(data) img_ = [] for name in filenames: - this_im = imread(f"{name}") + this_im = imread(name) if self.channels: this_im = this_im[self.channels] diff --git a/cyto_dl/image/transforms/multiscale_cropper.py b/cyto_dl/image/transforms/multiscale_cropper.py index 5cd999fa6..1f05b38b9 100644 --- a/cyto_dl/image/transforms/multiscale_cropper.py +++ b/cyto_dl/image/transforms/multiscale_cropper.py @@ -92,7 +92,7 @@ def _apply_slice(data, slicee): @staticmethod def _generate_slice(start_coords: Sequence[int], roi_size: Sequence[int]) -> slice: """Creates slice starting at `start_coords` of size `roi_size`""" - return [slice(None, None)] + [ + return [slice(None, None)] + [ # noqa: FURB140 slice(start, end) for start, end in zip(start_coords, start_coords + roi_size) ] diff --git a/cyto_dl/loggers/mlflow.py b/cyto_dl/loggers/mlflow.py index e23f6a2a8..9b2fd2df3 100644 --- a/cyto_dl/loggers/mlflow.py +++ b/cyto_dl/loggers/mlflow.py @@ -55,7 +55,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], mode="train" with tempfile.TemporaryDirectory() as tmp_dir: conf_path = Path(tmp_dir) / f"{mode}.yaml" - with open(conf_path, "w") as f: + with conf_path.open("w") as f: config = OmegaConf.create(params) OmegaConf.save(config=config, f=f) @@ -133,7 +133,7 @@ def _after_save_checkpoint(self, ckpt_callback: ModelCheckpoint) -> None: self.run_id, local_path=best_path, artifact_path=artifact_path ) - os.unlink(best_path) + best_path.unlink() else: filepath = ckpt_callback.best_model_path @@ -149,7 +149,7 @@ def _after_save_checkpoint(self, ckpt_callback: ModelCheckpoint) -> None: self.run_id, local_path=last_path, artifact_path=artifact_path ) - os.unlink(last_path) + last_path.unlink() else: self.experiment.log_artifact( self.run_id, local_path=filepath, artifact_path=artifact_path @@ -157,9 +157,12 @@ def _after_save_checkpoint(self, ckpt_callback: ModelCheckpoint) -> None: def _delete_local_artifact(repo, artifact_path): - artifact_path = local_file_uri_to_path( - os.path.join(repo._artifact_dir, artifact_path) if artifact_path else repo._artifact_dir + artifact_path = Path( + local_file_uri_to_path( + os.path.join(repo._artifact_dir, artifact_path) + if artifact_path + else repo._artifact_dir + ) ) - - if os.path.isfile(artifact_path): - os.remove(artifact_path) + if artifact_path.is_file(): + artifact_path.unlink() diff --git a/cyto_dl/models/base_model.py b/cyto_dl/models/base_model.py index 3dedcba35..8ef795252 100644 --- a/cyto_dl/models/base_model.py +++ b/cyto_dl/models/base_model.py @@ -23,7 +23,7 @@ def _is_primitive(value): - if isinstance(value, (type(None), bool, str, int, float)): + if value is None or isinstance(value, (bool, str, int, float)): return True if isinstance(value, (tuple, list)): diff --git a/cyto_dl/models/basic_model.py b/cyto_dl/models/basic_model.py index 14560b34c..52a3834eb 100644 --- a/cyto_dl/models/basic_model.py +++ b/cyto_dl/models/basic_model.py @@ -59,11 +59,11 @@ def __init__( super().__init__(metrics=metrics) - if network is None and pretrained_weights is None: + if network is pretrained_weights is None: raise ValueError("`network` and `pretrained_weights` can't both be None.") if pretrained_weights is not None: - pretrained_weights = torch.load(pretrained_weights) + pretrained_weights = torch.load(pretrained_weights) # nosec B614 if network is not None: self.network = network diff --git a/cyto_dl/models/classification/timepoint_classification.py b/cyto_dl/models/classification/timepoint_classification.py index dbd70c553..e20bc0e8d 100644 --- a/cyto_dl/models/classification/timepoint_classification.py +++ b/cyto_dl/models/classification/timepoint_classification.py @@ -50,7 +50,7 @@ def predict_step(self, batch, batch_idx): batch, "predict", logits, - name=f"{batch['track_id'].cpu().item()}", + name=str(batch["track_id"].cpu().item()), ) timepoints = np.array(batch["timepoints"][0][1:-1].split(",")).astype(int) diff --git a/cyto_dl/models/handlers/base_handler.py b/cyto_dl/models/handlers/base_handler.py index bdbff8226..90759927f 100644 --- a/cyto_dl/models/handlers/base_handler.py +++ b/cyto_dl/models/handlers/base_handler.py @@ -76,12 +76,12 @@ def postprocess(self, data): mode = self.config["return"].get("mode", "network") if mode == "path": - path = self.config["return"].get("path", "/tmp") # nosec: B108 + path = self.config["return"].get("path", "/tmp") # nosec B108 response_path = Path(path) / f"{uuid.uuid4()}.pt" - torch.save(data, response_path) + torch.save(data, response_path) # nosec B614 return [str(response_path)] buf = io.BytesIO() - torch.save(data, buf) + torch.save(data, buf) # nosec B614 buf.seek(0) return [buf.read()] diff --git a/cyto_dl/models/im2im/gan.py b/cyto_dl/models/im2im/gan.py index 30d8a704a..00fceeda2 100644 --- a/cyto_dl/models/im2im/gan.py +++ b/cyto_dl/models/im2im/gan.py @@ -95,7 +95,7 @@ def configure_optimizers(self): return (opts, scheds) def _train_forward(self, batch, stage, save_image, run_heads): - """during training we are only dealing with patches,so we can calculate per-patch loss, + """During training we are only dealing with patches,so we can calculate per-patch loss, metrics, postprocessing etc.""" z = self.backbone(batch[self.hparams.x_key]) return { @@ -106,7 +106,7 @@ def _train_forward(self, batch, stage, save_image, run_heads): } def _inference_forward(self, batch, stage, save_image, run_heads): - """during inference, we need to calculate per-fov loss/metrics/postprocessing. + """During inference, we need to calculate per-fov loss/metrics/postprocessing. To avoid storing and passing to each head the intermediate results of the backbone, we need to run backbone + taskheads patch by patch, then do saving/postprocessing/etc on the entire diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index 9ac05c05b..d8b8f0d13 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -111,7 +111,7 @@ def configure_optimizers(self): return (opts, scheds) def _train_forward(self, batch, stage, n_postprocess, run_heads): - """during training we are only dealing with patches,so we can calculate per-patch loss, + """During training we are only dealing with patches,so we can calculate per-patch loss, metrics, postprocessing etc.""" z = self.backbone(batch[self.hparams.x_key]) return { @@ -124,7 +124,7 @@ def forward(self, x, run_heads): return {task: self.task_heads[task](z) for task in run_heads} def _inference_forward(self, batch, stage, n_postprocess, run_heads): - """during inference, we need to calculate per-fov loss/metrics/postprocessing. + """During inference, we need to calculate per-fov loss/metrics/postprocessing. To avoid storing and passing to each head the intermediate results of the backbone, we need to run backbone + taskheads patch by patch, then do saving/postprocessing/etc on the entire @@ -173,7 +173,7 @@ def _sum_losses(self, losses): return losses def _get_unrun_heads(self, io_map): - """returns heads that don't have outputs yet.""" + """Returns heads that don't have outputs yet.""" updated_run_heads = [] # check that all output files exist for each head for head, head_io_map in io_map.items(): @@ -184,7 +184,7 @@ def _get_unrun_heads(self, io_map): return updated_run_heads def _combine_io_maps(self, io_maps): - """aggregate io_maps from per-head to per-input image.""" + """Aggregate io_maps from per-head to per-input image.""" io_map = {} # create input-> per head output mapping for head, head_io_map in io_maps.items(): @@ -218,7 +218,7 @@ def _get_run_heads(self, batch, stage, batch_idx): return run_heads, io_map def _to_tensor(self, batch): - """convert monai metatensors to tensors.""" + """Convert monai metatensors to tensors.""" for k, v in batch.items(): if isinstance(v, MetaTensor): batch[k] = v.as_tensor() diff --git a/cyto_dl/models/im2im/utils/instance_seg.py b/cyto_dl/models/im2im/utils/instance_seg.py index 9223321ff..78a9fd971 100644 --- a/cyto_dl/models/im2im/utils/instance_seg.py +++ b/cyto_dl/models/im2im/utils/instance_seg.py @@ -90,10 +90,7 @@ def skeleton_tall(self, img, max_label): return tall_skeleton def label_2d(self, img): - """ - dim = 2: return labeled image - dim = 3: label each z slice separately - """ + """Dim = 2: return labeled image dim = 3: label each z slice separately.""" if self.dim == 2: out, _ = label(img) return out @@ -185,9 +182,9 @@ def embed_from_skel(self, skel: np.ndarray, iseg: np.ndarray): if len(object_points) == 2: crop_embedding[:, object_points[0], object_points[1]] = point_embeddings elif len(object_points) == 3: - crop_embedding[ - :, object_points[0], object_points[1], object_points[2] - ] = point_embeddings + crop_embedding[:, object_points[0], object_points[1], object_points[2]] = ( + point_embeddings + ) crop_embedding = torch.from_numpy(self.smooth_embedding(crop_embedding)) @@ -432,7 +429,7 @@ def _get_point_embeddings(self, object_points, skeleton_points): return dist, tree.data[idx].T.astype(int) def kd_clustering(self, embeddings, skel): - """assign embedded points to closest skeleton.""" + """Assign embedded points to closest skeleton.""" skel = find_boundaries(skel, mode="inner") * skel # propagate labels to boundaries skel_points = np.stack(skel.nonzero()).T embed_points = np.stack(embeddings).T @@ -446,7 +443,7 @@ def kd_clustering(self, embeddings, skel): return embedding_labels def remove_small_skeletons(self, skel): - """remove small skeletons below self.min_size that are not touching the edge of the + """Remove small skeletons below self.min_size that are not touching the edge of the image.""" skel_removed = skel.copy() regions = find_objects(skel) diff --git a/cyto_dl/models/vae/base_vae.py b/cyto_dl/models/vae/base_vae.py index d2eda552c..b167464b9 100644 --- a/cyto_dl/models/vae/base_vae.py +++ b/cyto_dl/models/vae/base_vae.py @@ -30,6 +30,7 @@ def __init__( **base_kwargs, ): """Instantiate a basic VAE model. + Parameters ---------- encoder: nn.Module @@ -107,7 +108,7 @@ def __init__( super().__init__(metrics=metrics, **base_kwargs) for key in prior.keys(): - if isinstance(prior[key], (str, type(None))): + if prior[key] is None or isinstance(prior[key], str): if prior[key] == "gaussian": prior[key] = IsotropicGaussianPrior(dimensionality=latent_dim) else: diff --git a/cyto_dl/models/vae/image_canon_vae.py b/cyto_dl/models/vae/image_canon_vae.py index a67396788..2d1429354 100644 --- a/cyto_dl/models/vae/image_canon_vae.py +++ b/cyto_dl/models/vae/image_canon_vae.py @@ -124,7 +124,7 @@ def __init__( ) ) - if isinstance(prior, (str, type(None))): + if prior is None or isinstance(prior, str): if prior == "gaussian": encoder_out_size = 2 * latent_dim else: diff --git a/cyto_dl/models/vae/image_encoder.py b/cyto_dl/models/vae/image_encoder.py index 1374c43c7..c97fd85e4 100644 --- a/cyto_dl/models/vae/image_encoder.py +++ b/cyto_dl/models/vae/image_encoder.py @@ -143,8 +143,7 @@ def forward(self, x): y = self.net(x) pool_dims = (2, 3) if self.spatial_dims == 2 else (2, 3, 4) - y = y.tensor - y = y.mean(dim=pool_dims) + y = y.tensor.mean(dim=pool_dims) y_embedding = y[:, : self.out_dim] diff --git a/cyto_dl/models/vae/image_vae.py b/cyto_dl/models/vae/image_vae.py index 5d65d398f..a64e82e82 100644 --- a/cyto_dl/models/vae/image_vae.py +++ b/cyto_dl/models/vae/image_vae.py @@ -175,7 +175,7 @@ def __init__( _Scale(last_scale), ) - if isinstance(prior, (str, type(None))): + if prior is None or isinstance(prior, str): if prior == "gaussian": encoder_out_size = 2 * latent_dim else: diff --git a/cyto_dl/models/vae/priors/gaussian.py b/cyto_dl/models/vae/priors/gaussian.py index 01e1482c3..12b807736 100644 --- a/cyto_dl/models/vae/priors/gaussian.py +++ b/cyto_dl/models/vae/priors/gaussian.py @@ -51,8 +51,8 @@ def kl_divergence(cls, mean, logvar, tc_penalty_weight=None, reduction="sum"): @classmethod def sample(cls, mean, logvar): std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return eps.mul(std).add(mean) + eps = torch.randn_like(std).mul(std).add(mean) + return eps def forward(self, z, mode="kl", inference=False, **kwargs): mean_logvar = z @@ -130,8 +130,9 @@ def param_size(self): @classmethod def kl_divergence(cls, mu1, mu2, logvar1, logvar2, tc_penalty_weight=None, reduction="sum"): - """Computes the Kullback-Leibler divergence between two diagonal - gaussians (not necessarily isotropic). It also works batch-wise. + """Computes the Kullback-Leibler divergence between two diagonal gaussians (not necessarily + isotropic). It also works batch-wise. + Parameters ---------- mu1: torch.Tensor diff --git a/cyto_dl/nn/discriminators/n_layer_discriminator.py b/cyto_dl/nn/discriminators/n_layer_discriminator.py index 0cb021fac..bf8ed8ea2 100644 --- a/cyto_dl/nn/discriminators/n_layer_discriminator.py +++ b/cyto_dl/nn/discriminators/n_layer_discriminator.py @@ -33,7 +33,7 @@ def __init__( super().__init__() if dim not in (2, 3): raise ValueError(f"dim must be 2 or 3, got {dim}") - if type(norm_layer) == functools.partial: + if isinstance(norm_layer, functools.partial): use_bias = norm_layer.func != nn.BatchNorm3d else: use_bias = norm_layer != nn.BatchNorm3d diff --git a/cyto_dl/nn/head/base_head.py b/cyto_dl/nn/head/base_head.py index 8737f8fed..cc8621e6b 100644 --- a/cyto_dl/nn/head/base_head.py +++ b/cyto_dl/nn/head/base_head.py @@ -39,7 +39,7 @@ def _postprocess(self, img, img_type, n_postprocess=1): return [self.postprocess[img_type](img[i]) for i in range(n_postprocess)] def generate_io_map(self, input_filenames): - """generates map between input files and output files for a head. + """Generates map between input files and output files for a head. Only used for prediction """ @@ -80,14 +80,16 @@ def run_head( return { "loss": loss, "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), - "target": self._postprocess( - batch[self.head_name], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, - "input": self._postprocess( - batch[self.x_key], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, + "target": ( + self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None + ), + "input": ( + self._postprocess(batch[self.x_key], img_type="input", n_postprocess=n_postprocess) + if stage != "predict" + else None + ), } diff --git a/cyto_dl/nn/head/gan_head.py b/cyto_dl/nn/head/gan_head.py index 9072bb9bb..164db32ed 100644 --- a/cyto_dl/nn/head/gan_head.py +++ b/cyto_dl/nn/head/gan_head.py @@ -81,14 +81,16 @@ def run_head( "loss_D": loss_D, "loss_G": loss_G, "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), - "target": self._postprocess( - batch[self.head_name], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, - "input": self._postprocess( - batch[self.x_key], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, + "target": ( + self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None + ), + "input": ( + self._postprocess(batch[self.x_key], img_type="input", n_postprocess=n_postprocess) + if stage != "predict" + else None + ), } diff --git a/cyto_dl/nn/head/mae_head.py b/cyto_dl/nn/head/mae_head.py index a630855bf..693d10be5 100644 --- a/cyto_dl/nn/head/mae_head.py +++ b/cyto_dl/nn/head/mae_head.py @@ -27,14 +27,16 @@ def run_head( return { "loss": loss, "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), - "target": self._postprocess( - batch[self.head_name], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, - "input": self._postprocess( - batch[self.x_key], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, + "target": ( + self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None + ), + "input": ( + self._postprocess(batch[self.x_key], img_type="input", n_postprocess=n_postprocess) + if stage != "predict" + else None + ), } diff --git a/cyto_dl/nn/head/mask_head.py b/cyto_dl/nn/head/mask_head.py index a255ec2d0..fa3de7117 100644 --- a/cyto_dl/nn/head/mask_head.py +++ b/cyto_dl/nn/head/mask_head.py @@ -56,14 +56,16 @@ def run_head( return { "loss": loss, "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), - "target": self._postprocess( - batch[self.head_name], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, - "input": self._postprocess( - batch[self.x_key], img_type="input", n_postprocess=n_postprocess - ) - if stage != "predict" - else None, + "target": ( + self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None + ), + "input": ( + self._postprocess(batch[self.x_key], img_type="input", n_postprocess=n_postprocess) + if stage != "predict" + else None + ), } diff --git a/cyto_dl/nn/losses/gan_loss.py b/cyto_dl/nn/losses/gan_loss.py index f21657418..bb703da90 100644 --- a/cyto_dl/nn/losses/gan_loss.py +++ b/cyto_dl/nn/losses/gan_loss.py @@ -55,7 +55,8 @@ def get_target_tensor(self, prediction: torch.Tensor, target_is_real: bool): A label tensor filled with ground truth label, and with the size of input """ target_tensor = self.real_label if target_is_real else self.fake_label - return target_tensor.expand_as(prediction) + target_tensor = target_tensor.expand_as(prediction) # noqa: FURB184 + return target_tensor def __call__(self, prediction: torch.Tensor, target_is_real: bool): """Calculate loss given Discriminator's output and grount truth labels. diff --git a/cyto_dl/nn/losses/gaussian_nll_loss.py b/cyto_dl/nn/losses/gaussian_nll_loss.py index a15ca62c8..7f78371b8 100644 --- a/cyto_dl/nn/losses/gaussian_nll_loss.py +++ b/cyto_dl/nn/losses/gaussian_nll_loss.py @@ -21,9 +21,13 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: log_sigma = (sigma + self.eps).log().detach() loss = ( - 0.5 * torch.pow((target - input) / log_sigma.exp(), 2) - + log_sigma - + 0.5 * np.log(2 * np.pi) - ).reshape(input.shape[0], -1) + ( + 0.5 * torch.pow((target - input) / log_sigma.exp(), 2) + + log_sigma + + 0.5 * np.log(2 * np.pi) + ) + .reshape(input.shape[0], -1) + .sum(dim=1, keepdim=True) + ) - return loss.sum(dim=1, keepdim=True) + return loss diff --git a/cyto_dl/nn/mlp.py b/cyto_dl/nn/mlp.py index 0cbc221fb..96703ebe5 100644 --- a/cyto_dl/nn/mlp.py +++ b/cyto_dl/nn/mlp.py @@ -30,7 +30,7 @@ def __init__( net = [_make_block(sum(self.input_dims), hidden_layers[0])] - net += [ + net += [ # noqa: FURB140 _make_block(input_dim, output_dim) for (input_dim, output_dim) in zip(hidden_layers[0:], hidden_layers[1:]) ] diff --git a/cyto_dl/nn/point_cloud/dgcnn.py b/cyto_dl/nn/point_cloud/dgcnn.py index ce801d8bc..e34ad17e1 100644 --- a/cyto_dl/nn/point_cloud/dgcnn.py +++ b/cyto_dl/nn/point_cloud/dgcnn.py @@ -254,8 +254,9 @@ def _generate_plane_features(self, points, cond, plane="xz"): # scatter plane features from points fea_plane = cond.new_zeros(*view_dims1) cond = cond.permute(*permute_dims1) # B x 512 x T - fea_plane = scatter_mean(cond, index, out=fea_plane) # B x 512 x reso^2 - fea_plane = fea_plane.reshape(*view_dims2) # sparce matrix (B x 512 x reso x reso) + fea_plane = scatter_mean(cond, index, out=fea_plane).reshape( + *view_dims2 + ) # sparse matrix (B x 512 x reso x reso) # process the plane features with UNet if self.unet is not None: @@ -269,10 +270,9 @@ def _generate_grid_features(self, p, c): # scatter grid features from points fea_grid = c.new_zeros(p.size(0), self.num_features, self.reso_grid**3) c = c.permute(0, 2, 1) - fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 - fea_grid = fea_grid.reshape( + fea_grid = scatter_mean(c, index, out=fea_grid).reshape( p.size(0), self.num_features, self.reso_grid, self.reso_grid, self.reso_grid - ) # sparce matrix (B x 512 x reso x reso) + ) # sparse matrix (B x 512 x reso x reso) if self.unet3d is not None: fea_grid = self.unet3d(fea_grid) diff --git a/cyto_dl/nn/point_cloud/folding_net.py b/cyto_dl/nn/point_cloud/folding_net.py index 574efa626..91012183c 100644 --- a/cyto_dl/nn/point_cloud/folding_net.py +++ b/cyto_dl/nn/point_cloud/folding_net.py @@ -71,10 +71,8 @@ def __init__( def forward(self, x): x = self.project(x) - grid = self.grid.unsqueeze(0).expand(x.shape[0], -1, -1) - grid = grid.type_as(x) - x = x.unsqueeze(1) - cw_exp = x.expand(-1, grid.shape[1], -1) + grid = self.grid.unsqueeze(0).expand(x.shape[0], -1, -1).type_as(x) + cw_exp = x.unsqueeze(1).expand(-1, grid.shape[1], -1) cat1 = torch.cat((cw_exp, grid), dim=2) folding_result1 = self.folding1(cat1) diff --git a/cyto_dl/nn/point_cloud/graph_functions.py b/cyto_dl/nn/point_cloud/graph_functions.py index 9b51fd717..c71048c17 100644 --- a/cyto_dl/nn/point_cloud/graph_functions.py +++ b/cyto_dl/nn/point_cloud/graph_functions.py @@ -24,9 +24,7 @@ def knn(x, k): if idx_base.device != idx.device: idx_base = idx_base.to(idx.device) - idx = idx + idx_base - idx = idx.view(-1) - + idx = (idx + idx_base).view(-1) return idx diff --git a/cyto_dl/nn/spatial_transformer.py b/cyto_dl/nn/spatial_transformer.py index 1fa9ab86c..556903513 100644 --- a/cyto_dl/nn/spatial_transformer.py +++ b/cyto_dl/nn/spatial_transformer.py @@ -56,8 +56,7 @@ def __init__(self, n_input_ch=2, patch_shape=(64, 256, 512), n_conv_filters=32): self.fc_loc[2].bias.data.copy_(torch.tensor([0, 0, 0], dtype=torch.float)) def forward(self, x): - xs = self.localization(x) - xs = xs.view(-1, self.output_shape) + xs = self.localization(x).view(-1, self.output_shape) offsets = self.fc_loc(xs).squeeze() # create identity transformation matrix with only shifts theta = torch.eye(3, 4).reshape(1, 3, 4).repeat(x.shape[0], 1, 1) diff --git a/cyto_dl/nn/vits/blocks/cross_attention.py b/cyto_dl/nn/vits/blocks/cross_attention.py index a9dc47d38..d583e6d75 100644 --- a/cyto_dl/nn/vits/blocks/cross_attention.py +++ b/cyto_dl/nn/vits/blocks/cross_attention.py @@ -49,7 +49,7 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, y): - """query from decoder (x), key and value from encoder (y)""" + """Query from decoder (x), key and value from encoder (y)""" B, N, C = x.shape Ny = y.shape[1] q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) @@ -60,15 +60,18 @@ def forward(self, x, y): ) k, v = kv[0], kv[1] - attn = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.attn_drop, + attn = ( + F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop, + ) + .transpose(1, 2) + .reshape(B, N, C) ) - x = attn.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) + x = self.proj(attn) x = self.proj_drop(x) return x diff --git a/cyto_dl/nn/vits/seg.py b/cyto_dl/nn/vits/seg.py index ca36e465d..3e10e91a1 100644 --- a/cyto_dl/nn/vits/seg.py +++ b/cyto_dl/nn/vits/seg.py @@ -11,10 +11,8 @@ class EncodedSkip(torch.nn.Module): def __init__(self, spatial_dims, num_patches, emb_dim, n_decoder_filters, layer): super().__init__() - """ - layer = 0 is the smallest resolution, n is the highest - as the layer increases, the image size increases and the number of filters decreases - """ + """Layer = 0 is the smallest resolution, n is the highest as the layer increases, the image + size increases and the number of filters decreases.""" upsample = 2**layer self.n_out_channels = n_decoder_filters // (upsample**spatial_dims) @@ -50,7 +48,7 @@ def forward(self, features): class SuperresDecoder(torch.nn.Module): - """create unet-like decoder where each decoder layer is a fed a skip connection consisting of a + """Create unet-like decoder where each decoder layer is a fed a skip connection consisting of a different weighted sum of intermediate layer features.""" def __init__( @@ -223,7 +221,7 @@ def __init__( **encoder_kwargs, ) if encoder_ckpt is not None: - model = torch.load(encoder_ckpt, map_location="cuda:0") + model = torch.load(encoder_ckpt, map_location="cuda:0") # nosec B614 enc_state_dict = { k.replace("backbone.encoder.", ""): v for k, v in model["state_dict"].items() diff --git a/cyto_dl/utils/checkpoint.py b/cyto_dl/utils/checkpoint.py index e6f98aaea..597251869 100644 --- a/cyto_dl/utils/checkpoint.py +++ b/cyto_dl/utils/checkpoint.py @@ -7,7 +7,9 @@ def load_checkpoint(model, load_params): "ckpt_path" ), "ckpt_path must be provided to with argument weights_only=True" # load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights - state_dict = torch.load(load_params["ckpt_path"], map_location="cpu")["state_dict"] + state_dict = torch.load(load_params["ckpt_path"], map_location="cpu")[ + "state_dict" + ] # nosec B614 model.load_state_dict(state_dict, strict=load_params.get("strict", True)) # set ckpt_path to None to avoid loading checkpoint again with model.fit/model.test load_params["ckpt_path"] = None diff --git a/cyto_dl/utils/rich_utils.py b/cyto_dl/utils/rich_utils.py index 3fc5815d4..79e20ab92 100644 --- a/cyto_dl/utils/rich_utils.py +++ b/cyto_dl/utils/rich_utils.py @@ -45,8 +45,12 @@ def print_config_tree( # add fields from `print_order` to queue for field in print_order: - queue.append(field) if field in cfg else log.warning( - f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) ) # add all the other fields to queue (not specified in `print_order`) @@ -71,7 +75,7 @@ def print_config_tree( # save config tree to file if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + with Path(cfg.paths.output_dir, "config_tree.log").open("w") as file: rich.print(tree, file=file) @@ -93,7 +97,7 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: log.info(f"Tags: {cfg.tags}") if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + with Path(cfg.paths.output_dir, "tags.log").open("w") as file: rich.print(cfg.tags, file=file) diff --git a/cyto_dl/utils/spharm/rotation.py b/cyto_dl/utils/spharm/rotation.py index 06549ca73..991d415c7 100644 --- a/cyto_dl/utils/spharm/rotation.py +++ b/cyto_dl/utils/spharm/rotation.py @@ -96,7 +96,7 @@ def flip_spharm(input, paired_indices, flips=-1): def get_band_indices(columns, max_band, prefix="", flat=False): - """get the tensor indices for each band, based on the column order of the batch loader (given + """Get the tensor indices for each band, based on the column order of the batch loader (given by `columns`, assuming that it is in the same order). this is passed to `rotate_spharm` later, to rotate the spherical harmonics around the z axis diff --git a/scripts/publish_bumpver_handler.py b/scripts/publish_bumpver_handler.py index 0b2a51ada..7061eef0e 100644 --- a/scripts/publish_bumpver_handler.py +++ b/scripts/publish_bumpver_handler.py @@ -1,7 +1,6 @@ # this file is intended to be called by a github workflow (.github/workflows/publish_to_pypi.yaml) # it makes decisions based on the current version and the component specified for bumping, # which the workflow cannot do - """ TESTING: - add and commit any changes (keep track of this commit hash)