Skip to content

Commit

Permalink
Finish shape wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
lewish committed Feb 2, 2024
1 parent 78db370 commit c966c1c
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 199 deletions.
390 changes: 195 additions & 195 deletions arrayfire/ArrayFire.java

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions arrayfire/R0.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package arrayfire;

import arrayfire.numbers.U;

/**
* A rank 0 shape (Scalar).
*/
public class R0 extends Shape<U, U, U, U> {

public R0() {
super(af.u(), af.u(), af.u(), af.u());
}
}
14 changes: 14 additions & 0 deletions arrayfire/R1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package arrayfire;

import arrayfire.numbers.Num;
import arrayfire.numbers.U;

/**
* A rank 1 shape (Vector).
*/
public class R1<D0 extends Num<D0>> extends Shape<D0, U, U, U> {

public R1(D0 d0) {
super(d0, af.u(), af.u(), af.u());
}
}
14 changes: 14 additions & 0 deletions arrayfire/R2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package arrayfire;

import arrayfire.numbers.Num;
import arrayfire.numbers.U;

/**
* A rank 2 shape (Matrix).
*/
public class R2<D0 extends Num<D0>, D1 extends Num<D1>> extends Shape<D0, D1, U, U> {

public R2(D0 d0, D1 d1) {
super(d0, d1, af.u(), af.u());
}
}
14 changes: 14 additions & 0 deletions arrayfire/R3.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package arrayfire;

import arrayfire.numbers.Num;
import arrayfire.numbers.U;

/**
* A rank 3 shape.
*/
public class R3<D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>> extends Shape<D0, D1, D2, U> {

public R3(D0 d0, D1 d1, D2 d2) {
super(d0, d1, d2, af.u());
}
}
47 changes: 46 additions & 1 deletion arrayfire/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,22 @@
import arrayfire.numbers.N;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;


public record Shape<D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>>(D0 d0, D1 d1, D2 d2, D3 d3) {
public class Shape<D0 extends Num<D0>, D1 extends Num<D1>, D2 extends Num<D2>, D3 extends Num<D3>> {
private final D0 d0;
private final D1 d1;
private final D2 d2;
private final D3 d3;

public Shape(D0 d0, D1 d1, D2 d2, D3 d3) {
this.d0 = d0;
this.d1 = d1;
this.d2 = d2;
this.d3 = d3;
}

public int capacity() {
return d0.size() * d1.size() * d2.size() * d3.size();
Expand All @@ -21,4 +33,37 @@ public long[] dims() {
public String toString() {
return Arrays.toString(dims());
}

public D0 d0() {
return d0;
}

public D1 d1() {
return d1;
}

public D2 d2() {
return d2;
}

public D3 d3() {
return d3;
}

@Override
public boolean equals(Object obj) {
if (obj == this)
return true;
if (obj == null || obj.getClass() != this.getClass())
return false;
var that = (Shape) obj;
return Objects.equals(this.d0, that.d0) && Objects.equals(this.d1, that.d1) && Objects.equals(this.d2, that.d2) &&
Objects.equals(this.d3, that.d3);
}

@Override
public int hashCode() {
return Objects.hash(d0, d1, d2, d3);
}

}
2 changes: 1 addition & 1 deletion arrayfire/SvdResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
import arrayfire.numbers.U;

public record SvdResult<T extends DataType<?, ?>, D0 extends Num<D0>, D1 extends Num<D1>>(
Tensor<T, Shape<D0, D0, U, U>> u, Tensor<T, Shape<D0, U, U, U>> s, Tensor<T, Shape<D1, D1, U, U>> vt) {
Tensor<T, R2<D0, D0>> u, Tensor<T, R1<D0>> s, Tensor<T, R2<D1, D1>> vt) {
}
13 changes: 11 additions & 2 deletions arrayfire/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.Arrays;

public class Tensor<T extends DataType<?, ?>, S extends Shape<?, ?, ?, ?>> implements MemoryContainer {

Expand Down Expand Up @@ -95,8 +96,16 @@ public <OD0 extends Num<OD0>, OD1 extends Num<OD1>, OD2 extends Num<OD2>, OD3 ex
return af.reshape(this, af.shape(d0, d1, d2, d3));
}

public <OD0 extends Num<OD0>, OD1 extends Num<OD1>, OD2 extends Num<OD2>, OD3 extends Num<OD3>> Tensor<T, Shape<OD0, OD1, OD2, OD3>> reshape(
Shape<OD0, OD1, OD2, OD3> newShape) {
public <NS extends Shape<?, ? ,? ,?>> Tensor<T, NS> reshape(
NS newShape) {
return af.reshape(this, newShape);
}

public <NS extends Shape<?, ? ,? ,?>> Tensor<T, NS> castshape(
NS newShape) {
if (!Arrays.equals(shape.dims(), newShape.dims())) {
throw new IllegalArgumentException("Cannot cast shape " + shape + " to " + newShape);
}
return af.reshape(this, newShape);
}

Expand Down

0 comments on commit c966c1c

Please sign in to comment.