From c966c1ce0125138600e3012729571b421988a46f Mon Sep 17 00:00:00 2001 From: Lewis Hemens Date: Fri, 2 Feb 2024 08:56:32 +0000 Subject: [PATCH] Finish shape wrapping --- arrayfire/ArrayFire.java | 390 +++++++++++++++++++-------------------- arrayfire/R0.java | 13 ++ arrayfire/R1.java | 14 ++ arrayfire/R2.java | 14 ++ arrayfire/R3.java | 14 ++ arrayfire/Shape.java | 47 ++++- arrayfire/SvdResult.java | 2 +- arrayfire/Tensor.java | 13 +- 8 files changed, 308 insertions(+), 199 deletions(-) create mode 100644 arrayfire/R0.java create mode 100644 arrayfire/R1.java create mode 100644 arrayfire/R2.java create mode 100644 arrayfire/R3.java diff --git a/arrayfire/ArrayFire.java b/arrayfire/ArrayFire.java index bae2b31..d1f8ea3 100644 --- a/arrayfire/ArrayFire.java +++ b/arrayfire/ArrayFire.java @@ -181,7 +181,7 @@ public static > Index permutation(D dim) { /** * Creates a device tensor from the given native array. */ - public static
, AT extends NativeArray> Tensor> create( + public static
, AT extends NativeArray> Tensor> create( AT array) { return create(array, shape(n(array.length()))); } @@ -205,7 +205,7 @@ public static > Index permutation(D dim) { * This is not recommended in a production setting, as memory will be copied twice. Instead, use {@link #create(NativeArray)}. */ @SafeVarargs - public static , DT extends DataType> Tensor> create( + public static , DT extends DataType> Tensor> create( DT type, JT... values) { return tidy(() -> { var array = type.create(values.length); @@ -221,7 +221,7 @@ public static > Index permutation(D dim) { * This is not recommended in a production setting, as memory will be copied twice. Instead, use {@link #create(NativeArray)}. */ @SuppressWarnings("unchecked") - public static , DT extends DataType> Tensor> create( + public static , DT extends DataType> Tensor> create( DT type, JTA values) { return tidy(() -> { var length = Array.getLength(values); @@ -236,43 +236,47 @@ public static > Index permutation(D dim) { /** * Creates a {@link F32} device tensor from the given float values. */ - public static Tensor> create(float... values) { + public static Tensor> create(float... values) { return create(F32, values); } /** * Creates a {@link F64} device tensor from the given double values. */ - public static Tensor> create(double... values) { + public static Tensor> create(double... values) { return create(F64, values); } /** * Creates a {@link S32} device tensor from the given int values. */ - public static Tensor> create(int... values) { + public static Tensor> create(int... values) { return create(S32, values); } /** * Creates a constant scalar {@link F32} device tensor from the given float value. */ - public static Tensor> constant(float value) { + public static Tensor constant(float value) { return constant(F32, value); } /** * Creates a constant scalar {@link F64} device tensor from the given double value. */ - public static Tensor> constant(double value) { + public static Tensor constant(double value) { return constant(F64, value); } /** * Creates a constant scalar device tensor from the given type and double value. */ - public static
> Tensor> constant(DT type, double value) { - return constant(type, shape(u()), value); + public static
> Tensor constant(DT type, double value) { + return constant(type, scalar(), value); + } + + public static R0 scalar() { + return new R0(); } /** @@ -313,7 +317,8 @@ public static Index seq(int begin, int endInclusive) { /** * Returns a lookup index using the given tensor as lookup values (indices). */ - public static
, D0 extends Num> Index seq(Tensor> index) { + public static
, D0 extends Num, S extends Shape> Index seq( + Tensor index) { return new Index<>(index, index.shape().d0()::create); } @@ -341,44 +346,44 @@ public static Span span() { /** * Returns a 1D shape of the given size and type N. */ - public static Shape shape(int d0) { - return new Shape<>(n(d0), u(), u(), u()); + public static R1 shape(int d0) { + return new R1<>(n(d0)); } /** * Returns a 1D shape of the given dimension. */ - public static > Shape shape(D0 d0) { - return new Shape<>(d0, u(), u(), u()); + public static > R1 shape(D0 d0) { + return new R1<>(d0); } - public static > Shape shape(D0 d0, int d1) { - return new Shape<>(d0, n(d1), u(), u()); + public static > R2 shape(D0 d0, int d1) { + return new R2<>(d0, n(d1)); } - public static > Shape shape(int d0, D1 d1) { - return new Shape<>(n(d0), d1, u(), u()); + public static > R2 shape(int d0, D1 d1) { + return new R2<>(n(d0), d1); } - public static Shape shape(int d0, int d1) { - return new Shape<>(n(d0), n(d1), u(), u()); + public static R2 shape(int d0, int d1) { + return new R2<>(n(d0), n(d1)); } - public static , D1 extends Num> Shape shape(D0 d0, D1 d1) { - return new Shape<>(d0, d1, u(), u()); + public static , D1 extends Num> R2 shape(D0 d0, D1 d1) { + return new R2<>(d0, d1); } - public static Shape shape(int d0, int d1, int d2) { - return new Shape<>(n(d0), n(d1), n(d2), u()); + public static R3 shape(int d0, int d1, int d2) { + return new R3<>(n(d0), n(d1), n(d2)); } public static Shape shape(int d0, int d1, int d2, int d3) { return new Shape<>(n(d0), n(d1), n(d2), n(d3)); } - public static , D1 extends Num, D2 extends Num> Shape shape(D0 d0, D1 d1, + public static , D1 extends Num, D2 extends Num> R3 shape(D0 d0, D1 d1, D2 d2) { - return new Shape<>(d0, d1, d2, u()); + return new R3<>(d0, d1, d2); } public static , D1 extends Num, D2 extends Num, D3 extends Num> Shape shape( @@ -386,18 +391,18 @@ public static , D1 extends Num, D2 extends Num, D3 ex return new Shape<>(d0, d1, d2, d3); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( - String name, Tensor> a, - Functions.Function3 method, arrayfire.D0 dim, T resultType) { + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Operation.Builder.Unary>.Single>> reduce( + String name, Tensor a, Functions.Function3 method, + arrayfire.D0 dim, T resultType) { return operation(name) .inputs(a) .outputs(prototype(resultType, shape(u(), a.shape().d1(), a.shape().d2(), a.shape().d3()))) .operation(ptr -> method.apply(ptr, a.dereference(), dim.index())); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( - String name, Tensor> a, - Functions.Function3 method, arrayfire.D1 dim, T resultType) { + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Operation.Builder.Unary>.Single>> reduce( + String name, Tensor a, Functions.Function3 method, + arrayfire.D1 dim, T resultType) { return operation(name) .inputs(a) .outputs(prototype(resultType, shape(a.shape().d0(), u(), a.shape().d2(), a.shape().d3()))) @@ -405,18 +410,18 @@ public static , D1 extends Num, D2 extends Num, D3 ex } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( - String name, Tensor> a, - 
Functions.Function3 method, arrayfire.D2 dim, T resultType) { + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Operation.Builder.Unary>.Single>> reduce( + String name, Tensor a, Functions.Function3 method, + arrayfire.D2 dim, T resultType) { return operation(name) .inputs(a) .outputs(prototype(resultType, shape(a.shape().d0(), a.shape().d1(), u(), a.shape().d3()))) .operation(ptr -> method.apply(ptr, a.dereference(), dim.index())); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( - String name, Tensor> a, - Functions.Function3 method, arrayfire.D3 dim, T resultType) { + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Operation.Builder.Unary>.Single>> reduce( + String name, Tensor a, Functions.Function3 method, + arrayfire.D3 dim, T resultType) { return operation(name) .inputs(a) .outputs(prototype(resultType, shape(a.shape().d0(), a.shape().d1(), a.shape().d2(), u()))) @@ -486,14 +491,14 @@ public static , D1 extends Num, D2 extends Num, D3 ex /** * Create a tensor with values [0, n-1]. */ - public static Tensor> range(int n) { + public static Tensor> range(int n) { return range(U32, n); } /** * Create a tensor with values [0, n-1] of the given type. */ - public static > Tensor> range(T type, int n) { + public static > Tensor> range(T type, int n) { var shape = shape(n(n)); return operation("range") .inputs() @@ -635,14 +640,14 @@ private static MemorySegment nativeDims(Shape shape) { /** * Transpose D0 and D1 dimensions of the given tensor. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> transpose( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> transpose( + Tensor tensor) { return operation("transpose") .inputs(tensor) .outputs(prototype(tensor.type(), shape(tensor.shape().d1(), tensor.shape().d0(), tensor.shape().d2(), tensor.shape().d3()))) .operation(ptr -> arrayfire_h.af_transpose(ptr, tensor.dereference(), true)) - .grads((result, grads) -> transpose(grads)) + .grads((result, grads) -> transpose(grads).reshape(tensor.shape())) .build(); } @@ -848,45 +853,46 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig } - public static , S extends Shape> Tensor add(Tensor left, - Tensor right) { + public static , S extends Shape, SL extends S, SR extends S> Tensor add( + Tensor left, Tensor right) { return operation("add") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) .operation(ptr -> arrayfire_h.af_add(ptr, left.dereference(), right.dereference(), true)) - .grads((result, grads) -> new TensorPair<>(grads, grads)) + .grads((result, grads) -> new TensorPair<>(grads, grads.reshape(right.shape()))) .build(); } - public static , S extends Shape> Tensor sub(Tensor left, - Tensor right) { + public static , S extends Shape, SL extends S, SR extends S> Tensor sub( + Tensor left, Tensor right) { return operation("sub") .inputs(left, right) - .outputs(prototype(left.type(), left.shape())) + .outputs(prototype(left)) .operation(ptr -> arrayfire_h.af_sub(ptr, left.dereference(), right.dereference(), true)) - .grads((result, grads) -> new TensorPair<>(grads, grads.negate())) + .grads((result, grads) -> new TensorPair<>(grads, grads.negate().reshape(right.shape()))) .build(); } - public static , S 
extends Shape> Tensor ge(Tensor tensor, - Tensor rhs) { + public static , S extends Shape, SL extends S, SR extends S> Tensor ge( + Tensor left, Tensor right) { return operation("ge") - .inputs(tensor, rhs) - .outputs(prototype(B8, tensor.shape())) - .operation(ptr -> arrayfire_h.af_ge(ptr, tensor.dereference(), rhs.dereference(), true)) + .inputs(left, right) + .outputs(prototype(B8, left.shape())) + .operation(ptr -> arrayfire_h.af_ge(ptr, left.dereference(), right.dereference(), true)) .build(); } - public static , S extends Shape> Tensor le(Tensor tensor, - Tensor rhs) { + public static , S extends Shape, SL extends S, SR extends S> Tensor le( + Tensor left, Tensor right) { return operation("le") - .inputs(tensor, rhs) - .outputs(prototype(B8, tensor.shape())) - .operation(ptr -> arrayfire_h.af_le(ptr, tensor.dereference(), rhs.dereference(), true)) + .inputs(left, right) + .outputs(prototype(B8, left.shape())) + .operation(ptr -> arrayfire_h.af_le(ptr, left.dereference(), right.dereference(), true)) .build(); } - public static > Tensor and(Tensor left, Tensor right) { + public static , SL extends S, SR extends S> Tensor and(Tensor left, + Tensor right) { return operation("and") .inputs(left, right) .outputs(prototype(B8, left.shape())) @@ -894,8 +900,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig .build(); } - public static , S extends Shape> Tensor maxof(Tensor left, - Tensor right) { + public static , S extends Shape, SL extends S, SR extends S> Tensor maxof( + Tensor left, Tensor right) { return operation("maxof") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) @@ -903,13 +909,13 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig .grads((result, grads) -> { var leftIsMax = eq(result, left).cast(left.type()); var rightIsMax = eq(result, right).cast(left.type()); - return new TensorPair<>(mul(leftIsMax, grads), mul(rightIsMax, grads)); + return new TensorPair<>(mul(leftIsMax, grads), mul(rightIsMax, grads).reshape(right.shape())); }) .build(); } - public static , S extends Shape> Tensor minof(Tensor left, - Tensor right) { + public static , S extends Shape, SL extends S, SR extends S> Tensor minof( + Tensor left, Tensor right) { return operation("minof") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) @@ -917,26 +923,27 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig .grads((result, grads) -> { var leftIsMin = eq(result, left).cast(left.type()); var rightIsMin = eq(result, right).cast(left.type()); - return new TensorPair<>(mul(leftIsMin, grads), mul(rightIsMin, grads)); + return new TensorPair<>(mul(leftIsMin, grads), mul(rightIsMin, grads).castshape(right.shape())); }) .build(); } - public static , LD0 extends Num, RD0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> join( - Tensor> lhs, Tensor> rhs) { + public static , LD0 extends Num, RD0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, SL extends Shape, SR extends Shape> Tensor> join( + Tensor lhs, Tensor rhs) { return operation("join") .inputs(lhs, rhs) .outputs(prototype(lhs.type(), shape(n(lhs.shape().d0().size() + rhs.shape().d0().size()), lhs.shape().d1(), lhs.shape().d2(), lhs.shape().d3()))) .operation(ptr -> arrayfire_h.af_join(ptr, 0, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>(index(grads, seq(lhs.shape().d0())), - index(grads, seq(lhs.shape().d0().size(), rhs.shape().d0())))) + .grads( + (result, grads) -> new TensorPair<>(index(grads, 
seq(lhs.shape().d0())).castshape(lhs.shape()), + index(grads, seq(lhs.shape().d0().size(), rhs.shape().d0())).castshape(rhs.shape()))) .build(); } - public static , LD1 extends Num, RD1 extends Num, D0 extends Num, D2 extends Num, D3 extends Num> Tensor> join( - Tensor> lhs, Tensor> rhs, arrayfire.D1 ignored) { + public static , LD1 extends Num, RD1 extends Num, D0 extends Num, D2 extends Num, D3 extends Num, SL extends Shape, SR extends Shape> Tensor> join( + Tensor lhs, Tensor rhs, arrayfire.D1 ignored) { if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && lhs.shape().d2().size() == rhs.shape().d2().size() && lhs.shape().d3().size() == rhs.shape().d3().size())) { @@ -949,13 +956,14 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig shape(lhs.shape().d0(), n(lhs.shape().d1().size() + rhs.shape().d1().size()), lhs.shape().d2(), lhs.shape().d3()))) .operation(ptr -> arrayfire_h.af_join(ptr, 1, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>(index(grads, span(), seq(lhs.shape().d1())), - index(grads, span(), seq(lhs.shape().d1().size(), rhs.shape().d1())))) + .grads((result, grads) -> new TensorPair<>( + index(grads, span(), seq(lhs.shape().d1())).castshape(lhs.shape()), + index(grads, span(), seq(lhs.shape().d1().size(), rhs.shape().d1())).castshape(rhs.shape()))) .build(); } - public static , LD2 extends Num, RD2 extends Num, D0 extends Num, D1 extends Num, D3 extends Num> Tensor> join( - Tensor> lhs, Tensor> rhs, arrayfire.D2 ignored) { + public static , LD2 extends Num, RD2 extends Num, D0 extends Num, D1 extends Num, D3 extends Num, SL extends Shape, SR extends Shape> Tensor> join( + Tensor lhs, Tensor rhs, arrayfire.D2 ignored) { if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && lhs.shape().d1().size() == rhs.shape().d1().size() && lhs.shape().d3().size() == rhs.shape().d3().size())) { @@ -968,13 +976,15 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig shape(lhs.shape().d0(), lhs.shape().d1(), n(lhs.shape().d2().size() + rhs.shape().d2().size()), lhs.shape().d3()))) .operation(ptr -> arrayfire_h.af_join(ptr, 2, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>(index(grads, span(), span(), seq(lhs.shape().d2())), - index(grads, span(), span(), seq(lhs.shape().d2().size(), rhs.shape().d2())))) + .grads((result, grads) -> new TensorPair<>( + index(grads, span(), span(), seq(lhs.shape().d2())).castshape(lhs.shape()), + index(grads, span(), span(), seq(lhs.shape().d2().size(), rhs.shape().d2())).castshape( + rhs.shape()))) .build(); } - public static , LD3 extends Num, RD3 extends Num, D0 extends Num, D1 extends Num, D2 extends Num> Tensor> join( - Tensor> lhs, Tensor> rhs, arrayfire.D3 ignored) { + public static , LD3 extends Num, RD3 extends Num, D0 extends Num, D1 extends Num, D2 extends Num, SL extends Shape, SR extends Shape> Tensor> join( + Tensor lhs, Tensor rhs, arrayfire.D3 ignored) { if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && lhs.shape().d1().size() == rhs.shape().d1().size() && lhs.shape().d2().size() == rhs.shape().d2().size())) { @@ -986,184 +996,185 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig .outputs(prototype(lhs.type(), shape(lhs.shape().d0(), lhs.shape().d1(), lhs.shape().d2(), n(lhs.shape().d3().size() + rhs.shape().d3().size())))) .operation(ptr -> arrayfire_h.af_join(ptr, 3, lhs.dereference(), rhs.dereference())) - .grads( - (result, grads) -> new TensorPair<>(index(grads, span(), span(), span(), 
seq(lhs.shape().d3())), - index(grads, span(), span(), span(), seq(lhs.shape().d3().size(), rhs.shape().d3())))) + .grads((result, grads) -> new TensorPair<>( + index(grads, span(), span(), span(), seq(lhs.shape().d3())).castshape(lhs.shape()), + index(grads, span(), span(), span(), seq(lhs.shape().d3().size(), rhs.shape().d3())).castshape( + rhs.shape()))) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( - Tensor> tensor) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> sum( + Tensor tensor) { return sum(tensor, D0); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( - Tensor> tensor, arrayfire.D0 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> sum( + Tensor tensor, arrayfire.D0 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( - Tensor> tensor, arrayfire.D1 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> sum( + Tensor tensor, arrayfire.D1 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( - Tensor> tensor, arrayfire.D2 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> sum( + Tensor tensor, arrayfire.D2 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( - Tensor> tensor, arrayfire.D3 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> sum( + Tensor tensor, arrayfire.D3 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> mean( + Tensor tensor) { return mean(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( - Tensor> tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> mean( + Tensor tensor, arrayfire.D0 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), af.constant(tensor.type(), tensor.shape().d0().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( - Tensor> tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 
extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> mean( + Tensor tensor, arrayfire.D1 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), af.constant(tensor.type(), tensor.shape().d1().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( - Tensor> tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> mean( + Tensor tensor, arrayfire.D2 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), af.constant(tensor.type(), tensor.shape().d2().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( - Tensor> tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> mean( + Tensor tensor, arrayfire.D3 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), af.constant(tensor.type(), tensor.shape().d3().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> median( + Tensor tensor) { return median(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( - Tensor> tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> median( + Tensor tensor, arrayfire.D0 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( - Tensor> tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> median( + Tensor tensor, arrayfire.D1 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( - Tensor> tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> median( + Tensor tensor, arrayfire.D2 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( - Tensor> tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> median( + Tensor tensor, arrayfire.D3 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> max( + Tensor tensor) { return max(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( - Tensor> tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 
extends Num, D3 extends Num, S extends Shape> Tensor> max( + Tensor tensor, arrayfire.D0 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( - Tensor> tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> max( + Tensor tensor, arrayfire.D1 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( - Tensor> tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> max( + Tensor tensor, arrayfire.D2 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( - Tensor> tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> max( + Tensor tensor, arrayfire.D3 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> min( + Tensor tensor) { return min(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( - Tensor> tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> min( + Tensor tensor, arrayfire.D0 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( - Tensor> tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> min( + Tensor tensor, arrayfire.D1 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( - Tensor> tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> min( + Tensor tensor, arrayfire.D2 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( - Tensor> tensor, arrayfire.D3 dim) { + public static , 
D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> min( + Tensor tensor, arrayfire.D3 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D1 extends Num, D2 extends Num, D3 extends Num> ImaxResult> imax( - Tensor> tensor) { + public static , D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> ImaxResult> imax( + Tensor tensor) { var shape = shape(u(), tensor.shape().d1(), tensor.shape().d2(), tensor.shape().d3()); var pair = operation("imax") .inputs(tensor) @@ -1174,8 +1185,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig return new ImaxResult<>(pair.left(), pair.right()); } - public static , D1 extends Num, D2 extends Num, D3 extends Num, K extends Num> TopKResult> topk( - Tensor> tensor, K k) { + public static , D1 extends Num, D2 extends Num, D3 extends Num, K extends Num, S extends Shape> TopKResult> topk( + Tensor tensor, K k) { var shape = shape(k, tensor.shape().d1(), tensor.shape().d2(), tensor.shape().d3()); var pair = operation("topk") .inputs(tensor) @@ -1187,8 +1198,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig return new TopKResult<>(pair.left(), pair.right()); } - public static , D0 extends Num, D2 extends Num, D3 extends Num> Tensor> diag( - Tensor> tensor) { + public static , D0 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> diag( + Tensor tensor) { return operation("diag") .inputs(tensor) .outputs(prototype(tensor.type(), @@ -1199,8 +1210,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig } // https://arrayfire.org/docs/group__blas__func__matmul.htm - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD1 extends Num> Tensor> matmul( - Tensor> left, Tensor> right) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD1 extends Num, SL extends Shape, SR extends Shape> Tensor> matmul( + Tensor left, Tensor right) { if (left.shape().d1().size() != right.shape().d0().size()) { throw new IllegalArgumentException( String.format("Incompatible shapes for matmul, left: %s right: %s", left.shape(), right.shape())); @@ -1213,14 +1224,13 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig .grads((result, grads) -> { var leftGrads = matmul(grads, transpose(right)); var rightGrads = matmul(transpose(left), grads); - return new TensorPair<>(leftGrads, rightGrads); + return new TensorPair<>(leftGrads.castshape(left.shape()), rightGrads.castshape(right.shape())); }) .build(); } - public static , AD0 extends Num, AD1 extends Num, BD1 extends Num, CD1 extends Num, D2 extends Num, D3 extends Num> Tensor> matmul( - Tensor> a, Tensor> b, - Tensor> c) { + public static , AD0 extends Num, AD1 extends Num, BD1 extends Num, CD1 extends Num, D2 extends Num, D3 extends Num, SA extends Shape, SB extends Shape, SC extends Shape> Tensor> matmul( + Tensor a, Tensor b, Tensor c) { if (a.shape().d0().size() * b.shape().d1().size() < b.shape().d0().size() * c.shape().d1().size()) { var tmp = matmul(a, b); var result = matmul(tmp, c); @@ -1256,8 +1266,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable rig constant(tensor.type(), Double.POSITIVE_INFINITY).tileAs(tensor)); } - public static , S extends Shape> Tensor eq(Tensor left, - Tensor right) { + public static , S extends Shape, SL extends S, SR 
extends S> Tensor eq( + Tensor left, Tensor right) { return operation("eq") .inputs(left, right) .outputs(prototype(B8, left.shape())) @@ -1349,8 +1359,8 @@ public static Operation.Builder operation(String name) { .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> softmax( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor softmax( + Tensor tensor) { return softmax(tensor, 1f); } @@ -1363,8 +1373,8 @@ public static Function tidyOperation(Supplier, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> softmax( - Tensor> tensor, float temperature) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor softmax( + Tensor tensor, float temperature) { return operation("softmax").inputs(tensor).outputs(prototype(tensor)).operation(tidyOperation(() -> { var max = max(tensor); var normalized = sub(tensor, max.tileAs(tensor)); @@ -1465,13 +1475,13 @@ public static , D0 extends Num, D1 extends Num, } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> List>> batch( - Tensor> tensor, int batchSize) { + public static , D0 extends Num, D1 extends Num, S extends Shape> List>> batch( + Tensor tensor, int batchSize) { return batch(tensor, ArrayFire::n, batchSize); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, BDT extends Num> List>> batch( - Tensor> tensor, Function type, int batchSize) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape, BDT extends Num> List>> batch( + Tensor tensor, Function type, int batchSize) { var results = new ArrayList>>(); var d0Seq = seq(tensor.shape().d0()); for (int i = 0; i < tensor.shape().d1().size(); i += batchSize) { @@ -1542,26 +1552,24 @@ public static , D0 extends Num, D1 extends Num, .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( - Tensor> tensor, Tensor> filters) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num, S extends Shape, FS extends Shape> Tensor> convolve2( + Tensor tensor, Tensor filters) { return convolve2(tensor, filters, shape(1, 1), shape(0, 0), shape(1, 1)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( - Tensor> tensor, Tensor> filters, - Shape stride) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num, S extends Shape, FS extends Shape> Tensor> convolve2( + Tensor tensor, Tensor filters, Shape stride) { return convolve2(tensor, filters, stride, shape(0, 0), shape(1, 1)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( - Tensor> tensor, Tensor> filters, Shape stride, - Shape padding) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num, S extends Shape, FS extends Shape> Tensor> convolve2( + Tensor tensor, Tensor filters, Shape stride, Shape padding) { return convolve2(tensor, filters, stride, padding, shape(1, 1)); } - public static , D0 extends Num, D1 extends Num, D2 
extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( - Tensor> tensor, Tensor> filters, Shape stride, - Shape padding, Shape dilation) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num, S extends Shape, FS extends Shape> Tensor> convolve2( + Tensor tensor, Tensor filters, Shape stride, Shape padding, + Shape dilation) { // TODO: CoPilot wrote this, needs tests. var computedShape = shape(n((tensor.shape().d0().size() + 2 * padding.d0().size() - (filters.shape().d0().size() - 1) * dilation.d0().size() - 1) / @@ -1584,8 +1592,8 @@ public static , D0 extends Num, D1 extends Num, /** * L2 norm. */ - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> norm( - Tensor> tensor) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor> norm( + Tensor tensor) { var mul = mul(tensor, tensor); var sum = sum(mul); return sqrt(sum); @@ -1594,22 +1602,22 @@ public static , D0 extends Num, D1 extends Num, /** * Normalize by dividing by the L2 norm. */ - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> normalize( - Tensor> tensor) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor normalize( + Tensor tensor) { return div(cast(tensor, tensor.type().sumType()), norm(tensor).tileAs(tensor.shape())); } /** * Center by subtracting the average. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> center( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Tensor center( + Tensor tensor) { return sub(tensor, mean(tensor).tileAs(tensor)); } // svd - public static , D0 extends Num, D1 extends Num> SvdResult svd( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, S extends Shape> SvdResult svd( + Tensor tensor) { var trio = operation("svd") .inputs(tensor) .outputs(prototype(tensor.type(), shape(tensor.shape().d0(), tensor.shape().d0())), @@ -1623,8 +1631,8 @@ public static , D0 extends Num, D1 extends Num, /** * Computes the covariance matrix of the given matrix. */ - public static , D0 extends Num, D1 extends Num> Tensor> cov( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, S extends Shape> Tensor> cov( + Tensor tensor) { return tidy(() -> { var subMean = sub(tensor, mean(tensor, D1).tileAs(tensor)); var matrix = matmul(subMean, transpose(subMean)); @@ -1635,8 +1643,8 @@ public static , D0 extends Num, D1 extends Num, /** * Computes the ZCA whitening matrix of the given matrix. */ - public static , D0 extends Num, D1 extends Num> Tensor> zca( - Tensor> tensor) { + public static , D0 extends Num, D1 extends Num, S extends Shape> Tensor> zca( + Tensor tensor) { return tidy(() -> { var cov = cov(tensor); var svd = svd(cov); @@ -1649,8 +1657,8 @@ public static , D0 extends Num, D1 extends Num, /** * Inverts the given matrix. */ - public static , D extends Num> Tensor> inverse( - Tensor> tensor) { + public static , D extends Num, S extends Shape> Tensor inverse( + Tensor tensor) { return operation("inverse") .inputs(tensor) .outputs(prototype(tensor)) @@ -1659,8 +1667,8 @@ public static , D0 extends Num, D1 extends Num, } // TODO: Add uncropped version. 
- public static , D0 extends Num, D1 extends Num> Tensor> rotate( - Tensor> tensor, float angle, InterpolationType interpolationType) { + public static , D0 extends Num, D1 extends Num, S extends Shape> Tensor rotate( + Tensor tensor, float angle, InterpolationType interpolationType) { return operation("rotate") .inputs(tensor) .outputs(prototype(tensor)) @@ -1670,15 +1678,15 @@ public static , D0 extends Num, D1 extends Num, .build(); } - public static , D0 extends Num, D1 extends Num, ND0 extends Num, ND1 extends Num> Tensor> scale( - Tensor> tensor, ND0 nd0, ND1 nd1, InterpolationType interpolationType) { + public static , D0 extends Num, D1 extends Num, ND0 extends Num, ND1 extends Num, S extends Shape> Tensor> scale( + Tensor tensor, ND0 nd0, ND1 nd1, InterpolationType interpolationType) { return operation("scale") .inputs(tensor) .outputs(prototype(tensor.type(), shape(nd0, nd1))) .operation(ptr -> arrayfire_h.af_scale(ptr, tensor.dereference(), (float) nd0.size() / tensor.shape().d0().size(), (float) nd1.size() / tensor.shape().d1().size(), nd0.size(), nd1.size(), interpolationType.code())) - .grads((result, grads) -> scale(grads, tensor.shape().d0(), tensor.shape().d1(), interpolationType)) + .grads((result, grads) -> scale(grads, tensor.shape().d0(), tensor.shape().d1(), interpolationType).castshape(tensor.shape())) .build(); } @@ -1829,14 +1837,6 @@ static void handleStatus(Supplier res) { var result = Status.fromCode((int) res.get()); if (!Status.AF_SUCCESS.equals(result)) { throw new ArrayFireException(result); - // String lastError; - // try { - // lastError = lastError(); - // - // } catch (Exception e) { - // throw new RuntimeException("ArrayFireError: " + result.name()); - // } - // throw new RuntimeException("ArrayFireError: " + result.name() + ": " + lastError); } } diff --git a/arrayfire/R0.java b/arrayfire/R0.java new file mode 100644 index 0000000..dbf17a3 --- /dev/null +++ b/arrayfire/R0.java @@ -0,0 +1,13 @@ +package arrayfire; + +import arrayfire.numbers.U; + +/** + * A rank 0 shape (Scalar). + */ +public class R0 extends Shape { + + public R0() { + super(af.u(), af.u(), af.u(), af.u()); + } +} diff --git a/arrayfire/R1.java b/arrayfire/R1.java new file mode 100644 index 0000000..14e4aa6 --- /dev/null +++ b/arrayfire/R1.java @@ -0,0 +1,14 @@ +package arrayfire; + +import arrayfire.numbers.Num; +import arrayfire.numbers.U; + +/** + * A rank 1 shape (Vector). + */ +public class R1> extends Shape { + + public R1(D0 d0) { + super(d0, af.u(), af.u(), af.u()); + } +} diff --git a/arrayfire/R2.java b/arrayfire/R2.java new file mode 100644 index 0000000..65c7334 --- /dev/null +++ b/arrayfire/R2.java @@ -0,0 +1,14 @@ +package arrayfire; + +import arrayfire.numbers.Num; +import arrayfire.numbers.U; + +/** + * A rank 2 shape (Matrix). + */ +public class R2, D1 extends Num> extends Shape { + + public R2(D0 d0, D1 d1) { + super(d0, d1, af.u(), af.u()); + } +} diff --git a/arrayfire/R3.java b/arrayfire/R3.java new file mode 100644 index 0000000..92c503a --- /dev/null +++ b/arrayfire/R3.java @@ -0,0 +1,14 @@ +package arrayfire; + +import arrayfire.numbers.Num; +import arrayfire.numbers.U; + +/** + * A rank 3 shape. 
+ */ +public class R3, D1 extends Num, D2 extends Num> extends Shape { + + public R3(D0 d0, D1 d1, D2 d2) { + super(d0, d1, d2, af.u()); + } +} diff --git a/arrayfire/Shape.java b/arrayfire/Shape.java index 61f10c1..2a0c467 100644 --- a/arrayfire/Shape.java +++ b/arrayfire/Shape.java @@ -4,10 +4,22 @@ import arrayfire.numbers.N; import java.util.Arrays; +import java.util.Objects; import java.util.function.Function; -public record Shape, D1 extends Num, D2 extends Num, D3 extends Num>(D0 d0, D1 d1, D2 d2, D3 d3) { +public class Shape, D1 extends Num, D2 extends Num, D3 extends Num> { + private final D0 d0; + private final D1 d1; + private final D2 d2; + private final D3 d3; + + public Shape(D0 d0, D1 d1, D2 d2, D3 d3) { + this.d0 = d0; + this.d1 = d1; + this.d2 = d2; + this.d3 = d3; + } public int capacity() { return d0.size() * d1.size() * d2.size() * d3.size(); @@ -21,4 +33,37 @@ public long[] dims() { public String toString() { return Arrays.toString(dims()); } + + public D0 d0() { + return d0; + } + + public D1 d1() { + return d1; + } + + public D2 d2() { + return d2; + } + + public D3 d3() { + return d3; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) + return true; + if (obj == null || obj.getClass() != this.getClass()) + return false; + var that = (Shape) obj; + return Objects.equals(this.d0, that.d0) && Objects.equals(this.d1, that.d1) && Objects.equals(this.d2, that.d2) && + Objects.equals(this.d3, that.d3); + } + + @Override + public int hashCode() { + return Objects.hash(d0, d1, d2, d3); + } + } diff --git a/arrayfire/SvdResult.java b/arrayfire/SvdResult.java index 645332d..dd1ec9b 100644 --- a/arrayfire/SvdResult.java +++ b/arrayfire/SvdResult.java @@ -4,5 +4,5 @@ import arrayfire.numbers.U; public record SvdResult, D0 extends Num, D1 extends Num>( - Tensor> u, Tensor> s, Tensor> vt) { + Tensor> u, Tensor> s, Tensor> vt) { } diff --git a/arrayfire/Tensor.java b/arrayfire/Tensor.java index d400bbd..603d2b7 100644 --- a/arrayfire/Tensor.java +++ b/arrayfire/Tensor.java @@ -8,6 +8,7 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.util.Arrays; public class Tensor, S extends Shape> implements MemoryContainer { @@ -95,8 +96,16 @@ public , OD1 extends Num, OD2 extends Num, OD3 ex return af.reshape(this, af.shape(d0, d1, d2, d3)); } - public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor> reshape( - Shape newShape) { + public > Tensor reshape( + NS newShape) { + return af.reshape(this, newShape); + } + + public > Tensor castshape( + NS newShape) { + if (!Arrays.equals(shape.dims(), newShape.dims())) { + throw new IllegalArgumentException("Cannot cast shape " + shape + " to " + newShape); + } return af.reshape(this, newShape); }
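
Usage sketch, not part of the patch: a minimal illustration of the rank-typed shapes this change introduces. It assumes the static `af` facade already referenced in the diff (`af.u()`, `af.div(...)`) re-exports the `ArrayFire` statics, and the generic signatures are inferred from the declarations above (which arrived with their type parameters stripped), so treat the exact types as approximate. The class and variable names are hypothetical.

    import arrayfire.af;

    class ShapeWrappingExample {
        public static void main(String[] args) {
            af.tidy(() -> {
                // Rank now lives in the shape type: create(float...) returns an
                // R1-shaped tensor rather than a Shape padded with U up to rank 4.
                var vector = af.create(1f, 2f, 3f);

                // constant(float) is a true scalar: the new scalar() helper returns R0.
                var two = af.constant(2f);

                // shape(int, int) builds an R2; shape(int, int, int) builds an R3.
                var grid = af.shape(2, 3);

                // reshape(...) retypes a tensor to any shape of equal capacity, while
                // castshape(...) (used by the gradient lambdas above) only renames the
                // static shape type and throws if the runtime dims differ.
                var column = vector.reshape(af.shape(3, 1));
                return column;
            });
        }
    }

Because R0 through R3 all extend Shape, rank-generic operations can keep a single `S extends Shape` bound while rank-specific ones (transpose, matmul, svd) demand the richer R2/R3 types; that is what the broadened type parameters throughout this patch are wiring through.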