Skip to content

Commit

Permalink
harmonized format for docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Maren Hackenberg committed Aug 18, 2023
1 parent 1c5cc7e commit 020426d
Show file tree
Hide file tree
Showing 16 changed files with 419 additions and 248 deletions.
68 changes: 32 additions & 36 deletions src/CountDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@ LogGammaTerms(x, theta) = @. loggamma(x + theta) - loggamma(theta) - loggamma(on
Log likelihood (scalar) of a minibatch according to a zinb model.
Parameters
----------
x: Data
mu: mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
pi: logit of the dropout parameter (real support) (shape: minibatch x vars)
eps: numerical stability constant
Notes
-----
# Arguments
- `x`: data
- `mu`: mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
- `theta`: inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
- `zi`: logit of the dropout parameter (real support) (shape: minibatch x vars)
- `eps`: numerical stability constant
# Notes
We parametrize the bernoulli using the logits, hence the softplus functions appearing.
"""
function log_zinb_positive(x::AbstractMatrix{S}, mu::AbstractMatrix{S}, theta::AbstractVecOrMat{S}, zi::AbstractMatrix{S}, eps::S=S(1e-8)) where S <: Real
Expand All @@ -45,12 +43,11 @@ end
Log likelihood (scalar) of a minibatch according to a nb model.
Parameters
----------
x: Data
mu: mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
theta: inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
eps: numerical stability constant
# Arguments
- `x`: data
- `mu`: mean of the negative binomial (has to be positive support) (shape: minibatch x vars)
- `theta`: inverse dispersion parameter (has to be positive support) (shape: minibatch x vars)
- `eps`: numerical stability constant
"""
function log_nb_positive(x::AbstractMatrix{S}, mu::AbstractMatrix{S}, theta::AbstractVecOrMat{S}, eps::S=S(1e-8)) where S <: Real
if length(size(theta)) == 1
Expand All @@ -66,11 +63,10 @@ end
Log likelihood (scalar) of a minibatch according to a Poisson model.
Parameters
----------
x: Data
mu: mean=variance of the Poisson distribution (has to be positive support) (shape: minibatch x vars)
eps: numerical stability constant
# Arguments
- `x`: data
- `mu`: mean=variance of the Poisson distribution (has to be positive support) (shape: minibatch x vars)
- `eps`: numerical stability constant
"""
function log_poisson(x::AbstractMatrix{S}, mu::AbstractMatrix{S}, eps::S=S(1e-8)) where S <: Real
return logpdf.(Poisson.(mu), x)
Expand All @@ -80,14 +76,14 @@ end
_convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6)
NB parameterizations conversion.
Parameters
----------
mu: mean of the NB distribution.
theta: inverse overdispersion.
eps: constant used for numerical log stability. (Default value = 1e-6)
Returns
-------
the number of failures until the experiment is stopped and the success probability.
# Arguments
- `mu`: mean of the NB distribution.
- `theta`: inverse overdispersion.
- `eps`: constant used for numerical log stability. (Default value = 1e-6)
# Returns
- the number of failures until the experiment is stopped and the success probability.
"""
function _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6)
logits = log.(mu .+ eps) .- log.(theta .+ eps)
Expand All @@ -99,13 +95,13 @@ end
_convert_counts_logits_to_mean_disp(total_count, logits)
NB parameterizations conversion.
Parameters
----------
total_count: Number of failures until the experiment is stopped.
logits: success logits.
Returns
-------
the mean and inverse overdispersion of the NB distribution.
# Arguments
- `total_count`: Number of failures until the experiment is stopped.
- `logits`: success logits.
# Returns
- the mean and inverse overdispersion of the NB distribution.
"""
function _convert_counts_logits_to_mean_disp(total_count, logits)
theta = total_count
Expand Down
66 changes: 33 additions & 33 deletions src/EncoderDecoder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@ Julia implementation of the encoder of a single-cell VAE model corresponding to
Collects all information on the encoder parameters and stores the basic encoder and mean and variance encoders.
Can be constructed using keywords.
**Keyword arguments**
-------------------------
- `encoder`: `Flux.Chain` of fully connected layers realising the first part of the encoder (before the split in mean and variance). For details, see the source code of `FC_layers` in `src/Utils`.
- `mean_encoder`: `Flux.Dense` fully connected layer realising the latent mean encoder
- `n_input`: input dimension = number of genes/features
- `n_hidden`: number of hidden units to use in each hidden layer
- `n_output`: output dimension of the encoder = dimension of latent space
- `n_layers`: number of hidden layers in encoder and decoder
- `var_activation`: whether or not to use an activation function for the variance layer in the encoder
- `var_encoder`: `Flux.Dense` fully connected layer realising the latent variance encoder
- `var_eps`: numerical stability constant to add to the variance in the reparameterisation of the latent representation
- `z_transformation`: whether to apply a `softmax` transformation the latent z if assuming a lognormal instead of a normal distribution
# Fields for constructions
- `encoder`: `Flux.Chain` of fully connected layers realising the first part of the encoder (before the split in mean and variance). For details, see the source code of `FC_layers` in `src/Utils`.
- `mean_encoder`: `Flux.Dense` fully connected layer realising the latent mean encoder
- `n_input`: input dimension = number of genes/features
- `n_hidden`: number of hidden units to use in each hidden layer
- `n_output`: output dimension of the encoder = dimension of latent space
- `n_layers`: number of hidden layers in encoder and decoder
- `var_activation`: whether or not to use an activation function for the variance layer in the encoder
- `var_encoder`: `Flux.Dense` fully connected layer realising the latent variance encoder
- `var_eps`: numerical stability constant to add to the variance in the reparameterisation of the latent representation
- `z_transformation`: whether to apply a `softmax` transformation the latent z if assuming a lognormal instead of a normal distribution
"""
Base.@kwdef mutable struct scEncoder
encoder
Expand Down Expand Up @@ -61,13 +60,11 @@ Flux.@functor scEncoder
Constructor for an `scVAE` encoder. Initialises an `scEncoder` object according to the input parameters.
Julia implementation of the [`scvi-tools` encoder](https://github.com/scverse/scvi-tools/blob/b33b42a04403842591c04e414c8bb4099eaf7006/scvi/nn/_base_components.py#L202).
**Arguments:**
---------------------------
# Arguments
- `n_input`: input dimension = number of genes/features
- `n_output`: output dimension of the encoder = latent space dimension
**Keyword arguments:**
---------------------------
# Keyword arguments
- `activation_fn`: function to use as activation in all encoder neural network layers
- `bias`: whether or not to use bias parameters in the encoder neural network layers
- `n_hidden`: number of hidden units to use in each hidden layer (if an `Int` is passed, this number is used in all hidden layers,
Expand All @@ -80,6 +77,9 @@ Julia implementation of the [`scvi-tools` encoder](https://github.com/scverse/sc
- `use_layer_norm`: whether or not to apply layer normalization in the encoder layers
- `var_activation`: whether or not to use an activation function for the variance layer in the encoder
- `var_eps`: numerical stability constant to add to the variance in the reparameterisation of the latent representation
# Returns
- `scEncoder` object
"""
function scEncoder(
n_input::Int,
Expand Down Expand Up @@ -251,19 +251,18 @@ Julia implementation of the decoder for a single-cell VAE model corresponding to
Collects all information on the decoder parameters and stores the decoder parts.
Can be constructed using keywords.
**Keyword arguments**
-------------------------
- `n_input`: input dimension = dimension of latent space
- `n_hidden`: number of hidden units to use in each hidden layer (if an `Int` is passed, this number is used in all hidden layers,
alternatively an array of `Int`s can be passed, in which case the kth element corresponds to the number of units in the kth layer.
- `n_output`: output dimension of the decoder = number of genes/features
- `n_layers`: number of hidden layers in decoder
- `px_decoder`: `Flux.Chain` of fully connected layers realising the first part of the decoder (before the split in mean, dispersion and dropout decoder). For details, see the source code of `FC_layers` in `src/Utils`.
- `px_dropout_decoder`: if the generative distribution is zero-inflated negative binomial (`gene_likelihood = :zinb` in the `scVAE` model construction): `Flux.Dense` layer, else `nothing`.
- `px_r_decoder`: decoder for the dispersion parameter. If generative distribution is not some (zero-inflated) negative binomial, it is `nothing`. Else, it is a parameter vector or a `Flux.Dense`, depending on whether the dispersion is estimated per gene (`dispersion = :gene`), or per gene and cell (`dispersion = :gene_cell`)
- `px_scale_decoder`: decoder for the mean of the reconstruction, `Flux.Chain` of a `Dense` layer followed by `softmax` activation
- `use_batch_norm`: whether or not to apply batch normalization in the decoder layers
- `use_layer_norm`: whether or not to apply layer normalization in the decoder layers
# Fields for construction
- `n_input`: input dimension = dimension of latent space
- `n_hidden`: number of hidden units to use in each hidden layer (if an `Int` is passed, this number is used in all hidden layers,
alternatively an array of `Int`s can be passed, in which case the kth element corresponds to the number of units in the kth layer.
- `n_output`: output dimension of the decoder = number of genes/features
- `n_layers`: number of hidden layers in decoder
- `px_decoder`: `Flux.Chain` of fully connected layers realising the first part of the decoder (before the split in mean, dispersion and dropout decoder). For details, see the source code of `FC_layers` in `src/Utils`.
- `px_dropout_decoder`: if the generative distribution is zero-inflated negative binomial (`gene_likelihood = :zinb` in the `scVAE` model construction): `Flux.Dense` layer, else `nothing`.
- `px_r_decoder`: decoder for the dispersion parameter. If generative distribution is not some (zero-inflated) negative binomial, it is `nothing`. Else, it is a parameter vector or a `Flux.Dense`, depending on whether the dispersion is estimated per gene (`dispersion = :gene`), or per gene and cell (`dispersion = :gene_cell`)
- `px_scale_decoder`: decoder for the mean of the reconstruction, `Flux.Chain` of a `Dense` layer followed by `softmax` activation
- `use_batch_norm`: whether or not to apply batch normalization in the decoder layers
- `use_layer_norm`: whether or not to apply layer normalization in the decoder layers
"""
Base.@kwdef mutable struct scDecoder <: AbstractDecoder
n_input::Int
Expand Down Expand Up @@ -297,13 +296,11 @@ Flux.@functor scDecoder
Constructor for an `scVAE` decoder. Initialises an `scDecoder` object according to the input parameters.
Julia implementation of the [`scvi-tools` decoder](https://github.com/scverse/scvi-tools/blob/b33b42a04403842591c04e414c8bb4099eaf7006/scvi/nn/_base_components.py#L308).
**Arguments:**
---------------------------
# Arguments
- `n_input`: input dimension of the decoder = latent space dimension
- `n_output`: output dimension = number of genes/features in the data
**Keyword arguments:**
---------------------------
# Keyword arguments
- `activation_fn`: function to use as activation in all decoder neural network layers
- `bias`: whether or not to use bias parameters in the decoder neural network layers
- `dispersion`: whether to estimate the dispersion parameter for the (zero-inflated) negative binomial generative distribution per gene (`:gene`) or per gene and cell (`:gene_cell`)
Expand All @@ -314,6 +311,9 @@ Julia implementation of the [`scvi-tools` decoder](https://github.com/scverse/sc
- `use_activation`: whether or not to use an activation function in the decoder neural network layers; if `false`, overrides choice in `actication_fn`
- `use_batch_norm`: whether or not to apply batch normalization in the decoder layers
- `use_layer_norm`: whether or not to apply layer normalization in the decoder layers
# Returns
- `scDecoder` object
"""
function scDecoder(n_input, n_output;
activation_fn::Function=relu,
Expand Down
76 changes: 51 additions & 25 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
Calculates the latent representation obtained from encoding the `countmatrix` of the `AnnData` object
with a trained `scVAE` model by applying the function `get_latent_representation(m, adata.X)`.
Stored the latent representation in the `obsm` field of the input `AnnData` object as `name_latent`
Stored the latent representation in the `obsm` field of the input `AnnData` object as `name_latent`.
Returns the modified `AnnData` object.
# Arguments
- `adata::AnnData`: `AnnData` object to which to add the latent representation
- `m::scVAE`: trained `scVAE` model to use for encoding the data
# Keyword arguments
- `name_latent::String="scVI_latent"`: name of the field in `adata.obsm` where the latent representation is stored
# Returns
- the modified `AnnData` object.
"""
function register_latent_representation!(adata::AnnData, m::scVAE; name_latent::String="scVI_latent")
!m.is_trained && @warn("model has not been trained yet!")
Expand All @@ -25,7 +33,16 @@ the UMAP, if not, a latent representation is calculated and registered by callin
The UMAP is calculated using the Julia package [UMAP.jl](https://github.com/dillondaudert/UMAP.jl) with default parameters.
It is then stored in the `name_umap` field of the input `AnnData` object.
Returns the modified `AnnData` object.
# Arguments
- `adata::AnnData`: `AnnData` object to which to add the UMAP representation
- `m::scVAE`: trained `scVAE` model to use for encoding the data
# Keyword arguments
- `name_latent::String="scVI_latent"`: name of the field in `adata.obsm` where the latent representation is stored
- `name_umap::String="scVI_latent_umap"`: name of the field in `adata.obsm` where the UMAP representation is stored
# Returns
- the modified `AnnData` object.
"""
function register_umap_on_latent!(adata::AnnData, m::scVAE; name_latent::String="scVI_latent", name_umap::String="scVI_latent_umap")
register_latent_representation!(adata, m, name_latent = name_latent)
Expand All @@ -51,18 +68,19 @@ By default, the cells are color-coded according to the `celltypes` field of the
For plotting, the [VegaLite.jl](https://www.queryverse.org/VegaLite.jl/stable/) package is used.
**Arguments:**
---------------
# Arguments
- `m::scVAE`: trained `scVAE` model to use for embedding the data with the model encoder
- `adata:AnnData`: data to embed with the model; `adata.X` is encoded with `m`
**Keyword arguments:**
-------------------
# Keyword arguments
- `name_latent::String="scVI_latent"`: name of the field in `adata.obsm` where the latent representation is stored
- `name_latent_umap::String="scVI_latent_umap"`: name of the field in `adata.obsm` where the UMAP representation is stored
- `save_plot::Bool=true`: whether or not to save the plot
- `filename::String="UMAP_on_latent.pdf`: filename under which to save the plot. Has no effect if `save_plot==false`.
- `seed::Int=987`: which random seed to use for calculating UMAP (to ensure reproducibility)
# Returns
- the plot
"""
function plot_umap_on_latent(
m::scVAE, adata::AnnData;
Expand Down Expand Up @@ -108,18 +126,19 @@ By default, the cells are color-coded according to the `celltypes` column in `ad
For plotting, the [VegaLite.jl](https://www.queryverse.org/VegaLite.jl/stable/) package is used.
**Arguments:**
---------------
# Arguments
- `m::scVAE`: trained `scVAE` model to use for embedding the data with the model encoder
- `adata:AnnData`: data to embed with the model; `adata.X` is encoded with `m`
**Keyword arguments:**
-------------------
# Keyword arguments
- `name_latent::String="scVI_latent"`: name of the field in `adata.obsm` where the latent representation is stored
- `plot_title::String="scVI latent representation"`: title of the plot
- `save_plot::Bool=true`: whether or not to save the plot
- `filename::String="UMAP_on_latent.pdf`: filename under which to save the plot. Has no effect if `save_plot==false`.
- `seed::Int=987`: which random seed to use for calculating UMAP (to ensure reproducibility)
# Returns
- the plot
"""
function plot_latent_representation(
m::scVAE, adata::AnnData;
Expand Down Expand Up @@ -164,15 +183,16 @@ By default, the cells are color-coded according to the `celltypes` field of the
For plotting, the [VegaLite.jl](https://www.queryverse.org/VegaLite.jl/stable/) package is used.
**Arguments:**
---------------
# Arguments
- `m::scVAE`: trained `scVAE` model to use for embedding the data with the model encoder
- `adata:AnnData`: data to embed with the model; `adata.X` is encoded with `m`
**Keyword arguments:**
-------------------
# Keyword arguments
- `save_plot::Bool=true`: whether or not to save the plot
- `filename::String="UMAP_on_latent.pdf`: filename under which to save the plot. Has no effect if `save_plot==false`.
# Returns
- the plot
"""
function plot_pca_on_latent(
m::scVAE, adata::AnnData;
Expand Down Expand Up @@ -212,11 +232,13 @@ Depending on whether `z` is sampled from the prior or posterior, the function ca
The distribution ((zero-inflated) negative binomial or Poisson) is parametrised by `mu`, `theta` and `zi` (logits of dropout parameter).
The implementation is adapted from the corresponding [`scvi tools` function](https://github.com/YosefLab/scvi-tools/blob/f0a3ba6e11053069fd1857d2381083e5492fa8b8/scvi/distributions/_negative_binomial.py#L420)
**Arguments:**
-----------------
- `m::scVAE`: `scVAE` model from which the decoder is used for sampling
- `z::AbstractMatrix`: values of the latent representation to use as input for the decoder
- `library::AbstractMatrix`: library size values that are used for scaling in the decoder (either corresponding to the observed or the model-encoded library size)
# Arguments
- `m::scVAE`: `scVAE` model from which the decoder is used for sampling
- `z::AbstractMatrix`: values of the latent representation to use as input for the decoder
- `library::AbstractMatrix`: library size values that are used for scaling in the decoder (either corresponding to the observed or the model-encoded library size)
# Returns
- matrix of samples from the generative distribution defined by the decoder of the `scVAE` model
"""
function decodersample(m::scVAE, z::AbstractMatrix{S}, library::AbstractMatrix{S}) where S <: Real
px_scale, theta, mu, zi_logits = generative(m, z, library)
Expand Down Expand Up @@ -244,10 +266,12 @@ Subsequently samples from the generative distribution defined by the decoder bas
Returns the samples from the model.
**Arguments:**
--------------
# Arguments
- `m::scVAE`: trained `scVAE` model from which to sample
- `adata::AnnData`: `AnnData` object based on which to calculate the latent posterior
# Returns
- matrix of posterior samples from the model
"""
function sample_from_posterior(m::scVAE, adata::AnnData)
sample_from_posterior(m, adata.X')
Expand All @@ -268,14 +292,16 @@ Subsequently draws `n_samples` from the generative distribution defined by the d
Returns the samples from the model.
**Arguments:**
--------------
# Arguments
- `m::scVAE`: trained `scVAE` model from which to sample
- `adata::AnnData`: `AnnData` object based on which to calculate the library size
- `n_samples::Int`: number of samples to draw
**Keyword arguments:**
# Keyword arguments
- `sample_library_size::Bool=false`: whether or not to sample from the library size. If `false`, the mean of the observed library size is used.
# Returns
- matrix of prior samples from the model
"""
function sample_from_prior(m::scVAE, adata::AnnData, n_samples::Int; sample_library_size::Bool=false)
sample_from_prior(m, adata.X', n_samples, sample_library_size=sample_library_size)
Expand Down
Loading

0 comments on commit 020426d

Please sign in to comment.