Skip to content

Commit

Permalink
CPU and GPU ctz implementation (#979)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfellis authored Nov 22, 2024
1 parent 5b9b06a commit 7d8ec1c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
8 changes: 8 additions & 0 deletions alan/src/compile/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,14 @@ test_gpgpu!(gpu_ones => r#"
stdout "[1, 1, 2, 32]\n";
);

test_gpgpu!(gpu_ctz => r#"
export fn main {
let b = GBuffer([0.i32, 1.i32, 2.i32, -2_147_483_648.i32]);
b.map(fn (val: gi32) = val.ctz).read{i32}.print;
}"#;
stdout "[32, 0, 1, 31]\n";
);

// TODO: Fix u64 numeric constants to get u64 bitwise tests in the new test suite
test!(u64_bitwise => r#"
prefix u64 as ~ precedence 10
Expand Down
29 changes: 29 additions & 0 deletions alan/src/std/root.ln
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ fn{Rs} clz (a: u8) = {Method{"leading_zeros"} :: u8 -> u32}(a).u8;
fn{Js} clz Method{"clz"} :: u8 -> u8;
fn{Rs} ones (a: u8) = {Method{"count_ones"} :: u8 -> u32}(a).u8;
fn{Js} ones Method{"ones"} :: u8 -> u8;
fn{Rs} ctz (a: u8) = {Method{"trailing_zeros"} :: u8 -> u32}(a).u8;
fn{Js} ctz Method{"ctz"} :: u8 -> u8;

fn{Rs} add Method{"wrapping_add"} :: (u16, Deref{u16}) -> u16;
fn{Js} add Method{"wrappingAdd"} :: (u16, u16) -> u16;
Expand Down Expand Up @@ -789,6 +791,8 @@ fn{Rs} clz (a: u16) = {Method{"leading_zeros"} :: u16 -> u32}(a).u16;
fn{Js} clz Method{"clz"} :: u16 -> u16;
fn{Rs} ones (a: u16) = {Method{"count_ones"} :: u16 -> u32}(a).u16;
fn{Js} ones Method{"ones"} :: u16 -> u16;
fn{Rs} ctz (a: u16) = {Method{"trailing_zeros"} :: u16 -> u32}(a).u16;
fn{Js} ctz Method{"ctz"} :: u16 -> u16;

fn{Rs} add Method{"wrapping_add"} :: (u32, Deref{u32}) -> u32;
fn{Js} add Method{"wrappingAdd"} :: (u32, u32) -> u32;
Expand Down Expand Up @@ -840,6 +844,8 @@ fn{Rs} clz Method{"leading_zeros"} :: u32 -> u32;
fn{Js} clz Method{"clz"} :: u32 -> u32;
fn{Rs} ones (a: u32) = {Method{"count_ones"} :: u32 -> u32}(a).u32;
fn{Js} ones Method{"ones"} :: u32 -> u32;
fn{Rs} ctz (a: u32) = {Method{"trailing_zeros"} :: u32 -> u32}(a).u32;
fn{Js} ctz Method{"ctz"} :: u32 -> u32;

fn{Rs} add Method{"wrapping_add"} :: (u64, Deref{u64}) -> u64;
fn{Js} add Method{"wrappingAdd"} :: (u64, u64) -> u64;
Expand Down Expand Up @@ -891,6 +897,8 @@ fn{Rs} clz (a: u64) = {Method{"leading_zeros"} :: u64 -> u32}(a).u64;
fn{Js} clz Method{"clz"} :: u64 -> u64;
fn{Rs} ones (a: u64) = {Method{"count_ones"} :: u64 -> u32}(a).u64;
fn{Js} ones Method{"ones"} :: u64 -> u64;
fn{Rs} ctz (a: u64) = {Method{"trailing_zeros"} :: u64 -> u32}(a).u64;
fn{Js} ctz Method{"ctz"} :: u64 -> u64;

/// Signed Integer-related functions and function bindings
fn{Rs} add Method{"wrapping_add"} :: (i8, Deref{i8}) -> i8;
Expand Down Expand Up @@ -947,6 +955,8 @@ fn{Rs} clz (a: i8) = {Method{"leading_zeros"} :: i8 -> u32}(a).i8;
fn{Js} clz Method{"clz"} :: i8 -> i8;
fn{Rs} ones (a: i8) = {Method{"count_ones"} :: i8 -> u32}(a).i8;
fn{Js} ones Method{"ones"} :: i8 -> i8;
fn{Rs} ctz (a: i8) = {Method{"trailing_zeros"} :: i8 -> u32}(a).i8;
fn{Js} ctz Method{"ctz"} :: i8 -> i8;

fn{Rs} add Method{"wrapping_add"} :: (i16, Deref{i16}) -> i16;
fn{Js} add Method{"wrappingAdd"} :: (i16, i16) -> i16;
Expand Down Expand Up @@ -1002,6 +1012,8 @@ fn{Rs} clz (a: i16) = {Method{"leading_zeros"} :: i16 -> u32}(a).i16;
fn{Js} clz Method{"clz"} :: i16 -> i16;
fn{Rs} ones (a: i16) = {Method{"count_ones"} :: i16 -> u32}(a).i16;
fn{Js} ones Method{"ones"} :: i16 -> i16;
fn{Rs} ctz (a: i16) = {Method{"trailing_zeros"} :: i16 -> u32}(a).i16;
fn{Js} ctz Method{"ctz"} :: i16 -> i16;

fn{Rs} add Method{"wrapping_add"} :: (i32, Deref{i32}) -> i32;
fn{Js} add Method{"wrappingAdd"} :: (i32, i32) -> i32;
Expand Down Expand Up @@ -1057,6 +1069,8 @@ fn{Rs} clz (a: i32) = {Method{"leading_zeros"} :: i32 -> u32}(a).i32;
fn{Js} clz Method{"clz"} :: i32 -> i32;
fn{Rs} ones (a: i32) = {Method{"count_ones"} :: i32 -> u32}(a).i32;
fn{Js} ones Method{"ones"} :: i32 -> i32;
fn{Rs} ctz (a: i32) = {Method{"trailing_zeros"} :: i32 -> u32}(a).i32;
fn{Js} ctz Method{"ctz"} :: i32 -> i32;

fn{Rs} add Method{"wrapping_add"} :: (i64, Deref{i64}) -> i64;
fn{Js} add Method{"wrappingAdd"} :: (i64, i64) -> i64;
Expand Down Expand Up @@ -1112,6 +1126,8 @@ fn{Rs} clz (a: i64) = {Method{"leading_zeros"} :: i64 -> u32}(a).i64;
fn{Js} clz Method{"clz"} :: i64 -> i64;
fn{Rs} ones (a: i64) = {Method{"count_ones"} :: i64 -> u32}(a).i64;
fn{Js} ones Method{"ones"} :: i64 -> i64;
fn{Rs} ctz (a: i64) = {Method{"trailing_zeros"} :: i64 -> u32}(a).i64;
fn{Js} ctz Method{"ctz"} :: i64 -> i64;

/// String related bindings
fn{Rs} string "format!" :: ("{}", f32) -> string;
Expand Down Expand Up @@ -3705,6 +3721,19 @@ fn ones(v: gvec3u) = gones{gvec3u}(v);
fn ones(v: gvec4i) = gones{gvec4i}(v);
fn ones(v: gvec4u) = gones{gvec4u}(v);

fn gctz{I}(v: I) {
let varName = 'countTrailingZeros('.concat(v.varName).concat(')');
return {I}(varName, v.statements, v.buffers);
}
fn ctz(v: gi32) = gctz{gi32}(v);
fn ctz(v: gu32) = gctz{gu32}(v);
fn ctz(v: gvec2i) = gctz{gvec2i}(v);
fn ctz(v: gvec2u) = gctz{gvec2u}(v);
fn ctz(v: gvec3i) = gctz{gvec3i}(v);
fn ctz(v: gvec3u) = gctz{gvec3u}(v);
fn ctz(v: gvec4i) = gctz{gvec4i}(v);
fn ctz(v: gvec4u) = gctz{gvec4u}(v);

fn gadd{A, B}(a: A, b: B) {
let varName = '('.concat(a.varName).concat(' + ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
Expand Down
56 changes: 48 additions & 8 deletions alan/test.ln
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ export fn{Test} main {
.assert(eq, 1.i8.ones, 1.i8)
.assert(eq, 2.i8.ones, 1.i8)
.assert(eq, 3.i8.ones, 2.i8)
.assert(eq, (-1).i8.ones, 8.i8));
.assert(eq, (-1).i8.ones, 8.i8))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.i8.ctz, 8.i8)
.assert(eq, 1.i8.ctz, 0.i8)
.assert(eq, 2.i8.ctz, 1.i8)
.assert(eq, (-128).i8.ctz, 7.i8));

test.describe("Basic math tests i16")
.it("add")
Expand Down Expand Up @@ -98,7 +103,12 @@ export fn{Test} main {
.assert(eq, 1.i16.ones, 1.i16)
.assert(eq, 2.i16.ones, 1.i16)
.assert(eq, 3.i16.ones, 2.i16)
.assert(eq, (-1).i16.ones, 16.i16));
.assert(eq, (-1).i16.ones, 16.i16))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.i16.ctz, 16.i16)
.assert(eq, 1.i16.ctz, 0.i16)
.assert(eq, 2.i16.ctz, 1.i16)
.assert(eq, (-32768).i16.ctz, 15.i16));

