diff --git a/arrayfire/Array.java b/arrayfire/Array.java index b1662e8..85e300e 100644 --- a/arrayfire/Array.java +++ b/arrayfire/Array.java @@ -8,7 +8,6 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -import java.util.Arrays; import java.util.function.Function; public class Array, S extends Shape> implements MemoryContainer { @@ -37,9 +36,17 @@ MemorySegment segment() { * @return the wrapped void* pointer of the C af_array. */ public MemorySegment dereference() { - return segment.get(LAYOUT, 0L); + var value = segment.get(LAYOUT, 0L); + if (MemorySegment.NULL.equals(value)) { + throw new IllegalStateException( + String.format("Cannot dereference an uninitialized segment (nullptr) %s", shape)); + } + return value; } + public boolean materialized() { + return !MemorySegment.NULL.equals(segment.get(LAYOUT, 0L)); + } public int capacity() { return shape.capacity(); diff --git a/arrayfire/ArrayFire.java b/arrayfire/ArrayFire.java index 5f5b9c8..4a5c17e 100644 --- a/arrayfire/ArrayFire.java +++ b/arrayfire/ArrayFire.java @@ -730,6 +730,30 @@ public static , D0 extends Num, D1 extends Num, D2 .build(); } + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Array> transpose( + Array array, arrayfire.D1 d1, arrayfire.D2 d2) { + return operation("transpose_D1_D2") + .inputs(array) + .outputs(prototype(array.type(), + shape(array.shape().d0(), array.shape().d2(), array.shape().d1(), array.shape().d3()))) + .operation(ptr -> arrayfire_h.af_reorder(ptr, array.dereference(), 0, 2, 1, 3)) + .grads((result, grads) -> transpose(grads, d1, d2).reshape(array.shape())) + .build(); + } + + + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Array> transpose( + Array array, arrayfire.D2 d2, arrayfire.D3 d3) { + return operation("transpose_D2_D3") + .inputs(array) + .outputs(prototype(array.type(), + shape(array.shape().d0(), array.shape().d1(), array.shape().d3(), array.shape().d2()))) + .operation(ptr -> arrayfire_h.af_reorder(ptr, array.dereference(), 0, 1, 3, 2)) + .grads((result, grads) -> transpose(grads, d2, d3).reshape(array.shape())) + .build(); + } + + /** * Change the type of the array's D0 dimension to the given type variable provider. */ @@ -791,6 +815,9 @@ public static , OD0 extends Num, OD1 extends Num * Release the memory of the given array on the device. */ public static void release(Array array) { + if (!array.materialized()) { + return; + } handleStatus(() -> arrayfire_h.af_release_array(array.dereference())); Scope.untrack(array); } @@ -1792,6 +1819,19 @@ public static , D0 extends Num, D1 extends Num, D2 public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, FD0 extends Num, FD1 extends Num, FD3 extends Num, S extends Shape, FS extends Shape> Array> convolve2( Array array, Array filters, Shape stride, Shape padding, Shape dilation) { + if (array.shape().d2().size() != filters.shape().d2().size()) { + throw new IllegalArgumentException( + String.format("D2 for input %s and filters %s must match", array.shape(), filters.shape())); + } + if (stride.ndims() != 2) { + throw new IllegalArgumentException(String.format("Stride must be have 2 dims but was %s", stride)); + } + if (padding.ndims() != 2) { + throw new IllegalArgumentException(String.format("Padding must be have 2 dims but was %s", padding)); + } + if (dilation.ndims() != 2) { + throw new IllegalArgumentException(String.format("Dilation must be have 2 dims but was %s", dilation)); + } var computedShape = shape(n((array.shape().d0().size() + 2 * padding.d0().size() - (filters.shape().d0().size() - 1) * dilation.d0().size() - 1) / stride.d0().size() + 1), @@ -1802,11 +1842,29 @@ public static , D0 extends Num, D1 extends Num, D2 .inputs(array, filters) .outputs(prototype(array.type(), computedShape)) .operation(ptr -> { + // Potentially retry after GC due to https://github.com/arrayfire/arrayfire/issues/3402 retryWithGc(() -> handleStatus( () -> arrayfire_h.af_convolve2_nn(ptr, array.dereference(), filters.dereference(), 2, nativeDims(stride), 2, nativeDims(padding), 2, nativeDims(dilation)))); return Status.AF_SUCCESS.code(); }) + .grads((result, grads) -> { + // We can get the filter gradients back by performing a convolution again, reducing over the + // image batch as "channels" in reverse. + var inputTranspose = transpose(array, D2, D3); + var gradsTranspose = transpose(grads, D2, D3); + var filterGradsTranspose = convolve2(inputTranspose, gradsTranspose, stride, padding, dilation); + var filterGrads = transpose(filterGradsTranspose, D2, D3); + if (!Arrays.equals(filterGrads.shape().dims(), filters.shape().dims())) { + // This shouldn't happen, but I haven't extensively tested convolution variations. + throw new IllegalStateException( + String.format("Internal: Filter grads shape %s does not match filters shape %s", + filterGradsTranspose.shape(), filters.shape())); + } + return new ArrayPair<>(new ErrorArray<>(array.type(), array.shape(), + "Gradients cannot currently be computed for the input to a convolution"), + filterGrads.reshape(filters.shape())); + }) .build(); } diff --git a/arrayfire/ArrayFireTest.java b/arrayfire/ArrayFireTest.java index 861e39d..fd9160b 100644 --- a/arrayfire/ArrayFireTest.java +++ b/arrayfire/ArrayFireTest.java @@ -54,6 +54,9 @@ public void validateScope() { } private static void checkDims(Array array) { + if (!array.materialized()) { + return; + } try (Arena arena = Arena.ofConfined()) { var dims = arena.allocateArray(ValueLayout.JAVA_LONG, 4); handleStatus( @@ -499,22 +502,36 @@ public void convolve2() { var filters = af.create(new float[]{4, 3, 2, 1, 8, 6, 4, 2}).reshape(2, 2, 1, 2); var convolved = af.convolve2(input, filters); assertArrayEquals(new float[]{37, 47, 67, 77, 37 * 2, 47 * 2, 67 * 2, 77 * 2}, af.data(convolved)); + var filterGrads = af.grads(convolved, filters); + assertArrayEquals(new float[]{12, 16, 24, 28, 12, 16, 24, 28}, af.data(filterGrads)); } @Test public void convolve2Padding() { - var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2, 1); - var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2, 1, 1); + var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2); + var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2); var convolved = af.convolve2(input, filters, shape(1, 1), shape(1, 1)); assertArrayEquals(new float[]{4, 11, 6, 14, 30, 14, 6, 11, 4}, af.data(convolved)); + var filterGrads = af.grads(convolved, filters); + assertArrayEquals(new float[]{10, 10, 10, 10}, af.data(filterGrads)); } @Test public void convolve2Stride() { - var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2, 1); - var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2, 1, 1); + var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2); + var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2); var convolved = af.convolve2(input, filters, shape(2, 2), shape(1, 1)); assertArrayEquals(new float[]{4, 6, 6, 4}, af.data(convolved)); + var filterGrads = af.grads(convolved, filters); + assertArrayEquals(new float[]{1, 2, 3, 4}, af.data(filterGrads)); + } + + @Test(expected = UnsupportedOperationException.class) + public void convolve2InputGrads() { + var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2); + var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2); + var convolved = af.convolve2(input, filters, shape(2, 2), shape(1, 1)); + af.grads(convolved, input); } @Test diff --git a/arrayfire/ErrorArray.java b/arrayfire/ErrorArray.java new file mode 100644 index 0000000..8df642c --- /dev/null +++ b/arrayfire/ErrorArray.java @@ -0,0 +1,22 @@ +package arrayfire; + +import java.lang.foreign.MemorySegment; + +public class ErrorArray, S extends Shape> extends Array { + private final String errorMessage; + + public ErrorArray(T type, S shape, String errorMessage) { + super(type, shape); + this.errorMessage = errorMessage; + } + + @Override + public MemorySegment dereference() { + throw new UnsupportedOperationException(errorMessage); + } + + @Override + public MemorySegment segment() { + throw new UnsupportedOperationException(errorMessage); + } +}