diff --git a/arrayfire/ArrayFire.java b/arrayfire/ArrayFire.java index bbb9474..bae2b31 100644 --- a/arrayfire/ArrayFire.java +++ b/arrayfire/ArrayFire.java @@ -99,24 +99,24 @@ public static Scope scope() { /** * Sorts a tensor over D0. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sort( - Tensor tensor) { + public static , S extends Shape> Tensor sort(Tensor tensor) { return sort(tensor, D0); } /** * Sorts a tensor over the given dimension. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sort( - Tensor tensor, Dim dim) { + public static , S extends Shape> Tensor sort(Tensor tensor, + Dim dim) { return sort(tensor, dim, true); } /** * Sorts a tensor over the given dimension in ascending or descending order. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sort( - Tensor tensor, Dim dim, boolean ascending) { + public static , S extends Shape> Tensor sort(Tensor tensor, + Dim dim, + boolean ascending) { return operation("sort") .inputs(tensor) .outputs(tensor.prototype()) @@ -127,40 +127,39 @@ public static Scope scope() { /** * Returns a prototype tensor with the given type and shape. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Prototype prototype( - T type, Shape shape) { + public static , S extends Shape> Prototype prototype(T type, S shape) { return new Prototype<>(type, shape); } /** * Returns a prototype tensor with the same type and shape as the given tensor. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Prototype prototype( - Tensor tensor) { + public static , S extends Shape> Prototype prototype( + Tensor tensor) { return new Prototype<>(tensor.type(), tensor.shape()); } /** * Sorts a tensor over D0 and returns the values and indices of original values. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> SortIndexResult sortIndex( - Tensor tensor) { + public static , S extends Shape> SortIndexResult sortIndex( + Tensor tensor) { return sortIndex(tensor, D0); } /** * Sorts a tensor over the given dimension and returns the values and indices of original values. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> SortIndexResult sortIndex( - Tensor tensor, Dim dim) { + public static , S extends Shape> SortIndexResult sortIndex( + Tensor tensor, Dim dim) { return sortIndex(tensor, dim, true); } /** * Sorts a tensor over the given dimension in ascending or descending order and returns the values and indices of original values. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> SortIndexResult sortIndex( - Tensor tensor, Dim dim, boolean ascending) { + public static , S extends Shape> SortIndexResult sortIndex( + Tensor tensor, Dim dim, boolean ascending) { var pair = operation("sort_index") .inputs(tensor) .outputs(prototype(tensor.type(), tensor.shape()), prototype(U32, tensor.shape())) @@ -182,7 +181,7 @@ public static > Index permutation(D dim) { /** * Creates a device tensor from the given native array. */ - public static
, AT extends NativeArray> Tensor create( + public static
, AT extends NativeArray> Tensor> create( AT array) { return create(array, shape(n(array.length()))); } @@ -190,8 +189,8 @@ public static > Index permutation(D dim) { /** * Creates a device tensor from the given native array and shape. */ - public static
, AT extends NativeArray, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor create( - AT array, Shape shape) { + public static
, AT extends NativeArray, S extends Shape> Tensor create( + AT array, S shape) { return operation("create") .inputs() .outputs(prototype(array.type(), shape)) @@ -206,7 +205,7 @@ public static > Index permutation(D dim) { * This is not recommended in a production setting, as memory will be copied twice. Instead, use {@link #create(NativeArray)}. */ @SafeVarargs - public static , DT extends DataType> Tensor create( + public static , DT extends DataType> Tensor> create( DT type, JT... values) { return tidy(() -> { var array = type.create(values.length); @@ -222,7 +221,7 @@ public static > Index permutation(D dim) { * This is not recommended in a production setting, as memory will be copied twice. Instead, use {@link #create(NativeArray)}. */ @SuppressWarnings("unchecked") - public static , DT extends DataType> Tensor create( + public static , DT extends DataType> Tensor> create( DT type, JTA values) { return tidy(() -> { var length = Array.getLength(values); @@ -237,50 +236,50 @@ public static > Index permutation(D dim) { /** * Creates a {@link F32} device tensor from the given float values. */ - public static Tensor create(float... values) { + public static Tensor> create(float... values) { return create(F32, values); } /** * Creates a {@link F64} device tensor from the given double values. */ - public static Tensor create(double... values) { + public static Tensor> create(double... values) { return create(F64, values); } /** * Creates a {@link S32} device tensor from the given byte values. */ - public static Tensor create(int... values) { + public static Tensor> create(int... values) { return create(S32, values); } /** * Creates a constant scalar {@link F32} device tensor from the given float value. */ - public static Tensor constant(float value) { + public static Tensor> constant(float value) { return constant(F32, value); } /** * Creates a constant scalar {@link F64} device tensor from the given float value. */ - public static Tensor constant(double value) { + public static Tensor> constant(double value) { return constant(F64, value); } /** * Creates a constant scalar device tensor from the given type and double value. */ - public static
> Tensor constant(DT type, double value) { + public static
> Tensor> constant(DT type, double value) { return constant(type, shape(u()), value); } /** * Creates a constant device tensor from the given type, shape, and double value. */ - public static
, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor constant( - DT type, Shape shape, double value) { + public static
, S extends Shape> Tensor constant(DT type, S shape, + double value) { return operation("constant") .inputs() .outputs(prototype(type, shape)) @@ -314,8 +313,8 @@ public static Index seq(int begin, int endInclusive) { /** * Returns a lookup index using the given tensor as lookup values (indices). */ - public static
, D0 extends Num> Index seq(Tensor index) { - return new Index<>(index, index.d0()::create); + public static
, D0 extends Num> Index seq(Tensor> index) { + return new Index<>(index, index.shape().d0()::create); } /** @@ -349,15 +348,15 @@ public static Shape shape(int d0) { /** * Returns a 1D shape of the given dimension. */ - public static > Shape shape(D0 d0) { + public static > Shape shape(D0 d0) { return new Shape<>(d0, u(), u(), u()); } - public static > Shape shape(D0 d0, int d1) { + public static > Shape shape(D0 d0, int d1) { return new Shape<>(d0, n(d1), u(), u()); } - public static > Shape shape(int d0, D1 d1) { + public static > Shape shape(int d0, D1 d1) { return new Shape<>(n(d0), d1, u(), u()); } @@ -365,7 +364,7 @@ public static Shape shape(int d0, int d1) { return new Shape<>(n(d0), n(d1), u(), u()); } - public static , D1 extends Num> Shape shape(D0 d0, D1 d1) { + public static , D1 extends Num> Shape shape(D0 d0, D1 d1) { return new Shape<>(d0, d1, u(), u()); } @@ -377,50 +376,50 @@ public static Shape shape(int d0, int d1, int d2, int d3) { return new Shape<>(n(d0), n(d1), n(d2), n(d3)); } - public static , D1 extends Num, D2 extends Num> Shape shape(D0 d0, D1 d1, - D2 d2) { + public static , D1 extends Num, D2 extends Num> Shape shape(D0 d0, D1 d1, + D2 d2) { return new Shape<>(d0, d1, d2, u()); } - public static , D1 extends Num, D2 extends Num, D3 extends Num> Shape shape( + public static , D1 extends Num, D2 extends Num, D3 extends Num> Shape shape( D0 d0, D1 d1, D2 d2, D3 d3) { return new Shape<>(d0, d1, d2, d3); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary.Single reduce( - String name, Tensor a, + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( + String name, Tensor> a, Functions.Function3 method, arrayfire.D0 dim, T resultType) { return operation(name) .inputs(a) - .outputs(prototype(resultType, shape(u(), a.d1(), a.d2(), a.d3()))) + .outputs(prototype(resultType, shape(u(), a.shape().d1(), a.shape().d2(), a.shape().d3()))) .operation(ptr -> method.apply(ptr, a.dereference(), dim.index())); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary.Single reduce( - String name, Tensor a, + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( + String name, Tensor> a, Functions.Function3 method, arrayfire.D1 dim, T resultType) { return operation(name) .inputs(a) - .outputs(prototype(resultType, shape(a.d0(), u(), a.d2(), a.d3()))) + .outputs(prototype(resultType, shape(a.shape().d0(), u(), a.shape().d2(), a.shape().d3()))) .operation(ptr -> method.apply(ptr, a.dereference(), dim.index())); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary.Single reduce( - String name, Tensor a, + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( + String name, Tensor> a, Functions.Function3 method, arrayfire.D2 dim, T resultType) { return operation(name) .inputs(a) - .outputs(prototype(resultType, shape(a.d0(), a.d1(), u(), a.d3()))) + .outputs(prototype(resultType, shape(a.shape().d0(), a.shape().d1(), u(), a.shape().d3()))) .operation(ptr -> method.apply(ptr, a.dereference(), dim.index())); } - private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends 
Num, D3 extends Num> Operation.Builder.Unary.Single reduce( - String name, Tensor a, + private static , IT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce( + String name, Tensor> a, Functions.Function3 method, arrayfire.D3 dim, T resultType) { return operation(name) .inputs(a) - .outputs(prototype(resultType, shape(a.d0(), a.d1(), a.d2(), u()))) + .outputs(prototype(resultType, shape(a.shape().d0(), a.shape().d1(), a.shape().d2(), u()))) .operation(ptr -> method.apply(ptr, a.dereference(), dim.index())); } @@ -428,48 +427,44 @@ public static , D1 extends Num, D2 extends Num, D3 exten * Cast the given tensor to the given type. */ @SuppressWarnings("unchecked") - public static , OT extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor cast( - Tensor input, OT type) { + public static , OT extends DataType, S extends Shape> Tensor cast( + Tensor input, OT type) { if (input.type().equals(type)) { - return (Tensor) input; + return (Tensor) input; } return operation("cast") .inputs(input) .outputs(prototype(type, input.shape())) .operation(ptr -> arrayfire_h.af_cast(ptr, input.dereference(), type.code())) - .grads((result, grads) -> grads.cast(input.type())) + .grads((result, grads) -> cast(grads, input.type())) .build(); } /** * Returns a tensor of value 1 with the same type and shape as the given tensor. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor ones( - Tensor model) { + public static , S extends Shape> Tensor ones(Tensor model) { return ones(model.type(), model.shape()); } /** * Returns a tensor of value 1 with the given type and shape. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor ones( - T type, Shape shape) { + public static , S extends Shape> Tensor ones(T type, S shape) { return constant(type, 1).tileAs(shape); } /** * Returns a tensor of value 0 with the given type and shape. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor zeros( - T type, Shape shape) { + public static , S extends Shape> Tensor zeros(T type, S shape) { return constant(type, 0).tileAs(shape); } /** * Create a random tensor sampled from uniform distribution between [0, 1]. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor randu( - T type, Shape shape) { + public static , S extends Shape> Tensor randu(T type, S shape) { return operation("randu") .inputs() .outputs(prototype(type, shape)) @@ -480,8 +475,7 @@ public static , D1 extends Num, D2 extends Num, D3 exten /** * Create a random tensor sampled from a normal distribution with mean 0. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor randn( - T type, Shape shape) { + public static , S extends Shape> Tensor randn(T type, S shape) { return operation("randn") .inputs() .outputs(prototype(type, shape)) @@ -492,14 +486,14 @@ public static , D1 extends Num, D2 extends Num, D3 exten /** * Create a tensor with values [0, n-1]. */ - public static Tensor range(int n) { + public static Tensor> range(int n) { return range(U32, n); } /** * Create a tensor with values [0, n-1] of the given type. 
*/ - public static > Tensor range(T type, int n) { + public static > Tensor> range(T type, int n) { var shape = shape(n(n)); return operation("range") .inputs() @@ -525,13 +519,13 @@ public static void setRandomEngineType(RandomEngineType type) { /** * Pull data from the device to the host, returning a native array. */ - public static , T extends DataType> AT data(Tensor a) { + public static , T extends DataType> AT data(Tensor a) { var result = a.type().create(a.capacity()); handleStatus(() -> arrayfire_h.af_get_data_ptr(result.segment(), a.dereference())); return result; } - private static void checkDims(Tensor tensor) { + private static void checkDims(Tensor tensor) { try (Arena arena = Arena.ofConfined()) { var dims = arena.allocateArray(ValueLayout.JAVA_LONG, 4); handleStatus( @@ -641,11 +635,12 @@ private static MemorySegment nativeDims(Shape shape) { /** * Transpose D0 and D1 dimensions of the given tensor. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor transpose( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> transpose( + Tensor> tensor) { return operation("transpose") .inputs(tensor) - .outputs(prototype(tensor.type(), shape(tensor.d1(), tensor.d0(), tensor.d2(), tensor.d3()))) + .outputs(prototype(tensor.type(), + shape(tensor.shape().d1(), tensor.shape().d0(), tensor.shape().d2(), tensor.shape().d3()))) .operation(ptr -> arrayfire_h.af_transpose(ptr, tensor.dereference(), true)) .grads((result, grads) -> transpose(grads)) .build(); @@ -654,47 +649,47 @@ private static MemorySegment nativeDims(Shape shape) { /** * Change the type of the tensor's D0 dimension to the given type variable provider. */ - public static , OD0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor castshape( - Tensor tensor, Function d0) { - return reshape(tensor, shape(d0.apply(tensor.d0().size()), tensor.d1(), tensor.d2(), tensor.d3())); + public static , OD0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> castshape( + Tensor> tensor, Function d0) { + return reshape(tensor, + shape(d0.apply(tensor.shape().d0().size()), tensor.shape().d1(), tensor.shape().d2(), tensor.shape().d3())); } /** * Change the type of the tensor's D0, D1 dimensions to the given type variable providers. */ - public static , OD0 extends Num, OD1 extends Num, D2 extends Num, D3 extends Num> Tensor castshape( - Tensor tensor, Function d0, Function d1) { + public static , OD0 extends Num, OD1 extends Num, D2 extends Num, D3 extends Num> Tensor> castshape( + Tensor> tensor, Function d0, Function d1) { return reshape(tensor, - shape(d0.apply(tensor.d0().size()), d1.apply(tensor.d1().size()), tensor.d2(), tensor.d3())); + shape(d0.apply(tensor.shape().d0().size()), d1.apply(tensor.shape().d1().size()), tensor.shape().d2(), + tensor.shape().d3())); } /** * Change the type of the tensor's D0, D1, D2 dimensions to the given type variable providers. 
*/ - public static , OD0 extends Num, OD1 extends Num, OD2 extends Num, D3 extends Num> Tensor castshape( - Tensor tensor, Function d0, Function d1, + public static , OD0 extends Num, OD1 extends Num, OD2 extends Num, D3 extends Num> Tensor> castshape( + Tensor> tensor, Function d0, Function d1, Function d2) { - return reshape(tensor, - shape(d0.apply(tensor.d0().size()), d1.apply(tensor.d1().size()), d2.apply(tensor.d2().size()), - tensor.d3())); + return reshape(tensor, shape(d0.apply(tensor.shape().d0().size()), d1.apply(tensor.shape().d1().size()), + d2.apply(tensor.shape().d2().size()), tensor.shape().d3())); } /** * Change the type of the tensor's dimensions to the given type variable providers. */ - public static , OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor castshape( - Tensor tensor, Function d0, Function d1, Function d2, - Function d3) { - return reshape(tensor, - shape(d0.apply(tensor.d0().size()), d1.apply(tensor.d1().size()), d2.apply(tensor.d2().size()), - d3.apply(tensor.d3().size()))); + public static , OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor> castshape( + Tensor> tensor, Function d0, Function d1, + Function d2, Function d3) { + return reshape(tensor, shape(d0.apply(tensor.shape().d0().size()), d1.apply(tensor.shape().d1().size()), + d2.apply(tensor.shape().d2().size()), d3.apply(tensor.shape().d3().size()))); } /** * Reshape the tensor to the given shape. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor reshape( - Tensor tensor, Shape newShape) { + public static , S extends Shape, NS extends Shape> Tensor reshape( + Tensor tensor, NS newShape) { if (tensor.shape().capacity() != newShape.capacity()) { throw new IllegalArgumentException( String.format("New shape %s doesn't have same capacity as original shape %s", newShape, @@ -712,7 +707,7 @@ private static MemorySegment nativeDims(Shape shape) { /** * Release the memory of the given tensor on the device. */ - public static void release(Tensor tensor) { + public static void release(Tensor tensor) { handleStatus(() -> arrayfire_h.af_release_array(tensor.dereference())); Scope.untrack(tensor); } @@ -720,8 +715,7 @@ public static void release(Tensor tensor) { /** * Retain the given tensor, increasing its ref count by 1 and return a new container for it. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor retain( - Tensor tensor) { + public static
, S extends Shape> Tensor retain(Tensor tensor) { return operation("retain") .inputs(tensor) .outputs(prototype(tensor.type(), tensor.shape())) @@ -733,8 +727,8 @@ public static void release(Tensor tensor) { /** * Set the values of the given variable to the values of the given tensor. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation set( - Variable variable, Tensor tensor) { + public static , S extends Shape> Operation set(Variable variable, + Tensor tensor) { return operation("set").inputs(tensor).outputs().operation(() -> { handleStatus(() -> arrayfire_h.af_release_array(variable.dereference())); handleStatus(() -> arrayfire_h.af_retain_array(variable.segment(), tensor.dereference())); @@ -744,7 +738,7 @@ public static void release(Tensor tensor) { /** * Return the ref count of the given tensor. */ - public static int refCount(Tensor tensor) { + public static int refCount(Tensor tensor) { try (Arena arena = Arena.ofConfined()) { var result = arena.allocate(ValueLayout.JAVA_INT); handleStatus(() -> arrayfire_h.af_get_data_ref_count(result, tensor.dereference())); @@ -755,8 +749,8 @@ public static int refCount(Tensor tensor) { /** * Create a variable with the given initializer. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Variable variable( - Supplier> initializer) { + public static , S extends Shape> Variable variable( + Supplier> initializer) { var tensor = af.tidy(initializer); var variable = new Variable<>(tensor.type(), tensor.shape()); variable.segment().copyFrom(tensor.segment()); @@ -767,8 +761,8 @@ public static int refCount(Tensor tensor) { /** * Create params with the given initializer and optimizer. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Params params( - Supplier> initializer, OptimizerProvider optimizerProvider) { + public static , S extends Shape> Params params( + Supplier> initializer, OptimizerProvider optimizerProvider) { var tensor = af.tidy(initializer); var params = new Params<>(tensor.type(), tensor.shape(), optimizerProvider); params.segment().copyFrom(tensor.segment()); @@ -779,8 +773,7 @@ public static int refCount(Tensor tensor) { /** * Evaluate the tensor, telling the ArrayFire JIT compiler that you want the literal values of the tensor. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor eval( - Tensor tensor) { + public static , S extends Shape> Tensor eval(Tensor tensor) { handleStatus(() -> arrayfire_h.af_eval(tensor.dereference())); return tensor; } @@ -788,7 +781,7 @@ public static int refCount(Tensor tensor) { /** * Evaluate the tensors, telling the ArrayFire JIT compiler that you want the literal values of the tensors. */ - public static void eval(Tensor... tensors) { + public static void eval(Tensor... tensors) { try (Arena arena = Arena.ofConfined()) { var array = arena.allocateArray(ValueLayout.ADDRESS, tensors.length); for (int i = 0; i < tensors.length; i++) { @@ -801,15 +794,17 @@ public static void eval(Tensor... tensors) { /** * Multiply two tensors together element wise, broadcasting the smaller tensor to the larger tensor's shape. 
*/ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mul( - Tensor tensor, Tileable tileable) { + public static , S extends Shape> Tensor mul(Tensor tensor, + Tileable tileable) { checkTileableIsSmaller(tensor, tileable); return mul(tensor, tileable.tensor().tileAs(tensor)); } - private static void checkTileableIsSmaller(Tensor left, Tileable right) { - if (left.d0().size() < right.tensor().d0().size() || left.d1().size() < right.tensor().d1().size() || - left.d2().size() < right.tensor().d2().size() || left.d3().size() < right.tensor().d3().size()) { + private static void checkTileableIsSmaller(Tensor left, Tileable right) { + if (left.shape().d0().size() < right.tensor().shape().d0().size() || + left.shape().d1().size() < right.tensor().shape().d1().size() || + left.shape().d2().size() < right.tensor().shape().d2().size() || + left.shape().d3().size() < right.tensor().shape().d3().size()) { throw new IllegalArgumentException( String.format("Tileable shape %s is larger than tensor shape %s", right.tensor().shape(), left.shape())); @@ -819,16 +814,16 @@ private static void checkTileableIsSmaller(Tensor left, Tileable< /** * Multiply the tensor by a scalar value. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mul( - Tensor left, double right) { + public static , S extends Shape> Tensor mul(Tensor left, + double right) { return mul(left, af.constant(left.type(), left.shape(), right)); } /** * Multiply two tensors together, element wise. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mul( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor mul(Tensor left, + Tensor right) { return operation("mul") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) @@ -837,24 +832,24 @@ private static void checkTileableIsSmaller(Tensor left, Tileable< .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor div( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor div(Tensor left, + Tensor right) { return operation("div") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) .operation(ptr -> arrayfire_h.af_div(ptr, left.dereference(), right.dereference(), true)) .grads((result, grads) -> { - var rightReciprocal = af.div(af.constant(1f).cast(left.type()).tileAs(right), right); + var rightReciprocal = div(constant(1f).cast(left.type()).tileAs(right), right); var leftGrads = mul(rightReciprocal, grads); - var rightGrads = af.mul(af.mul(leftGrads, left.negate()), rightReciprocal); + var rightGrads = mul(mul(leftGrads, left.negate()), rightReciprocal); return new TensorPair<>(leftGrads, rightGrads); }) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor add( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor add(Tensor left, + Tensor right) { return operation("add") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) @@ -863,8 +858,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable< .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sub( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor sub(Tensor left, + Tensor right) { return operation("sub") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) @@ -873,8 +868,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable< 
.build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor ge( - Tensor tensor, Tensor rhs) { + public static , S extends Shape> Tensor ge(Tensor tensor, + Tensor rhs) { return operation("ge") .inputs(tensor, rhs) .outputs(prototype(B8, tensor.shape())) @@ -882,8 +877,8 @@ private static void checkTileableIsSmaller(Tensor left, Tileable< .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor le( - Tensor tensor, Tensor rhs) { + public static , S extends Shape> Tensor le(Tensor tensor, + Tensor rhs) { return operation("le") .inputs(tensor, rhs) .outputs(prototype(B8, tensor.shape())) @@ -891,8 +886,7 @@ private static void checkTileableIsSmaller(Tensor left, Tileable< .build(); } - public static , D1 extends Num, D2 extends Num, D3 extends Num> Tensor and( - Tensor left, Tensor right) { + public static > Tensor and(Tensor left, Tensor right) { return operation("and") .inputs(left, right) .outputs(prototype(B8, left.shape())) @@ -900,271 +894,277 @@ public static , D1 extends Num, D2 extends Num, D3 exten .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor maxof( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor maxof(Tensor left, + Tensor right) { return operation("maxof") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) .operation(ptr -> arrayfire_h.af_maxof(ptr, left.dereference(), right.dereference(), true)) .grads((result, grads) -> { - var leftIsMax = af.eq(result, left).cast(left.type()); - var rightIsMax = af.eq(result, right).cast(left.type()); + var leftIsMax = eq(result, left).cast(left.type()); + var rightIsMax = eq(result, right).cast(left.type()); return new TensorPair<>(mul(leftIsMax, grads), mul(rightIsMax, grads)); }) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor minof( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor minof(Tensor left, + Tensor right) { return operation("minof") .inputs(left, right) .outputs(prototype(left.type(), left.shape())) .operation(ptr -> arrayfire_h.af_minof(ptr, left.dereference(), right.dereference(), true)) .grads((result, grads) -> { - var leftIsMin = af.eq(result, left).cast(left.type()); - var rightIsMin = af.eq(result, right).cast(left.type()); + var leftIsMin = eq(result, left).cast(left.type()); + var rightIsMin = eq(result, right).cast(left.type()); return new TensorPair<>(mul(leftIsMin, grads), mul(rightIsMin, grads)); }) .build(); } - public static , LD0 extends Num, RD0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor join( - Tensor lhs, Tensor rhs) { + public static , LD0 extends Num, RD0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> join( + Tensor> lhs, Tensor> rhs) { return operation("join") .inputs(lhs, rhs) - .outputs( - prototype(lhs.type(), shape(n(lhs.d0().size() + rhs.d0().size()), lhs.d1(), lhs.d2(), lhs.d3()))) + .outputs(prototype(lhs.type(), + shape(n(lhs.shape().d0().size() + rhs.shape().d0().size()), lhs.shape().d1(), lhs.shape().d2(), + lhs.shape().d3()))) .operation(ptr -> arrayfire_h.af_join(ptr, 0, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>(index(grads, seq(lhs.d0())), - index(grads, seq(lhs.d0().size(), rhs.d0())))) + .grads((result, grads) -> new TensorPair<>(index(grads, seq(lhs.shape().d0())), + index(grads, seq(lhs.shape().d0().size(), rhs.shape().d0())))) .build(); } - public 
static , LD1 extends Num, RD1 extends Num, D0 extends Num, D2 extends Num, D3 extends Num> Tensor join( - Tensor lhs, Tensor rhs, arrayfire.D1 ignored) { - if (!(lhs.d0().size() == rhs.d0().size() && lhs.d2().size() == rhs.d2().size() && - lhs.d3().size() == rhs.d3().size())) { + public static , LD1 extends Num, RD1 extends Num, D0 extends Num, D2 extends Num, D3 extends Num> Tensor> join( + Tensor> lhs, Tensor> rhs, arrayfire.D1 ignored) { + if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && + lhs.shape().d2().size() == rhs.shape().d2().size() && + lhs.shape().d3().size() == rhs.shape().d3().size())) { throw new IllegalArgumentException( String.format("Incompatible shapes to join along d1: %s, %s", lhs.shape(), rhs.shape())); } return operation("join") .inputs(lhs, rhs) - .outputs( - prototype(lhs.type(), shape(lhs.d0(), n(lhs.d1().size() + rhs.d1().size()), lhs.d2(), lhs.d3()))) + .outputs(prototype(lhs.type(), + shape(lhs.shape().d0(), n(lhs.shape().d1().size() + rhs.shape().d1().size()), lhs.shape().d2(), + lhs.shape().d3()))) .operation(ptr -> arrayfire_h.af_join(ptr, 1, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>(index(grads, span(), seq(lhs.d1())), - index(grads, span(), seq(lhs.d1().size(), rhs.d1())))) + .grads((result, grads) -> new TensorPair<>(index(grads, span(), seq(lhs.shape().d1())), + index(grads, span(), seq(lhs.shape().d1().size(), rhs.shape().d1())))) .build(); } - public static , LD2 extends Num, RD2 extends Num, D0 extends Num, D1 extends Num, D3 extends Num> Tensor join( - Tensor lhs, Tensor rhs, arrayfire.D2 ignored) { - if (!(lhs.d0().size() == rhs.d0().size() && lhs.d1().size() == rhs.d1().size() && - lhs.d3().size() == rhs.d3().size())) { + public static , LD2 extends Num, RD2 extends Num, D0 extends Num, D1 extends Num, D3 extends Num> Tensor> join( + Tensor> lhs, Tensor> rhs, arrayfire.D2 ignored) { + if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && + lhs.shape().d1().size() == rhs.shape().d1().size() && + lhs.shape().d3().size() == rhs.shape().d3().size())) { throw new IllegalArgumentException( String.format("Incompatible shapes to join along d2: %s, %s", lhs.shape(), rhs.shape())); } return operation("join") .inputs(lhs, rhs) - .outputs( - prototype(lhs.type(), shape(lhs.d0(), lhs.d1(), n(lhs.d2().size() + rhs.d2().size()), lhs.d3()))) + .outputs(prototype(lhs.type(), + shape(lhs.shape().d0(), lhs.shape().d1(), n(lhs.shape().d2().size() + rhs.shape().d2().size()), + lhs.shape().d3()))) .operation(ptr -> arrayfire_h.af_join(ptr, 2, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>(index(grads, span(), span(), seq(lhs.d2())), - index(grads, span(), span(), seq(lhs.d2().size(), rhs.d2())))) + .grads((result, grads) -> new TensorPair<>(index(grads, span(), span(), seq(lhs.shape().d2())), + index(grads, span(), span(), seq(lhs.shape().d2().size(), rhs.shape().d2())))) .build(); } - public static , LD3 extends Num, RD3 extends Num, D0 extends Num, D1 extends Num, D2 extends Num> Tensor join( - Tensor lhs, Tensor rhs, arrayfire.D3 ignored) { - if (!(lhs.d0().size() == rhs.d0().size() && lhs.d1().size() == rhs.d1().size() && - lhs.d2().size() == rhs.d2().size())) { + public static , LD3 extends Num, RD3 extends Num, D0 extends Num, D1 extends Num, D2 extends Num> Tensor> join( + Tensor> lhs, Tensor> rhs, arrayfire.D3 ignored) { + if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && + lhs.shape().d1().size() == rhs.shape().d1().size() && + lhs.shape().d2().size() == 
rhs.shape().d2().size())) { throw new IllegalArgumentException( String.format("Incompatible shapes to join along d3: %s, %s", lhs.shape(), rhs.shape())); } return operation("join") .inputs(lhs, rhs) - .outputs( - prototype(lhs.type(), shape(lhs.d0(), lhs.d1(), lhs.d2(), n(lhs.d3().size() + rhs.d3().size())))) + .outputs(prototype(lhs.type(), shape(lhs.shape().d0(), lhs.shape().d1(), lhs.shape().d2(), + n(lhs.shape().d3().size() + rhs.shape().d3().size())))) .operation(ptr -> arrayfire_h.af_join(ptr, 3, lhs.dereference(), rhs.dereference())) - .grads((result, grads) -> new TensorPair<>( - index(grads, span(), span(), span(), seq(lhs.d3())), - index(grads, span(), span(), span(), seq(lhs.d3().size(), rhs.d3())))) + .grads( + (result, grads) -> new TensorPair<>(index(grads, span(), span(), span(), seq(lhs.shape().d3())), + index(grads, span(), span(), span(), seq(lhs.shape().d3().size(), rhs.shape().d3())))) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sum( - Tensor tensor) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( + Tensor> tensor) { return sum(tensor, D0); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sum( - Tensor tensor, arrayfire.D0 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( + Tensor> tensor, arrayfire.D0 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sum( - Tensor tensor, arrayfire.D1 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( + Tensor> tensor, arrayfire.D1 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sum( - Tensor tensor, arrayfire.D2 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( + Tensor> tensor, arrayfire.D2 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sum( - Tensor tensor, arrayfire.D3 dim) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> sum( + Tensor> tensor, arrayfire.D3 dim) { return reduce("sum", tensor, arrayfire_h::af_sum, dim, tensor.type().sumType()) .grads((result, grads) -> grads.cast(tensor.type()).tileAs(tensor)) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mean( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( + Tensor> tensor) { return mean(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mean( - Tensor tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> 
Tensor> mean( + Tensor> tensor, arrayfire.D0 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), - af.constant(tensor.type(), tensor.d0().size()).tileAs(tensor))) + af.constant(tensor.type(), tensor.shape().d0().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mean( - Tensor tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( + Tensor> tensor, arrayfire.D1 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), - af.constant(tensor.type(), tensor.d1().size()).tileAs(tensor))) + af.constant(tensor.type(), tensor.shape().d1().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mean( - Tensor tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( + Tensor> tensor, arrayfire.D2 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), - af.constant(tensor.type(), tensor.d2().size()).tileAs(tensor))) + af.constant(tensor.type(), tensor.shape().d2().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor mean( - Tensor tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> mean( + Tensor> tensor, arrayfire.D3 dim) { return reduce("mean", tensor, arrayfire_h::af_mean, dim, tensor.type()) .grads((result, grads) -> af.div(grads.tileAs(tensor), - af.constant(tensor.type(), tensor.d3().size()).tileAs(tensor))) + af.constant(tensor.type(), tensor.shape().d3().size()).tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor median( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( + Tensor> tensor) { return median(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor median( - Tensor tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( + Tensor> tensor, arrayfire.D0 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor median( - Tensor tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( + Tensor> tensor, arrayfire.D1 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor median( - Tensor tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( + Tensor> tensor, arrayfire.D2 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor median( - Tensor tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> median( + Tensor> tensor, 
arrayfire.D3 dim) { return reduce("median", tensor, arrayfire_h::af_median, dim, tensor.type()).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor max( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( + Tensor> tensor) { return max(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor max( - Tensor tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( + Tensor> tensor, arrayfire.D0 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor max( - Tensor tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( + Tensor> tensor, arrayfire.D1 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor max( - Tensor tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( + Tensor> tensor, arrayfire.D2 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor max( - Tensor tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> max( + Tensor> tensor, arrayfire.D3 dim) { return reduce("max", tensor, arrayfire_h::af_max, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor min( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( + Tensor> tensor) { return min(tensor, D0); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor min( - Tensor tensor, arrayfire.D0 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( + Tensor> tensor, arrayfire.D0 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor min( - Tensor tensor, arrayfire.D1 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( + Tensor> tensor, arrayfire.D1 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor min( - Tensor tensor, arrayfire.D2 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 
extends Num> Tensor> min( + Tensor> tensor, arrayfire.D2 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor min( - Tensor tensor, arrayfire.D3 dim) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> min( + Tensor> tensor, arrayfire.D3 dim) { return reduce("min", tensor, arrayfire_h::af_min, dim, tensor.type()) .grads((result, grads) -> mul(af.eq(result.tileAs(tensor), tensor).cast(grads.type()), grads.tileAs(tensor))) .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> ImaxResult imax( - Tensor tensor) { - var shape = shape(u(), tensor.d1(), tensor.d2(), tensor.d3()); + public static , D1 extends Num, D2 extends Num, D3 extends Num> ImaxResult> imax( + Tensor> tensor) { + var shape = shape(u(), tensor.shape().d1(), tensor.shape().d2(), tensor.shape().d3()); var pair = operation("imax") .inputs(tensor) .outputs(prototype(tensor.type(), shape), prototype(U32, shape)) @@ -1174,9 +1174,9 @@ public static , D1 extends Num, D2 extends Num, D3 exten return new ImaxResult<>(pair.left(), pair.right()); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, K extends Num> TopKResult topk( - Tensor tensor, K k) { - var shape = shape(k, tensor.d1(), tensor.d2(), tensor.d3()); + public static , D1 extends Num, D2 extends Num, D3 extends Num, K extends Num> TopKResult> topk( + Tensor> tensor, K k) { + var shape = shape(k, tensor.shape().d1(), tensor.shape().d2(), tensor.shape().d3()); var pair = operation("topk") .inputs(tensor) .outputs(prototype(tensor.type(), shape), prototype(U32, shape)) @@ -1187,38 +1187,41 @@ public static , D1 extends Num, D2 extends Num, D3 exten return new TopKResult<>(pair.left(), pair.right()); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor diag( - Tensor tensor) { + public static , D0 extends Num, D2 extends Num, D3 extends Num> Tensor> diag( + Tensor> tensor) { return operation("diag") .inputs(tensor) - .outputs(prototype(tensor.type(), shape(tensor.d0(), tensor.d0(), tensor.d2(), tensor.d3()))) + .outputs(prototype(tensor.type(), + shape(tensor.shape().d0(), tensor.shape().d0(), tensor.shape().d2(), tensor.shape().d3()))) .operation(ptr -> arrayfire_h.af_diag_create(ptr, tensor.dereference(), 0)) // TODO: Implement grad function. 
.build(); } // https://arrayfire.org/docs/group__blas__func__matmul.htm - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD1 extends Num> Tensor matmul( - Tensor left, Tensor right) { - if (left.d1().size() != right.d0().size()) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD1 extends Num> Tensor> matmul( + Tensor> left, Tensor> right) { + if (left.shape().d1().size() != right.shape().d0().size()) { throw new IllegalArgumentException( String.format("Incompatible shapes for matmul, left: %s right: %s", left.shape(), right.shape())); } return operation("matmul") .inputs(left, right) - .outputs(prototype(left.type(), shape(left.d0(), right.d1(), left.d2(), left.d3()))) + .outputs(prototype(left.type(), + shape(left.shape().d0(), right.shape().d1(), left.shape().d2(), left.shape().d3()))) .operation(ptr -> arrayfire_h.af_matmul(ptr, left.dereference(), right.dereference(), 0, 0)) .grads((result, grads) -> { - var leftGrads = matmul(grads, right.transpose()); - var rightGrads = matmul(left.transpose(), grads); + var leftGrads = matmul(grads, transpose(right)); + var rightGrads = matmul(transpose(left), grads); return new TensorPair<>(leftGrads, rightGrads); }) .build(); } - public static , AD0 extends Num, AD1 extends Num, BD1 extends Num, CD1 extends Num, D2 extends Num, D3 extends Num> Tensor matmul( - Tensor a, Tensor b, Tensor c) { - if (a.d0().size() * b.d1().size() < b.d0().size() * c.d1().size()) { + public static , AD0 extends Num, AD1 extends Num, BD1 extends Num, CD1 extends Num, D2 extends Num, D3 extends Num> Tensor> matmul( + Tensor> a, Tensor> b, + Tensor> c) { + if (a.shape().d0().size() * b.shape().d1().size() < b.shape().d0().size() * c.shape().d1().size()) { var tmp = matmul(a, b); var result = matmul(tmp, c); tmp.release(); @@ -1231,8 +1234,9 @@ public static , D1 extends Num, D2 extends Num, D3 exten } } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor clamp( - Tensor tensor, Tensor lo, Tensor hi) { + public static , S extends Shape> Tensor clamp(Tensor tensor, + Tensor lo, + Tensor hi) { return operation("clamp") .inputs(tensor) .outputs(prototype(tensor)) @@ -1247,14 +1251,13 @@ public static , D1 extends Num, D2 extends Num, D3 exten } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor relu( - Tensor tensor) { + public static , S extends Shape> Tensor relu(Tensor tensor) { return clamp(tensor, constant(tensor.type(), 0f).tileAs(tensor), constant(tensor.type(), Double.POSITIVE_INFINITY).tileAs(tensor)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor eq( - Tensor left, Tensor right) { + public static , S extends Shape> Tensor eq(Tensor left, + Tensor right) { return operation("eq") .inputs(left, right) .outputs(prototype(B8, left.shape())) @@ -1262,14 +1265,12 @@ public static , D1 extends Num, D2 extends Num, D3 exten .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor negate( - Tensor tensor) { + public static , S extends Shape> Tensor negate(Tensor tensor) { var minusOne = constant(tensor.type(), tensor.shape(), -1); return mul(tensor, minusOne); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor exp( - Tensor tensor) { + public static , S extends Shape> Tensor exp(Tensor tensor) { return operation("exp") .inputs(tensor) .outputs(prototype(tensor)) @@ -1278,13 +1279,13 @@ public static , D1 
extends Num, D2 extends Num, D3 exten .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor pow( - Tensor tensor, double pow) { + public static , S extends Shape> Tensor pow(Tensor tensor, + double pow) { return pow(tensor, constant(tensor.type(), tensor.shape(), pow)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor pow( - Tensor tensor, Tensor pow) { + public static , S extends Shape> Tensor pow(Tensor tensor, + Tensor pow) { return operation("pow") .inputs(tensor) .outputs(prototype(tensor)) @@ -1297,8 +1298,7 @@ public static , D1 extends Num, D2 extends Num, D3 exten /** * Returns 1 for negative numbers and 0 for positive numbers. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor signbit( - Tensor tensor) { + public static , S extends Shape> Tensor signbit(Tensor tensor) { return operation("signbit") .inputs(tensor) .outputs(tensor.prototype()) @@ -1309,8 +1309,7 @@ public static , D1 extends Num, D2 extends Num, D3 exten /** * Returns -1 for negative numbers and 1 for positive numbers. */ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor signum( - Tensor tensor) { + public static , S extends Shape> Tensor signum(Tensor tensor) { return operation("signum") .inputs(tensor) .outputs(tensor.prototype()) @@ -1323,8 +1322,7 @@ public static Operation.Builder operation(String name) { return new Operation.Builder().name(name); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor log( - Tensor tensor) { + public static , S extends Shape> Tensor log(Tensor tensor) { return operation("log") .inputs(tensor) .outputs(prototype(tensor)) @@ -1333,8 +1331,7 @@ public static Operation.Builder operation(String name) { .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor abs( - Tensor input) { + public static , S extends Shape> Tensor abs(Tensor input) { return operation("abs") .inputs(input) .outputs(prototype(input)) @@ -1343,8 +1340,7 @@ public static Operation.Builder operation(String name) { .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sqrt( - Tensor tensor) { + public static , S extends Shape> Tensor sqrt(Tensor tensor) { return operation("sqrt") .inputs(tensor) .outputs(prototype(tensor)) @@ -1353,12 +1349,12 @@ public static Operation.Builder operation(String name) { .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor softmax( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> softmax( + Tensor> tensor) { return softmax(tensor, 1f); } - public static Function tidyOperation(Supplier> fn) { + public static Function tidyOperation(Supplier> fn) { return ptr -> { var result = tidy(fn); ptr.copyFrom(result.segment()); @@ -1367,8 +1363,8 @@ public static Function tidyOperation(Supplier, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor softmax( - Tensor tensor, float temperature) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> softmax( + Tensor> tensor, float temperature) { return operation("softmax").inputs(tensor).outputs(prototype(tensor)).operation(tidyOperation(() -> { var max = max(tensor); var normalized = sub(tensor, max.tileAs(tensor)); @@ -1378,24 +1374,23 @@ public static , D0 extends Num, D1 extends Num, D // 
Compact all dimensions except the first into a batch dimension, so we have a spare dimension for the jacobian. var shape = result.shape(); var workingShape = af.shape(shape.d0(), af.u(), - af.b(result.d1().size() * result.d2().size() * result.d3().size())); + af.b(result.shape().d1().size() * result.shape().d2().size() * result.shape().d3().size())); var resultTensor = result.reshape(workingShape); var gradsTensor = grads.reshape(workingShape); var positives = af.mul(resultTensor, gradsTensor); - var negatives = af.matmul(resultTensor, resultTensor.transpose(), gradsTensor); + var negatives = af.matmul(resultTensor, transpose(resultTensor), gradsTensor); var inputGrads = af.sub(positives, negatives); return inputGrads.reshape(tensor.shape()); }).build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sigmoid( - Tensor tensor) { + public static , S extends Shape> Tensor sigmoid(Tensor tensor) { var one = ones(tensor); return div(one, add(one, exp(negate(tensor)))); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor sparse( - Tensor tensor, Storage storage) { + public static , S extends Shape> Tensor sparse(Tensor tensor, + Storage storage) { return operation("sparse") .inputs(tensor) .outputs(prototype(tensor.type(), tensor.shape())) @@ -1405,48 +1400,48 @@ public static , D0 extends Num, D1 extends Num, D .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Index i0) { - return index(tensor, i0, seq(tensor.d1()), seq(tensor.d2()), seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Index i0) { + return index(tensor, i0, seq(tensor.shape().d1()), seq(tensor.shape().d2()), seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Index i0, Index i1) { - return index(tensor, i0, i1, seq(tensor.d2()), seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Index i0, Index i1) { + return index(tensor, i0, i1, seq(tensor.shape().d2()), seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Span ignored0, Index i1) { - return index(tensor, seq(tensor.d0()), i1, seq(tensor.d2()), seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Span ignored0, Index i1) { + return index(tensor, seq(tensor.shape().d0()), i1, seq(tensor.shape().d2()), seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Span ignored0, Span ignored1, Index i2) { - return index(tensor, seq(tensor.d0()), seq(tensor.d1()), i2, seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Span ignored0, Span ignored1, Index i2) { + return index(tensor, seq(tensor.shape().d0()), seq(tensor.shape().d1()), i2, seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Span ignored0, Index i1, Index i2) { - return index(tensor, seq(tensor.d0()), i1, i2, seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 
extends Num> Tensor> index( + Tensor> tensor, Span ignored0, Index i1, Index i2) { + return index(tensor, seq(tensor.shape().d0()), i1, i2, seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Index i0, Span ignored1, Index i2) { - return index(tensor, i0, seq(tensor.d1()), i2, seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Index i0, Span ignored1, Index i2) { + return index(tensor, i0, seq(tensor.shape().d1()), i2, seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Index i0, Index i1, Index i2) { - return index(tensor, i0, i1, i2, seq(tensor.d3())); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Index i0, Index i1, Index i2) { + return index(tensor, i0, i1, i2, seq(tensor.shape().d3())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Span ignored0, Span ignored1, Span ignored2, Index i3) { - return index(tensor, seq(tensor.d0()), seq(tensor.d1()), seq(tensor.d2()), i3); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor> tensor, Span ignored0, Span ignored1, Span ignored2, Index i3) { + return index(tensor, seq(tensor.shape().d0()), seq(tensor.shape().d1()), seq(tensor.shape().d2()), i3); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor index( - Tensor tensor, Index i0, Index i1, Index i2, Index i3) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> index( + Tensor tensor, Index i0, Index i1, Index i2, Index i3) { return operation("index") .inputs(tensor) .outputs(prototype(tensor.type(), @@ -1470,34 +1465,34 @@ public static , D0 extends Num, D1 extends Num, D } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> List> batch( - Tensor tensor, int batchSize) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> List>> batch( + Tensor> tensor, int batchSize) { return batch(tensor, ArrayFire::n, batchSize); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, BDT extends Num> List> batch( - Tensor tensor, Function type, int batchSize) { - var results = new ArrayList>(); - var d0Seq = seq(tensor.d0()); - for (int i = 0; i < tensor.d1().size(); i += batchSize) { - var computedD1Size = Math.min(batchSize, tensor.d1().size() - i); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, BDT extends Num> List>> batch( + Tensor> tensor, Function type, int batchSize) { + var results = new ArrayList>>(); + var d0Seq = seq(tensor.shape().d0()); + for (int i = 0; i < tensor.shape().d1().size(); i += batchSize) { + var computedD1Size = Math.min(batchSize, tensor.shape().d1().size() - i); var slice = index(tensor, d0Seq, seq(i, i + computedD1Size - 1)); - results.add(slice.reshape(shape(tensor.d0(), type.apply(computedD1Size)))); + results.add(slice.reshape(shape(tensor.shape().d0(), type.apply(computedD1Size)))); } return results; } @SuppressWarnings({"unchecked", "rawtypes"}) - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor tileAs( - Tensor tensor, 
Shape newShape) { + public static , S extends Shape, NS extends Shape> Tensor tileAs( + Tensor tensor, NS newShape) { if (newShape.capacity() % tensor.shape().capacity() != 0) { throw new IllegalArgumentException( String.format("Can't tile perfectly from %s to %s", tensor.shape(), newShape)); } - int d0ratio = newShape.d0().size() / tensor.d0().size(); - int d1ratio = newShape.d1().size() / tensor.d1().size(); - int d2ratio = newShape.d2().size() / tensor.d2().size(); - int d3ratio = newShape.d3().size() / tensor.d3().size(); + int d0ratio = newShape.d0().size() / tensor.shape().d0().size(); + int d1ratio = newShape.d1().size() / tensor.shape().d1().size(); + int d2ratio = newShape.d2().size() / tensor.shape().d2().size(); + int d3ratio = newShape.d3().size() / tensor.shape().d3().size(); return operation("tile") .inputs(tensor) .outputs(prototype(tensor.type(), newShape)) @@ -1507,39 +1502,38 @@ public static , D0 extends Num, D1 extends Num, D } @SuppressWarnings({"unchecked", "rawtypes"}) - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor sumAs( - Tensor input, Shape newShape) { + public static , T extends DataType, S extends Shape, NS extends Shape> Tensor sumAs( + Tensor input, NS newShape) { // I think there is a nicer way to do this in at most two operations. Tensor result = input; - if (newShape.d0() != input.d0()) { + if (newShape.d0() != input.shape().d0()) { if (newShape.d0().size() != 1) throw new IllegalArgumentException("Can't sum over D0 from " + input.shape() + " to " + newShape); result = sum(result); } - if (newShape.d1() != input.d1()) { + if (newShape.d1() != input.shape().d1()) { if (newShape.d1().size() != 1) throw new IllegalArgumentException("Can't sum over D1 from " + input.shape() + " to " + newShape); result = sum(result); } - if (newShape.d2() != input.d2()) { + if (newShape.d2() != input.shape().d2()) { if (newShape.d2().size() != 1) throw new IllegalArgumentException("Can't sum over D2 from " + input.shape() + " to " + newShape); result = sum(result); } - if (newShape.d3() != input.d3()) { + if (newShape.d3() != input.shape().d3()) { if (newShape.d3().size() != 1) throw new IllegalArgumentException("Can't sum over D3 from " + input.shape() + " to " + newShape); result = sum(result); } - return ((Tensor) result).reshape(newShape); + return reshape(((Tensor) result), newShape); } - public static > Tensor flatten(Tensor tensor) { + public static > Tensor> flatten(Tensor tensor) { return reshape(tensor, shape(tensor.shape().capacity())); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor flip( - Tensor tensor) { + public static , S extends Shape> Tensor flip(Tensor tensor) { return operation("flip") .inputs(tensor) .outputs(prototype(tensor)) @@ -1548,31 +1542,33 @@ public static , D0 extends Num, D1 extends Num, D .build(); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD3 extends Num> Tensor convolve2( - Tensor tensor, Tensor filters) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( + Tensor> tensor, Tensor> filters) { return convolve2(tensor, filters, shape(1, 1), shape(0, 0), shape(1, 1)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD3 extends Num> Tensor convolve2( - Tensor tensor, Tensor filters, Shape stride) { + 
public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( + Tensor> tensor, Tensor> filters, + Shape stride) { return convolve2(tensor, filters, stride, shape(0, 0), shape(1, 1)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD3 extends Num> Tensor convolve2( - Tensor tensor, Tensor filters, Shape stride, + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( + Tensor> tensor, Tensor> filters, Shape stride, Shape padding) { return convolve2(tensor, filters, stride, padding, shape(1, 1)); } - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD3 extends Num> Tensor convolve2( - Tensor tensor, Tensor filters, Shape stride, + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num> Tensor> convolve2( + Tensor> tensor, Tensor> filters, Shape stride, Shape padding, Shape dilation) { // TODO: CoPilot wrote this, needs tests. - var computedShape = shape( - n((tensor.d0().size() + 2 * padding.d0().size() - (filters.d0().size() - 1) * dilation.d0().size() - 1) / - stride.d0().size() + 1), - n((tensor.d1().size() + 2 * padding.d1().size() - (filters.d1().size() - 1) * dilation.d1().size() - 1) / - stride.d1().size() + 1), filters.d3(), tensor.d3()); + var computedShape = shape(n((tensor.shape().d0().size() + 2 * padding.d0().size() - + (filters.shape().d0().size() - 1) * dilation.d0().size() - 1) / + stride.d0().size() + 1), + n((tensor.shape().d1().size() + 2 * padding.d1().size() - + (filters.shape().d1().size() - 1) * dilation.d1().size() - 1) / stride.d1().size() + 1), + filters.shape().d3(), tensor.shape().d3()); return operation("convolve2") .inputs(tensor, filters) .outputs(prototype(tensor.type(), computedShape)) @@ -1588,8 +1584,8 @@ public static , D0 extends Num, D1 extends Num, D /** * L2 norm. */ - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor norm( - Tensor tensor) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> norm( + Tensor> tensor) { var mul = mul(tensor, tensor); var sum = sum(mul); return sqrt(sum); @@ -1598,27 +1594,27 @@ public static , D0 extends Num, D1 extends Num, D /** * Normalize by dividing by the L2 norm. */ - public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor normalize( - Tensor tensor) { + public static , T extends DataType, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> normalize( + Tensor> tensor) { return div(cast(tensor, tensor.type().sumType()), norm(tensor).tileAs(tensor.shape())); } /** * Center by subtracting the average. 
*/ - public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor center( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> center( + Tensor> tensor) { return sub(tensor, mean(tensor).tileAs(tensor)); } // svd - public static , D0 extends Num, D1 extends Num> SvdResult svd( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num> SvdResult svd( + Tensor> tensor) { var trio = operation("svd") .inputs(tensor) - .outputs(prototype(tensor.type(), shape(tensor.d0(), tensor.d0())), - prototype(tensor.type(), shape(tensor.d0())), - prototype(tensor.type(), shape(tensor.d1(), tensor.d1()))) + .outputs(prototype(tensor.type(), shape(tensor.shape().d0(), tensor.shape().d0())), + prototype(tensor.type(), shape(tensor.shape().d0())), + prototype(tensor.type(), shape(tensor.shape().d1(), tensor.shape().d1()))) .operation((u, s, v) -> arrayfire_h.af_svd(u, s, v, tensor.dereference())) .build(); return new SvdResult<>(trio.left(), trio.middle(), trio.right()); @@ -1627,34 +1623,34 @@ public static , D0 extends Num, D1 extends Num, D /** * Computes the covariance matrix of the given matrix. */ - public static , D0 extends Num, D1 extends Num> Tensor cov( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num> Tensor> cov( + Tensor> tensor) { return tidy(() -> { var subMean = sub(tensor, mean(tensor, D1).tileAs(tensor)); - var matrix = matmul(subMean, subMean.transpose()); - return div(matrix, constant(matrix.type(), matrix.shape(), tensor.d1().size() - 1.0f)); + var matrix = matmul(subMean, transpose(subMean)); + return div(matrix, constant(matrix.type(), matrix.shape(), tensor.shape().d1().size() - 1.0f)); }); } /** * Computes the ZCA whitening matrix of the given matrix. */ - public static , D0 extends Num, D1 extends Num> Tensor zca( - Tensor tensor) { + public static , D0 extends Num, D1 extends Num> Tensor> zca( + Tensor> tensor) { return tidy(() -> { var cov = cov(tensor); var svd = svd(cov); var invSqrtS = diag(div(constant(svd.s().type(), svd.s().shape(), 1.0f), sqrt(add(svd.s(), constant(svd.s().type(), svd.s().shape(), 1e-5f))))); - return matmul(svd.u(), matmul(invSqrtS, svd.u().transpose())); + return matmul(svd.u(), matmul(invSqrtS, transpose(svd.u()))); }); } /** * Inverts the given matrix. */ - public static , D extends Num> Tensor inverse( - Tensor tensor) { + public static , D extends Num> Tensor> inverse( + Tensor> tensor) { return operation("inverse") .inputs(tensor) .outputs(prototype(tensor)) @@ -1663,8 +1659,8 @@ public static , D0 extends Num, D1 extends Num, D } // TODO: Add uncropped version. 
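The computedShape expression in the convolve2 overloads above applies the standard convolution sizing rule per spatial dimension. As a sanity check, here is the same arithmetic as a tiny standalone Java sketch (plain ints instead of the typed Num dimensions; the method name is illustrative, not part of the library):

    // Mirrors the diff's expression:
    // (in + 2 * pad - (filter - 1) * dilation - 1) / stride + 1
    static int convOutputSize(int in, int filter, int stride, int pad, int dilation) {
        return (in + 2 * pad - (filter - 1) * dilation - 1) / stride + 1;
    }

    // Example: convOutputSize(28, 3, 1, 0, 1) == 26, i.e. a 28-wide input with a
    // 3-wide filter, stride 1, no padding and no dilation shrinks to 26.

With the defaults the shorter overloads pass through (stride shape(1, 1), padding shape(0, 0), dilation shape(1, 1)), this reduces to in - filter + 1.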
- public static , D0 extends Num, D1 extends Num> Tensor rotate( - Tensor tensor, float angle, InterpolationType interpolationType) { + public static , D0 extends Num, D1 extends Num> Tensor> rotate( + Tensor> tensor, float angle, InterpolationType interpolationType) { return operation("rotate") .inputs(tensor) .outputs(prototype(tensor)) @@ -1674,15 +1670,15 @@ public static , D0 extends Num, D1 extends Num, D .build(); } - public static , D0 extends Num, D1 extends Num, ND0 extends Num, ND1 extends Num> Tensor scale( - Tensor tensor, ND0 nd0, ND1 nd1, InterpolationType interpolationType) { + public static , D0 extends Num, D1 extends Num, ND0 extends Num, ND1 extends Num> Tensor> scale( + Tensor> tensor, ND0 nd0, ND1 nd1, InterpolationType interpolationType) { return operation("scale") .inputs(tensor) .outputs(prototype(tensor.type(), shape(nd0, nd1))) - .operation( - ptr -> arrayfire_h.af_scale(ptr, tensor.dereference(), (float) nd0.size() / tensor.d0().size(), - (float) nd1.size() / tensor.d1().size(), nd0.size(), nd1.size(), interpolationType.code())) - .grads((result, grads) -> scale(grads, tensor.d0(), tensor.d1(), interpolationType)) + .operation(ptr -> arrayfire_h.af_scale(ptr, tensor.dereference(), + (float) nd0.size() / tensor.shape().d0().size(), (float) nd1.size() / tensor.shape().d1().size(), + nd0.size(), nd1.size(), interpolationType.code())) + .grads((result, grads) -> scale(grads, tensor.shape().d0(), tensor.shape().d1(), interpolationType)) .build(); } @@ -1805,12 +1801,12 @@ public static U u(int value) { return U; } - public static > T grads(Tensor loss, T tensor) { + public static > T grads(Tensor loss, T tensor) { var graph = new Graph(scope().operations()); return graph.grads(loss, tensor); } - public static void optimize(Tensor loss) { + public static void optimize(Tensor loss) { var graph = new Graph(scope().operations()); graph.optimize(loss); } diff --git a/arrayfire/ArrayFireTest.java b/arrayfire/ArrayFireTest.java index 2202974..73db212 100644 --- a/arrayfire/ArrayFireTest.java +++ b/arrayfire/ArrayFireTest.java @@ -132,7 +132,7 @@ public void sortIndex() { public void permutationIndex() { af.tidy(() -> { var arr = af.create(1, 2, 3, 4, 5, 6, 7, 8).reshape(2, 4); - var permutation = af.permutation(arr.d1()); + var permutation = af.permutation(arr.shape().d1()); var shuffled = af.index(arr, af.span(), permutation); var data = af.data(shuffled); assertArrayEquals(new int[]{5, 6, 1, 2, 7, 8, 3, 4}, data.java()); @@ -143,7 +143,7 @@ public void permutationIndex() { public void transpose() { af.tidy(() -> { var arr = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2); - var transpose = arr.transpose(); + var transpose = af.transpose(arr); assertArrayEquals(new float[]{1, 3, 2, 4}, af.data(transpose).java(), 1E-5f); }); } @@ -172,7 +172,7 @@ public void matmul() { af.tidy(() -> { var left = af.create(new float[]{1, 2, 3, 4}).reshape(a(2), b(2)); var right = af.create(new float[]{1, 2, 3, 4, 5, 6}).reshape(a(2), c(3)); - var result = af.matmul(left.transpose(), right); + var result = af.matmul(af.transpose(left), right); assertArrayEquals(new float[]{5, 11, 11, 25, 17, 39}, data(result).java(), 1E-5f); }); } @@ -182,7 +182,7 @@ public void matmulS32() { af.tidy(() -> { var left = af.create(new float[]{1, 2, 3, 4}).reshape(a(2), b(2)); var right = af.create(new float[]{1, 2, 3, 4, 5, 6}).reshape(a(2), c(3)); - var result = af.matmul(left.transpose(), right); + var result = af.matmul(af.transpose(left), right); assertArrayEquals(new float[]{5, 11, 11, 25, 17, 39}, 
data(result).java(), 1E-5f); }); } @@ -292,7 +292,7 @@ public void mulScalar() { public void min() { af.tidy(() -> { var data = af.create(new float[]{-5, 12, 0, 1}); - var result = data.min(); + var result = af.min(data); assertArrayEquals(new float[]{-5}, af.data(result).java(), 1e-5f); }); } @@ -384,9 +384,9 @@ public void index4D() { var data = af .create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) .reshape(2, 2, 2, 2); - Tensor result = af.index(data, af.seq(af.create(0).castshape(af::a)), - af.seq(af.create(1).castshape(af::b)), af.seq(af.create(0).castshape(af::c)), - af.seq(af.create(1).castshape(af::d))); + Tensor> result = af.index(data, af.seq(af.create(0).reshape(af.a(1))), + af.seq(af.create(1).reshape(af.b(1))), af.seq(af.create(0).reshape(af.c(1))), + af.seq(af.create(1).reshape(af.d(1)))); assertArrayEquals(new float[]{11}, af.data(result).java(), 1E-5f); }); } @@ -554,7 +554,7 @@ public void graph() { af.tidy(() -> { var left = af.create(new float[]{1, 2, 3, 4}).reshape(a(2), b(2)); var right = af.create(new float[]{1, 2, 3, 4, 5, 6}).reshape(a(2), c(3)); - var leftT = left.transpose(); + var leftT = af.transpose(left); var matmul = af.matmul(leftT, right); var softmax = af.softmax(matmul); var sum = af.sum(matmul); diff --git a/arrayfire/GradFunction.java b/arrayfire/GradFunction.java index b507754..b496a1e 100644 --- a/arrayfire/GradFunction.java +++ b/arrayfire/GradFunction.java @@ -1,8 +1,5 @@ package arrayfire; -import arrayfire.DataType; -import arrayfire.Tensor; -import arrayfire.TensorPair; import arrayfire.numbers.Num; import java.util.List; @@ -10,15 +7,13 @@ @FunctionalInterface interface GradFunction { - List> grads(Tensor resultGrads); + List> grads(Tensor resultGrads); - interface Unary, RD0 extends Num, RD1 extends Num, RD2 extends Num, RD3 extends Num, I0T extends DataType, I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num> { - Tensor grads(Tensor result, - Tensor grads); + interface Unary, IT extends Tensor> { + IT grads(RT result, RT grads); } - interface Binary, RD0 extends Num, RD1 extends Num, RD2 extends Num, RD3 extends Num, I0T extends DataType, I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num, I1T extends DataType, I1D0 extends Num, I1D1 extends Num, I1D2 extends Num, I1D3 extends Num> { - TensorPair grads( - Tensor result, Tensor grads); + interface Binary, I0T extends Tensor, I1T extends Tensor> { + TensorPair grads(RT result, RT grads); } } diff --git a/arrayfire/Graph.java b/arrayfire/Graph.java index 29a0a77..1ab3783 100644 --- a/arrayfire/Graph.java +++ b/arrayfire/Graph.java @@ -74,7 +74,7 @@ public void optimize(Tensor loss) { } } - public > T grads(Tensor loss, T tensor) { + public > T grads(Tensor loss, T tensor) { var grads = grads(loss, new Tensor[]{tensor}); return grads.get(tensor); } @@ -153,7 +153,7 @@ void put(Tensor tensor, Tensor grads) { } @SuppressWarnings("unchecked") - public > T get(T tensor) { + public > T get(T tensor) { return (T) gradsByTensor.get(tensor); } } diff --git a/arrayfire/ImaxResult.java b/arrayfire/ImaxResult.java index 7508c87..70a54cf 100644 --- a/arrayfire/ImaxResult.java +++ b/arrayfire/ImaxResult.java @@ -1,7 +1,5 @@ package arrayfire; -import arrayfire.numbers.Num; - -public record ImaxResult, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num>( - Tensor values, Tensor indices) { +public record ImaxResult, S extends Shape>(Tensor values, + Tensor indices) { } \ No newline at end of file diff --git a/arrayfire/Index.java 
b/arrayfire/Index.java index 7ab7ec9..24502ea 100644 --- a/arrayfire/Index.java +++ b/arrayfire/Index.java @@ -17,12 +17,12 @@ public class Index> { ValueLayout.JAVA_BOOLEAN.withName("isSeq"), ValueLayout.JAVA_BOOLEAN.withName("isBatch"), MemoryLayout.paddingLayout(6)); - private final Tensor arr; + private final Tensor arr; private final Seq seq; private final Function generator; - Index(Tensor arr, Function generator) { + Index(Tensor arr, Function generator) { this.arr = arr; this.seq = null; this.generator = generator; diff --git a/arrayfire/Operation.java b/arrayfire/Operation.java index 7f42c4d..06ec72c 100644 --- a/arrayfire/Operation.java +++ b/arrayfire/Operation.java @@ -1,6 +1,5 @@ package arrayfire; -import arrayfire.numbers.Num; import arrayfire.utils.Functions; import java.lang.foreign.MemorySegment; @@ -39,7 +38,7 @@ public void apply() { } } - public GradFunction grads() { + GradFunction grads() { return grads; } @@ -56,14 +55,13 @@ public Nullary inputs() { return new Nullary(); } - public , I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num> Unary inputs( - Tensor input) { + public > Unary inputs(IT input) { operation.inputs.add(input); return new Unary<>(); } - public , I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num, I1T extends DataType, I1D0 extends Num, I1D1 extends Num, I1D2 extends Num, I1D3 extends Num> Binary inputs( - Tensor left, Tensor right) { + public , I0S extends Shape, I1T extends DataType, I1S extends Shape> Binary, Tensor> inputs( + Tensor left, Tensor right) { operation.inputs.add(left); operation.inputs.add(right); return new Binary<>(); @@ -71,49 +69,48 @@ public Nullary inputs() { public class Nullary { - public , OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Single outputs( - Prototype prototype) { + public , OS extends Shape> Single> outputs( + Prototype prototype) { operation.outputs.add(new Tensor<>(prototype)); return new Single<>(); } - public class Single, OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> { + public class Single> { - public Single operation(Function function) { + public Single operation(Function function) { operation.apply = (outputs) -> af.handleStatus(() -> function.apply(outputs.getFirst().segment())); return this; } @SuppressWarnings("unchecked") - public Tensor build() { + public OT build() { af.scope().register(operation); - return (Tensor) operation.outputs.getFirst(); + return (OT) operation.outputs.getFirst(); } } } - public class Unary, I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num> { + public class Unary> { - public , OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> None outputs() { + public None outputs() { return new None(); } - public , OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Single outputs( - Prototype prototype) { + public , OS extends Shape> Single> outputs( + Prototype prototype) { operation.outputs.add(new Tensor<>(prototype)); return new Single<>(); } - public , O0D0 extends Num, O0D1 extends Num, O0D2 extends Num, O0D3 extends Num, O1T extends DataType, O1D0 extends Num, O1D1 extends Num, O1D2 extends Num, O1D3 extends Num> Pair outputs( - Prototype left, Prototype right) { + public , O0S extends Shape, O1T extends DataType, O1S extends Shape> Pair, Tensor> outputs( + Prototype left, Prototype right) { operation.outputs.add(new Tensor<>(left)); operation.outputs.add(new Tensor<>(right)); return new Pair<>(); } - public , O0D0 extends Num, O0D1 extends 
Num, O0D2 extends Num, O0D3 extends Num, O1T extends DataType, O1D0 extends Num, O1D1 extends Num, O1D2 extends Num, O1D3 extends Num, O2T extends DataType, O2D0 extends Num, O2D1 extends Num, O2D2 extends Num, O2D3 extends Num> Trio outputs( - Prototype left, Prototype middle, - Prototype right) { + public , O0S extends Shape, O1T extends DataType, O1S extends Shape, O2T extends DataType, O2S extends Shape> Trio, Tensor, Tensor> outputs( + Prototype left, Prototype middle, Prototype right) { operation.outputs.add(new Tensor<>(left)); operation.outputs.add(new Tensor<>(middle)); operation.outputs.add(new Tensor<>(right)); @@ -133,51 +130,47 @@ public Operation build() { } } - public class Single, OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> { + public class Single> { - public Single operation(Function function) { + public Single operation(Function function) { operation.apply = (outputs) -> af.handleStatus(() -> function.apply(outputs.getFirst().segment())); return this; } @SuppressWarnings("unchecked") - public Single grads( - GradFunction.Unary unaryGradFunction) { + public Single grads(GradFunction.Unary unaryGradFunction) { operation.grads = (grads) -> { - var inputGrad = unaryGradFunction.grads(operation.outputs.getFirst(), (Tensor) grads); + var inputGrad = unaryGradFunction.grads((OT) operation.outputs.getFirst(), (OT) grads); return List.of(inputGrad); }; return this; } @SuppressWarnings("unchecked") - public Tensor build() { + public OT build() { af.scope().register(operation); - return (Tensor) operation.outputs.getFirst(); + return (OT) operation.outputs.getFirst(); } } - public class Pair, O0D0 extends Num, O0D1 extends Num, O0D2 extends Num, O0D3 extends Num, O1T extends DataType, O1D0 extends Num, O1D1 extends Num, O1D2 extends Num, O1D3 extends Num> { + public class Pair, O1T extends Tensor> { - public Pair operation( - Functions.Function2 function) { + public Pair operation(Functions.Function2 function) { operation.apply = (outputs) -> af.handleStatus( () -> function.apply(outputs.getFirst().segment(), outputs.get(1).segment())); return this; } @SuppressWarnings("unchecked") - public TensorPair build() { + public TensorPair build() { af.scope().register(operation); - return new TensorPair<>( - (Tensor) operation.outputs.getFirst(), - (Tensor) operation.outputs.get(1)); + return new TensorPair<>((O0T) operation.outputs.getFirst(), (O1T) operation.outputs.get(1)); } } - public class Trio, O0D0 extends Num, O0D1 extends Num, O0D2 extends Num, O0D3 extends Num, O1T extends DataType, O1D0 extends Num, O1D1 extends Num, O1D2 extends Num, O1D3 extends Num, O2T extends DataType, O2D0 extends Num, O2D1 extends Num, O2D2 extends Num, O2D3 extends Num> { + public class Trio, O1T extends Tensor, O2T extends Tensor> { - public Trio operation( + public Trio operation( Functions.Function3 function) { operation.apply = (outputs) -> af.handleStatus( () -> function.apply(outputs.getFirst().segment(), outputs.get(1).segment(), @@ -186,45 +179,42 @@ public Trio build() { + public TensorTrio build() { af.scope().register(operation); - return new TensorTrio<>( - (Tensor) operation.outputs.getFirst(), - (Tensor) operation.outputs.get(1), - (Tensor) operation.outputs.get(2)); + return new TensorTrio<>((O0T) operation.outputs.getFirst(), (O1T) operation.outputs.get(1), + (O2T) operation.outputs.get(2)); } } } - public class Binary, I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num, I1T extends DataType, I1D0 extends Num, I1D1 extends Num, I1D2 extends Num, 
I1D3 extends Num> { + public class Binary, I1T extends Tensor> { - public , OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> Single outputs( - Prototype prototype) { + public , OS extends Shape> Single> outputs( + Prototype prototype) { operation.outputs.add(new Tensor<>(prototype)); return new Single<>(); } - public class Single, OD0 extends Num, OD1 extends Num, OD2 extends Num, OD3 extends Num> { + public class Single> { - public Single operation(Function function) { + public Single operation(Function function) { operation.apply = (outputs) -> af.handleStatus(() -> function.apply(outputs.getFirst().segment())); return this; } @SuppressWarnings("unchecked") - public Single grads( - GradFunction.Binary binaryGradFunction) { + public Single grads(GradFunction.Binary binaryGradFunction) { operation.grads = (grads) -> { - var inputGrad = binaryGradFunction.grads(operation.outputs.getFirst(), (Tensor) grads); + var inputGrad = binaryGradFunction.grads((OT) operation.outputs.getFirst(), (OT) grads); return List.of(inputGrad.left(), inputGrad.right()); }; return this; } @SuppressWarnings("unchecked") - public Tensor build() { + public OT build() { af.scope().register(operation); - return (Tensor) operation.outputs.getFirst(); + return (OT) operation.outputs.getFirst(); } } } diff --git a/arrayfire/Optimizer.java b/arrayfire/Optimizer.java index 6d03c82..2ed7970 100644 --- a/arrayfire/Optimizer.java +++ b/arrayfire/Optimizer.java @@ -1,8 +1,6 @@ package arrayfire; -import arrayfire.numbers.Num; +public interface Optimizer, S extends Shape> { -public interface Optimizer, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> { - - public void optimize(Params params, Tensor gradients); + public void optimize(Params params, Tensor gradients); } diff --git a/arrayfire/Params.java b/arrayfire/Params.java index d51e12a..090b17d 100644 --- a/arrayfire/Params.java +++ b/arrayfire/Params.java @@ -6,16 +6,16 @@ /** * A variable with an optimizer. 
*/ -public class Params, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> extends Variable { +public class Params, S extends Shape> extends Variable { - private final Optimizer optimizer; + private final Optimizer optimizer; - public Params(T type, Shape shape, OptimizerProvider optimizerProvider) { + public Params(T type, S shape, OptimizerProvider optimizerProvider) { super(type, shape); this.optimizer = optimizerProvider.get(); } - public void optimize(Tensor gradients) { + public void optimize(Tensor gradients) { if (optimizer == null) { throw new IllegalStateException("Attempting to optimize params but no optimizer is provided."); } diff --git a/arrayfire/Prototype.java b/arrayfire/Prototype.java index 59fb6a1..84b2a75 100644 --- a/arrayfire/Prototype.java +++ b/arrayfire/Prototype.java @@ -1,7 +1,4 @@ package arrayfire; -import arrayfire.numbers.Num; - -public record Prototype, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num>( - T type, Shape shape) { +public record Prototype, S extends Shape>(T type, S shape) { } diff --git a/arrayfire/Shape.java b/arrayfire/Shape.java index 0365f7f..61f10c1 100644 --- a/arrayfire/Shape.java +++ b/arrayfire/Shape.java @@ -7,7 +7,7 @@ import java.util.function.Function; -public record Shape, D1 extends Num, D2 extends Num, D3 extends Num>(D0 d0, D1 d1, D2 d2, D3 d3) { +public record Shape, D1 extends Num, D2 extends Num, D3 extends Num>(D0 d0, D1 d1, D2 d2, D3 d3) { public int capacity() { return d0.size() * d1.size() * d2.size() * d3.size(); diff --git a/arrayfire/SortIndexResult.java b/arrayfire/SortIndexResult.java index eef02cd..4710a50 100644 --- a/arrayfire/SortIndexResult.java +++ b/arrayfire/SortIndexResult.java @@ -1,7 +1,5 @@ package arrayfire; -import arrayfire.numbers.Num; - -public record SortIndexResult, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num>( - Tensor values, Tensor indices) { +public record SortIndexResult, S extends Shape>(Tensor values, + Tensor indices) { } diff --git a/arrayfire/SvdResult.java b/arrayfire/SvdResult.java index cb6affe..645332d 100644 --- a/arrayfire/SvdResult.java +++ b/arrayfire/SvdResult.java @@ -3,8 +3,6 @@ import arrayfire.numbers.Num; import arrayfire.numbers.U; -public record SvdResult, D0 extends Num, D1 extends Num>( - Tensor u, - Tensor s, - Tensor vt) { +public record SvdResult, D0 extends Num, D1 extends Num>( + Tensor> u, Tensor> s, Tensor> vt) { } diff --git a/arrayfire/Tensor.java b/arrayfire/Tensor.java index d698aba..d400bbd 100644 --- a/arrayfire/Tensor.java +++ b/arrayfire/Tensor.java @@ -8,21 +8,20 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -import java.util.function.Function; -public class Tensor, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> implements MemoryContainer { +public class Tensor, S extends Shape> implements MemoryContainer { // Contains a single device pointer. 
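To make the new generic layout concrete: a Tensor is now parameterized by its data type and a single Shape, and the Shape record carries the four typed dimensions itself. A minimal sketch of how the pieces compose, using only constructors and helpers that appear in this diff (the concrete dimension sizes are hypothetical):

    // Shape now travels as one type parameter instead of four.
    var shape = af.shape(af.n(2), af.n(3));   // a 2 x 3 shape (hypothetical sizes)
    var proto = af.prototype(U32, shape);     // Prototype pairs a DataType with that Shape
    var tensor = new Tensor<>(proto);         // Tensor is typed by (DataType, Shape)

    // Dimension accessors moved from Tensor onto Shape:
    var rows = tensor.shape().d0();           // previously tensor.d0()

This substitution is what drives the whole diff: every four-way D0..D3 type-parameter list collapses into a single S extends Shape bound, and call sites reach dimensions through shape().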
public static final AddressLayout LAYOUT = ValueLayout.ADDRESS; private final T type; - private final Shape shape; + private final S shape; private final MemorySegment segment; - public Tensor(Prototype prototype) { + public Tensor(Prototype prototype) { this(prototype.type(), prototype.shape()); } - Tensor(T type, Shape shape) { + Tensor(T type, S shape) { this.type = type; this.shape = shape; this.segment = Arena.ofAuto().allocate(LAYOUT); @@ -39,31 +38,16 @@ public MemorySegment dereference() { return segment.get(LAYOUT, 0L); } - public D0 d0() { - return shape.d0(); - } - - public D1 d1() { - return shape.d1(); - } - - public D2 d2() { - return shape.d2(); - } - - public D3 d3() { - return shape.d3(); - } public int capacity() { return shape.capacity(); } - public Shape shape() { + public S shape() { return shape; } - public Prototype prototype() { + public Prototype prototype() { return new Prototype<>(type, shape); } @@ -76,65 +60,42 @@ public String toString() { return "AfTensor{" + "type=" + type + ", shape=" + shape + '}'; } - public Tensor transpose() { - return af.transpose(this); - } - - public > Tensor castshape(Function d0) { - return af.castshape(this, d0); - } - - public , OD1 extends Num> Tensor castshape(Function d0, - Function d1) { - return af.castshape(this, d0, d1); - } - - public , OD1 extends Num, OD2 extends Num> Tensor castshape( - Function d0, Function d1, Function d2) { - return af.castshape(this, d0, d1, d2); - } - public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor castshape( - Function d0, Function d1, Function d2, Function d3) { - return af.castshape(this, d0, d1, d2, d3); - } - - public Tensor reshape(int d0) { + public Tensor> reshape(int d0) { return af.reshape(this, af.shape(d0)); } - public Tensor reshape(int d0, int d1) { + public Tensor> reshape(int d0, int d1) { return af.reshape(this, af.shape(d0, d1)); } - public Tensor reshape(int d0, int d1, int d2) { + public Tensor> reshape(int d0, int d1, int d2) { return af.reshape(this, af.shape(d0, d1, d2)); } - public Tensor reshape(int d0, int d1, int d2, int d3) { + public Tensor> reshape(int d0, int d1, int d2, int d3) { return af.reshape(this, af.shape(d0, d1, d2, d3)); } - public > Tensor reshape(OD0 d0) { + public > Tensor> reshape(OD0 d0) { return af.reshape(this, af.shape(d0)); } - public , OD1 extends Num> Tensor reshape(OD0 d0, OD1 d1) { + public , OD1 extends Num> Tensor> reshape(OD0 d0, OD1 d1) { return af.reshape(this, af.shape(d0, d1)); } - public , OD1 extends Num, OD2 extends Num> Tensor reshape(OD0 d0, - OD1 d1, - OD2 d2) { + public , OD1 extends Num, OD2 extends Num> Tensor> reshape( + OD0 d0, OD1 d1, OD2 d2) { return af.reshape(this, af.shape(d0, d1, d2)); } - public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor reshape( + public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor> reshape( OD0 d0, OD1 d1, OD2 d2, OD3 d3) { return af.reshape(this, af.shape(d0, d1, d2, d3)); } - public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor reshape( + public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor> reshape( Shape newShape) { return af.reshape(this, newShape); } @@ -143,115 +104,77 @@ public void release() { af.release(this); } - Tensor retain() { + Tensor retain() { return af.retain(this); } - public Tensor eval() { + public Tensor eval() { return af.eval(this); } - public Tensor mean() { - return af.mean(this); - } - - public Tensor mean(arrayfire.D0 dim) { - return af.mean(this, dim); - } - - public Tensor mean(arrayfire.D1 
dim) { - return af.mean(this, dim); - } - - public Tensor median() { - return af.median(this); - } - - public Tensor max() { - return af.max(this); - } - - public Tensor max(arrayfire.D1 dim) { - return af.max(this, dim); - } - - public Tensor min() { - return af.min(this); - } - - public Tensor clamp(Tensor lo, Tensor hi) { + public Tensor clamp(Tensor lo, Tensor hi) { return af.clamp(this, lo, hi); } - public Tensor relu() { + public Tensor relu() { return af.relu(this); } - public Tensor negate() { + public Tensor negate() { return af.negate(this); } - public Tensor exp() { + public Tensor exp() { return af.exp(this); } - public Tensor abs() { + public Tensor abs() { return af.abs(this); } - public Tensor sqrt() { + public Tensor sqrt() { return af.sqrt(this); } - public Tensor sigmoid() { + public Tensor sigmoid() { return af.sigmoid(this); } - public Tensor sparse(Storage storage) { + public Tensor sparse(Storage storage) { return af.sparse(this, storage); } - public Tileable tile() { + public Tileable tile() { return new Tileable<>(this); } - public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor tileAs( - Tensor newShapeTensor) { + public > Tensor tileAs(Tensor newShapeTensor) { return af.tileAs(this, newShapeTensor.shape()); } - public , OD1 extends Num, OD2 extends Num, OD3 extends Num> Tensor tileAs( - Shape newShape) { + public > Tensor tileAs(NS newShape) { return af.tileAs(this, newShape); } - public Tensor flatten() { + public Tensor> flatten() { return af.flatten(this); } - public Tensor flip() { + public Tensor flip() { return af.flip(this); } - public Tensor move(Scope scope) { + public Tensor move(Scope scope) { Scope.move(this, scope); return this; } - public > Tensor cast(TN t) { + public > Tensor cast(TN t) { return af.cast(this, t); } - /** - * Normalize by dividing by the L2 norm. 
-     */
-
-    public Tensor center() {
-        return af.center(this);
-    }
-
     @Override
     public void dispose() {
         release();
diff --git a/arrayfire/TensorPair.java b/arrayfire/TensorPair.java
index 2135819..2c2d81b 100644
--- a/arrayfire/TensorPair.java
+++ b/arrayfire/TensorPair.java
@@ -1,10 +1,5 @@
 package arrayfire;
-import arrayfire.DataType;
-import arrayfire.Tensor;
-import arrayfire.numbers.Num;
-
-public record TensorPair, I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num, I1T extends DataType, I1D0 extends Num, I1D1 extends Num, I1D2 extends Num, I1D3 extends Num>(
-        Tensor left, Tensor right) {
+public record TensorPair, R extends Tensor>(L left, R right) {
 }
diff --git a/arrayfire/TensorTrio.java b/arrayfire/TensorTrio.java
index 80fab9b..6104de0 100644
--- a/arrayfire/TensorTrio.java
+++ b/arrayfire/TensorTrio.java
@@ -1,11 +1,6 @@
 package arrayfire;
-import arrayfire.DataType;
-import arrayfire.Tensor;
-import arrayfire.numbers.Num;
-
-public record TensorTrio, I0D0 extends Num, I0D1 extends Num, I0D2 extends Num, I0D3 extends Num, I1T extends DataType, I1D0 extends Num, I1D1 extends Num, I1D2 extends Num, I1D3 extends Num, I2T extends DataType, I2D0 extends Num, I2D1 extends Num, I2D2 extends Num, I2D3 extends Num>(
-        Tensor left, Tensor middle,
-        Tensor right) {
+public record TensorTrio, T2 extends Tensor, T3 extends Tensor>(T1 left, T2 middle,
+        T3 right) {
 }
diff --git a/arrayfire/Tileable.java b/arrayfire/Tileable.java
index b350512..6e489fa 100644
--- a/arrayfire/Tileable.java
+++ b/arrayfire/Tileable.java
@@ -1,7 +1,4 @@
 package arrayfire;
-import arrayfire.numbers.Num;
-
-public record Tileable, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num>(
-        Tensor tensor) {
+public record Tileable, S extends Shape>(Tensor tensor) {
 }
diff --git a/arrayfire/TopKResult.java b/arrayfire/TopKResult.java
index 9926139..aa4b009 100644
--- a/arrayfire/TopKResult.java
+++ b/arrayfire/TopKResult.java
@@ -1,7 +1,5 @@
 package arrayfire;
-import arrayfire.numbers.Num;
-
-public record TopKResult, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num>(
-        Tensor values, Tensor indices) {
+public record TopKResult, S extends Shape>(Tensor values,
+        Tensor indices) {
 }
diff --git a/arrayfire/Variable.java b/arrayfire/Variable.java
index b4924be..92f85ae 100644
--- a/arrayfire/Variable.java
+++ b/arrayfire/Variable.java
@@ -1,17 +1,15 @@
 package arrayfire;
-import arrayfire.numbers.Num;
-
 /**
  * A variable with an optimizer.
  */
-public class Variable, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> extends Tensor {
+public class Variable, S extends Shape> extends Tensor {
-    public Variable(T type, Shape shape) {
+    public Variable(T type, S shape) {
         super(type, shape);
     }
-    public void set(Tensor tensor) {
+    public void set(Tensor tensor) {
         af.set(this, tensor);
     }
 }
diff --git a/arrayfire/optimizers/OptimizerProvider.java b/arrayfire/optimizers/OptimizerProvider.java
index 993d91b..7e6e00f 100644
--- a/arrayfire/optimizers/OptimizerProvider.java
+++ b/arrayfire/optimizers/OptimizerProvider.java
@@ -1,9 +1,9 @@
 package arrayfire.optimizers;
-import arrayfire.Optimizer;
 import arrayfire.DataType;
-import arrayfire.numbers.Num;
+import arrayfire.Optimizer;
+import arrayfire.Shape;
 public interface OptimizerProvider {
-    , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Optimizer get();
+    , S extends Shape> Optimizer get();
 }
diff --git a/arrayfire/optimizers/SGD.java b/arrayfire/optimizers/SGD.java
index 5c34635..128d7cd 100644
--- a/arrayfire/optimizers/SGD.java
+++ b/arrayfire/optimizers/SGD.java
@@ -1,10 +1,6 @@
 package arrayfire.optimizers;
-import arrayfire.Optimizer;
-import arrayfire.Params;
-import arrayfire.Tensor;
-import arrayfire.af;
-import arrayfire.DataType;
+import arrayfire.*;
 import arrayfire.numbers.Num;
 public class SGD implements OptimizerProvider {
@@ -20,14 +16,14 @@ public SGD learningRate(double learningRate) {
         return this;
     }
-    public , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Optimizer get() {
+    public , S extends Shape> Optimizer get() {
         return new SGDOptimizer<>();
    }
-    public class SGDOptimizer, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> implements Optimizer {
+    public class SGDOptimizer, S extends Shape> implements Optimizer {
         @Override
-        public void optimize(Params params, Tensor gradients) {
+        public void optimize(Params params, Tensor gradients) {
             params.set(af.sub(params, af.mul(gradients, learningRate)));
         }
     }
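End-to-end, the refactored optimizer path now reads roughly like this at a call site. A hedged sketch only: new SGD(), the F32 type constant and the loss expression are assumptions for illustration, while Params(type, shape, optimizerProvider), af.optimize(loss) and the update rule params -= learningRate * gradients are what the diff itself shows:

    af.tidy(() -> {
        // Params carries its own optimizer; the SGD construction here is assumed,
        // the diff only shows the learningRate(double) builder method.
        var weights = new Params<>(F32, af.shape(af.n(4)), new SGD().learningRate(0.1));
        // Placeholder loss so there is something to differentiate.
        var loss = af.sum(af.mul(weights, weights));
        // Builds the graph, computes gradients, and calls optimize() on each Params.
        af.optimize(loss);
    });

Because Optimizer is now keyed on (T, S) rather than (T, D0, D1, D2, D3), the same SGDOptimizer type-checks against whatever shape its Params was declared with.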