Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes for MeasureBase v0.15 #122

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0418307
Require Julia v1.10
oschulz Sep 15, 2024
3c8c393
Move ChainRulesCore support to extension
oschulz Sep 15, 2024
8d17caa
Add ForwardDiff extension
oschulz Sep 15, 2024
3d45f64
Add Distributions and DistributionsForwardDiff extensions
oschulz Sep 15, 2024
aad3dae
Add DistributionsChainRulesCore extension
oschulz Sep 15, 2024
b70ce70
Add function asmeasure
oschulz Sep 15, 2024
0f89a57
Add AsMeasure
oschulz Nov 1, 2024
0fb2009
Add collection utils
oschulz Nov 1, 2024
04e24b0
Add mreshape
oschulz Nov 1, 2024
3135e1e
STASH Distributions ext impl
oschulz Nov 1, 2024
1d27188
Remove PowerWeightedMeasure
oschulz Nov 1, 2024
9844583
Remove kernelfactor
oschulz Nov 1, 2024
cd92b8d
Rename pullback to pullbck and export it
oschulz Nov 1, 2024
f500aea
Rename bind to mbind and deprecate rightarrowtail
oschulz Nov 1, 2024
fefe36d
Introduce mintegrate and mintegrate_exp
oschulz Nov 1, 2024
fb3c98c
Remove the rebase function
oschulz Nov 1, 2024
86db05e
Rename bind to mbind and remove fish operator
oschulz Nov 1, 2024
58f0746
Change field order of Bind and improve docs.
oschulz Nov 1, 2024
0cdca3d
Remove operator otimes
oschulz Nov 1, 2024
3c61180
Removes PointwiseProductMeasure
oschulz Nov 1, 2024
867adbe
Remove scrd operator
oschulz Nov 1, 2024
e2b88a2
Remove ll-operator
oschulz Nov 1, 2024
5404ff1
Add measure operators in submodule MeasureOperators
oschulz Nov 1, 2024
aa7416b
Improve docstring for mbind
oschulz Nov 1, 2024
470db23
Improve likelihood docs
oschulz Nov 1, 2024
28f25ef
Apply JuliaFormatter
oschulz Nov 1, 2024
bf35d40
Improve Likelihood ctor
oschulz Nov 1, 2024
7aae58e
Fix typo in _mintegrate_exp_impl exception
oschulz Nov 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
- 'pre'
os:
Expand Down
19 changes: 16 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["Chad Scherrer <[email protected]>", "Oliver Schulz <oschulz@mp
version = "0.14.11"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand All @@ -29,21 +28,35 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
MeasureBaseChainRulesCoreExt = "ChainRulesCore"
MeasureBaseDistributionsExt = "Distributions"
MeasureBaseDistributionsChainRulesCoreExt = ["Distributions", "ChainRulesCore"]
MeasureBaseDistributionsForwardDiffExt = ["Distributions", "ForwardDiff"]
MeasureBaseForwardDiffExt = "ForwardDiff"

[compat]
ChainRulesCore = "1"
ChangesOfVariables = "0.1.3"
Compat = "3.35, 4"
ConstructionBase = "1.3"
DensityInterface = "0.4"
Distributions = "0.25.1"
FillArrays = "0.12, 0.13, 1"
ForwardDiff = "0.8, 0.9, 0.10"
FunctionChains = "0.1"
IfElse = "0.1"
IntervalSets = "0.7"
InverseFunctions = "0.1.8"
IrrationalConstants = "0.1, 0.2"
LinearAlgebra = "1"
LogExpFunctions = "0.3"
LogarithmicNumbers = "1"
LogExpFunctions = "0.3"
MappedArrays = "0.4"
NaNMath = "0.3, 1"
PrettyPrinting = "0.3, 0.4"
Expand All @@ -54,4 +67,4 @@ Static = "0.8, 1"
Statistics = "1"
Test = "1"
Tricks = "0.1"
julia = "1.6"
julia = "1.10"
84 changes: 84 additions & 0 deletions ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).

module MeasureBaseChainRulesCoreExt

using MeasureBase
using ChainRulesCore: NoTangent, ZeroTangent
import ChainRulesCore


# = collection utils =========================================================

using MeasureBase: _dropfront, _dropback, _rev_cumsum, _exp_cumsum_log

function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x)
result = _pushfront(v, x)
function _pushfront_pullback(thunked_ΔΩ)
ΔΩ = ChainRulesCore.unthunk(thunked_ΔΩ)
(NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)])
end
return result, _pushfront_pullback
end


function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x)
result = _pushback(v, x)
function _pushback_pullback(thunked_ΔΩ)
ΔΩ = ChainRulesCore.unthunk(thunked_ΔΩ)
(NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)])
end
return result, _pushback_pullback
end


function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector)
result = _rev_cumsum(xs)
function _rev_cumsum_pullback(ΔΩ)
∂xs = ChainRulesCore.@thunk cumsum(ChainRulesCore.unthunk(ΔΩ))
(NoTangent(), ∂xs)
end
return result, _rev_cumsum_pullback
end


function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector)
result = _exp_cumsum_log(xs)
function _exp_cumsum_log_pullback(ΔΩ)
∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* ChainRulesCore.unthunk(ΔΩ))
(NoTangent(), ∂xs)
end
return result, _exp_cumsum_log_pullback
end


# = measure functions ========================================================


@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
y = _checksupport(cond, result)
function _checksupport_pullback(ȳ)
return NoTangent(), ZeroTangent(), one(ȳ)
end
y, _checksupport_pullback
end


_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
return require_insupport(μ, x), _require_insupport_pullback
end


_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback


_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback


_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback


end # module MeasureBaseChainRulesCoreExt
9 changes: 9 additions & 0 deletions ext/MeasureBaseDistributionsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).

module MeasureBaseDistributionsChainRulesCoreExt

using MeasureBase
import Distributions
import ChainRulesCore

end # module MeasureBaseDistributionsChainRulesCoreExt
7 changes: 7 additions & 0 deletions ext/MeasureBaseDistributionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).

module MeasureBaseDistributionsExt

include "distributions/distributions.jl"

end # module MeasureBaseDistributionsExt
9 changes: 9 additions & 0 deletions ext/MeasureBaseDistributionsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).

module MeasureBaseDistributionsForwardDiffExt

using MeasureBase
import Distributions
import ForwardDiff

end # module MeasureBaseDistributionsForwardDiffExt
14 changes: 14 additions & 0 deletions ext/MeasureBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).

module MeasureBaseForwardDiffExt

using MeasureBase
import ForwardDiff

function MeasureBase.containsnan(x::ForwardDiff.Dual)
a = containsnan(x.value)
b = containsnan(x.partials)
return a || b
end

end # module MeasureBaseForwardDiffExt
Loading