Skip to content

Commit

Permalink
Create containers with map instead of for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Sep 23, 2019
1 parent 8ce7535 commit ecb7915
Show file tree
Hide file tree
Showing 16 changed files with 439 additions and 354 deletions.
2 changes: 2 additions & 0 deletions src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ export DenseAxisArray, SparseAxisArray
include("DenseAxisArray.jl")
include("SparseAxisArray.jl")
include("generate_container.jl")
include("container.jl")
include("macro.jl")

end
2 changes: 2 additions & 0 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

include("nested_iterator.jl")

"""
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
data::Dict{K,T}
Expand Down
22 changes: 22 additions & 0 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
const ArrayIndices{N} = Base.Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}}
container(f::Function, indices) = container(f, indices, default_container(indices))
default_container(::ArrayIndices) = Array
function container(f::Function, indices::ArrayIndices, ::Type{Array})
return map(I -> f(I...), indices)
end
default_container(::Base.Iterators.ProductIterator) = DenseAxisArray
function container(f::Function, indices::Base.Iterators.ProductIterator,
::Type{DenseAxisArray})
return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...)
end
default_container(::NestedIterator) = SparseAxisArray
function container(f::Function, indices::NestedIterator,
::Type{SparseAxisArray})
mappings = map(I -> I => f(I...), indices)
data = Dict(mappings)
if length(mappings) != length(data)
# TODO compute idx
error(string("Repeated index ", idx,". Index sets must have unique elements."))
end
return SparseAxisArray(Dict(data))
end
170 changes: 170 additions & 0 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
using Base.Meta

"""
_extract_kw_args(args)
Process the arguments to a macro, separating out the keyword arguments.
Return a tuple of (flat_arguments, keyword arguments, and requested_container),
where `requested_container` is a symbol to be passed to `parse_container`.
"""
function _extract_kw_args(args)
kw_args = filter(x -> isexpr(x, :(=)) && x.args[1] != :container , collect(args))
flat_args = filter(x->!isexpr(x, :(=)), collect(args))
requested_container = :Auto
for kw in args
if isexpr(kw, :(=)) && kw.args[1] == :container
requested_container = kw.args[2]
end
end
return flat_args, kw_args, requested_container
end

function _try_parse_idx_set(arg::Expr)
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and
# Expr(:ref, :x, Expr(:kw, :i, 1)) respectively.
if arg.head === :kw || arg.head === :(=)
@assert length(arg.args) == 2
return true, arg.args[1], arg.args[2]
elseif isexpr(arg, :call) && arg.args[1] === :in
return true, arg.args[2], arg.args[3]
else
return false, nothing, nothing
end
end
function _explicit_oneto(index_set)
s = Meta.isexpr(index_set,:escape) ? index_set.args[1] : index_set
if Meta.isexpr(s,:call) && length(s.args) == 3 && s.args[1] == :(:) && s.args[2] == 1
return :(Base.OneTo($index_set))
else
return index_set
end
end

function _expr_is_splat(ex::Expr)
if ex.head == :(...)
return true
elseif ex.head == :escape
return _expr_is_splat(ex.args[1])
end
return false
end
_expr_is_splat(::Any) = false

"""
_parse_ref_sets(expr::Expr)
Helper function for macros to construct container objects. Takes an `Expr` that
specifies the container, e.g. `:(x[i=1:3,[:red,:blue]],k=S; i+k <= 6)`, and
returns:
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]`
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]`
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present
"""
function _parse_ref_sets(_error::Function, expr::Expr)
c = copy(expr)
idxvars = Any[]
idxsets = Any[]
# On 0.7, :(t[i;j]) is a :ref, while t[i,j;j] is a :typed_vcat.
# In both cases :t is the first arg.
if isexpr(c, :typed_vcat) || isexpr(c, :ref)
popfirst!(c.args)
end
condition = :()
if isexpr(c, :vcat) || isexpr(c, :typed_vcat)
# Parameters appear as plain args at the end.
if length(c.args) > 2
_error("Unsupported syntax $c.")
elseif length(c.args) == 2
condition = pop!(c.args)
end # else no condition.
elseif isexpr(c, :ref) || isexpr(c, :vect)
# Parameters appear at the front.
if isexpr(c.args[1], :parameters)
if length(c.args[1].args) != 1
_error("Invalid syntax: $c. Multiple semicolons are not " *
"supported.")
end
condition = popfirst!(c.args).args[1]
end
end
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) || isexpr(c, :ref)
if isexpr(c.args[1], :parameters)
@assert length(c.args[1].args) == 1
condition = popfirst!(c.args).args[1]
end # else no condition.
end

for s in c.args
parse_done = false
if isa(s, Expr)
parse_done, idxvar, _idxset = _try_parse_idx_set(s::Expr)
if parse_done
idxset = esc(_idxset)
end
end
if !parse_done # No index variable specified
idxvar = gensym()
idxset = esc(s)
end
push!(idxvars, idxvar)
push!(idxsets, idxset)
end
return idxvars, idxsets, condition
end

"""
_build_ref_sets(expr::Expr)
Helper function for macros to construct container objects. Takes an `Expr` that
specifies the container, e.g. `:(x[i=1:3,[:red,:blue]],k=S; i+k <= 6)`, and
returns:
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]`
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]`
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present
"""
function _build_ref_sets(_error::Function, expr::Expr)
idxvars, idxsets, condition = _parse_ref_sets(_error, expr)
if any(_expr_is_splat.(idxsets))
_error("cannot use splatting operator `...` in the definition of an index set.")
end
has_dependent = has_dependent_sets(idxvars, idxsets)
if has_dependent || condition != :()
esc_idxvars = esc.(idxvars)
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)]
if condition == :()
indices = :(Containers.NestedIterator(($(idxfuns...),)))
else
condition_fun = :(($(esc_idxvars...),) -> $condition)
indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun))
end
else
indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...))))
end
return idxvars, indices
end

function container_code(idxvars, indices, code, requested_container)
esc_idxvars = esc.(idxvars)
func = :(($(esc_idxvars...),) -> $code)
if requested_container == :Auto
return :(Containers.container($func, $indices))
else
return :(Containers.container($func, $indices, $requested_container))
end
end
function parse_container(_error, var, value, requested_container)
idxvars, indices = _build_ref_sets(_error, var)
return container_code(idxvars, indices, value, requested_container)
end

macro container(args...)
args, kw_args, requested_container = _extract_kw_args(args)
@assert length(args) == 2
@assert isempty(kw_args)
var, value = args
name = var.args[1]
code = parse_container(error, var, esc(value), requested_container)
return :($(esc(name)) = $code)
end
44 changes: 44 additions & 0 deletions src/Containers/nested_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
struct NestedIterator{T}
iterators::T # Tuple of functions
condition::Function
end
NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true)
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
if elem_state === nothing
return nothing
end
elem, state = elem_state
elems_states = first_iterate(
it, i + 1, (elems..., elem),
(states..., (iterator, state, elem)))
if elems_states !== nothing
return elems_states
end
return next_iterate(it, i, elems, states, iterator, iterate(iterator, state))
end
function first_iterate(it::NestedIterator, i, elems, states)
if i > length(it.iterators)
if it.condition(elems...)
return elems, states
else
return nothing
end
end
iterator = it.iterators[i](elems...)
return next_iterate(it, i, elems, states, iterator, iterate(iterator))
end
function tail_iterate(it::NestedIterator, i, elems, states)
if i > length(it.iterators)
return nothing
end
next = tail_iterate(it, i + 1, (elems..., states[i][3]), states)
if next !== nothing
return next
end
iterator = states[i][1]
next_iterate(it, i, elems, states[1:(i - 1)], iterator, iterate(iterator, states[i][2]))
end
Base.iterate(it::NestedIterator) = first_iterate(it, 1, tuple(), tuple())
Base.iterate(it::NestedIterator, states) = tail_iterate(it, 1, tuple(), states)
5 changes: 3 additions & 2 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import ForwardDiff
include("_Derivatives/_Derivatives.jl")
using ._Derivatives

include("Containers/Containers.jl")

# Exports are at the end of the file.

# Deprecations for JuMP v0.18 -> JuMP v0.19 transition
Expand Down Expand Up @@ -460,7 +462,7 @@ end
"""
set_time_limit_sec(model::Model, limit)
Sets the time limit (in seconds) of the solver.
Sets the time limit (in seconds) of the solver.
Can be unset using `unset_time_limit_sec` or with `limit` set to `nothing`.
"""
function set_time_limit_sec(model::Model, limit)
Expand Down Expand Up @@ -768,7 +770,6 @@ struct NonlinearParameter <: AbstractJuMPScalar
end

include("copy.jl")
include("Containers/Containers.jl")
include("operators.jl")
include("macros.jl")
include("optimizer_interface.jl")
Expand Down
Loading

0 comments on commit ecb7915

Please sign in to comment.