Skip to content

Commit

Permalink
Working SVD reconstruction, more blog post examples
Browse files Browse the repository at this point in the history
  • Loading branch information
lewish committed Dec 15, 2023
1 parent 27ca51b commit e1ab972
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 47 deletions.
56 changes: 50 additions & 6 deletions arrayfire/ArrayFire.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
import java.lang.foreign.ValueLayout;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;

public class ArrayFire {

public static final U8 U8 = new U8();
public static final U64 U64 = new U64();
public static final U32 U32 = new U32();
public static final F32 F32 = new F32();
Expand Down Expand Up @@ -388,6 +391,18 @@ public static void setRandomEngineType(RandomEngineType type) {
return result;
}

public static void checkDims(Tensor<?, ?, ?, ?, ?> tensor) {
var trueDims = getDims(tensor.segment());
var expectedDims = tensor.shape().dims();
for (int i = 0; i < trueDims.length; i++) {
if (trueDims[i] != expectedDims[i]) {
throw new IllegalStateException(
String.format("Expected dimensions %s but got %s", Arrays.toString(expectedDims),
Arrays.toString(trueDims)));
}
}
}

public static long[] getDims(MemorySegment a) {
try (Arena arena = Arena.ofConfined()) {
var dims = arena.allocateArray(ValueLayout.JAVA_LONG, 4);
Expand Down Expand Up @@ -804,6 +819,18 @@ private static void checkTileableIsSmaller(Tensor<?, ?, ?, ?, ?> left, Tileable<
ptr -> arrayfire_h.af_matmul(ptr, tensor.dereference(), rhs.dereference(), 0, 0));
}

public static <T extends DataType<?, ?>, AD0 extends IntNumber<?>, AD1 extends IntNumber<?>, BD1 extends IntNumber<?>, CD1 extends IntNumber<?>, D2 extends IntNumber<?>, D3 extends IntNumber<?>> Tensor<T, AD0, CD1, D2, D3> matmul(
Tensor<T, AD0, AD1, D2, D3> a, Tensor<T, AD1, BD1, D2, D3> b, Tensor<T, BD1, CD1, D2, D3> c) {
return tidy(() -> {
// Determine the optimal order of operations.
if (a.d0().size() * b.d1().size() < b.d0().size() * c.d1().size()) {
return matmul(matmul(a, b), c);
} else {
return matmul(a, matmul(b, c));
}
});
}

public static <T extends DataType<?, ?>, D0 extends IntNumber<?>, D1 extends IntNumber<?>, D2 extends IntNumber<?>, D3 extends IntNumber<?>> Tensor<T, D0, D1, D2, D3> clamp(
Tensor<T, D0, D1, D2, D3> tensor, Tensor<T, ?, ?, ?, ?> lo, Tensor<T, ?, ?, ?, ?> hi) {
// TODO: Batch parameter.
Expand Down Expand Up @@ -1096,18 +1123,23 @@ public static <T extends DataType<?, T>, D0 extends IntNumber<?>, D1 extends Int
// svd
public static <T extends DataType<?, ?>, D0 extends IntNumber<?>, D1 extends IntNumber<?>> SvdResult<T, D0, D1> svd(
Tensor<T, D0, D1, U, U> tensor) {
var u = new Tensor<>(tensor.type(), shape(tensor.shape().d1(), tensor.shape().d1()));
var s = new Tensor<>(tensor.type(), shape(tensor.shape().d1()));
var v = new Tensor<>(tensor.type(), shape(tensor.shape().d0(), tensor.shape().d0()));
var u = new Tensor<>(tensor.type(), shape(tensor.shape().d0(), tensor.shape().d0()));
var s = new Tensor<>(tensor.type(), shape(tensor.shape().d0()));
var v = new Tensor<>(tensor.type(), shape(tensor.shape().d1(), tensor.shape().d1()));
handleStatus(() -> arrayfire_h.af_svd(u.segment(), s.segment(), v.segment(), tensor.dereference()));
checkDims(u);
checkDims(s);
checkDims(v);
return new SvdResult<>(u, s, v);
}

public static <T extends DataType<?, ?>, D0 extends IntNumber<?>, D1 extends IntNumber<?>> Tensor<T, D0, D0, U, U> cov(
Tensor<T, D0, D1, U, U> tensor) {
var subMean = sub(tensor, mean(tensor, D1).tileAs(tensor));
var matrix = matmul(subMean, subMean.transpose());
return div(matrix, constant(matrix.type(), matrix.shape(), tensor.shape().d1().size() - 1.0f));
return tidy(() -> {
var subMean = sub(tensor, mean(tensor, D1).tileAs(tensor));
var matrix = matmul(subMean, subMean.transpose());
return div(matrix, constant(matrix.type(), matrix.shape(), tensor.shape().d1().size() - 1.0f));
});
}

public static <T extends DataType<?, ?>, D0 extends IntNumber<?>, D1 extends IntNumber<?>> Tensor<T, D0, D0, U, U> zcaMatrix(
Expand Down Expand Up @@ -1181,6 +1213,18 @@ public static K k(int value) {
return new K(value);
}

public static X x(int value) {
return new X(value);
}

public static Y y(int value) {
return new Y(value);
}

public static Z z(int value) {
return new Z(value);
}

public static final U U = new U(1);

public static U u() {
Expand Down
92 changes: 60 additions & 32 deletions arrayfire/ArrayFireTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ public void matmulS32() {
});
}

@Test
public void svd() {
af.tidy(() -> {
var a = af.a(2);
var b = af.b(3);
var matrix = af.create(F32, new float[]{1, 2, 3, 4, 5, 6}).reshape(a, b);
var svd = af.svd(matrix);
var u = svd.u(); // Tensor<F32, A, A, U, U>
var s = svd.s(); // Tensor<F32, A, U, U, U>
var vt = svd.vt(); // Tensor<F32, B, B, U, U>
// Recreate the matrix from the SVD.
var recreated = af.matmul(u, af.diag(s), af.index(vt, af.seq(a))); // Tensor<F32, A, B, U, U>
assertArrayEquals(new float[]{1, 2, 3, 4, 5, 6}, data(recreated).java(), 1E-5f);
});
}

@Test
public void mul() {
af.tidy(() -> {
Expand Down Expand Up @@ -220,15 +236,6 @@ public void mulScalar() {
});
}

// @Test
// public void sum() {
// af.tidy(() -> {
// var data = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2)));
// var result = data.flatten().sum();
// Assert.assertArrayEquals(new float[]{10}, af.data(result),.toHeap() 1E-5f);
// });
// }

@Test
public void min() {
af.tidy(() -> {
Expand Down Expand Up @@ -257,14 +264,41 @@ public void imaxMatrix() {
}

@Test
public void sumMatrix() {
public void sum() {
af.tidy(() -> {
var data = af.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}).reshape(2, 2, 2,
2);
assertArrayEquals(new float[]{3, 7, 11, 15, 19, 23, 27, 31}, af.data(af.sum(data)).java(), 1E-5f);
assertArrayEquals(new float[]{4, 6, 12, 14, 20, 22, 28, 30}, af.data(af.sum(data, af.D1)).java(), 1E-5f);
assertArrayEquals(new float[]{6, 8, 10, 12, 22, 24, 26, 28}, af.data(af.sum(data, af.D2)).java(), 1E-5f);
assertArrayEquals(new float[]{10, 12, 14, 16, 18, 20, 22, 24}, af.data(af.sum(data, af.D3)).java(), 1E-5f);
});
}

@Test
public void sumB8() {
af.tidy(() -> {
var data = af.create(U8, new byte[]{1, 2, 3, 4}).reshape(2, 2);
var sum = af.sum(data);
assertArrayEquals(new int[]{3, 7}, af.data(sum).java());
});
}

@Test
public void mean() {
af.tidy(() -> {
var data = af.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(4, 2);
var result = af.sum(data);
assertArrayEquals(new float[]{10, 26}, af.data(result).java(), 1E-5f);
var data = af.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}).reshape(2, 2, 2,
2);

assertArrayEquals(new float[]{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f, 13.5f, 15.5f},
af.data(af.mean(data)).java(), 1E-5f);
assertArrayEquals(new float[]{2, 3, 6, 7, 10, 11, 14, 15}, af.data(af.mean(data, af.D1)).java(), 1E-5f);
assertArrayEquals(new float[]{3, 4, 5, 6, 11, 12, 13, 14}, af.data(af.mean(data, af.D2)).java(), 1E-5f);
assertArrayEquals(new float[]{5, 6, 7, 8, 9, 10, 11, 12}, af.data(af.mean(data, af.D3)).java(), 1E-5f);
});
}


@Test
public void slice() {
af.tidy(() -> {
Expand Down Expand Up @@ -302,6 +336,19 @@ public void index4D() {
});
}

@Test
public void index3D() {
af.tidy(() -> {
var a = af.a(2);
var data = af.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(a, a, a);
var result = af.index(data,
af.span(),
af.seq(0, 0),
af.seq(af.create(U32, new int[]{1})));
assertArrayEquals(new float[]{5, 6}, af.data(result).java(), 1E-5f);
});
}

@Test
public void indexSpan() {
af.tidy(() -> {
Expand Down Expand Up @@ -450,25 +497,6 @@ public void scale() {
});
}

// @Test
// public void allBackends() {
// af.tidy(() -> {
// var originalBackend = af.backend();
// try {
// for (var backend : af.availableBackends()) {
// af.setBackend(backend);
// System.out.println(af.deviceInfo());
// convolve2();
// mulBroadcast();
// matmul();
// convolve2();
// }
// } finally {
// af.setBackend(originalBackend);
// }
// });
// }

@Test
public void useAcrossScopes() {
af.tidy(() -> {
Expand Down
6 changes: 3 additions & 3 deletions arrayfire/SvdResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import arrayfire.numbers.U;

public record SvdResult<T extends DataType<?, ?>, D0 extends IntNumber<?>, D1 extends IntNumber<?>>(
Tensor<T, D1, D1, U, U> u,
Tensor<T, D1, U, U, U> s,
Tensor<T, D0, D0, U, U> v) {
Tensor<T, D0, D0, U, U> u,
Tensor<T, D0, U, U, U> s,
Tensor<T, D1, D1, U, U> vt) {
}
13 changes: 7 additions & 6 deletions arrayfire/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.List;
import java.util.function.Function;

public class Tensor<T extends DataType<?, ?>, D0 extends IntNumber<?>, D1 extends IntNumber<?>, D2 extends IntNumber<?>, D3 extends IntNumber<?>> implements TensorLike<T, D0, D1, D2, D3>, MemoryContainer {
Expand Down Expand Up @@ -82,8 +81,9 @@ public <OD0 extends IntNumber<?>> Tensor<T, OD0, D1, D2, D3> castshape(Function<
return af.castshape(this, d0);
}

public <OD0 extends IntNumber<?>, OD1 extends IntNumber<?>> Tensor<T, OD0, OD1, D2, D3> castshape(Function<Integer, OD0> d0,
Function<Integer, OD1> d1) {
public <OD0 extends IntNumber<?>, OD1 extends IntNumber<?>> Tensor<T, OD0, OD1, D2, D3> castshape(
Function<Integer, OD0> d0,
Function<Integer, OD1> d1) {
return af.castshape(this, d0, d1);
}

Expand Down Expand Up @@ -123,9 +123,10 @@ public Tensor<T, N, N, N, N> reshape(int d0, int d1, int d2, int d3) {
}


public <OD0 extends IntNumber<?>, OD1 extends IntNumber<?>, OD2 extends IntNumber<?>> Tensor<T, OD0, OD1, OD2, U> reshape(OD0 d0,
OD1 d1,
OD2 d2) {
public <OD0 extends IntNumber<?>, OD1 extends IntNumber<?>, OD2 extends IntNumber<?>> Tensor<T, OD0, OD1, OD2, U> reshape(
OD0 d0,
OD1 d1,
OD2 d2) {
return af.reshape(this, af.shape(d0, d1, d2));
}

Expand Down
38 changes: 38 additions & 0 deletions arrayfire/containers/U8Array.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package arrayfire.containers;

import arrayfire.datatypes.U8;

import java.lang.foreign.ValueLayout;

import static arrayfire.ArrayFire.U8;

public class U8Array extends NativeArray<U8, Byte, byte[]> {

public U8Array(int length) {
super(U8, length);
}

@Override
public ValueLayout.OfByte layout() {
return ValueLayout.JAVA_BYTE;
}

@Override
public Byte get(int index) {
return segment.getAtIndex(layout(), index);
}

@Override
public void set(int index, Byte value) {
segment.setAtIndex(layout(), index, value);
}

@Override
public byte[] java() {
var array = new byte[length];
for (int i = 0; i < array.length; i++) {
array[i] = get(i);
}
return array;
}
}
25 changes: 25 additions & 0 deletions arrayfire/datatypes/U8.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package arrayfire.datatypes;

import arrayfire.containers.B8Array;
import arrayfire.containers.U8Array;

import static arrayfire.af.U32;

public class U8 implements DataType<U8Array, U32> {

@Override
public int code() {
return DataTypeEnum.U8.code();
}

@Override
public U32 sumType() {
return U32;
}

@Override
public U8Array create(int length) {
return new U8Array(length);
}
}

8 changes: 8 additions & 0 deletions arrayfire/numbers/X.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package arrayfire.numbers;

public record X(int size) implements IntNumber<X> {
@Override
public X create(int size) {
return new X(size);
}
}
8 changes: 8 additions & 0 deletions arrayfire/numbers/Y.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package arrayfire.numbers;

public record Y(int size) implements IntNumber<Y> {
@Override
public Y create(int size) {
return new Y(size);
}
}
8 changes: 8 additions & 0 deletions arrayfire/numbers/Z.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package arrayfire.numbers;

public record Z(int size) implements IntNumber<Z> {
@Override
public Z create(int size) {
return new Z(size);
}
}

0 comments on commit e1ab972

Please sign in to comment.