Skip to content

Commit

Permalink
Update to Static.jl v0.8 (#115)
Browse files Browse the repository at this point in the history
* Update CI configuration

* Update to Static.jl v0.8

* Increase package version to v0.14.6

* format

* Fix docstring for PowerMeasure

Co-authored-by: Chad Scherrer <[email protected]>

* Rename dslength and dssize

---------

Co-authored-by: Chad Scherrer <[email protected]>
  • Loading branch information
oschulz and cscherrer authored May 30, 2023
1 parent 99a603b commit 5bba40f
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 70 deletions.
60 changes: 22 additions & 38 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ on:
push:
branches:
- master
tags: '*'
pull_request:


concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
Expand All @@ -19,59 +19,43 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1.6'
- '1.7'
- '1.8'
- '1'
- 'nightly'
os:
- ubuntu-latest
arch:
- x64
include:
- version: 1
os: ubuntu-latest
arch: x86
- version: 1
os: macOS-latest
arch: x64
- version: 1
os: windows-latest
arch: x64
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
with:
coverage: ${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64' }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
- uses: codecov/codecov-action@v3
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.arch == 'x64'
with:
file: lcov.info
# docs:
# name: Documentation
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v2
# - uses: julia-actions/setup-julia@v1
# with:
# version: '1'
# - run: |
# julia --project=docs -e '
# using Pkg
# Pkg.develop(PackageSpec(path=pwd()))
# Pkg.instantiate()'
# - run: |
# julia --project=docs -e '
# using Documenter: DocMeta, doctest
# using MeasureBase
# DocMeta.setdocmeta!(MeasureBase, :DocTestSetup, :(using MeasureBase); recursive=true)
# doctest(MeasureBase)'
# - run: julia --project=docs docs/make.jl
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}

4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureBase"
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.14.5"
version = "0.14.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -48,6 +48,6 @@ NaNMath = "0.3, 1"
PrettyPrinting = "0.3, 0.4"
Reexport = "1"
SpecialFunctions = "2"
Static = "0.5, 0.6"
Static = "0.8"
Tricks = "0.1"
julia = "1.3"
3 changes: 2 additions & 1 deletion src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ using PrettyPrinting
const Pretty = PrettyPrinting

using ChainRulesCore
using FillArrays
import FillArrays
using Static
using FunctionChains

Expand Down Expand Up @@ -106,6 +106,7 @@ using Compat

using IrrationalConstants

include("static.jl")
include("smf.jl")
include("getdof.jl")
include("transport.jl")
Expand Down
53 changes: 30 additions & 23 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,68 @@
import Base
using FillArrays: Fill
# """
# A power measure is a product of a measure with itself. The number of elements in
# the product determines the dimensionality of the resulting support.

# Note that power measures are only well-defined for integer powers.
export PowerMeasure

# The nth power of a measure μ can be written μ^x.
# """
# PowerMeasure{M,N,D} = ProductMeasure{Fill{M,N,D}}
"""
struct PowerMeasure{M,...} <: AbstractProductMeasure
export PowerMeasure
A power measure is a product of a measure with itself. The number of elements in
the product determines the dimensionality of the resulting support.
Note that power measures are only well-defined for integer powers.
The nth power of a measure μ can be written μ^n.
"""
struct PowerMeasure{M,A} <: AbstractProductMeasure
parent::M
axes::A
end

maybestatic_length::PowerMeasure) = prod(maybestatic_size(μ))
maybestatic_size::PowerMeasure) = map(maybestatic_length, μ.axes)

function Pretty.tile::PowerMeasure)
sz = length.(μ.axes)
arg1 = Pretty.tile.parent)
arg2 = Pretty.tile(length(sz) == 1 ? only(sz) : sz)
return Pretty.pair_layout(arg1, arg2; sep = " ^ ")
end

# ToDo: Make rand return static arrays for statically-sized power measures.

function _cartidxs(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N}
CartesianIndices(map(_dynamic, axs))
end

function Base.rand(
rng::AbstractRNG,
::Type{T},
d::PowerMeasure{M},
) where {T,M<:AbstractMeasure}
map(CartesianIndices(d.axes)) do _
map(_cartidxs(d.axes)) do _
rand(rng, T, d.parent)
end
end

function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T}
map(CartesianIndices(d.axes)) do _
map(_cartidxs(d.axes)) do _
rand(rng, d.parent)
end
end

@inline _pm_axes(sz::Tuple{Vararg{<:IntegerLike,N}}) where {N} = map(one_to, sz)
@inline _pm_axes(axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {N} = axs

@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N}
a = axes(Fill{T,N}(x, sz))
A = typeof(a)
PowerMeasure{T,A}(x, a)
PowerMeasure(x, _pm_axes(sz))
end

marginals(d::PowerMeasure) = Fill(d.parent, d.axes)
marginals(d::PowerMeasure) = fill_with(d.parent, d.axes)

function Base.:^::AbstractMeasure, dims::Tuple{Vararg{<:AbstractArray,N}}) where {N}
powermeasure(μ, dims)
end

Base.:^::AbstractMeasure, dims::Tuple) = powermeasure(μ, Base.OneTo.(dims))
Base.:^::AbstractMeasure, dims::Tuple) = powermeasure(μ, one_to.(dims))
Base.:^::AbstractMeasure, n) = powermeasure(μ, (n,))

# Base.show(io::IO, d::PowerMeasure) = print(io, d.parent, " ^ ", size(d.xs))
Expand All @@ -75,18 +85,15 @@ end
end
end

