Skip to content

Commit

Permalink
Add a simple MNIST example (#3)
Browse files Browse the repository at this point in the history
* A simple end-to-end MNIST example

* Remove comments
  • Loading branch information
lewish authored Feb 12, 2024
1 parent f0e8ca1 commit 9b966d2
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 20 deletions.
26 changes: 25 additions & 1 deletion WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ workspace(
name = "arrayfire_java_fla",
)

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")

RULES_JVM_EXTERNAL_TAG = "5.3"

Expand Down Expand Up @@ -39,3 +39,27 @@ maven_install(
load("//:deps.bzl", "arrayfire_java_fla_deps")

arrayfire_java_fla_deps()

# MNIST dataset archives (.gz), each fetched as a single file and pinned by its
# SHA-256 digest. Consumers reference them as `@<name>//file`.

# Training-set images.
http_file(
name = "mnist_train_images",
sha256 = "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609",
urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz"],
)

# Training-set labels.
http_file(
name = "mnist_train_labels",
sha256 = "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c",
urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz"],
)

# Test-set (t10k) images.
http_file(
name = "mnist_test_images",
sha256 = "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6",
urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz"],
)

# Test-set (t10k) labels.
http_file(
name = "mnist_test_labels",
sha256 = "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
urls = ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz"],
)
68 changes: 52 additions & 16 deletions arrayfire/ArrayFire.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;

public class ArrayFire {

Expand Down Expand Up @@ -1724,22 +1724,58 @@ public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2

}

public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, S extends Shape<D0, D1, U, U>> List<Array<T, Shape<D0, N, U, U>>> batch(
Array<T, S> array, int batchSize) {
return batch(array, ArrayFire::n, batchSize);
/**
 * Splits {@code array} into consecutive batches along dimension 0.
 *
 * <p>Each returned supplier indexes into {@code array} only when its {@code get()} is called. The last batch holds
 * the remainder when the dimension size is not a multiple of {@code batchSize}.
 *
 * @param array the array to split
 * @param ignored marker that selects the dimension-0 overload; its value is not read
 * @param batchSize maximum number of entries per batch along dimension 0; must be positive
 * @return one supplier per batch, ordered by increasing offset
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> List<Supplier<Array<T, Shape<D0, D1, D2, D3>>>> batch(
    Array<T, S> array, arrayfire.D0 ignored, int batchSize) {
    // A zero batch size would turn the ceil() below into infinity (Integer.MAX_VALUE batches); a negative one
    // would silently produce no batches at all.
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
    }
    var total = array.shape().d0().size();
    return IntStream
        .range(0, (int) Math.ceil(total / (double) batchSize))
        .<Supplier<Array<T, Shape<D0, D1, D2, D3>>>>mapToObj(i -> {
            var offset = i * batchSize;
            // The final batch may be shorter than batchSize.
            var dim = array.shape().d0().create(Math.min(batchSize, total - offset));
            return () -> index(array, seq(offset, dim));
        })
        .toList();
}

/**
 * Splits {@code array} into consecutive batches along dimension 1, keeping all of dimension 0.
 *
 * <p>Each returned supplier indexes into {@code array} only when its {@code get()} is called. The last batch holds
 * the remainder when the dimension size is not a multiple of {@code batchSize}.
 *
 * @param array the array to split
 * @param ignored marker that selects the dimension-1 overload; its value is not read
 * @param batchSize maximum number of entries per batch along dimension 1; must be positive
 * @return one supplier per batch, ordered by increasing offset
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> List<Supplier<Array<T, Shape<D0, D1, D2, D3>>>> batch(
    Array<T, S> array, arrayfire.D1 ignored, int batchSize) {
    // A zero batch size would turn the ceil() below into infinity (Integer.MAX_VALUE batches); a negative one
    // would silently produce no batches at all.
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
    }
    var total = array.shape().d1().size();
    return IntStream
        .range(0, (int) Math.ceil(total / (double) batchSize))
        .<Supplier<Array<T, Shape<D0, D1, D2, D3>>>>mapToObj(i -> {
            var offset = i * batchSize;
            // The final batch may be shorter than batchSize.
            var dim = array.shape().d1().create(Math.min(batchSize, total - offset));
            return () -> index(array, span(), seq(offset, dim));
        })
        .toList();
}

/**
 * Splits {@code array} into consecutive batches along dimension 2, keeping all of dimensions 0 and 1.
 *
 * <p>Each returned supplier indexes into {@code array} only when its {@code get()} is called. The last batch holds
 * the remainder when the dimension size is not a multiple of {@code batchSize}.
 *
 * @param array the array to split
 * @param ignored marker that selects the dimension-2 overload; its value is not read
 * @param batchSize maximum number of entries per batch along dimension 2; must be positive
 * @return one supplier per batch, ordered by increasing offset
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> List<Supplier<Array<T, Shape<D0, D1, D2, D3>>>> batch(
    Array<T, S> array, arrayfire.D2 ignored, int batchSize) {
    // A zero batch size would turn the ceil() below into infinity (Integer.MAX_VALUE batches); a negative one
    // would silently produce no batches at all.
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
    }
    var total = array.shape().d2().size();
    return IntStream
        .range(0, (int) Math.ceil(total / (double) batchSize))
        .<Supplier<Array<T, Shape<D0, D1, D2, D3>>>>mapToObj(i -> {
            var offset = i * batchSize;
            // The final batch may be shorter than batchSize.
            var dim = array.shape().d2().create(Math.min(batchSize, total - offset));
            return () -> index(array, span(), span(), seq(offset, dim));
        })
        .toList();
}

/**
 * Splits {@code array} into consecutive batches along dimension 3, keeping all of dimensions 0, 1 and 2.
 *
 * <p>Each returned supplier indexes into {@code array} only when its {@code get()} is called. The last batch holds
 * the remainder when the dimension size is not a multiple of {@code batchSize}.
 *
 * @param array the array to split
 * @param ignored marker that selects the dimension-3 overload; its value is not read
 * @param batchSize maximum number of entries per batch along dimension 3; must be positive
 * @return one supplier per batch, ordered by increasing offset
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> List<Supplier<Array<T, Shape<D0, D1, D2, D3>>>> batch(
    Array<T, S> array, arrayfire.D3 ignored, int batchSize) {
    // A zero batch size would turn the ceil() below into infinity (Integer.MAX_VALUE batches); a negative one
    // would silently produce no batches at all.
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
    }
    var total = array.shape().d3().size();
    return IntStream
        .range(0, (int) Math.ceil(total / (double) batchSize))
        .<Supplier<Array<T, Shape<D0, D1, D2, D3>>>>mapToObj(i -> {
            var offset = i * batchSize;
            // The final batch may be shorter than batchSize.
            var dim = array.shape().d3().create(Math.min(batchSize, total - offset));
            return () -> index(array, span(), span(), span(), seq(offset, dim));
        })
        .toList();
}

/**
 * Eagerly cuts a two-dimensional array into slices of at most {@code batchSize} entries along dimension 1.
 *
 * @param array the array to split
 * @param type builds the dimension-1 value of each slice's shape from that slice's size
 * @param batchSize maximum number of dimension-1 entries per slice
 * @return the slices in order; the last one may be smaller than {@code batchSize}
 */
public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, S extends Shape<D0, D1, U, U>, BDT extends Num<BDT>> List<Array<T, Shape<D0, BDT, U, U>>> batch(
    Array<T, S> array, Function<Integer, BDT> type, int batchSize) {
    var batches = new ArrayList<Array<T, Shape<D0, BDT, U, U>>>();
    var wholeD0 = seq(array.shape().d0());
    int start = 0;
    while (start < array.shape().d1().size()) {
        var width = Math.min(batchSize, array.shape().d1().size() - start);
        // seq(first, last) is given an inclusive end here, hence the -1.
        var window = index(array, wholeD0, seq(start, start + width - 1));
        var batchShape = shape(array.shape().d0(), type.apply(width));
        batches.add(window.reshape(batchShape));
        start += batchSize;
    }
    return batches;
}

@SuppressWarnings({"unchecked", "rawtypes"})
public static <T extends DataType<?>, S extends Shape<?, ?, ?, ?>, NS extends Shape<?, ?, ?, ?>> Array<T, NS> tileAs(
Expand Down Expand Up @@ -1873,7 +1909,7 @@ public static <T extends DataType<?>, D0 extends Num<D0>, D1 extends Num<D1>, D2
*/
public static <ST extends DataType<?>, T extends DataType<? extends DataType.Meta<ST, ?, ?>>, D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>, S extends Shape<D0, D1, D2, D3>> Array<ST, Shape<U, D1, D2, D3>> norm(
    Array<T, S> array) {
    // Square every element. The stale `pow(array, array)` line raised each element to its own value rather than
    // to the power of two, and declared `mul` a second time.
    var squared = pow(array, 2);
    // Summing yields a unit first dimension, matching the declared return shape.
    var total = sum(squared);
    return sqrt(total);
}
Expand Down
30 changes: 27 additions & 3 deletions arrayfire/ArrayFireTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,33 @@ public void join() {
@Test
public void batch() {
    // A 1x5 row batched along D1 in groups of two gives {1, 2}, {3, 4} and {5}.
    var data = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 5);
    // batch(...) returns suppliers, so each batch is materialised with get().
    var batches = af.batch(data, D1, 2);
    assertArrayEquals(new float[]{1, 2}, af.data(batches.get(0).get()));
    // The last batch holds only the remainder.
    assertArrayEquals(new float[]{5}, af.data(batches.get(2).get()));
}

@Test
public void batchD1() {
    // Five values along dimension 1, split into batches of at most two.
    var input = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 5);
    var parts = af.batch(input, D1, 2);

    var firstBatch = parts.get(0).get();
    assertArrayEquals(new float[]{1, 2}, af.data(firstBatch));

    // Third batch is the single leftover element.
    var lastBatch = parts.get(2).get();
    assertArrayEquals(new float[]{5}, af.data(lastBatch));
}

@Test
public void batchD2() {
    // Five values along dimension 2, split into batches of at most two.
    var input = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 1, 5);
    var parts = af.batch(input, D2, 2);

    var firstBatch = parts.get(0).get();
    assertArrayEquals(new float[]{1, 2}, af.data(firstBatch));

    // Third batch is the single leftover element.
    var lastBatch = parts.get(2).get();
    assertArrayEquals(new float[]{5}, af.data(lastBatch));
}

@Test
public void batchD3() {
    // Five values along dimension 3, split into batches of at most two.
    var input = af.create(new float[]{1, 2, 3, 4, 5}).reshape(1, 1, 1, 5);
    var parts = af.batch(input, D3, 2);

    var firstBatch = parts.get(0).get();
    assertArrayEquals(new float[]{1, 2}, af.data(firstBatch));

    // Third batch is the single leftover element.
    var lastBatch = parts.get(2).get();
    assertArrayEquals(new float[]{5}, af.data(lastBatch));
}

@Test
Expand Down
12 changes: 12 additions & 0 deletions arrayfire/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ public D3 d3() {
return d3;
}

/**
 * Returns the flat position of element {@code (d0i, d1i, d2i, d3i)}, with dimension 0 varying fastest.
 */
public int offset(int d0i, int d1i, int d2i, int d3i) {
    // Each stride is the number of elements spanned by one step along that dimension.
    int d1Stride = d0.size();
    int d2Stride = d1.size() * d1Stride;
    int d3Stride = d2.size() * d2Stride;
    return d3i * d3Stride + d2i * d2Stride + d1i * d1Stride + d0i;
}

/**
 * Returns the flat position of element {@code (d0i, d1i, d2i)}, with dimension 0 varying fastest.
 */
public int offset(int d0i, int d1i, int d2i) {
    // Each stride is the number of elements spanned by one step along that dimension.
    int d1Stride = d0.size();
    int d2Stride = d1.size() * d1Stride;
    return d2i * d2Stride + d1i * d1Stride + d0i;
}

/**
 * Returns the flat position of element {@code (d0i, d1i)}, with dimension 0 varying fastest.
 */
public int offset(int d0i, int d1i) {
    // One step along dimension 1 skips a full run of dimension 0.
    int d1Stride = d0.size();
    return d0i + d1i * d1Stride;
}

@Override
public boolean equals(Object obj) {
if (obj == this)
Expand Down
19 changes: 19 additions & 0 deletions examples/mnist/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Library built from every Java source in this package. The four MNIST archives
# declared in the workspace are attached as runtime data.
java_library(
name = "mnist",
srcs = glob(["*.java"]),
data = [
"@mnist_test_images//file",
"@mnist_test_labels//file",
"@mnist_train_images//file",
"@mnist_train_labels//file",
],
deps = [
"//arrayfire",
],
)

# Runnable entry point for the example: `bazel run examples/mnist:SimpleNN`.
java_binary(
name = "SimpleNN",
main_class = "examples.mnist.SimpleNN",
runtime_deps = [":mnist"],
)
91 changes: 91 additions & 0 deletions examples/mnist/Dataset.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package examples.mnist;

import arrayfire.HostArray;
import arrayfire.Shape;
import arrayfire.U8;
import arrayfire.af;
import arrayfire.numbers.I;
import arrayfire.numbers.N;
import arrayfire.numbers.U;

import java.io.FileInputStream;
import java.nio.file.Path;
import java.util.List;
import java.util.zip.GZIPInputStream;

/**
 * MNIST images and labels held in host memory, with the training files followed by the test files.
 *
 * @param images pixel bytes laid out as {@code IMAGE_SIZE x TOTAL_COUNT}, one image per index of the second dimension
 * @param labels label bytes laid out as {@code 1 x TOTAL_COUNT}, one label per image
 */
public record Dataset(HostArray<U8, Byte, Shape<I, N, U, U>> images, HostArray<U8, Byte, Shape<U, N, U, U>> labels) {

    /** Number of images across the training and test files combined. */
    public static final int TOTAL_COUNT = 70000;
    /** Number of distinct labels (digits 0 through 9). */
    public static final int LABEL_COUNT = 10;
    /** Number of pixels in one image. */
    public static final int IMAGE_SIZE = 28 * 28;
    public static final int IMAGE_WIDTH = 28;
    public static final int IMAGE_HEIGHT = 28;

    /**
     * Reads the four MNIST archives from the Bazel runfiles directory and decodes them.
     *
     * @return the combined dataset, training entries first
     * @throws IllegalStateException if {@code JAVA_RUNFILES} is not set or a file has unexpected contents
     * @throws RuntimeException if an archive cannot be read
     */
    public static Dataset load() {
        var runFiles = System.getenv("JAVA_RUNFILES");
        if (runFiles == null) {
            // Without this check Path.of(null, ...) fails with an unexplained NullPointerException.
            throw new IllegalStateException(
                "JAVA_RUNFILES is not set; run this example through Bazel so the MNIST files can be located");
        }
        var images = getImages(
            List.of(readGzipBytes(Path.of(runFiles, "mnist_train_images/file/downloaded").toString()),
                readGzipBytes(Path.of(runFiles, "mnist_test_images/file/downloaded").toString())));
        var labels = getLabels(
            List.of(readGzipBytes(Path.of(runFiles, "mnist_train_labels/file/downloaded").toString()),
                readGzipBytes(Path.of(runFiles, "mnist_test_labels/file/downloaded").toString())));
        return new Dataset(images, labels);
    }

    /** Returns the fully decompressed contents of the gzip file at {@code path}. */
    private static byte[] readGzipBytes(String path) {
        // try-with-resources closes both streams, including when decompression fails part-way.
        try (var fis = new FileInputStream(path); var gis = new GZIPInputStream(fis)) {
            return gis.readAllBytes();
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException("Failed to read gzip file " + path, e);
        }
    }

    /** Decodes the big-endian 32-bit integer stored at {@code data[index .. index + 3]}. */
    private static int readBigEndianInt(byte[] data, int index) {
        return ((data[index] & 0xFF) << 24) | ((data[index + 1] & 0xFF) << 16) |
            ((data[index + 2] & 0xFF) << 8) | (data[index + 3] & 0xFF);
    }

    /**
     * Copies the pixels of every image in {@code datas} into one host array, in file order.
     *
     * @throws IllegalStateException if a file does not declare {@code IMAGE_HEIGHT x IMAGE_WIDTH} images
     */
    private static HostArray<U8, Byte, Shape<I, N, U, U>> getImages(List<byte[]> datas) {
        var shape = af.shape(af.i(IMAGE_SIZE), af.n(TOTAL_COUNT));
        var imageIndex = 0;
        var images = af.createHost(af.U8, shape);
        for (byte[] data : datas) {
            // The first 8 header bytes are skipped; the next two integers are the row and column counts.
            int rows = readBigEndianInt(data, 8);
            int cols = readBigEndianInt(data, 12);
            if (rows != IMAGE_HEIGHT || cols != IMAGE_WIDTH) {
                throw new IllegalStateException(String.format("Expected 28x28 but rows/cols where %s/%s", rows, cols));
            }
            int byteIndex = 16;
            while (byteIndex < data.length) {
                for (int i = 0; i < rows; i++) {
                    for (int j = 0; j < cols; j++) {
                        // Row-major position inside the image: the row stride is the column count.
                        images.set(shape.offset(i * cols + j, imageIndex), data[byteIndex]);
                        byteIndex++;
                    }
                }
                imageIndex++;
            }
        }
        return images;
    }

    /**
     * Copies every label in {@code datas} into one host array, in file order, and checks each is a digit.
     *
     * @throws IllegalStateException if any stored label is outside 0..9
     */
    private static HostArray<U8, Byte, Shape<U, N, U, U>> getLabels(List<byte[]> datas) {
        var shape = af.shape(af.U, af.n(TOTAL_COUNT));
        var labels = af.createHost(af.U8, shape);
        var labelIndex = 0;
        for (byte[] data : datas) {
            // Label bytes start after the 8-byte header.
            for (int i = 8; i < data.length; i++) {
                labels.set(labelIndex, data[i]);
                labelIndex++;
            }
        }
        for (int i = 0; i < labels.length(); i++) {
            var label = labels.get(i);
            if (label < 0 || label > 9) {
                throw new IllegalStateException(
                    String.format("Label greater than 9 or less than 0: %s", label));
            }
        }
        return labels;
    }
}
93 changes: 93 additions & 0 deletions examples/mnist/SimpleNN.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package examples.mnist;

import arrayfire.Array;
import arrayfire.Shape;
import arrayfire.U8;
import arrayfire.af;
import arrayfire.numbers.I;
import arrayfire.numbers.N;
import arrayfire.numbers.U;
import arrayfire.optimizers.SGD;
import arrayfire.utils.Functions;

import java.util.stream.IntStream;

/**
 * MNIST digit classifier with one ReLU hidden layer and a softmax output layer, trained with SGD.
 * $ bazel run examples/mnist:SimpleNN
 */
public class SimpleNN {
    /** Builds the two weight matrices and hands the forward/training step to {@link #run}. */
    public static void main(String[] args) {
        af.tidy(() -> {
            af.setSeed(0);

            var sgd = SGD.create().learningRate(0.1f);
            var hiddenSize = af.a(2000);
            var inputToHidden = af.params(
                () -> af.normalize(af.randn(af.F32, af.shape(af.i(Dataset.IMAGE_SIZE), hiddenSize))), sgd);
            var hiddenToOutput = af.params(
                () -> af.normalize(af.randn(af.F32, af.shape(hiddenSize, af.l(Dataset.LABEL_COUNT)))), sgd);

            run((imageBatch, labelBatch, train) -> {
                var pixels = imageBatch.cast(af.F32);
                var normalized = af.normalize(af.center(pixels));
                var activations = af.relu(af.matmul(af.transpose(inputToHidden), normalized));
                var probabilities = af.softmax(af.matmul(af.transpose(hiddenToOutput), activations));
                if (train) {
                    // Loss is the element-wise squared difference between one-hot labels and predictions.
                    var oneHot = af.oneHot(labelBatch.cast(af.S32), af.l(Dataset.LABEL_COUNT));
                    var squaredError = af.pow(af.sub(oneHot, probabilities), 2);
                    af.optimize(squaredError);
                }
                // Predicted label is the index of the largest output.
                return af.imax(probabilities).indices().cast(af.U8);
            });
        });
    }


    /**
     * Loads MNIST, splits it into 60000 training and 10000 test entries, and for each of 50 epochs runs
     * {@code fn} over every batch of 256, printing the fraction of correct predictions for both sets.
     *
     * @param fn receives an image batch, a label batch and a training flag, and returns predicted labels
     */
    public static void run(
        Functions.Function3<Array<U8, Shape<I, N, U, U>>, Array<U8, Shape<U, N, U, U>>, Boolean, Array<U8, Shape<U, N, U, U>>> fn) {
        var dataset = Dataset.load();
        // Reorder images and labels with one shared permutation so each image stays paired with its label.
        var order = af.permutation(af.n(Dataset.TOTAL_COUNT));
        var images = af.index(af.create(dataset.images()), af.span(), order);
        var labels = af.index(af.create(dataset.labels()), af.span(), order);
        // First 60000 entries train the model, the remaining 10000 evaluate it.
        var trainImages = af.index(images, af.span(), af.seq(0, 60000 - 1));
        var trainLabels = af.index(labels, af.span(), af.seq(0, 60000 - 1));
        var testImages = af.index(images, af.span(), af.seq(60000, 70000 - 1));
        var testLabels = af.index(labels, af.span(), af.seq(60000, 70000 - 1));

        var epochs = 50;
        var batchSize = 256;

        for (int epoch = 0; epoch < epochs; epoch++) {
            var trainImageParts = af.batch(trainImages, af.D1, batchSize);
            var trainLabelParts = af.batch(trainLabels, af.D1, batchSize);
            // Training pass: count correct predictions while updating the weights.
            var trainHits = IntStream.range(0, trainImageParts.size()).mapToLong(i -> af.tidy(() -> {
                var batchImages = trainImageParts.get(i).get();
                var batchLabels = trainLabelParts.get(i).get();
                var guesses = fn.apply(batchImages, batchLabels, true);
                var hits = af.sum(af.eq(guesses, batchLabels).flatten());
                return af.data(hits).get(0);
            })).sum();
            // Evaluation pass: the model is given zeroed labels and no training happens.
            var testHits = af.tidy(() -> {
                var testImageParts = af.batch(testImages, af.D1, batchSize);
                var testLabelParts = af.batch(testLabels, af.D1, batchSize);
                return IntStream.range(0, testImageParts.size()).mapToLong(i -> af.tidy(() -> {
                    var batchImages = testImageParts.get(i).get();
                    var batchLabels = testLabelParts.get(i).get();
                    var blankLabels = af.zeros(batchLabels.type(), batchLabels.shape());
                    var guesses = fn.apply(batchImages, blankLabels, false);
                    var hits = af.sum(af.eq(guesses, batchLabels).flatten());
                    return af.data(hits).get(0);
                })).sum();
            });
            System.out.printf("Epoch %s: Train: %.5f, Test: %.5f%n", epoch,
                trainHits / (double) trainImages.shape().d1().size(),
                testHits / (double) testImages.shape().d1().size());
        }

    }
}

0 comments on commit 9b966d2

Please sign in to comment.