diff --git a/arrayfire/ArrayFire.java b/arrayfire/ArrayFire.java
index bbb9474..bae2b31 100644
--- a/arrayfire/ArrayFire.java
+++ b/arrayfire/ArrayFire.java
@@ -99,24 +99,24 @@ public static Scope scope() {
/**
* Sorts a tensor over D0.
*/
- public static
, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor sort(
- Tensor tensor) {
+ public static , S extends Shape, ?, ?, ?>> Tensor sort(Tensor tensor) {
return sort(tensor, D0);
}
/**
* Sorts a tensor over the given dimension.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor sort(
- Tensor tensor, Dim dim) {
+ public static , S extends Shape, ?, ?, ?>> Tensor sort(Tensor tensor,
+ Dim dim) {
return sort(tensor, dim, true);
}
/**
* Sorts a tensor over the given dimension in ascending or descending order.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor sort(
- Tensor tensor, Dim dim, boolean ascending) {
+ public static , S extends Shape, ?, ?, ?>> Tensor sort(Tensor tensor,
+ Dim dim,
+ boolean ascending) {
return operation("sort")
.inputs(tensor)
.outputs(tensor.prototype())
@@ -127,40 +127,39 @@ public static Scope scope() {
/**
* Returns a prototype tensor with the given type and shape.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Prototype prototype(
- T type, Shape shape) {
+ public static , S extends Shape, ?, ?, ?>> Prototype prototype(T type, S shape) {
return new Prototype<>(type, shape);
}
/**
* Returns a prototype tensor with the same type and shape as the given tensor.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Prototype prototype(
- Tensor tensor) {
+ public static , S extends Shape, ?, ?, ?>> Prototype prototype(
+ Tensor tensor) {
return new Prototype<>(tensor.type(), tensor.shape());
}
/**
* Sorts a tensor over D0 and returns the values and indices of original values.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> SortIndexResult sortIndex(
- Tensor tensor) {
+ public static , S extends Shape, ?, ?, ?>> SortIndexResult sortIndex(
+ Tensor tensor) {
return sortIndex(tensor, D0);
}
/**
* Sorts a tensor over the given dimension and returns the values and indices of original values.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> SortIndexResult sortIndex(
- Tensor tensor, Dim dim) {
+ public static , S extends Shape, ?, ?, ?>> SortIndexResult sortIndex(
+ Tensor tensor, Dim dim) {
return sortIndex(tensor, dim, true);
}
/**
* Sorts a tensor over the given dimension in ascending or descending order and returns the values and indices of original values.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> SortIndexResult sortIndex(
- Tensor tensor, Dim dim, boolean ascending) {
+ public static , S extends Shape, ?, ?, ?>> SortIndexResult sortIndex(
+ Tensor tensor, Dim dim, boolean ascending) {
var pair = operation("sort_index")
.inputs(tensor)
.outputs(prototype(tensor.type(), tensor.shape()), prototype(U32, tensor.shape()))
@@ -182,7 +181,7 @@ public static > Index permutation(D dim) {
/**
* Creates a device tensor from the given native array.
*/
- public static , AT extends NativeArray> Tensor create(
+ public static , AT extends NativeArray> Tensor> create(
AT array) {
return create(array, shape(n(array.length())));
}
@@ -190,8 +189,8 @@ public static > Index permutation(D dim) {
/**
* Creates a device tensor from the given native array and shape.
*/
- public static , AT extends NativeArray, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor create(
- AT array, Shape shape) {
+ public static , AT extends NativeArray, S extends Shape, ?, ?, ?>> Tensor create(
+ AT array, S shape) {
return operation("create")
.inputs()
.outputs(prototype(array.type(), shape))
@@ -206,7 +205,7 @@ public static > Index permutation(D dim) {
* This is not recommended in a production setting, as memory will be copied twice. Instead, use {@link #create(NativeArray)}.
*/
@SafeVarargs
- public static , DT extends DataType> Tensor create(
+ public static , DT extends DataType> Tensor> create(
DT type, JT... values) {
return tidy(() -> {
var array = type.create(values.length);
@@ -222,7 +221,7 @@ public static > Index permutation(D dim) {
* This is not recommended in a production setting, as memory will be copied twice. Instead, use {@link #create(NativeArray)}.
*/
@SuppressWarnings("unchecked")
- public static , DT extends DataType> Tensor create(
+ public static , DT extends DataType> Tensor> create(
DT type, JTA values) {
return tidy(() -> {
var length = Array.getLength(values);
@@ -237,50 +236,50 @@ public static > Index permutation(D dim) {
/**
* Creates a {@link F32} device tensor from the given float values.
*/
- public static Tensor create(float... values) {
+ public static Tensor> create(float... values) {
return create(F32, values);
}
/**
* Creates a {@link F64} device tensor from the given double values.
*/
- public static Tensor create(double... values) {
+ public static Tensor> create(double... values) {
return create(F64, values);
}
/**
* Creates a {@link S32} device tensor from the given byte values.
*/
- public static Tensor create(int... values) {
+ public static Tensor> create(int... values) {
return create(S32, values);
}
/**
* Creates a constant scalar {@link F32} device tensor from the given float value.
*/
- public static Tensor constant(float value) {
+ public static Tensor> constant(float value) {
return constant(F32, value);
}
/**
* Creates a constant scalar {@link F64} device tensor from the given float value.
*/
- public static Tensor constant(double value) {
+ public static Tensor> constant(double value) {
return constant(F64, value);
}
/**
* Creates a constant scalar device tensor from the given type and double value.
*/
- public static > Tensor constant(DT type, double value) {
+ public static > Tensor> constant(DT type, double value) {
return constant(type, shape(u()), value);
}
/**
* Creates a constant device tensor from the given type, shape, and double value.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor constant(
- DT type, Shape shape, double value) {
+ public static , S extends Shape, ?, ?, ?>> Tensor constant(DT type, S shape,
+ double value) {
return operation("constant")
.inputs()
.outputs(prototype(type, shape))
@@ -314,8 +313,8 @@ public static Index seq(int begin, int endInclusive) {
/**
* Returns a lookup index using the given tensor as lookup values (indices).
*/
- public static , D0 extends Num> Index seq(Tensor index) {
- return new Index<>(index, index.d0()::create);
+ public static , D0 extends Num> Index seq(Tensor> index) {
+ return new Index<>(index, index.shape().d0()::create);
}
/**
@@ -349,15 +348,15 @@ public static Shape shape(int d0) {
/**
* Returns a 1D shape of the given dimension.
*/
- public static > Shape shape(D0 d0) {
+ public static > Shape shape(D0 d0) {
return new Shape<>(d0, u(), u(), u());
}
- public static > Shape shape(D0 d0, int d1) {
+ public static > Shape shape(D0 d0, int d1) {
return new Shape<>(d0, n(d1), u(), u());
}
- public static > Shape shape(int d0, D1 d1) {
+ public static > Shape shape(int d0, D1 d1) {
return new Shape<>(n(d0), d1, u(), u());
}
@@ -365,7 +364,7 @@ public static Shape shape(int d0, int d1) {
return new Shape<>(n(d0), n(d1), u(), u());
}
- public static , D1 extends Num>> Shape shape(D0 d0, D1 d1) {
+ public static , D1 extends Num> Shape shape(D0 d0, D1 d1) {
return new Shape<>(d0, d1, u(), u());
}
@@ -377,50 +376,50 @@ public static Shape shape(int d0, int d1, int d2, int d3) {
return new Shape<>(n(d0), n(d1), n(d2), n(d3));
}
- public static , D1 extends Num>, D2 extends Num>> Shape shape(D0 d0, D1 d1,
- D2 d2) {
+ public static , D1 extends Num, D2 extends Num> Shape shape(D0 d0, D1 d1,
+ D2 d2) {
return new Shape<>(d0, d1, d2, u());
}
- public static , D1 extends Num>, D2 extends Num>, D3 extends Num>> Shape shape(
+ public static , D1 extends Num, D2 extends Num, D3 extends Num> Shape shape(
D0 d0, D1 d1, D2 d2, D3 d3) {
return new Shape<>(d0, d1, d2, d3);
}
- private static , IT extends DataType, ?>, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Operation.Builder.Unary.Single reduce(
- String name, Tensor a,
+ private static , IT extends DataType, ?>, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce(
+ String name, Tensor> a,
Functions.Function3 method, arrayfire.D0 dim, T resultType) {
return operation(name)
.inputs(a)
- .outputs(prototype(resultType, shape(u(), a.d1(), a.d2(), a.d3())))
+ .outputs(prototype(resultType, shape(u(), a.shape().d1(), a.shape().d2(), a.shape().d3())))
.operation(ptr -> method.apply(ptr, a.dereference(), dim.index()));
}
- private static , IT extends DataType, ?>, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Operation.Builder.Unary.Single reduce(
- String name, Tensor a,
+ private static , IT extends DataType, ?>, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce(
+ String name, Tensor> a,
Functions.Function3 method, arrayfire.D1 dim, T resultType) {
return operation(name)
.inputs(a)
- .outputs(prototype(resultType, shape(a.d0(), u(), a.d2(), a.d3())))
+ .outputs(prototype(resultType, shape(a.shape().d0(), u(), a.shape().d2(), a.shape().d3())))
.operation(ptr -> method.apply(ptr, a.dereference(), dim.index()));
}
- private static , IT extends DataType, ?>, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Operation.Builder.Unary.Single reduce(
- String name, Tensor a,
+ private static , IT extends DataType, ?>, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce(
+ String name, Tensor> a,
Functions.Function3 method, arrayfire.D2 dim, T resultType) {
return operation(name)
.inputs(a)
- .outputs(prototype(resultType, shape(a.d0(), a.d1(), u(), a.d3())))
+ .outputs(prototype(resultType, shape(a.shape().d0(), a.shape().d1(), u(), a.shape().d3())))
.operation(ptr -> method.apply(ptr, a.dereference(), dim.index()));
}
- private static , IT extends DataType, ?>, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Operation.Builder.Unary.Single reduce(
- String name, Tensor a,
+ private static , IT extends DataType, ?>, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Operation.Builder.Unary>>.Single>> reduce(
+ String name, Tensor> a,
Functions.Function3 method, arrayfire.D3 dim, T resultType) {
return operation(name)
.inputs(a)
- .outputs(prototype(resultType, shape(a.d0(), a.d1(), a.d2(), u())))
+ .outputs(prototype(resultType, shape(a.shape().d0(), a.shape().d1(), a.shape().d2(), u())))
.operation(ptr -> method.apply(ptr, a.dereference(), dim.index()));
}
@@ -428,48 +427,44 @@ public static , D1 extends Num>, D2 extends Num>, D3 exten
* Cast the given tensor to the given type.
*/
@SuppressWarnings("unchecked")
- public static , OT extends DataType, ?>, D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor cast(
- Tensor input, OT type) {
+ public static , OT extends DataType, ?>, S extends Shape, ?, ?, ?>> Tensor cast(
+ Tensor input, OT type) {
if (input.type().equals(type)) {
- return (Tensor) input;
+ return (Tensor) input;
}
return operation("cast")
.inputs(input)
.outputs(prototype(type, input.shape()))
.operation(ptr -> arrayfire_h.af_cast(ptr, input.dereference(), type.code()))
- .grads((result, grads) -> grads.cast(input.type()))
+ .grads((result, grads) -> cast(grads, input.type()))
.build();
}
/**
* Returns a tensor of value 1 with the same type and shape as the given tensor.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor ones(
- Tensor model) {
+ public static , S extends Shape, ?, ?, ?>> Tensor ones(Tensor model) {
return ones(model.type(), model.shape());
}
/**
* Returns a tensor of value 1 with the given type and shape.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor ones(
- T type, Shape shape) {
+ public static , S extends Shape, ?, ?, ?>> Tensor ones(T type, S shape) {
return constant(type, 1).tileAs(shape);
}
/**
* Returns a tensor of value 0 with the given type and shape.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor zeros(
- T type, Shape shape) {
+ public static , S extends Shape, ?, ?, ?>> Tensor zeros(T type, S shape) {
return constant(type, 0).tileAs(shape);
}
/**
* Create a random tensor sampled from uniform distribution between [0, 1].
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor randu(
- T type, Shape shape) {
+ public static , S extends Shape, ?, ?, ?>> Tensor randu(T type, S shape) {
return operation("randu")
.inputs()
.outputs(prototype(type, shape))
@@ -480,8 +475,7 @@ public static , D1 extends Num>, D2 extends Num>, D3 exten
/**
* Create a random tensor sampled from a normal distribution with mean 0.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor randn(
- T type, Shape shape) {
+ public static , S extends Shape, ?, ?, ?>> Tensor randn(T type, S shape) {
return operation("randn")
.inputs()
.outputs(prototype(type, shape))
@@ -492,14 +486,14 @@ public static , D1 extends Num>, D2 extends Num>, D3 exten
/**
* Create a tensor with values [0, n-1].
*/
- public static Tensor range(int n) {
+ public static Tensor> range(int n) {
return range(U32, n);
}
/**
* Create a tensor with values [0, n-1] of the given type.
*/
- public static > Tensor range(T type, int n) {
+ public static > Tensor> range(T type, int n) {
var shape = shape(n(n));
return operation("range")
.inputs()
@@ -525,13 +519,13 @@ public static void setRandomEngineType(RandomEngineType type) {
/**
* Pull data from the device to the host, returning a native array.
*/
- public static , T extends DataType> AT data(Tensor a) {
+ public static , T extends DataType> AT data(Tensor a) {
var result = a.type().create(a.capacity());
handleStatus(() -> arrayfire_h.af_get_data_ptr(result.segment(), a.dereference()));
return result;
}
- private static void checkDims(Tensor, ?, ?, ?, ?> tensor) {
+ private static void checkDims(Tensor, ?> tensor) {
try (Arena arena = Arena.ofConfined()) {
var dims = arena.allocateArray(ValueLayout.JAVA_LONG, 4);
handleStatus(
@@ -641,11 +635,12 @@ private static MemorySegment nativeDims(Shape, ?, ?, ?> shape) {
/**
* Transpose D0 and D1 dimensions of the given tensor.
*/
- public static , D0 extends Num>, D1 extends Num>, D2 extends Num>, D3 extends Num>> Tensor transpose(
- Tensor tensor) {
+ public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num> Tensor> transpose(
+ Tensor