Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scale, Product Array Combinator #349

Merged
merged 10 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar
export walk, gallop, follow, extrude, laminate
export fiber, fiber!, Fiber!, pattern!, dropdefaults, dropdefaults!, redefault!
export diagmask, lotrimask, uptrimask, bandmask
export offset, permissive, protocolize, swizzle, toeplitz, window
export scale, product, offset, permissive, protocolize, swizzle, toeplitz, window

export choose, minby, maxby, overwrite, initwrite, d

export default, AsArray

export parallelAnalysis, ParallelAnalysisResults
export parallel, extent, dimless
export parallel, realextent, extent, dimless
export CPU, CPULocalVector, CPULocalMemory

export Limit
Expand Down Expand Up @@ -120,6 +120,8 @@ include("tensors/combinators/offset.jl")
include("tensors/combinators/toeplitz.jl")
include("tensors/combinators/windowed.jl")
include("tensors/combinators/swizzle.jl")
include("tensors/combinators/scale.jl")
include("tensors/combinators/product.jl")


include("traits.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/FinchNotation/FinchNotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ module FinchNotation

export getval, getname

export overwrite, initwrite, Dimensionless, dimless, extent
export overwrite, initwrite, Dimensionless, dimless, extent, realextent

export d

Expand Down
1 change: 1 addition & 0 deletions src/FinchNotation/syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ overwrite(l, r) = r
struct Dimensionless end
const dimless = Dimensionless()
function extent end
function realextent end

struct FinchParserVisitor
nodes
Expand Down
52 changes: 45 additions & 7 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,27 @@
stop
end

@kwdef struct ContinuousExtent
start
stop
end

function virtual_call(::typeof(extent), ctx, start, stop)
if isconstant(start) && isconstant(stop)
Extent(start, stop)
end
end
function virtual_call(::typeof(realextent), ctx, start, stop)
if isconstant(start) && isconstant(stop)
ContinuousExtent(start, stop)
end
end

virtual_uncall(ext::Extent) = call(extent, ext.start, ext.stop)
virtual_uncall(ext::ContinuousExtent) = call(realextent, ext.start, ext.stop)

FinchNotation.finch_leaf(x::Extent) = virtual(x)
FinchNotation.finch_leaf(x::ContinuousExtent) = virtual(x)

Base.:(==)(a::Extent, b::Extent) =
a.start == b.start &&
Expand Down Expand Up @@ -162,6 +174,13 @@
stop = call(+, ext.stop, delta)
)
end
function shiftdim(ext::ContinuousExtent, delta)
ContinuousExtent(
start = call(+, ext.start, delta),
stop = call(+, ext.stop, delta)
)
end


shiftdim(ext::Dimensionless, delta) = dimless
shiftdim(ext::ParallelDimension, delta) = ParallelDimension(ext, shiftdim(ext.ext, delta), ext.device)
Expand All @@ -174,6 +193,32 @@
end
end


function scaledim(ext::Extent, scale)
Extent(

Check warning on line 198 in src/dimensions.jl

View check run for this annotation

Codecov / codecov/patch

src/dimensions.jl#L197-L198

Added lines #L197 - L198 were not covered by tests
start = call(*, ext.start, scale),
stop = call(*, ext.stop, scale)
)
end
function scaledim(ext::ContinuousExtent, scale)
ContinuousExtent(
start = call(*, ext.start, scale),
stop = call(*, ext.stop, scale)
)
end

scaledim(ext::Dimensionless, scale) = dimless
scaledim(ext::ParallelDimension, scale) = ParallelDimension(ext, scaledim(ext.ext, scale), ext.device)

Check warning on line 211 in src/dimensions.jl

View check run for this annotation

Codecov / codecov/patch

src/dimensions.jl#L211

Added line #L211 was not covered by tests

function scaledim(ext::FinchNode, body)
if ext.kind === virtual
scaledim(ext.val, body)
else
error("unimplemented")

Check warning on line 217 in src/dimensions.jl

View check run for this annotation

Codecov / codecov/patch

src/dimensions.jl#L217

Added line #L217 was not covered by tests
end
end


#virtual_intersect(ctx, a, b) = virtual_intersect(ctx, promote(a, b)...)
function virtual_intersect(ctx, a, b)
println(a, b)
Expand Down Expand Up @@ -204,13 +249,6 @@
)
end

@kwdef struct ContinuousExtent
start
stop
end

FinchNotation.finch_leaf(x::ContinuousExtent) = virtual(x)

make_extent(::Type, start, stop) = throw(ArgumentError("Unsupported type"))
make_extent(::Type{T}, start, stop) where T <: Integer = Extent(start, stop)
make_extent(::Type{T}, start, stop) where T <: Real = ContinuousExtent(start, stop)
Expand Down
152 changes: 152 additions & 0 deletions src/tensors/combinators/product.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
struct ProductArray{dim, Body} <: AbstractCombinator
body::Body
end

ProductArray(body, dim) = ProductArray{dim}(body)
ProductArray{dim}(body::Body) where {dim, Body} = ProductArray{dim, Body}(body)

