-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* A simple end-to-end MNIST example
* Remove comments
- Loading branch information
Showing
7 changed files
with
319 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Library holding every Java source in this package (Dataset, SimpleNN).
java_library(
    name = "mnist",
    srcs = glob(["*.java"]),
    # The four gzip-compressed MNIST archives; Dataset.load() resolves them
    # relative to the JAVA_RUNFILES directory at run time.
    data = [
        "@mnist_test_images//file",
        "@mnist_test_labels//file",
        "@mnist_train_images//file",
        "@mnist_train_labels//file",
    ],
    deps = [
        "//arrayfire",
    ],
)
|
||
# Runnable entry point for the example: bazel run examples/mnist:SimpleNN
java_binary(
    name = "SimpleNN",
    main_class = "examples.mnist.SimpleNN",
    # All code and data come from the :mnist library above.
    runtime_deps = [":mnist"],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package examples.mnist; | ||
|
||
import arrayfire.HostArray;
import arrayfire.Shape;
import arrayfire.U8;
import arrayfire.af;
import arrayfire.numbers.I;
import arrayfire.numbers.N;
import arrayfire.numbers.U;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.util.List;
import java.util.zip.GZIPInputStream;
|
||
public record Dataset(HostArray<U8, Byte, Shape<I, N, U, U>> images, HostArray<U8, Byte, Shape<U, N, U, U>> labels) { | ||
|
||
public static int TOTAL_COUNT = 70000; | ||
public static int LABEL_COUNT = 10; | ||
public static int IMAGE_SIZE = 28 * 28; | ||
public static int IMAGE_WIDTH = 28; | ||
public static int IMAGE_HEIGHT = 28; | ||
|
||
public static Dataset load() { | ||
var runFiles = System.getenv().get("JAVA_RUNFILES"); | ||
var images = getImages( | ||
List.of(readGzipBytes(Path.of(runFiles, "mnist_train_images/file/downloaded").toString()), | ||
readGzipBytes(Path.of(runFiles, "mnist_test_images/file/downloaded").toString()))); | ||
var labels = getLabels( | ||
List.of(readGzipBytes(Path.of(runFiles, "mnist_train_labels/file/downloaded").toString()), | ||
readGzipBytes(Path.of(runFiles, "mnist_test_labels/file/downloaded").toString()))); | ||
return new Dataset(images, labels); | ||
} | ||
|
||
private static byte[] readGzipBytes(String path) { | ||
try { | ||
var fis = new FileInputStream(path); | ||
var gis = new GZIPInputStream(fis); | ||
return gis.readAllBytes(); | ||
} catch (Exception e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private static HostArray<U8, Byte, Shape<I, N, U, U>> getImages(List<byte[]> datas) { | ||
var shape = af.shape(af.i(IMAGE_SIZE), af.n(TOTAL_COUNT)); | ||
var imageIndex = 0; | ||
var images = af.createHost(af.U8, shape); | ||
for (byte[] data : datas) { | ||
int byteIndex = 8; | ||
int rows = ((data[byteIndex + 3] & 0xFF)) | ((data[byteIndex + 2] & 0xFF) << 8) | | ||
((data[byteIndex + 1] & 0xFF) << 16) | ((data[byteIndex] & 0xFF) << 24); | ||
byteIndex += 4; | ||
int cols = ((data[byteIndex + 3] & 0xFF)) | ((data[byteIndex + 2] & 0xFF) << 8) | | ||
((data[byteIndex + 1] & 0xFF) << 16) | ((data[byteIndex] & 0xFF) << 24); | ||
byteIndex += 4; | ||
if (rows != 28 || cols != 28) { | ||
throw new IllegalStateException(String.format("Expected 28x28 but rows/cols where %s/%s", rows, cols)); | ||
} | ||
while (byteIndex < data.length) { | ||
for (int i = 0; i < rows; i++) { | ||
for (int j = 0; j < cols; j++) { | ||
images.set(shape.offset(i * rows + j, imageIndex), data[byteIndex]); | ||
byteIndex++; | ||
} | ||
} | ||
imageIndex++; | ||
} | ||
} | ||
return images; | ||
} | ||
|
||
private static HostArray<U8, Byte, Shape<U, N, U, U>> getLabels(List<byte[]> datas) { | ||
var shape = af.shape(af.U, af.n(TOTAL_COUNT)); | ||
var labels = af.createHost(af.U8, shape); | ||
var labelIndex = 0; | ||
for (byte[] data : datas) { | ||
for (int i = 8; i < data.length; i++) { | ||
labels.set(labelIndex, data[i]); | ||
labelIndex++; | ||
} | ||
} | ||
for (int i = 0; i < labels.length(); i++) { | ||
if (!(labels.get(i) <= 9 && labels.get(i) >= 0)) { | ||
throw new IllegalStateException( | ||
String.format("Label greater than 9 or less than 0: %s", labels.get(i))); | ||
} | ||
} | ||
return labels; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package examples.mnist; | ||
|
||
import arrayfire.Array; | ||
import arrayfire.Shape; | ||
import arrayfire.U8; | ||
import arrayfire.af; | ||
import arrayfire.numbers.I; | ||
import arrayfire.numbers.N; | ||
import arrayfire.numbers.U; | ||
import arrayfire.optimizers.SGD; | ||
import arrayfire.utils.Functions; | ||
|
||
import java.util.stream.IntStream; | ||
|
||
/** | ||
* A simple 2 layer neural network for classifying MNIST digits. | ||
* $ bazel run examples/mnist:SimpleNN | ||
*/ | ||
public class SimpleNN { | ||
public static void main(String[] args) { | ||
af.tidy(() -> { | ||
af.setSeed(0); | ||
|
||
var optimizer = SGD.create().learningRate(0.1f); | ||
var hiddenDim = af.a(2000); | ||
var hiddenWeights = af.params( | ||
() -> af.normalize(af.randn(af.F32, af.shape(af.i(Dataset.IMAGE_SIZE), hiddenDim))), optimizer); | ||
var weights = af.params( | ||
() -> af.normalize(af.randn(af.F32, af.shape(hiddenDim, af.l(Dataset.LABEL_COUNT)))), optimizer); | ||
|
||
run((imageBatch, labelBatch, train) -> { | ||
var imagesF32 = imageBatch.cast(af.F32); | ||
var imageNorm = af.normalize(af.center(imagesF32)); | ||
var hidden = af.relu(af.matmul(af.transpose(hiddenWeights), imageNorm)); | ||
var predict = af.softmax(af.matmul(af.transpose(weights), hidden)); | ||
if (train) { | ||
var labelsOneHot = af.oneHot(labelBatch.cast(af.S32), af.l(Dataset.LABEL_COUNT)); | ||
var rmsLoss = af.pow(af.sub(labelsOneHot, predict), 2); | ||
af.optimize(rmsLoss); | ||
} | ||
return af.imax(predict).indices().cast(af.U8); | ||
}); | ||
}); | ||
} | ||
|
||
|
||
public static void run( | ||
Functions.Function3<Array<U8, Shape<I, N, U, U>>, Array<U8, Shape<U, N, U, U>>, Boolean, Array<U8, Shape<U, N, U, U>>> fn) { | ||
var mnist = Dataset.load(); | ||
// Sort images and labels. | ||
var permutation = af.permutation(af.n(Dataset.TOTAL_COUNT)); | ||
var images = af.index(af.create(mnist.images()), af.span(), permutation); | ||
var labels = af.index(af.create(mnist.labels()), af.span(), permutation); | ||
// Split into train and test sets. | ||
var trainImages = af.index(images, af.span(), af.seq(0, 60000 - 1)); | ||
var trainLabels = af.index(labels, af.span(), af.seq(0, 60000 - 1)); | ||
var testImages = af.index(images, af.span(), af.seq(60000, 70000 - 1)); | ||
var testLabels = af.index(labels, af.span(), af.seq(60000, 70000 - 1)); | ||
|
||
var epochs = 50; | ||
var batchSize = 256; | ||
|
||
IntStream.range(0, epochs).forEach(epoch -> { | ||
var trainImageBatches = af.batch(trainImages, af.D1, batchSize); | ||
var trainLabelBatches = af.batch(trainLabels, af.D1, batchSize); | ||
// Train. | ||
var trainCorrect = IntStream.range(0, trainImageBatches.size()).mapToLong(i -> af.tidy(() -> { | ||
var trainImagesBatch = trainImageBatches.get(i).get(); | ||
var trainLabelsBatch = trainLabelBatches.get(i).get(); | ||
var predicted = fn.apply(trainImagesBatch, trainLabelsBatch, true); | ||
var correct = af.sum(af.eq(predicted, trainLabelsBatch).flatten()); | ||
return af.data(correct).get(0); | ||
})).sum(); | ||
// Test. | ||
var testCorrect = af.tidy(() -> { | ||
var testImageBatches = af.batch(testImages, af.D1, batchSize); | ||
var testLabelBatches = af.batch(testLabels, af.D1, batchSize); | ||
return IntStream.range(0, testImageBatches.size()).mapToLong(i -> af.tidy(() -> { | ||
var testImagesBatch = testImageBatches.get(i).get(); | ||
var testLabelsBatch = testLabelBatches.get(i).get(); | ||
var predicted = fn.apply(testImagesBatch, af.zeros(testLabelsBatch.type(), testLabelsBatch.shape()), | ||
false); | ||
var correct = af.sum(af.eq(predicted, testLabelsBatch).flatten()); | ||
return af.data(correct).get(0); | ||
})).sum(); | ||
}); | ||
System.out.printf("Epoch %s: Train: %.5f, Test: %.5f%n", epoch, | ||
trainCorrect / (double) trainImages.shape().d1().size(), | ||
testCorrect / (double) testImages.shape().d1().size()); | ||
}); | ||
|
||
} | ||
} |