From 1199da5c0c4d7c93495de77179c31bec9e1adf52 Mon Sep 17 00:00:00 2001 From: Lewis Hemens Date: Mon, 12 Feb 2024 11:19:09 +0000 Subject: [PATCH 1/2] A simple end-to-end MNIST example --- WORKSPACE.bazel | 26 +++++++++- arrayfire/ArrayFire.java | 83 +++++++++++++++++++++++++------- arrayfire/ArrayFireTest.java | 30 ++++++++++-- arrayfire/Shape.java | 12 +++++ examples/mnist/BUILD | 19 ++++++++ examples/mnist/Dataset.java | 91 +++++++++++++++++++++++++++++++++++ examples/mnist/SimpleNN.java | 93 ++++++++++++++++++++++++++++++++++++ 7 files changed, 334 insertions(+), 20 deletions(-) create mode 100644 examples/mnist/BUILD create mode 100644 examples/mnist/Dataset.java create mode 100644 examples/mnist/SimpleNN.java diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 8069a16..af954cf 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -2,7 +2,7 @@ workspace( name = "arrayfire_java_fla", ) -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") RULES_JVM_EXTERNAL_TAG = "5.3" @@ -39,3 +39,27 @@ maven_install( load("//:deps.bzl", "arrayfire_java_fla_deps") arrayfire_java_fla_deps() + +http_file( + name = "mnist_train_images", + sha256 = "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609", + urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz"], +) + +http_file( + name = "mnist_train_labels", + sha256 = "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c", + urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz"], +) + +http_file( + name = "mnist_test_images", + sha256 = "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6", + urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz"], +) + +http_file( + name = "mnist_test_labels", + sha256 = "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6", + urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz"], +) diff --git a/arrayfire/ArrayFire.java b/arrayfire/ArrayFire.java index 4a5c17e..52b4544 100644 --- a/arrayfire/ArrayFire.java +++ b/arrayfire/ArrayFire.java @@ -10,12 +10,12 @@ import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.IntStream; public class ArrayFire { @@ -1177,6 +1177,21 @@ public static , LD1 extends Num, RD1 extends Num .build(); } +// public static , D0 extends Num, D1 extends Num, D2 extends Num, +// D3 extends Num, SL extends Shape> Array> join( +// List> arrays, arrayfire.D1 ignored) { +// return operation("join") +// .inputs(lhs, rhs) +// .outputs(prototype(lhs.type(), +// shape(lhs.shape().d0(), n(lhs.shape().d1().size() + rhs.shape().d1().size()), lhs.shape().d2(), +// lhs.shape().d3()))) +// .operation(ptr -> arrayfire_h.af_join(ptr, 1, lhs.dereference(), rhs.dereference())) +// .grads((result, grads) -> new ArrayPair<>( +// index(grads, span(), seq(lhs.shape().d1())).reshape(lhs.shape()), +// index(grads, span(), seq(lhs.shape().d1().size(), rhs.shape().d1())).reshape(rhs.shape()))) +// .build(); +// } + public static , LD2 extends Num, RD2 extends Num, D0 extends Num, D1 extends Num, D3 extends Num, SL extends Shape, SR extends Shape> Array> join( Array lhs, Array rhs, arrayfire.D2 ignored) { if (!(lhs.shape().d0().size() == rhs.shape().d0().size() && @@ -1724,22 +1739,58 @@ public static , D0 extends Num, D1 extends Num, D2 } - public static , D0 extends Num, D1 extends Num, S extends Shape> List>> batch( - Array array, int batchSize) { - return batch(array, ArrayFire::n, batchSize); + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> List>>> batch( + Array array, arrayfire.D0 ignored, int batchSize) { + return IntStream + .range(0, (int) Math.ceil(array.shape().d0().size() / (double) batchSize)) + .>>>mapToObj(i -> { + var offset = i * batchSize; + var computedSize = Math.min(batchSize, array.shape().d0().size() - offset); + var dim = array.shape().d0().create(computedSize); + return () -> index(array, seq(offset, dim)); + }) + .toList(); + } + + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> List>>> batch( + Array array, arrayfire.D1 ignored, int batchSize) { + return IntStream + .range(0, (int) Math.ceil(array.shape().d1().size() / (double) batchSize)) + .>>>mapToObj(i -> { + var offset = i * batchSize; + var computedSize = Math.min(batchSize, array.shape().d1().size() - offset); + var dim = array.shape().d1().create(computedSize); + return () -> index(array, span(), seq(offset, dim)); + }) + .toList(); + } + + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> List>>> batch( + Array array, arrayfire.D2 ignored, int batchSize) { + return IntStream + .range(0, (int) Math.ceil(array.shape().d2().size() / (double) batchSize)) + .>>>mapToObj(i -> { + var offset = i * batchSize; + var computedSize = Math.min(batchSize, array.shape().d2().size() - offset); + var dim = array.shape().d2().create(computedSize); + return () -> index(array, span(), span(), seq(offset, dim)); + }) + .toList(); + } + + public static , D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> List>>> batch( + Array array, arrayfire.D3 ignored, int batchSize) { + return IntStream + .range(0, (int) Math.ceil(array.shape().d3().size() / (double) batchSize)) + .>>>mapToObj(i -> { + var offset = i * batchSize; + var computedSize = Math.min(batchSize, array.shape().d3().size() - offset); + var dim = array.shape().d3().create(computedSize); + return () -> index(array, span(), span(), span(), seq(offset, dim)); + }) + .toList(); } - public static , D0 extends Num, D1 extends Num, S extends Shape, BDT extends Num> List>> batch( - Array array, Function type, int batchSize) { - var results = new ArrayList>>(); - var d0Seq = seq(array.shape().d0()); - for (int i = 0; i < array.shape().d1().size(); i += batchSize) { - var computedD1Size = Math.min(batchSize, array.shape().d1().size() - i); - var slice = index(array, d0Seq, seq(i, i + computedD1Size - 1)); - results.add(slice.reshape(shape(array.shape().d0(), type.apply(computedD1Size)))); - } - return results; - } @SuppressWarnings({"unchecked", "rawtypes"}) public static , S extends Shape, NS extends Shape> Array tileAs( @@ -1873,7 +1924,7 @@ public static , D0 extends Num, D1 extends Num, D2 */ public static , T extends DataType>, D0 extends Num, D1 extends Num, D2 extends Num, D3 extends Num, S extends Shape> Array> norm( Array array) { - var mul = pow(array, array); + var mul = pow(array, 2); var sum = sum(mul); return sqrt(sum); } diff --git a/arrayfire/ArrayFireTest.java b/arrayfire/ArrayFireTest.java index fd9160b..7ffba78 100644 --- a/arrayfire/ArrayFireTest.java +++ b/arrayfire/ArrayFireTest.java @@ -491,9 +491,33 @@ public void join() { @Test public void batch() { var data = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 5); - var batches = af.batch(data, 2); - assertArrayEquals(new float[]{1, 2}, af.data(batches.get(0))); - assertArrayEquals(new float[]{5}, af.data(batches.get(2))); + var batches = af.batch(data, D1, 2); + assertArrayEquals(new float[]{1, 2}, af.data(batches.get(0).get())); + assertArrayEquals(new float[]{5}, af.data(batches.get(2).get())); + } + + @Test + public void batchD1() { + var data = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 5); + var batches = af.batch(data, D1, 2); + assertArrayEquals(new float[]{1, 2}, af.data(batches.get(0).get())); + assertArrayEquals(new float[]{5}, af.data(batches.get(2).get())); + } + + @Test + public void batchD2() { + var data = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 1, 5); + var batches = af.batch(data, D2, 2); + assertArrayEquals(new float[]{1, 2}, af.data(batches.get(0).get())); + assertArrayEquals(new float[]{5}, af.data(batches.get(2).get())); + } + + @Test + public void batchD3() { + var data = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 1, 1, 5); + var batches = af.batch(data, D3, 2); + assertArrayEquals(new float[]{1, 2}, af.data(batches.get(0).get())); + assertArrayEquals(new float[]{5}, af.data(batches.get(2).get())); } @Test diff --git a/arrayfire/Shape.java b/arrayfire/Shape.java index 1154d75..e541096 100644 --- a/arrayfire/Shape.java +++ b/arrayfire/Shape.java @@ -64,6 +64,18 @@ public D3 d3() { return d3; } + public int offset(int d0i, int d1i, int d2i, int d3i) { + return d3i * d2.size() * d1.size() * d0.size() + d2i * d1.size() * d0.size() + d1i * d0.size() + d0i; + } + + public int offset(int d0i, int d1i, int d2i) { + return d2i * d1.size() * d0.size() + d1i * d0.size() + d0i; + } + + public int offset(int d0i, int d1i) { + return d1i * d0.size() + d0i; + } + @Override public boolean equals(Object obj) { if (obj == this) diff --git a/examples/mnist/BUILD b/examples/mnist/BUILD new file mode 100644 index 0000000..bdfac7f --- /dev/null +++ b/examples/mnist/BUILD @@ -0,0 +1,19 @@ +java_library( + name = "mnist", + srcs = glob(["*.java"]), + data = [ + "@mnist_test_images//file", + "@mnist_test_labels//file", + "@mnist_train_images//file", + "@mnist_train_labels//file", + ], + deps = [ + "//arrayfire", + ], +) + +java_binary( + name = "SimpleNN", + main_class = "examples.mnist.SimpleNN", + runtime_deps = [":mnist"], +) diff --git a/examples/mnist/Dataset.java b/examples/mnist/Dataset.java new file mode 100644 index 0000000..f1ce31a --- /dev/null +++ b/examples/mnist/Dataset.java @@ -0,0 +1,91 @@ +package examples.mnist; + +import arrayfire.HostArray; +import arrayfire.Shape; +import arrayfire.U8; +import arrayfire.af; +import arrayfire.numbers.I; +import arrayfire.numbers.N; +import arrayfire.numbers.U; + +import java.io.FileInputStream; +import java.nio.file.Path; +import java.util.List; +import java.util.zip.GZIPInputStream; + +public record Dataset(HostArray> images, HostArray> labels) { + + public static int TOTAL_COUNT = 70000; + public static int LABEL_COUNT = 10; + public static int IMAGE_SIZE = 28 * 28; + public static int IMAGE_WIDTH = 28; + public static int IMAGE_HEIGHT = 28; + + public static Dataset load() { + var runFiles = System.getenv().get("JAVA_RUNFILES"); + var images = getImages( + List.of(readGzipBytes(Path.of(runFiles, "mnist_train_images/file/downloaded").toString()), + readGzipBytes(Path.of(runFiles, "mnist_test_images/file/downloaded").toString()))); + var labels = getLabels( + List.of(readGzipBytes(Path.of(runFiles, "mnist_train_labels/file/downloaded").toString()), + readGzipBytes(Path.of(runFiles, "mnist_test_labels/file/downloaded").toString()))); + return new Dataset(images, labels); + } + + private static byte[] readGzipBytes(String path) { + try { + var fis = new FileInputStream(path); + var gis = new GZIPInputStream(fis); + return gis.readAllBytes(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static HostArray> getImages(List datas) { + var shape = af.shape(af.i(IMAGE_SIZE), af.n(TOTAL_COUNT)); + var imageIndex = 0; + var images = af.createHost(af.U8, shape); + for (byte[] data : datas) { + int byteIndex = 8; + int rows = ((data[byteIndex + 3] & 0xFF)) | ((data[byteIndex + 2] & 0xFF) << 8) | + ((data[byteIndex + 1] & 0xFF) << 16) | ((data[byteIndex] & 0xFF) << 24); + byteIndex += 4; + int cols = ((data[byteIndex + 3] & 0xFF)) | ((data[byteIndex + 2] & 0xFF) << 8) | + ((data[byteIndex + 1] & 0xFF) << 16) | ((data[byteIndex] & 0xFF) << 24); + byteIndex += 4; + if (rows != 28 || cols != 28) { + throw new IllegalStateException(String.format("Expected 28x28 but rows/cols where %s/%s", rows, cols)); + } + while (byteIndex < data.length) { + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + images.set(shape.offset(i * rows + j, imageIndex), data[byteIndex]); + byteIndex++; + } + } + imageIndex++; + } + } + return images; + } + + private static HostArray> getLabels(List datas) { + var shape = af.shape(af.U, af.n(TOTAL_COUNT)); + var labels = af.createHost(af.U8, shape); + var labelIndex = 0; + for (byte[] data : datas) { + for (int i = 8; i < data.length; i++) { + labels.set(labelIndex, data[i]); + labelIndex++; + } + } + for (int i = 0; i < labels.length(); i++) { + if (!(labels.get(i) <= 9 && labels.get(i) >= 0)) { + throw new IllegalStateException( + String.format("Label greater than 9 or less than 0: %s", labels.get(i))); + } + } + return labels; + } +} \ No newline at end of file diff --git a/examples/mnist/SimpleNN.java b/examples/mnist/SimpleNN.java new file mode 100644 index 0000000..cba0dc2 --- /dev/null +++ b/examples/mnist/SimpleNN.java @@ -0,0 +1,93 @@ +package examples.mnist; + +import arrayfire.Array; +import arrayfire.Shape; +import arrayfire.U8; +import arrayfire.af; +import arrayfire.numbers.I; +import arrayfire.numbers.N; +import arrayfire.numbers.U; +import arrayfire.optimizers.SGD; +import arrayfire.utils.Functions; + +import java.util.stream.IntStream; + +/** + * A simple 2 layer neural network for classifying MNIST digits. + * $ bazel run examples/mnist:SimpleNN + */ +public class SimpleNN { + public static void main(String[] args) { + af.tidy(() -> { + af.setSeed(0); + + var optimizer = SGD.create().learningRate(0.1f); + var hiddenDim = af.a(2000); + var hiddenWeights = af.params( + () -> af.normalize(af.randn(af.F32, af.shape(af.i(Dataset.IMAGE_SIZE), hiddenDim))), optimizer); + var weights = af.params( + () -> af.normalize(af.randn(af.F32, af.shape(hiddenDim, af.l(Dataset.LABEL_COUNT)))), optimizer); + + run((imageBatch, labelBatch, train) -> { + var imagesF32 = imageBatch.cast(af.F32); + var imageNorm = af.normalize(af.center(imagesF32)); + var hidden = af.relu(af.matmul(af.transpose(hiddenWeights), imageNorm)); + var predict = af.softmax(af.matmul(af.transpose(weights), hidden)); + if (train) { + var labelsOneHot = af.oneHot(labelBatch.cast(af.S32), af.l(Dataset.LABEL_COUNT)); + var rmsLoss = af.pow(af.sub(labelsOneHot, predict), 2); + af.optimize(rmsLoss); + } + return af.imax(predict).indices().cast(af.U8); + }); + }); + } + + + public static void run( + Functions.Function3>, Array>, Boolean, Array>> fn) { + var mnist = Dataset.load(); + // Sort images and labels. + var permutation = af.permutation(af.n(Dataset.TOTAL_COUNT)); + var images = af.index(af.create(mnist.images()), af.span(), permutation); + var labels = af.index(af.create(mnist.labels()), af.span(), permutation); + // Split into train and test sets. + var trainImages = af.index(images, af.span(), af.seq(0, 60000 - 1)); + var trainLabels = af.index(labels, af.span(), af.seq(0, 60000 - 1)); + var testImages = af.index(images, af.span(), af.seq(60000, 70000 - 1)); + var testLabels = af.index(labels, af.span(), af.seq(60000, 70000 - 1)); + + var epochs = 50; + var batchSize = 256; + + IntStream.range(0, epochs).forEach(epoch -> { + var trainImageBatches = af.batch(trainImages, af.D1, batchSize); + var trainLabelBatches = af.batch(trainLabels, af.D1, batchSize); + // Train. + var trainCorrect = IntStream.range(0, trainImageBatches.size()).mapToLong(i -> af.tidy(() -> { + var trainImagesBatch = trainImageBatches.get(i).get(); + var trainLabelsBatch = trainLabelBatches.get(i).get(); + var predicted = fn.apply(trainImagesBatch, trainLabelsBatch, true); + var correct = af.sum(af.eq(predicted, trainLabelsBatch).flatten()); + return af.data(correct).get(0); + })).sum(); + // Test. + var testCorrect = af.tidy(() -> { + var testImageBatches = af.batch(testImages, af.D1, batchSize); + var testLabelBatches = af.batch(testLabels, af.D1, batchSize); + return IntStream.range(0, testImageBatches.size()).mapToLong(i -> af.tidy(() -> { + var testImagesBatch = testImageBatches.get(i).get(); + var testLabelsBatch = testLabelBatches.get(i).get(); + var predicted = fn.apply(testImagesBatch, af.zeros(testLabelsBatch.type(), testLabelsBatch.shape()), + false); + var correct = af.sum(af.eq(predicted, testLabelsBatch).flatten()); + return af.data(correct).get(0); + })).sum(); + }); + System.out.printf("Epoch %s: Train: %.5f, Test: %.5f%n", epoch, + trainCorrect / (double) trainImages.shape().d1().size(), + testCorrect / (double) testImages.shape().d1().size()); + }); + + } +} \ No newline at end of file From 2da5f7ad0e2707535d866a2c9ee5e8d5a10663c2 Mon Sep 17 00:00:00 2001 From: Lewis Hemens Date: Mon, 12 Feb 2024 11:20:37 +0000 Subject: [PATCH 2/2] Remove comments --- arrayfire/ArrayFire.java | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/arrayfire/ArrayFire.java b/arrayfire/ArrayFire.java index 52b4544..9d052aa 100644 --- a/arrayfire/ArrayFire.java +++ b/arrayfire/ArrayFire.java @@ -1177,21 +1177,6 @@ public static , LD1 extends Num, RD1 extends Num .build(); } -// public static , D0 extends Num, D1 extends Num, D2 extends Num, -// D3 extends Num, SL extends Shape> Array> join( -// List> arrays, arrayfire.D1 ignored) { -// return operation("join") -// .inputs(lhs, rhs) -// .outputs(prototype(lhs.type(), -// shape(lhs.shape().d0(), n(lhs.shape().d1().size() + rhs.shape().d1().size()), lhs.shape().d2(), -// lhs.shape().d3()))) -// .operation(ptr -> arrayfire_h.af_join(ptr, 1, lhs.dereference(), rhs.dereference())) -// .grads((result, grads) -> new ArrayPair<>( -// index(grads, span(), seq(lhs.shape().d1())).reshape(lhs.shape()), -// index(grads, span(), seq(lhs.shape().d1().size(), rhs.shape().d1())).reshape(rhs.shape()))) -// .build(); -// } - public static , LD2 extends Num, RD2 extends Num, D0 extends Num, D1 extends Num, D3 extends Num, SL extends Shape, SR extends Shape> Array> join( Array lhs, Array rhs, arrayfire.D2 ignored) { if (!(lhs.shape().d0().size() == rhs.shape().d0().size() &&