Check warning on line 6 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L5-L6

Added lines #L5 - L6 were not covered by tests

Base.show(io::IO, ex::ProductArray) = Base.show(io, MIME"text/plain"(), ex)
function Base.show(io::IO, mime::MIME"text/plain", ex::ProductArray{dim}) where {dim}
print(io, "ProductArray{$dim}($(ex.body))")

Check warning on line 10 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L8-L10

Added lines #L8 - L10 were not covered by tests
end

#Base.getindex(arr::ProductArray, i...) = ...

struct VirtualProductArray <: AbstractVirtualCombinator
body
dim
end

function is_injective(lvl::VirtualProductArray, ctx)
sub = is_injective(lvl.body, ctx)
return [sub[1:lvl.dim]..., false, sub[lvl.dim + 1:end]...]

Check warning on line 22 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L20-L22

Added lines #L20 - L22 were not covered by tests
end
function is_concurrent(lvl::VirtualProductArray, ctx)
sub = is_concurrent(lvl.body, ctx)
return [sub[1:lvl.dim]..., false, sub[lvl.dim + 1:end]...]

Check warning on line 26 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L24-L26

Added lines #L24 - L26 were not covered by tests
end
is_atomic(lvl::VirtualProductArray, ctx) = is_atomic(lvl.body, ctx)

Check warning on line 28 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L28

Added line #L28 was not covered by tests

Base.show(io::IO, ex::VirtualProductArray) = Base.show(io, MIME"text/plain"(), ex)
function Base.show(io::IO, mime::MIME"text/plain", ex::VirtualProductArray)
print(io, "VirtualProductArray($(ex.body), $(ex.dim))")

Check warning on line 32 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L30-L32

Added lines #L30 - L32 were not covered by tests
end

Base.summary(io::IO, ex::VirtualProductArray) = print(io, "VProduct($(summary(ex.body)), $(ex.dim))")

Check warning on line 35 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L35

Added line #L35 was not covered by tests

FinchNotation.finch_leaf(x::VirtualProductArray) = virtual(x)

Check warning on line 37 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L37

Added line #L37 was not covered by tests

function virtualize(ex, ::Type{ProductArray{dim, Body}}, ctx) where {dim, Body}
VirtualProductArray(virtualize(:($ex.body), Body, ctx), dim)

Check warning on line 40 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L39-L40

Added lines #L39 - L40 were not covered by tests
end

products(body, dim) = ProductArray(body, dim)
function virtual_call(::typeof(products), ctx, body, dim)
@assert isliteral(dim)
VirtualProductArray(body, dim.val)

Check warning on line 46 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L43-L46

Added lines #L43 - L46 were not covered by tests
end

virtual_uncall(arr::VirtualProductArray) = call(products, arr.body, arr.dim)

Check warning on line 49 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L49

Added line #L49 was not covered by tests

lower(tns::VirtualProductArray, ctx::AbstractCompiler, ::DefaultStyle) = :(ProductArray($(ctx(tns.body)), $(tns.dim)))

Check warning on line 51 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L51

Added line #L51 was not covered by tests

#virtual_size(arr::Fill, ctx::AbstractCompiler) = (dimless,) # this is needed for multidimensional convolution..
#virtual_size(arr::Simplify, ctx::AbstractCompiler) = (dimless,)
#virtual_size(arr::Furlable, ctx::AbstractCompiler) = (dimless,)

function virtual_size(arr::VirtualProductArray, ctx::AbstractCompiler)
dims = virtual_size(arr.body, ctx)
return (dims[1:arr.dim - 1]..., dimless, dimless, dims[arr.dim + 1:end]...)

Check warning on line 59 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L57-L59

Added lines #L57 - L59 were not covered by tests
end
function virtual_resize!(arr::VirtualProductArray, ctx::AbstractCompiler, dims...)
virtual_resize!(arr.body, ctx, dims[1:arr.dim - 1]..., dimless, dims[arr.dim + 2:end]...)

Check warning on line 62 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
end

function instantiate_reader(arr::VirtualProductArray, ctx, protos)
VirtualProductArray(instantiate_reader(arr.body, ctx, [protos[1:arr.dim]; protos[arr.dim + 2:end]]), arr.dim)

Check warning on line 66 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L65-L66

Added lines #L65 - L66 were not covered by tests
end
function instantiate_updater(arr::VirtualProductArray, ctx, protos)
VirtualProductArray(instantiate_updater(arr.body, ctx, [protos[1:arr.dim]; protos[arr.dim + 2:end]]), arr.dim)

Check warning on line 69 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
end

(ctx::Stylize{<:AbstractCompiler})(node::VirtualProductArray) = ctx(node.body)
function stylize_access(node, ctx::Stylize{<:AbstractCompiler}, tns::VirtualProductArray)
stylize_access(node, ctx, tns.body)

Check warning on line 74 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L72-L74

Added lines #L72 - L74 were not covered by tests
end

function popdim(node::VirtualProductArray, ctx)
if length(virtual_size(node, ctx)) == node.dim
return node.body

