-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
91a30eb
commit 3a78132
Showing
14 changed files
with
1,485 additions
and
399 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
module UnrolledUtilitiesStaticArraysExt | ||
|
||
import UnrolledUtilities | ||
import StaticArrays: SVector, MVector | ||
|
||
@inline UnrolledUtilities.output_type_for_promotion(::SVector) = SVector | ||
@inline UnrolledUtilities.constructor_from_tuple(::Type{SVector}) = SVector | ||
|
||
@inline UnrolledUtilities.output_type_for_promotion(::MVector) = MVector | ||
@inline UnrolledUtilities.constructor_from_tuple(::Type{MVector}) = MVector | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
""" | ||
StaticBitVector{N, [U]}(f) | ||
StaticBitVector{N, [U]}([bit]) | ||
A statically-sized analogue of `BitVector` with `Unsigned` chunks of type `U`, | ||
which can be constructed using either a function `f(n)` or a constant `bit`. By | ||
default, `U` is set to `UInt8` and `bit` is set to `false`. | ||
This iterator can only store `Bool`s, so its `output_type_for_promotion` is a | ||
`ConditionalOutputType`. Efficient methods are provided for `unrolled_map`, | ||
`unrolled_accumulate`, `unrolled_take`, and `unrolled_drop`, though the methods | ||
for `unrolled_map` and `unrolled_accumulate` only apply when their output's | ||
first item is a `Bool`. No other unrolled functions can use `StaticBitVector`s | ||
as output types. | ||
""" | ||
struct StaticBitVector{N, U <: Unsigned, I <: NTuple{<:Any, U}} <: | ||
StaticSequence{N} | ||
ints::I | ||
end | ||
@inline StaticBitVector{N, U}(ints) where {N, U} = | ||
StaticBitVector{N, U, typeof(ints)}(ints) | ||
@inline StaticBitVector{N}(args...) where {N} = | ||
StaticBitVector{N, UInt8}(args...) | ||
|
||
@inline function StaticBitVector{N, U}(bit::Bool = false) where {N, U} | ||
n_bits_per_int = 8 * sizeof(U) | ||
n_ints = cld(N, n_bits_per_int) | ||
ints = ntuple(Returns(bit ? ~zero(U) : zero(U)), Val(n_ints)) | ||
return StaticBitVector{N, U}(ints) | ||
end | ||
|
||
@inline function StaticBitVector{N, U}(f::Function) where {N, U} | ||
n_bits_per_int = 8 * sizeof(U) | ||
n_ints = cld(N, n_bits_per_int) | ||
ints = ntuple(Val(n_ints)) do int_index | ||
@inline | ||
first_index = n_bits_per_int * (int_index - 1) + 1 | ||
unrolled_reduce( | ||
StaticOneTo(min(n_bits_per_int, N - first_index + 1)); | ||
init = zero(U), | ||
) do int, bit_index | ||
@inline | ||
bit_offset = bit_index - 1 | ||
int | U(f(first_index + bit_offset)::Bool) << bit_offset | ||
end | ||
end | ||
return StaticBitVector{N, U}(ints) | ||
end | ||
|
||
@inline function int_index_and_bit_offset(::Type{U}, n) where {U} | ||
int_offset, bit_offset = divrem(n - 1, 8 * sizeof(U)) | ||
return (int_offset + 1, bit_offset) | ||
end | ||
|
||
@inline function generic_getindex( | ||
itr::StaticBitVector{<:Any, U}, | ||
n::Integer, | ||
) where {U} | ||
int_index, bit_offset = int_index_and_bit_offset(U, n) | ||
int = itr.ints[int_index] | ||
return Bool(int >> bit_offset & one(int)) | ||
end | ||
|
||
@inline function Base.setindex( | ||
itr::StaticBitVector{N, U}, | ||
bit::Bool, | ||
n::Integer, | ||
) where {N, U} | ||
int_index, bit_offset = int_index_and_bit_offset(U, n) | ||
int = itr.ints[int_index] | ||
int′ = int & ~(one(int) << bit_offset) | U(bit) << bit_offset | ||
ints = Base.setindex(itr.ints, int′, int_index) | ||
return StaticBitVector{N, U}(ints) | ||
end | ||
|
||
@inline output_type_for_promotion(::StaticBitVector{<:Any, U}) where {U} = | ||
ConditionalOutputType(Bool, StaticBitVector{<:Any, U}) | ||
|
||
@inline function unrolled_map_into( | ||
::Type{StaticBitVector{<:Any, U}}, | ||
f, | ||
itrs..., | ||
) where {U} | ||
lazy_itr = Iterators.map(f, itrs...) | ||
N = length(lazy_itr) | ||
return StaticBitVector{N, U}(Base.Fix1(generic_getindex, lazy_itr)) | ||
end | ||
|
||
@inline function unrolled_accumulate_into( | ||
::Type{StaticBitVector{<:Any, U}}, | ||
op, | ||
itr, | ||
init, | ||
transform, | ||
) where {U} | ||
N = length(itr) | ||
n_bits_per_int = 8 * sizeof(U) | ||
n_ints = cld(N, n_bits_per_int) | ||
ints = unrolled_accumulate_into_tuple( | ||
StaticOneTo(n_ints); | ||
init = (nothing, init), | ||
transform = first, | ||
) do (_, init_value_for_new_int), int_index | ||
@inline | ||
first_index = n_bits_per_int * (int_index - 1) + 1 | ||
unrolled_reduce( | ||
StaticOneTo(min(n_bits_per_int, N - first_index + 1)); | ||
init = (zero(U), init_value_for_new_int), | ||
) do (int, prev_value), bit_index | ||
@inline | ||
bit_offset = bit_index - 1 | ||
item = generic_getindex(itr, first_index + bit_offset) | ||
new_value = | ||
first_index + bit_offset == 1 && prev_value isa NoInit ? | ||
item : op(prev_value, item) | ||
(int | U(transform(new_value)::Bool) << bit_offset, new_value) | ||
end | ||
end | ||
return StaticBitVector{N, U}(ints) | ||
end | ||
|
||
# TODO: Add unrolled_push and unrolled_append | ||
|
||
@inline function unrolled_take( | ||
itr::StaticBitVector{<:Any, U}, | ||
::Val{N}, | ||
) where {N, U} | ||
n_bits_per_int = 8 * sizeof(U) | ||
n_ints = cld(N, n_bits_per_int) | ||
ints = unrolled_take(itr.ints, Val(n_ints)) | ||
return StaticBitVector{N, U}(ints) | ||
end | ||
|
||
@inline function unrolled_drop( | ||
itr::StaticBitVector{N_old, U}, | ||
::Val{N}, | ||
) where {N_old, N, U} | ||
n_bits_per_int = 8 * sizeof(U) | ||
n_ints = cld(N_old - N, n_bits_per_int) | ||
n_dropped_ints = length(itr.ints) - n_ints | ||
bit_offset = N - n_bits_per_int * n_dropped_ints | ||
ints_without_offset = unrolled_drop(itr.ints, Val(n_dropped_ints)) | ||
ints = if bit_offset == 0 | ||
ints_without_offset | ||
else | ||
cur_ints = ints_without_offset | ||
next_ints = unrolled_push(unrolled_drop(cur_ints, Val(1)), nothing) | ||
unrolled_map_into_tuple(cur_ints, next_ints) do cur_int, next_int | ||
@inline | ||
isnothing(next_int) ? cur_int >> bit_offset : | ||
cur_int >> bit_offset | next_int << (n_bits_per_int - bit_offset) | ||
end | ||
end | ||
return StaticBitVector{N_old - N, U}(ints) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
StaticOneTo(N) | ||
A lazy and statically-sized analogue of `Base.OneTo(N)`. | ||
This iterator can only store the integers from 1 to `N`, so its | ||
`output_type_for_promotion` is `NoOutputType()`. An efficient method is provided | ||
for `unrolled_take`, but no other unrolled functions can use `StaticOneTo`s as | ||
output types. | ||
""" | ||
struct StaticOneTo{N} <: StaticSequence{N} end | ||
@inline StaticOneTo(N) = StaticOneTo{N}() | ||
|
||
@inline generic_getindex(::StaticOneTo, n) = n | ||
|
||
@inline output_type_for_promotion(::StaticOneTo) = NoOutputType() | ||
|
||
@inline unrolled_take(::StaticOneTo, ::Val{N}) where {N} = StaticOneTo(N) |
Oops, something went wrong.