-
-
Notifications
You must be signed in to change notification settings - Fork 399
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create containers with map instead of for loops
- Loading branch information
Showing
16 changed files
with
439 additions
and
354 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
const ArrayIndices{N} = Base.Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}} | ||
container(f::Function, indices) = container(f, indices, default_container(indices)) | ||
default_container(::ArrayIndices) = Array | ||
function container(f::Function, indices::ArrayIndices, ::Type{Array}) | ||
return map(I -> f(I...), indices) | ||
end | ||
default_container(::Base.Iterators.ProductIterator) = DenseAxisArray | ||
function container(f::Function, indices::Base.Iterators.ProductIterator, | ||
::Type{DenseAxisArray}) | ||
return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...) | ||
end | ||
default_container(::NestedIterator) = SparseAxisArray | ||
function container(f::Function, indices::NestedIterator, | ||
::Type{SparseAxisArray}) | ||
mappings = map(I -> I => f(I...), indices) | ||
data = Dict(mappings) | ||
if length(mappings) != length(data) | ||
# TODO compute idx | ||
error(string("Repeated index ", idx,". Index sets must have unique elements.")) | ||
end | ||
return SparseAxisArray(Dict(data)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
using Base.Meta | ||
|
||
""" | ||
_extract_kw_args(args) | ||
Process the arguments to a macro, separating out the keyword arguments. | ||
Return a tuple of (flat_arguments, keyword arguments, and requested_container), | ||
where `requested_container` is a symbol to be passed to `parse_container`. | ||
""" | ||
function _extract_kw_args(args) | ||
kw_args = filter(x -> isexpr(x, :(=)) && x.args[1] != :container , collect(args)) | ||
flat_args = filter(x->!isexpr(x, :(=)), collect(args)) | ||
requested_container = :Auto | ||
for kw in args | ||
if isexpr(kw, :(=)) && kw.args[1] == :container | ||
requested_container = kw.args[2] | ||
end | ||
end | ||
return flat_args, kw_args, requested_container | ||
end | ||
|
||
function _try_parse_idx_set(arg::Expr) | ||
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and | ||
# Expr(:ref, :x, Expr(:kw, :i, 1)) respectively. | ||
if arg.head === :kw || arg.head === :(=) | ||
@assert length(arg.args) == 2 | ||
return true, arg.args[1], arg.args[2] | ||
elseif isexpr(arg, :call) && arg.args[1] === :in | ||
return true, arg.args[2], arg.args[3] | ||
else | ||
return false, nothing, nothing | ||
end | ||
end | ||
function _explicit_oneto(index_set) | ||
s = Meta.isexpr(index_set,:escape) ? index_set.args[1] : index_set | ||
if Meta.isexpr(s,:call) && length(s.args) == 3 && s.args[1] == :(:) && s.args[2] == 1 | ||
return :(Base.OneTo($index_set)) | ||
else | ||
return index_set | ||
end | ||
end | ||
|
||
function _expr_is_splat(ex::Expr) | ||
if ex.head == :(...) | ||
return true | ||
elseif ex.head == :escape | ||
return _expr_is_splat(ex.args[1]) | ||
end | ||
return false | ||
end | ||
_expr_is_splat(::Any) = false | ||
|
||
""" | ||
_parse_ref_sets(expr::Expr) | ||
Helper function for macros to construct container objects. Takes an `Expr` that | ||
specifies the container, e.g. `:(x[i=1:3,[:red,:blue]],k=S; i+k <= 6)`, and | ||
returns: | ||
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]` | ||
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]` | ||
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present | ||
""" | ||
function _parse_ref_sets(_error::Function, expr::Expr) | ||
c = copy(expr) | ||
idxvars = Any[] | ||
idxsets = Any[] | ||
# On 0.7, :(t[i;j]) is a :ref, while t[i,j;j] is a :typed_vcat. | ||
# In both cases :t is the first arg. | ||
if isexpr(c, :typed_vcat) || isexpr(c, :ref) | ||
popfirst!(c.args) | ||
end | ||
condition = :() | ||
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) | ||
# Parameters appear as plain args at the end. | ||
if length(c.args) > 2 | ||
_error("Unsupported syntax $c.") | ||
elseif length(c.args) == 2 | ||
condition = pop!(c.args) | ||
end # else no condition. | ||
elseif isexpr(c, :ref) || isexpr(c, :vect) | ||
# Parameters appear at the front. | ||
if isexpr(c.args[1], :parameters) | ||
if length(c.args[1].args) != 1 | ||
_error("Invalid syntax: $c. Multiple semicolons are not " * | ||
"supported.") | ||
end | ||
condition = popfirst!(c.args).args[1] | ||
end | ||
end | ||
if isexpr(c, :vcat) || isexpr(c, :typed_vcat) || isexpr(c, :ref) | ||
if isexpr(c.args[1], :parameters) | ||
@assert length(c.args[1].args) == 1 | ||
condition = popfirst!(c.args).args[1] | ||
end # else no condition. | ||
end | ||
|
||
for s in c.args | ||
parse_done = false | ||
if isa(s, Expr) | ||
parse_done, idxvar, _idxset = _try_parse_idx_set(s::Expr) | ||
if parse_done | ||
idxset = esc(_idxset) | ||
end | ||
end | ||
if !parse_done # No index variable specified | ||
idxvar = gensym() | ||
idxset = esc(s) | ||
end | ||
push!(idxvars, idxvar) | ||
push!(idxsets, idxset) | ||
end | ||
return idxvars, idxsets, condition | ||
end | ||
|
||
""" | ||
_build_ref_sets(expr::Expr) | ||
Helper function for macros to construct container objects. Takes an `Expr` that | ||
specifies the container, e.g. `:(x[i=1:3,[:red,:blue]],k=S; i+k <= 6)`, and | ||
returns: | ||
1. `idxvars`: Names for the index variables, e.g. `[:i, gensym(), :k]` | ||
2. `idxsets`: Sets used for indexing, e.g. `[1:3, [:red,:blue], S]` | ||
3. `condition`: Expr containing any conditional imposed on indexing, or `:()` if none is present | ||
""" | ||
function _build_ref_sets(_error::Function, expr::Expr) | ||
idxvars, idxsets, condition = _parse_ref_sets(_error, expr) | ||
if any(_expr_is_splat.(idxsets)) | ||
_error("cannot use splatting operator `...` in the definition of an index set.") | ||
end | ||
has_dependent = has_dependent_sets(idxvars, idxsets) | ||
if has_dependent || condition != :() | ||
esc_idxvars = esc.(idxvars) | ||
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)] | ||
if condition == :() | ||
indices = :(Containers.NestedIterator(($(idxfuns...),))) | ||
else | ||
condition_fun = :(($(esc_idxvars...),) -> $condition) | ||
indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun)) | ||
end | ||
else | ||
indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...)))) | ||
end | ||
return idxvars, indices | ||
end | ||
|
||
function container_code(idxvars, indices, code, requested_container) | ||
esc_idxvars = esc.(idxvars) | ||
func = :(($(esc_idxvars...),) -> $code) | ||
if requested_container == :Auto | ||
return :(Containers.container($func, $indices)) | ||
else | ||
return :(Containers.container($func, $indices, $requested_container)) | ||
end | ||
end | ||
function parse_container(_error, var, value, requested_container) | ||
idxvars, indices = _build_ref_sets(_error, var) | ||
return container_code(idxvars, indices, value, requested_container) | ||
end | ||
|
||
macro container(args...) | ||
args, kw_args, requested_container = _extract_kw_args(args) | ||
@assert length(args) == 2 | ||
@assert isempty(kw_args) | ||
var, value = args | ||
name = var.args[1] | ||
code = parse_container(error, var, esc(value), requested_container) | ||
return :($(esc(name)) = $code) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
struct NestedIterator{T} | ||
iterators::T # Tuple of functions | ||
condition::Function | ||
end | ||
NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true) | ||
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown() | ||
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown() | ||
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state) | ||
if elem_state === nothing | ||
return nothing | ||
end | ||
elem, state = elem_state | ||
elems_states = first_iterate( | ||
it, i + 1, (elems..., elem), | ||
(states..., (iterator, state, elem))) | ||
if elems_states !== nothing | ||
return elems_states | ||
end | ||
return next_iterate(it, i, elems, states, iterator, iterate(iterator, state)) | ||
end | ||
function first_iterate(it::NestedIterator, i, elems, states) | ||
if i > length(it.iterators) | ||
if it.condition(elems...) | ||
return elems, states | ||
else | ||
return nothing | ||
end | ||
end | ||
iterator = it.iterators[i](elems...) | ||
return next_iterate(it, i, elems, states, iterator, iterate(iterator)) | ||
end | ||
function tail_iterate(it::NestedIterator, i, elems, states) | ||
if i > length(it.iterators) | ||
return nothing | ||
end | ||
next = tail_iterate(it, i + 1, (elems..., states[i][3]), states) | ||
if next !== nothing | ||
return next | ||
end | ||
iterator = states[i][1] | ||
next_iterate(it, i, elems, states[1:(i - 1)], iterator, iterate(iterator, states[i][2])) | ||
end | ||
Base.iterate(it::NestedIterator) = first_iterate(it, 1, tuple(), tuple()) | ||
Base.iterate(it::NestedIterator, states) = tail_iterate(it, 1, tuple(), states) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.