Skip to content

Commit

Permalink
First pass at simplified shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
lewish committed Jan 30, 2024
1 parent def9ae5 commit 78db370
Show file tree
Hide file tree
Showing 21 changed files with 485 additions and 613 deletions.
728 changes: 362 additions & 366 deletions arrayfire/ArrayFire.java

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions arrayfire/ArrayFireTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void sortIndex() {
public void permutationIndex() {
af.tidy(() -> {
var arr = af.create(1, 2, 3, 4, 5, 6, 7, 8).reshape(2, 4);
var permutation = af.permutation(arr.d1());
var permutation = af.permutation(arr.shape().d1());
var shuffled = af.index(arr, af.span(), permutation);
var data = af.data(shuffled);
assertArrayEquals(new int[]{5, 6, 1, 2, 7, 8, 3, 4}, data.java());
Expand All @@ -143,7 +143,7 @@ public void permutationIndex() {
public void transpose() {
af.tidy(() -> {
var arr = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2);
var transpose = arr.transpose();
var transpose = af.transpose(arr);
assertArrayEquals(new float[]{1, 3, 2, 4}, af.data(transpose).java(), 1E-5f);
});
}
Expand Down Expand Up @@ -172,7 +172,7 @@ public void matmul() {
af.tidy(() -> {
var left = af.create(new float[]{1, 2, 3, 4}).reshape(a(2), b(2));
var right = af.create(new float[]{1, 2, 3, 4, 5, 6}).reshape(a(2), c(3));
var result = af.matmul(left.transpose(), right);
var result = af.matmul(af.transpose(left), right);
assertArrayEquals(new float[]{5, 11, 11, 25, 17, 39}, data(result).java(), 1E-5f);
});
}
Expand All @@ -182,7 +182,7 @@ public void matmulS32() {
af.tidy(() -> {
var left = af.create(new float[]{1, 2, 3, 4}).reshape(a(2), b(2));
var right = af.create(new float[]{1, 2, 3, 4, 5, 6}).reshape(a(2), c(3));
var result = af.matmul(left.transpose(), right);
var result = af.matmul(af.transpose(left), right);
assertArrayEquals(new float[]{5, 11, 11, 25, 17, 39}, data(result).java(), 1E-5f);
});
}
Expand Down Expand Up @@ -292,7 +292,7 @@ public void mulScalar() {
public void min() {
af.tidy(() -> {
var data = af.create(new float[]{-5, 12, 0, 1});
var result = data.min();
var result = af.min(data);
assertArrayEquals(new float[]{-5}, af.data(result).java(), 1e-5f);
});
}
Expand Down Expand Up @@ -384,9 +384,9 @@ public void index4D() {
var data = af
.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})
.reshape(2, 2, 2, 2);
Tensor<F32, A, B, C, D> result = af.index(data, af.seq(af.create(0).castshape(af::a)),
af.seq(af.create(1).castshape(af::b)), af.seq(af.create(0).castshape(af::c)),
af.seq(af.create(1).castshape(af::d)));
Tensor<F32, Shape<A, B, C, D>> result = af.index(data, af.seq(af.create(0).reshape(af.a(1))),
af.seq(af.create(1).reshape(af.b(1))), af.seq(af.create(0).reshape(af.c(1))),
af.seq(af.create(1).reshape(af.d(1))));
assertArrayEquals(new float[]{11}, af.data(result).java(), 1E-5f);
});
}
Expand Down Expand Up @@ -554,7 +554,7 @@ public void graph() {
af.tidy(() -> {
var left = af.create(new float[]{1, 2, 3, 4}).reshape(a(2), b(2));
var right = af.create(new float[]{1, 2, 3, 4, 5, 6}).reshape(a(2), c(3));
var leftT = left.transpose();
var leftT = af.transpose(left);
var matmul = af.matmul(leftT, right);
var softmax = af.softmax(matmul);
var sum = af.sum(matmul);
Expand Down
15 changes: 5 additions & 10 deletions arrayfire/GradFunction.java
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
package arrayfire;

import arrayfire.DataType;
import arrayfire.Tensor;
import arrayfire.TensorPair;
import arrayfire.numbers.Num;

import java.util.List;

@FunctionalInterface
interface GradFunction {

List<Tensor<?, ?, ?, ?, ?>> grads(Tensor<?, ?, ?, ?, ?> resultGrads);
List<Tensor<?, ?>> grads(Tensor<?, ?> resultGrads);

interface Unary<RT extends DataType<?, ?>, RD0 extends Num<?>, RD1 extends Num<?>, RD2 extends Num<?>, RD3 extends Num<?>, I0T extends DataType<?, ?>, I0D0 extends Num<?>, I0D1 extends Num<?>, I0D2 extends Num<?>, I0D3 extends Num<?>> {
Tensor<I0T, I0D0, I0D1, I0D2, I0D3> grads(Tensor<RT, RD0, RD1, RD2, RD3> result,
Tensor<RT, RD0, RD1, RD2, RD3> grads);
interface Unary<RT extends Tensor<?, ?>, IT extends Tensor<?, ?>> {
IT grads(RT result, RT grads);
}

interface Binary<RT extends DataType<?, ?>, RD0 extends Num<?>, RD1 extends Num<?>, RD2 extends Num<?>, RD3 extends Num<?>, I0T extends DataType<?, ?>, I0D0 extends Num<?>, I0D1 extends Num<?>, I0D2 extends Num<?>, I0D3 extends Num<?>, I1T extends DataType<?, ?>, I1D0 extends Num<?>, I1D1 extends Num<?>, I1D2 extends Num<?>, I1D3 extends Num<?>> {
TensorPair<I0T, I0D0, I0D1, I0D2, I0D3, I1T, I1D0, I1D1, I1D2, I1D3> grads(
Tensor<RT, RD0, RD1, RD2, RD3> result, Tensor<RT, RD0, RD1, RD2, RD3> grads);
interface Binary<RT extends Tensor<?, ?>, I0T extends Tensor<?, ?>, I1T extends Tensor<?, ?>> {
TensorPair<I0T, I1T> grads(RT result, RT grads);
}
}
4 changes: 2 additions & 2 deletions arrayfire/Graph.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void optimize(Tensor loss) {
}
}

public <T extends Tensor<?, ?, ?, ?, ?>> T grads(Tensor loss, T tensor) {
public <T extends Tensor<?, ?>> T grads(Tensor loss, T tensor) {
var grads = grads(loss, new Tensor[]{tensor});
return grads.get(tensor);
}
Expand Down Expand Up @@ -153,7 +153,7 @@ void put(Tensor tensor, Tensor grads) {
}

@SuppressWarnings("unchecked")
public <T extends Tensor<?, ?, ?, ?, ?>> T get(T tensor) {
public <T extends Tensor<?, ?>> T get(T tensor) {
return (T) gradsByTensor.get(tensor);
}
}
Expand Down
6 changes: 2 additions & 4 deletions arrayfire/ImaxResult.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package arrayfire;

import arrayfire.numbers.Num;

public record ImaxResult<T extends DataType<?, ?>, D0 extends Num<?>, D1 extends Num<?>, D2 extends Num<?>, D3 extends Num<?>>(
Tensor<T, D0, D1, D2, D3> values, Tensor<U32, D0, D1, D2, D3> indices) {
public record ImaxResult<T extends DataType<?, ?>, S extends Shape<?, ?, ?, ?>>(Tensor<T, S> values,
Tensor<U32, S> indices) {
}
4 changes: 2 additions & 2 deletions arrayfire/Index.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ public class Index<D extends Num<D>> {
ValueLayout.JAVA_BOOLEAN.withName("isSeq"), ValueLayout.JAVA_BOOLEAN.withName("isBatch"),
MemoryLayout.paddingLayout(6));

private final Tensor<?, ?, ?, ?, ?> arr;
private final Tensor<?, ?> arr;
private final Seq seq;

private final Function<Integer, D> generator;

Index(Tensor<?, ?, ?, ?, ?> arr, Function<Integer, D> generator) {
Index(Tensor<?, ?> arr, Function<Integer, D> generator) {
this.arr = arr;
this.seq = null;
this.generator = generator;
Expand Down
Loading

0 comments on commit 78db370

Please sign in to comment.