Skip to content

Commit

Permalink
Added joint version of the power-loss function
Browse files Browse the repository at this point in the history
  • Loading branch information
msainsburydale committed Nov 13, 2023
1 parent 4722d43 commit e614116
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
# ---- kpowerloss ----

"""
kpowerloss(θ̂, y, k; agg = mean, safeorigin = true, ϵ = 0.1)
kpowerloss(θ̂, y, k; agg = mean, joint = true, safeorigin = true, ϵ = 0.1)
For `k` ∈ (0, ∞), the `k`-th power absolute-distance loss,
Expand All @@ -23,25 +23,33 @@ L(θ̂, θ) = |θ̂ - θ|ᵏ,
```
contains the squared-error, absolute-error, and 0-1 loss functions as special
cases (the latter obtained in the limit as `k` → 0).
cases (the latter obtained in the limit as `k` → 0). It is Lipschitz continuous
iff `k` = 1, convex iff `k` ≥ 1, and strictly convex iff `k` > 1: it is
quasiconvex for all `k` > 0.
It is Lipschitz continuous iff `k` = 1, convex iff `k` ≥ 1, and strictly convex
iff `k` > 1. It is quasiconvex for all `k` > 0.
If `joint = true`, the L₁ norm is computed over each parameter vector, so that
the resulting Bayes estimator is the mode of the joint posterior distribution;
otherwise, the Bayes estimator is the vector containing the modes of the
marginal posterior distributions.
If `safeorigin = true`, the loss function is modified to avoid pathologies
around the origin, so that the resulting loss function behaves similarly to the
absolute-error loss in the `ϵ`-interval surrounding the origin.
"""
function kpowerloss(θ̂, θ, k; safeorigin::Bool = true, agg = mean, ϵ = ofeltype(θ̂, 0.1))
function kpowerloss(θ̂, θ, k; safeorigin::Bool = true, agg = mean, ϵ = ofeltype(θ̂, 0.1), joint::Bool = true)

_check_sizes(θ̂, θ)

d = abs.(θ̂ .- θ)
if joint
d = sum(d, dims = 1)
end

if safeorigin
d = abs.(θ̂ .- θ)
b = d .> ϵ
L = vcat(d[b] .^ k, _safefunction.(d[.!b], k, ϵ))
else
L = abs.(θ̂ .- θ).^k
L = d.^k
end

return agg(L)
Expand All @@ -52,7 +60,6 @@ function _safefunction(d, k, ϵ)
ϵ^(k - 1) * d
end


# ---- quantile loss ----

"""
Expand Down

0 comments on commit e614116

Please sign in to comment.