From e75a3fc4520a1167eab605067cab135d9983824d Mon Sep 17 00:00:00 2001 From: Lewis Hemens Date: Mon, 5 Feb 2024 11:53:34 +0000 Subject: [PATCH] Add examples folder and fix README --- .ijwb/.bazelproject | 2 +- README.md | 11 ++++++----- examples/BUILD | 6 ++++++ examples/SimpleSGD.java | 20 ++++++++++++++++++++ 4 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 examples/BUILD create mode 100644 examples/SimpleSGD.java diff --git a/.ijwb/.bazelproject b/.ijwb/.bazelproject index 8c44738..4fa00b1 100644 --- a/.ijwb/.bazelproject +++ b/.ijwb/.bazelproject @@ -1,7 +1,7 @@ directories: # Add the directories you want added as source here # By default, we've added your entire workspace ('.') - arrayfire + . # Automatically includes all relevant targets under the 'directories' above derive_targets_from_directories: true diff --git a/README.md b/README.md index d12807f..ef6344e 100644 --- a/README.md +++ b/README.md @@ -9,23 +9,24 @@ Currently, the API only supports a subset of functionality of ArrayFire, along w The following example demonstrates usage of the API, including autograd. ```java -import arrayfire.*; -import arrayfire.af.*; +import arrayfire.optimizers.SGD; + +import static arrayfire.af.*; void main() { tidy(() -> { - var a = params(randu(F32, shape(5)), SGD.create()); + var a = params(() -> randu(F32, shape(5)), SGD.create()); var b = randu(F32, shape(5)); var latestLoss = Float.POSITIVE_INFINITY; for (int i = 0; i < 50 || latestLoss >= 1E-10; i++) { latestLoss = tidy(() -> { var mul = mul(a, b); - var loss = pow(sub(sum(mul), constant(5)), 2); + var loss = pow(sub(sum(mul), 5), 2); optimize(loss); return data(loss).get(0); }); } - assertEquals(0, latestLoss, 1E-10); + System.out.println(latestLoss); }); } ``` diff --git a/examples/BUILD b/examples/BUILD new file mode 100644 index 0000000..8927dd2 --- /dev/null +++ b/examples/BUILD @@ -0,0 +1,6 @@ +java_binary( + name = "SimpleSGD", + srcs = ["SimpleSGD.java"], + main_class = "SimpleSGD", + deps = ["//arrayfire"], +) diff --git a/examples/SimpleSGD.java b/examples/SimpleSGD.java new file mode 100644 index 0000000..c9f14d2 --- /dev/null +++ b/examples/SimpleSGD.java @@ -0,0 +1,20 @@ +import arrayfire.optimizers.SGD; + +import static arrayfire.af.*; + +void main() { + tidy(() -> { + var a = params(() -> randu(F32, shape(5)), SGD.create()); + var b = randu(F32, shape(5)); + var latestLoss = Float.POSITIVE_INFINITY; + for (int i = 0; i < 50 || latestLoss >= 1E-10; i++) { + latestLoss = tidy(() -> { + var mul = mul(a, b); + var loss = pow(sub(sum(mul), 5), 2); + optimize(loss); + return data(loss).get(0); + }); + } + System.out.println(latestLoss); + }); +} \ No newline at end of file