Commit

Merge pull request #125 from basf/develop
include quantile regression
AnFreTh authored Sep 19, 2024
2 parents 967f49f + ccfc75a commit 48d22da
Showing 4 changed files with 54 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -200,6 +200,7 @@ MambularLSS allows you to model the full distribution of a response variable, no
- **negativebinom**: For over-dispersed count data.
- **inversegamma**: Often used as a prior in Bayesian inference.
- **categorical**: For data with more than two categories.
- **Quantile**: For quantile regression using the pinball loss.

These distribution classes make MambularLSS versatile in modeling various data types and distributions.

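A minimal usage sketch of the new family (not part of the commit): it assumes the scikit-learn-style `MambularLSS` estimator exposed by `mambular.models` and the `family` / `distributional_kwargs` arguments visible in `sklearn_base_lss.py` below; data, column names, and quantile values are illustrative.

```python
# Hypothetical usage of the new "quantile" family; exact fit() arguments may
# differ between mambular versions.
import numpy as np
import pandas as pd
from mambular.models import MambularLSS  # assumed import path

rng = np.random.default_rng(0)
X = pd.DataFrame({"x1": rng.normal(size=256), "x2": rng.normal(size=256)})
y = X["x1"].to_numpy() + 0.5 * rng.normal(size=256)

model = MambularLSS()
# "quantile" selects the Quantile loss added in this commit; the requested
# quantiles are assumed to be forwarded through distributional_kwargs.
model.fit(X, y, family="quantile", distributional_kwargs={"quantiles": [0.1, 0.5, 0.9]})
preds = model.predict(X)  # expected: one column per requested quantile
```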
2 changes: 1 addition & 1 deletion mambular/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.2.2"
__version__ = "0.2.3"
6 changes: 3 additions & 3 deletions mambular/models/sklearn_base_lss.py
@@ -30,6 +30,7 @@
    NormalDistribution,
    PoissonDistribution,
    StudentTDistribution,
    Quantile,
)
from lightning.pytorch.callbacks import ModelSummary

@@ -210,11 +211,9 @@ def build_model(
            X, y, X_val, y_val, val_size=val_size, random_state=random_state
        )

        num_classes = len(np.unique(y))

        self.task_model = TaskModel(
            model_class=self.base_model,
            num_classes=num_classes,
            num_classes=self.family.param_count,
            config=self.config,
            cat_feature_info=self.data_module.cat_feature_info,
            num_feature_info=self.data_module.num_feature_info,
@@ -347,6 +346,7 @@ def fit(
"negativebinom": NegativeBinomialDistribution,
"inversegamma": InverseGammaDistribution,
"categorical": CategoricalDistribution,
"quantile": Quantile,
}

if distributional_kwargs is None:
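The `num_classes=self.family.param_count` change above sizes the network head from the chosen family rather than from the number of unique target values, which is what makes multi-quantile output possible. A rough sketch of that relationship (illustrative; it assumes `param_count` is derived from the `param_names` passed to `BaseDistribution`):

```python
# Illustrative only: assumes param_count == len(param_names) on BaseDistribution.
from mambular.utils.distributions import Quantile

family = Quantile(quantiles=[0.25, 0.5, 0.75])
print(family.param_count)  # expected: 3, used as the output width of the task model
```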
49 changes: 49 additions & 0 deletions mambular/utils/distributions.py
@@ -504,3 +504,52 @@ def compute_loss(self, predictions, y_true):
        # Compute the negative log-likelihood
        nll = -cat_dist.log_prob(y_true).mean()
        return nll


class Quantile(BaseDistribution):
    """
    Quantile Regression Loss class.

    This class computes the quantile loss (also known as pinball loss) for a set of quantiles.
    It is used to handle quantile regression tasks where we aim to predict a given quantile of the target distribution.

    Parameters
    ----------
    name : str, optional
        The name of the distribution, by default "Quantile".
    quantiles : list of float, optional
        A list of quantiles to be used for computing the loss, by default [0.25, 0.5, 0.75].

    Attributes
    ----------
    quantiles : list of float
        List of quantiles for which the pinball loss is computed.

    Methods
    -------
    compute_loss(predictions, y_true)
        Computes the quantile regression loss between the predictions and true values.
    """

    def __init__(self, name="Quantile", quantiles=[0.25, 0.5, 0.75]):
        param_names = [
            f"q_{q}" for q in quantiles
        ]  # Use string representations of quantiles
        super().__init__(name, param_names)
        self.quantiles = quantiles

    def compute_loss(self, predictions, y_true):

        assert not y_true.requires_grad  # Ensure y_true does not require gradients
        assert predictions.size(0) == y_true.size(0)  # Ensure batch size matches

        losses = []
        for i, q in enumerate(self.quantiles):
            errors = y_true - predictions[:, i]  # Calculate errors for each quantile
            # Compute the pinball loss
            quantile_loss = torch.max((q - 1) * errors, q * errors)
            losses.append(quantile_loss)

        # Sum losses across quantiles and compute mean
        loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
        return loss
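A quick self-contained check of the loss defined above (illustrative, not part of the commit); the tensor values are arbitrary and the import path follows the file shown in this diff:

```python
# Sanity-check the new Quantile (pinball) loss on toy tensors.
import torch
from mambular.utils.distributions import Quantile

q_loss = Quantile(quantiles=[0.1, 0.5, 0.9])
y_true = torch.tensor([1.0, 2.0, 3.0])
# One prediction column per quantile (order: q=0.1, q=0.5, q=0.9).
predictions = torch.tensor([
    [0.7, 1.0, 1.4],
    [1.8, 2.1, 2.5],
    [2.6, 3.0, 3.7],
])
# Under-prediction of a quantile is weighted by q, over-prediction by (1 - q),
# so higher quantiles pull the fitted values upward.
print(float(q_loss.compute_loss(predictions, y_true)))
```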