test.describe("Basic math tests i32")
.it("add")
Expand Down Expand Up @@ -130,7 +140,12 @@ export fn{Test} main {
.assert(eq, 1.i32.ones, 1.i32)
.assert(eq, 2.i32.ones, 1.i32)
.assert(eq, 3.i32.ones, 2.i32)
.assert(eq, (-1).i32.ones, 32.i32));
.assert(eq, (-1).i32.ones, 32.i32))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.i32.ctz, 32.i32)
.assert(eq, 1.i32.ctz, 0.i32)
.assert(eq, 2.i32.ctz, 1.i32)
.assert(eq, (-2_147_483_648).i32.ctz, 31.i32));

test.describe("Basic math tests i64")
.it("add")
Expand Down Expand Up @@ -162,7 +177,12 @@ export fn{Test} main {
.assert(eq, 1.ones, 1)
.assert(eq, 2.ones, 1)
.assert(eq, 3.ones, 2)
.assert(eq, -1.ones, 64));
.assert(eq, -1.ones, 64))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.ctz, 64)
.assert(eq, 1.ctz, 0)
.assert(eq, 2.ctz, 1)
.assert(eq, (-9_223_372_036_854_775_808).ctz, 63));

test.describe("Basic math tests u8")
.it("add")
Expand Down Expand Up @@ -190,7 +210,12 @@ export fn{Test} main {
.assert(eq, 1.u8.ones, 1.u8)
.assert(eq, 2.u8.ones, 1.u8)
.assert(eq, 3.u8.ones, 2.u8)
.assert(eq, 255.u8.ones, 8.u8));
.assert(eq, 255.u8.ones, 8.u8))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.u8.ctz, 8.u8)
.assert(eq, 1.u8.ctz, 0.u8)
.assert(eq, 2.u8.ctz, 1.u8)
.assert(eq, 128.u8.ctz, 7.u8));

