Skip to content

Commit

Permalink
fix: round quantiles when target is integer (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber authored Mar 30, 2024
1 parent 91e0c4d commit 917002c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ When the input data is a pandas DataFrame, the output is also a pandas DataFrame

| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
|-----------:|--------:|-------:|-------:|-------:|-------:|--------:|
| 1357 | 114783 | 120894 | 131618 | 175760 | 188051 | 205448 |
| 2367 | 67416 | 80073 | 86753 | 117854 | 127582 | 142321 |
| 2822 | 119422 | 132047 | 138724 | 178526 | 197246 | 214205 |
| 2126 | 94030 | 99849 | 110891 | 150249 | 164703 | 182528 |
| 1544 | 68996 | 81516 | 88231 | 121774 | 132425 | 147110 |
| 1357 | 114784 | 120894 | 131618 | 175761 | 188052 | 205449 |
| 2367 | 67417 | 80074 | 86754 | 117854 | 127583 | 142322 |
| 2822 | 119423 | 132048 | 138725 | 178526 | 197246 | 214206 |
| 2126 | 94031 | 99850 | 110891 | 150249 | 164703 | 182528 |
| 1544 | 68996 | 81516 | 88232 | 121774 | 132425 | 147110 |

Let's visualize the predicted quantiles on the test set:

Expand Down Expand Up @@ -116,11 +116,11 @@ When the input data is a pandas DataFrame, the output is also a pandas DataFrame

| house_id | 0.025 | 0.975 |
|-----------:|--------:|--------:|
| 1357 | 107202 | 206290 |
| 2367 | 66665 | 146004 |
| 2822 | 115591 | 220314 |
| 2126 | 85288 | 183037 |
| 1544 | 67889 | 150646 |
| 1357 | 107203 | 206290 |
| 2367 | 66665 | 146005 |
| 2822 | 115592 | 220315 |
| 2126 | 85288 | 183038 |
| 1544 | 67890 | 150646 |

## Contributing

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,10 @@ def predict_quantiles(
Δŷ_quantiles = Δŷ_quantiles[
np.arange(Δŷ_quantiles.shape[0]), :, np.argmin(dispersion, axis=-1)
]
ŷ_quantiles: FloatMatrix[F] = (ŷ[:, np.newaxis] + Δŷ_quantiles).astype(self.y_dtype_)
ŷ_quantiles: FloatMatrix[F] = ŷ[:, np.newaxis] + Δŷ_quantiles
if self.y_is_integer_:
ŷ_quantiles = np.round(ŷ_quantiles)
ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_)
# Convert ŷ_quantiles to a pandas DataFrame if X is a pandas DataFrame.
if hasattr(X, "dtypes") and hasattr(X, "index"):
try:
Expand Down

0 comments on commit 917002c

Please sign in to comment.