Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convolve2 gradients #2

Merged
merged 3 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions arrayfire/Array.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.Arrays;
import java.util.function.Function;

public class Array<T extends DataType<?>, S extends Shape<?, ?, ?, ?>> implements MemoryContainer {
Expand Down Expand Up @@ -37,9 +36,17 @@ MemorySegment segment() {
* @return the wrapped void* pointer of the C af_array.
*/
public MemorySegment dereference() {
return segment.get(LAYOUT, 0L);
var value = segment.get(LAYOUT, 0L);
if (MemorySegment.NULL.equals(value)) {
throw new IllegalStateException(
String.format("Cannot dereference an uninitialized segment (nullptr) %s", shape));
}
return value;
}

public boolean materialized() {
return !MemorySegment.NULL.equals(segment.get(LAYOUT, 0L));
}

public int capacity() {
return shape.capacity();
Expand Down
58 changes: 58 additions & 0 deletions arrayfire/ArrayFire.java
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,30 @@ public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2
.build();
}

public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> Array<T, Shape<D0, D2, D1, D3>> transpose(
Array<T, S> array, arrayfire.D1 d1, arrayfire.D2 d2) {
return operation("transpose_D1_D2")
.inputs(array)
.outputs(prototype(array.type(),
shape(array.shape().d0(), array.shape().d2(), array.shape().d1(), array.shape().d3())))
.operation(ptr -> arrayfire_h.af_reorder(ptr, array.dereference(), 0, 2, 1, 3))
.grads((result, grads) -> transpose(grads, d1, d2).reshape(array.shape()))
.build();
}


public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> Array<T, Shape<D0, D1, D3, D2>> transpose(
Array<T, S> array, arrayfire.D2 d2, arrayfire.D3 d3) {
return operation("transpose_D2_D3")
.inputs(array)
.outputs(prototype(array.type(),
shape(array.shape().d0(), array.shape().d1(), array.shape().d3(), array.shape().d2())))
.operation(ptr -> arrayfire_h.af_reorder(ptr, array.dereference(), 0, 1, 3, 2))
.grads((result, grads) -> transpose(grads, d2, d3).reshape(array.shape()))
.build();
}


/**
* Change the type of the array's D0 dimension to the given type variable provider.
*/
Expand Down Expand Up @@ -791,6 +815,9 @@ public static <T extends DataType<?>, OD0 extends Num<OD0>, OD1 extends Num<OD1>
* Release the memory of the given array on the device.
*/
public static void release(Array<?, ?> array) {
if (!array.materialized()) {
return;
}
handleStatus(() -> arrayfire_h.af_release_array(array.dereference()));
Scope.untrack(array);
}
Expand Down Expand Up @@ -1792,6 +1819,19 @@ public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2
public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, FD0 extends Num<FD0>, FD1 extends Num<FD1>, FD3 extends Num<FD3>, S extends Shape<D0, D1, D2, D3>, FS extends Shape<FD0, FD1, D2, FD3>> Array<T, Shape<N, N, FD3, D3>> convolve2(
Array<T, S> array, Array<T, FS> filters, Shape<?, ?, ?, ?> stride, Shape<?, ?, ?, ?> padding,
Shape<?, ?, ?, ?> dilation) {
if (array.shape().d2().size() != filters.shape().d2().size()) {
throw new IllegalArgumentException(
String.format("D2 for input %s and filters %s must match", array.shape(), filters.shape()));
}
if (stride.ndims() != 2) {
throw new IllegalArgumentException(String.format("Stride must be have 2 dims but was %s", stride));
}
if (padding.ndims() != 2) {
throw new IllegalArgumentException(String.format("Padding must be have 2 dims but was %s", padding));
}
if (dilation.ndims() != 2) {
throw new IllegalArgumentException(String.format("Dilation must be have 2 dims but was %s", dilation));
}
var computedShape = shape(n((array.shape().d0().size() + 2 * padding.d0().size() -
(filters.shape().d0().size() - 1) * dilation.d0().size() - 1) /
stride.d0().size() + 1),
Expand All @@ -1802,11 +1842,29 @@ public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2
.inputs(array, filters)
.outputs(prototype(array.type(), computedShape))
.operation(ptr -> {
// Potentially retry after GC due to https://github.com/arrayfire/arrayfire/issues/3402
retryWithGc(() -> handleStatus(
() -> arrayfire_h.af_convolve2_nn(ptr, array.dereference(), filters.dereference(), 2,
nativeDims(stride), 2, nativeDims(padding), 2, nativeDims(dilation))));
return Status.AF_SUCCESS.code();
})
.grads((result, grads) -> {
// We can get the filter gradients back by performing a convolution again, reducing over the
// image batch as "channels" in reverse.
var inputTranspose = transpose(array, D2, D3);
var gradsTranspose = transpose(grads, D2, D3);
var filterGradsTranspose = convolve2(inputTranspose, gradsTranspose, stride, padding, dilation);
var filterGrads = transpose(filterGradsTranspose, D2, D3);
if (!Arrays.equals(filterGrads.shape().dims(), filters.shape().dims())) {
// This shouldn't happen, but I haven't extensively tested convolution variations.
throw new IllegalStateException(
String.format("Internal: Filter grads shape %s does not match filters shape %s",
filterGradsTranspose.shape(), filters.shape()));
}
return new ArrayPair<>(new ErrorArray<>(array.type(), array.shape(),
"Gradients cannot currently be computed for the input to a convolution"),
filterGrads.reshape(filters.shape()));
})
.build();
}

Expand Down
25 changes: 21 additions & 4 deletions arrayfire/ArrayFireTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ public void validateScope() {
}

private static void checkDims(Array<?, ?> array) {
if (!array.materialized()) {
return;
}
try (Arena arena = Arena.ofConfined()) {
var dims = arena.allocateArray(ValueLayout.JAVA_LONG, 4);
handleStatus(
Expand Down Expand Up @@ -499,22 +502,36 @@ public void convolve2() {
var filters = af.create(new float[]{4, 3, 2, 1, 8, 6, 4, 2}).reshape(2, 2, 1, 2);
var convolved = af.convolve2(input, filters);
assertArrayEquals(new float[]{37, 47, 67, 77, 37 * 2, 47 * 2, 67 * 2, 77 * 2}, af.data(convolved));
var filterGrads = af.grads(convolved, filters);
assertArrayEquals(new float[]{12, 16, 24, 28, 12, 16, 24, 28}, af.data(filterGrads));
}

@Test
public void convolve2Padding() {
var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2, 1);
var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2, 1, 1);
var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2);
var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2);
var convolved = af.convolve2(input, filters, shape(1, 1), shape(1, 1));
assertArrayEquals(new float[]{4, 11, 6, 14, 30, 14, 6, 11, 4}, af.data(convolved));
var filterGrads = af.grads(convolved, filters);
assertArrayEquals(new float[]{10, 10, 10, 10}, af.data(filterGrads));
}

@Test
public void convolve2Stride() {
var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2, 1);
var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2, 1, 1);
var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2);
var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2);
var convolved = af.convolve2(input, filters, shape(2, 2), shape(1, 1));
assertArrayEquals(new float[]{4, 6, 6, 4}, af.data(convolved));
var filterGrads = af.grads(convolved, filters);
assertArrayEquals(new float[]{1, 2, 3, 4}, af.data(filterGrads));
}

@Test(expected = UnsupportedOperationException.class)
public void convolve2InputGrads() {
var input = af.create(new float[]{1, 2, 3, 4}).reshape(2, 2);
var filters = af.create(new float[]{4, 3, 2, 1}).reshape(2, 2);
var convolved = af.convolve2(input, filters, shape(2, 2), shape(1, 1));
af.grads(convolved, input);
}

@Test
Expand Down
22 changes: 22 additions & 0 deletions arrayfire/ErrorArray.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package arrayfire;

import java.lang.foreign.MemorySegment;

public class ErrorArray<T extends DataType<?>, S extends Shape<?, ?, ?, ?>> extends Array<T, S> {
private final String errorMessage;

public ErrorArray(T type, S shape, String errorMessage) {
super(type, shape);
this.errorMessage = errorMessage;
}

@Override
public MemorySegment dereference() {
throw new UnsupportedOperationException(errorMessage);
}

@Override
public MemorySegment segment() {
throw new UnsupportedOperationException(errorMessage);
}
}
Loading