test.describe("Basic math tests u16")
.it("add")
Expand Down Expand Up @@ -218,7 +243,12 @@ export fn{Test} main {
.assert(eq, 1.u16.ones, 1.u16)
.assert(eq, 2.u16.ones, 1.u16)
.assert(eq, 3.u16.ones, 2.u16)
.assert(eq, 65535.u16.ones, 16.u16));
.assert(eq, 65535.u16.ones, 16.u16))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.u16.ctz, 16.u16)
.assert(eq, 1.u16.ctz, 0.u16)
.assert(eq, 2.u16.ctz, 1.u16)
.assert(eq, 32768.u16.ctz, 15.u16));

test.describe("Basic math tests u32")
.it("add")
Expand Down Expand Up @@ -246,7 +276,12 @@ export fn{Test} main {
.assert(eq, 1.u32.ones, 1.u32)
.assert(eq, 2.u32.ones, 1.u32)
.assert(eq, 3.u32.ones, 2.u32)
.assert(eq, 4_294_967_295.u32.ones, 32.u32));
.assert(eq, 4_294_967_295.u32.ones, 32.u32))
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.u32.ctz, 32.u32)
.assert(eq, 1.u32.ctz, 0.u32)
.assert(eq, 2.u32.ctz, 1.u32)
.assert(eq, 2_147_483_648.u32.ctz, 31.u32));

test.describe("Basic math tests u64")
.it("add")
Expand All @@ -273,8 +308,13 @@ export fn{Test} main {
.it("ones", fn (test: Mut{Testing}) = test
.assert(eq, 1.u64.ones, 1.u64)
.assert(eq, 2.u64.ones, 1.u64)
.assert(eq, 3.u64.ones, 2.u64));
.assert(eq, 3.u64.ones, 2.u64))
// .assert(eq, 9_223_372_036_854_775_808.ones, 64.u32)); TODO: Same u64 representation issue
.it("ctz", fn (test: Mut{Testing}) = test
.assert(eq, 0.u64.ctz, 64.u64)
.assert(eq, 1.u64.ctz, 0.u64)
.assert(eq, 2.u64.ctz, 1.u64));
// .assert(eq, 9_223_372_036_854_775_808.u64.ctz, 63.u64)); TODO: Same u64 representation

test.describe("Basic math tests f32")
.it("add")
Expand Down

0 comments on commit 7d8ec1c

Please sign in to comment.