@inline function logdensity_def(
d::PowerMeasure{M,Tuple{Base.OneTo{StaticInt{N}}}},
x,
) where {M,N}
@inline function logdensity_def(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N}
parent = d.parent
sum(1:N) do j
@inbounds logdensity_def(parent, x[j])
end
end

@inline function logdensity_def(
d::PowerMeasure{M,NTuple{N,Base.OneTo{StaticInt{0}}}},
d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}},
x,
) where {M,N}
static(0.0)
Expand All @@ -110,7 +117,7 @@ end

@inline getdof::PowerMeasure) = getdof.parent) * prod(map(length, μ.axes))

@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Base.OneTo{StaticInt{0}}}}) where {N}
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N}
static(0)
end

Expand All @@ -135,7 +142,7 @@ logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0)

# To avoid ambiguities
function logdensity_def(
::PowerMeasure{P,Tuple{Vararg{Base.OneTo{Static.StaticInt{0}},N}}},
::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}},
x,
) where {P<:PrimitiveMeasure,N}
static(0.0)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/smart-constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
###############################################################################
# ProductMeasure

productmeasure(mar::Fill) = powermeasure(mar.value, mar.axes)
productmeasure(mar::FillArrays.Fill) = powermeasure(mar.value, mar.axes)

function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
return powermeasure(mar.f.value, axes(mar.data))
Expand Down
2 changes: 1 addition & 1 deletion src/domains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct Simplex <: CodimOne end

function zeroset(::Simplex)
f(x::AbstractArray{T}) where {T} = sum(x) - one(T)
∇f(x::AbstractArray{T}) where {T} = Fill(one(T), size(x))
∇f(x::AbstractArray{T}) where {T} = fill_with(one(T), size(x))
ZeroSet(f, ∇f)
end

Expand Down
6 changes: 3 additions & 3 deletions src/standard/stdmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x)
end

function transport_def::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x)
return Fill(transport_def.parent, μ, only(x)), map(length, ν.axes)...)
return fill_with(transport_def.parent, μ, only(x)), map(length, ν.axes))
end

function transport_def(
Expand All @@ -35,7 +35,7 @@ end
# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):

_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))

function transport_to(::Type{NU}, μ) where {NU<:StdMeasure}
Expand Down Expand Up @@ -90,7 +90,7 @@ end
@inline _offset_cumsum(s, x) = (s,)
@inline _offset_cumsum(s) = ()

function _stdvar_viewranges(μs::Tuple, startidx::Integer)
function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike)
N = map(getdof, μs)
offs = _offset_cumsum(startidx, N...)
map((o, n) -> o:o+n-1, offs, N)
Expand Down
61 changes: 61 additions & 0 deletions src/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
MeasureBase.IntegerLike
Equivalent to `Union{Integer,Static.StaticInt}`.
"""
const IntegerLike = Union{Integer,Static.StaticInt}

"""
MeasureBase.one_to(n::IntegerLike)
Creates a range from one to n.
Returns an instance of `Base.OneTo` or `Static.SOneTo`, depending
on the type of `n`.
"""
@inline one_to(n::Integer) = Base.OneTo(n)
@inline one_to(::Static.StaticInt{N}) where {N} = Static.SOneTo{N}()

_dynamic(x::Number) = dynamic(x)
_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N)
_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r)

"""
MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N
Creates an array of size `sz` filled with `x`.
Returns an instance of `FillArrays.Fill`.
"""
function fill_with end

@inline function fill_with(x::T, sz::Tuple{Vararg{<:IntegerLike,N}}) where {T,N}
fill_with(x, map(one_to, sz))
end

@inline function fill_with(x::T, axs::Tuple{Vararg{<:AbstractUnitRange,N}}) where {T,N}
# While `FillArrays.Fill` (mostly?) works with axes that are static unit
# ranges, some operations that automatic differentiation requires do fail
# on such instances of `Fill` (e.g. `reshape` from dynamic to static size).
# So need to use standard ranges for the axes for now:
dyn_axs = map(_dynamic, axs)
FillArrays.Fill(x, dyn_axs)
end

"""
MeasureBase.maybestatic_length(x)::IntegerLike
Returns the length of `x` as a dynamic or static integer.
"""
maybestatic_length(x) = length(x)
maybestatic_length(x::AbstractUnitRange) = length(x)
function maybestatic_length(::Static.OptionallyStaticUnitRange{StaticInt{A},StaticInt{B}}) where {A,B}
StaticInt{B - A + 1}()
end

"""
MeasureBase.maybestatic_size(x)::Tuple{Vararg{IntegerLike}}
Returns the size of `x` as a tuple of dynamic or static integers.
"""
maybestatic_size(x) = size(x)
6 changes: 5 additions & 1 deletion test/transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ using LogExpFunctions: logit
using ChainRulesTestUtils

@testset "transport_to" begin
test_rrule(MeasureBase._origin_depth, pushfwd(exp, StdUniform()))
test_rrule(
MeasureBase._origin_depth,
pushfwd(exp, StdUniform()),
output_tangent = static(0),
)

for (f, μ) in [
(logit, StdUniform())
Expand Down

2 comments on commit 5bba40f

@cscherrer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/84550

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.6 -m "<description of version>" 5bba40fd7e04e805160c66b9a645617ad0ed5041
git push origin v0.14.6

Please sign in to comment.