Check warning on line 79 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L77-L79

Added lines #L77 - L79 were not covered by tests
else
return node

Check warning on line 81 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L81

Added line #L81 was not covered by tests
end
end

truncate(node::VirtualProductArray, ctx, ext, ext_2) = VirtualProductArray(truncate(node.body, ctx, ext, ext_2), node.dim)

Check warning on line 85 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L85

Added line #L85 was not covered by tests

function get_point_body(node::VirtualProductArray, ctx, ext, idx)
body_2 = get_point_body(node.body, ctx, ext, idx)
if body_2 === nothing
return nothing

Check warning on line 90 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L87-L90

Added lines #L87 - L90 were not covered by tests
else
return popdim(VirtualProductArray(body_2, node.dim), ctx)

Check warning on line 92 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L92

Added line #L92 was not covered by tests
end
end

(ctx::ThunkVisitor)(node::VirtualProductArray) = VirtualProductArray(ctx(node.body), node.dim)

Check warning on line 96 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L96

Added line #L96 was not covered by tests

function get_run_body(node::VirtualProductArray, ctx, ext)
body_2 = get_run_body(node.body, ctx, ext)
if body_2 === nothing
return nothing

Check warning on line 101 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L98-L101

Added lines #L98 - L101 were not covered by tests
else
return popdim(VirtualProductArray(body_2, node.dim), ctx)

Check warning on line 103 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L103

Added line #L103 was not covered by tests
end
end

function get_acceptrun_body(node::VirtualProductArray, ctx, ext)
body_2 = get_acceptrun_body(node.body, ctx, ext)
if body_2 === nothing
return nothing

Check warning on line 110 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L107-L110

Added lines #L107 - L110 were not covered by tests
else
return popdim(VirtualProductArray(body_2, node.dim), ctx)

Check warning on line 112 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L112

Added line #L112 was not covered by tests
end
end

function (ctx::SequenceVisitor)(node::VirtualProductArray)
map(ctx(node.body)) do (keys, body)
return keys => VirtualProductArray(body, node.dim)

Check warning on line 118 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L116-L118

Added lines #L116 - L118 were not covered by tests
end
end

phase_body(node::VirtualProductArray, ctx, ext, ext_2) = VirtualProductArray(phase_body(node.body, ctx, ext, ext_2), node.dim)
phase_range(node::VirtualProductArray, ctx, ext) = phase_range(node.body, ctx, ext)

Check warning on line 123 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L122-L123

Added lines #L122 - L123 were not covered by tests

get_spike_body(node::VirtualProductArray, ctx, ext, ext_2) = VirtualProductArray(get_spike_body(node.body, ctx, ext, ext_2), node.dim)
get_spike_tail(node::VirtualProductArray, ctx, ext, ext_2) = VirtualProductArray(get_spike_tail(node.body, ctx, ext, ext_2), node.dim)

Check warning on line 126 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L125-L126

Added lines #L125 - L126 were not covered by tests

visit_fill(node, tns::VirtualProductArray) = visit_fill(node, tns.body)
visit_simplify(node::VirtualProductArray) = VirtualProductArray(visit_simplify(node.body), node.dim)

Check warning on line 129 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L128-L129

Added lines #L128 - L129 were not covered by tests

(ctx::SwitchVisitor)(node::VirtualProductArray) = map(ctx(node.body)) do (guard, body)
guard => VirtualProductArray(body, node.dim)

Check warning on line 132 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L131-L132

Added lines #L131 - L132 were not covered by tests
end

jumper_body(node::VirtualProductArray, ctx, ext) = VirtualProductArray(jumper_body(node.body, ctx, ext), node.dim)
stepper_body(node::VirtualProductArray, ctx, ext) = VirtualProductArray(stepper_body(node.body, ctx, ext), node.dim)
stepper_seek(node::VirtualProductArray, ctx, ext) = stepper_seek(node.body, ctx, ext)
jumper_seek(node::VirtualProductArray, ctx, ext) = jumper_seek(node.body, ctx, ext)

Check warning on line 138 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L135-L138

Added lines #L135 - L138 were not covered by tests

getroot(tns::VirtualProductArray) = getroot(tns.body)

Check warning on line 140 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L140

Added line #L140 was not covered by tests

function unfurl(tns::VirtualProductArray, ctx, ext, mode, protos...)
if length(virtual_size(tns, ctx)) == tns.dim + 1
Unfurled(tns,

Check warning on line 144 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L142-L144

Added lines #L142 - L144 were not covered by tests
Lookup(
body = (ctx, idx) -> VirtualPermissiveArray(VirtualScaleArray(tns.body, ([literal(1) for _ in 1:tns.dim - 1]..., idx)), ([false for _ in 1:tns.dim - 1]..., true)),

Check warning on line 146 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L146

Added line #L146 was not covered by tests
)
)
else
VirtualProductArray(unfurl(tns.body, ctx, ext, mode, protos...), tns.dim)

Check warning on line 150 in src/tensors/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/combinators/product.jl#L150

Added line #L150 was not covered by tests
end
end
Loading
Loading