Skip to content

Commit

Permalink
clippy fixes (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
psvri authored May 11, 2024
1 parent bf84502 commit e90aa0d
Show file tree
Hide file tree
Showing 21 changed files with 446 additions and 374 deletions.
566 changes: 258 additions & 308 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ edition = "2021"
[workspace.dependencies]
pyo3 = { version = "0.20.3", features = ["extension-module"] }
#arrow_gpu = { path="../arrow-gpu/crates/arrow"}
arrow_gpu = { git="https://github.com/psvri/arrow-gpu.git", rev = "537d9bc08df795e28b0f58e7fd5f07da841c6ba0"}
wgpu = "0.19.3"
arrow_gpu = { git="https://github.com/psvri/arrow-gpu.git", rev = "639cbdfc92c51e3d1fd4e4c8804675b9ebe3053e"}
wgpu = "0.20.0"
once_cell = "1.19.0"
bytemuck = "1.15.0"

Expand Down
2 changes: 1 addition & 1 deletion crates/wgpy_core/src/array_routines/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub fn broadcast_if_required(
pipeline: &mut ArrowComputePipeline,
) -> Option<NdArray> {
if arr.shape != broadcasted_shape {
Some(broadcast_to_op(arr, &broadcasted_shape, pipeline))
Some(broadcast_to_op(arr, broadcasted_shape, pipeline))
} else {
None
}
Expand Down
44 changes: 36 additions & 8 deletions crates/wgpy_core/src/array_routines/where_routine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::borrow::Cow;

use arrow_gpu::{array::ArrowArrayGPU, kernels::merge_dyn};

use crate::{
broadcast::{broadcast_shape, broadcast_to},
utils::Holder,
NdArray,
};

Expand All @@ -10,14 +13,39 @@ pub fn where_(mask: &NdArray, x: &NdArray, y: &NdArray) -> NdArray {
if let ArrowArrayGPU::BooleanArrayGPU(bool_mask) = &mask.data {
let broadcast_shape =
broadcast_shape(&mask.shape, &broadcast_shape(&x.shape, &y.shape).unwrap()).unwrap();
let broadcasted_x = broadcast_to(x, &broadcast_shape);
let broadcasted_y = broadcast_to(y, &broadcast_shape);
let merged_array = merge_dyn(&broadcasted_x.data, &broadcasted_y.data, bool_mask);
NdArray {
shape: x.shape.clone(),
dims: x.dims,
data: merged_array,
dtype: x.dtype,

let broadcasted_x = if x.shape != broadcast_shape {
Holder::Owned(broadcast_to(x, &broadcast_shape))
} else {
Holder::Borrowed(x)
};

let broadcasted_y = if y.shape != broadcast_shape {
Holder::Owned(broadcast_to(y, &broadcast_shape))
} else {
Holder::Borrowed(y)
};

let broadcasted_mask = if mask.shape != broadcast_shape {
Holder::Owned(broadcast_to(mask, &broadcast_shape))
} else {
Holder::Borrowed(mask)
};

if let ArrowArrayGPU::BooleanArrayGPU(bool_mask) = &broadcasted_mask.as_ref().data {
let merged_array = merge_dyn(
&broadcasted_x.as_ref().data,
&broadcasted_y.as_ref().data,
bool_mask,
);
NdArray {
shape: x.shape.clone(),
dims: x.dims,
data: merged_array,
dtype: x.dtype,
}
} else {
unreachable!()
}
} else {
panic!("Mask is not of boolean type")
Expand Down
1 change: 1 addition & 0 deletions crates/wgpy_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod ndarray;
pub(crate) mod operand;
pub(crate) mod types;
pub(crate) mod ufunc;
pub(crate) mod utils;

pub use array_routines::*;
pub use arrow_gpu::utils::ScalarArray;
Expand Down
2 changes: 1 addition & 1 deletion crates/wgpy_core/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl NdArray {
}

pub fn take(&self, indices: &NdArray, axis: Option<u32>) -> Self {
if let Some(_) = axis {
if axis.is_some() {
todo!()
} else {
if let ArrowArrayGPU::UInt32ArrayGPU(indices_array) = &indices.data {
Expand Down
6 changes: 3 additions & 3 deletions crates/wgpy_core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ impl<'a> From<&'a str> for Dtype {
}
}

impl Into<ArrowType> for Dtype {
fn into(self) -> ArrowType {
match self {
impl From<Dtype> for ArrowType {
fn from(val: Dtype) -> Self {
match val {
Dtype::Int8 => ArrowType::Int8Type,
Dtype::Int16 => ArrowType::Int16Type,
Dtype::Int32 => ArrowType::Int32Type,
Expand Down
13 changes: 13 additions & 0 deletions crates/wgpy_core/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
pub enum Holder<'a, B> {
Borrowed(&'a B),
Owned(B),
}

impl<'a, T> AsRef<T> for Holder<'a, T> {
fn as_ref(&self) -> &T {
match self {
Holder::Borrowed(x) => x,
Holder::Owned(x) => x,
}
}
}
16 changes: 16 additions & 0 deletions crates/wgpy_math/compute_shader/f32/cross.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@group(0) @binding(0)
var<storage, read> input_1 : array<u32>;

@group(0) @binding(1)
var<storage, read> input_2 : array<u32>;

@group(0) @binding(2)
var<storage, read_write> input_2 : array<u32>;

@compute
@workgroup_size(256)
fn cross_(@builtin(global_invocation_id) global_id: vec3<u32>) {
if global_id.x * 3 < arrayLength(input_1) {

}
}
16 changes: 15 additions & 1 deletion crates/wgpy_math/src/misc.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use arrow_gpu::{gpu_utils::ArrowComputePipeline, kernels::*};
use webgpupy_core::{
broadcast_shape, broadcast_shapes, broadcast_to_op, ufunc_nin1_nout1, ufunc_nin1_nout1_body,
Dtype, NdArray,
ufunc_nin2_nout1, ufunc_nin2_nout1_body, Dtype, NdArray,
};

ufunc_nin1_nout1_body!(sqrt, sqrt_op_dyn);
ufunc_nin1_nout1_body!(cbrt, cbrt_op_dyn);
ufunc_nin1_nout1_body!(absolute, abs_op_dyn);
ufunc_nin2_nout1_body!(power, power_op_dyn);

pub fn clip(a: &NdArray, a_min: Option<&NdArray>, a_max: Option<&NdArray>) -> NdArray {
match (a_min, a_max) {
Expand Down Expand Up @@ -103,6 +105,18 @@ pub fn clip(a: &NdArray, a_min: Option<&NdArray>, a_max: Option<&NdArray>) -> Nd
}
}

pub fn cross(a: &NdArray, b: &NdArray) -> NdArray {
let shader = match (a.dtype, b.dtype) {
(Dtype::Float32, Dtype::Float32) => include_str!("../compute_shader/f32/cross.wgsl"),
_ => panic!(
"cross not supported for type {:?} and {:?}",
a.dtype, b.dtype
),
};

todo!()
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
7 changes: 6 additions & 1 deletion crates/wgpy_pyo3/src/misc_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ use webgpupy::clip;

impl_ufunc_nin1_nout1!(_sqrt, webgpupy::sqrt);
impl_ufunc_nin1_nout1!(_cbrt, webgpupy::cbrt);
impl_ufunc_nin1_nout1!(_absolute, webgpupy::absolute);
impl_ufunc_nin2_nout1!(_maximum, webgpupy::maximum);
impl_ufunc_nin2_nout1!(_minimum, webgpupy::minimum);
impl_ufunc_nin2_nout1!(_power, webgpupy::power);

// TODO add ufunc kwargs support
#[pyfunction(name = "clip")]
Expand Down Expand Up @@ -42,12 +44,15 @@ pub fn create_py_items(m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(_cbrt, m)?)?;
m.add_function(wrap_pyfunction!(_maximum, m)?)?;
m.add_function(wrap_pyfunction!(_minimum, m)?)?;
m.add_function(wrap_pyfunction!(_minimum, m)?)?;
m.add_function(wrap_pyfunction!(_absolute, m)?)?;
m.add_function(wrap_pyfunction!(_power, m)?)?;
m.add_function(wrap_pyfunction!(clip_, m)?)?;

add_ufunc_nin1_nout1!(m, "sqrt");
add_ufunc_nin1_nout1!(m, "cbrt");
add_ufunc_nin2_nout1!(m, "maximum");
add_ufunc_nin2_nout1!(m, "minimum");
add_ufunc_nin2_nout1!(m, "absolute");
add_ufunc_nin2_nout1!(m, "power");
Ok(())
}
9 changes: 9 additions & 0 deletions crates/wgpy_pyo3/src/ndarraypy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
convert_pyobj_into_array_u32, convert_pyobj_into_operand, convert_pyobj_into_scalar,
convert_pyobj_into_vec_ndarray,
logical::{_equal, _greater, _lesser},
misc_math::_absolute,
types::{into_dtypepy, into_optional_dtypepy, DtypePy},
};

Expand Down Expand Up @@ -75,6 +76,10 @@ impl NdArrayPy {
Ok(_divide(py, slf.as_ref(), other, None, None))
}

pub fn __rtruediv__(slf: &PyCell<Self>, py: Python<'_>, other: &PyAny) -> PyResult<Self> {
Ok(_divide(py, other, slf.as_ref(), None, None))
}

pub fn __add__(slf: &PyCell<Self>, py: Python<'_>, other: &PyAny) -> PyResult<Self> {
Ok(_add(py, slf.as_ref(), other, None, None))
}
Expand Down Expand Up @@ -113,6 +118,10 @@ impl NdArrayPy {
})
}

pub fn __abs__(slf: &PyCell<Self>, py: Python<'_>) -> PyResult<Self> {
Ok(_absolute(py, slf.as_ref(), None, None))
}

pub fn __eq__(slf: &PyCell<Self>, py: Python<'_>, other: &PyAny) -> PyResult<Self> {
Ok(_equal(py, slf.as_ref(), other, None, None))
}
Expand Down
16 changes: 13 additions & 3 deletions crates/wgpy_pyo3/wp_tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,45 @@
import numpy as np
from test_utils import almost_equals


@pytest.fixture
def input_array():
return [[1.0], [2.0], [3.0], [4.0]]


@pytest.fixture
def wp_array(input_array):
return wp.array(input_array)


@pytest.fixture
def np_array(input_array):
return np.array(input_array)


def test_mul_f32(wp_array, np_array):
almost_equals(wp_array * 10.0, np_array * 10)



def test_rmul_f32(wp_array, np_array):
almost_equals(10.0 * wp_array, 10 * np_array)


def test_divide_f32(wp_array, np_array):
almost_equals(wp_array / 10.0, np_array / 10)


def test_add_f32(wp_array, np_array):
almost_equals(wp_array + 10.0, np_array + 10)


def test_radd_f32(wp_array, np_array):
almost_equals(10.0 + wp_array, 10 + np_array)


def test_sub_f32(wp_array, np_array):
almost_equals(wp_array - 10.0, np_array - 10)



def test_rsub_f32(wp_array, np_array):
almost_equals(10.0 - wp_array, 10 - np_array)
almost_equals(10.0 - wp_array, 10 - np_array)
33 changes: 29 additions & 4 deletions crates/wgpy_pyo3/wp_tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,56 @@
import numpy as np
from test_utils import almost_equals


@pytest.fixture
def input_array():
return [[1.0], [2.0], [3.0], [4.0]]


@pytest.fixture
def wp_array(input_array):
return wp.array(input_array)


@pytest.fixture
def np_array(input_array):
return np.array(input_array)


def test_array(wp_array):
assert wp_array.shape == [4,1]
assert wp_array.shape == [4, 1]
assert wp_array.tolist() == [[1.0], [2.0], [3.0], [4.0]]


def test_array_astype(wp_array):
arr = wp_array.astype('uint8')
arr = wp_array.astype("uint8")
assert arr.tolist() == [[1], [2], [3], [4]]


def test_indexing(wp_array, np_array):
almost_equals(wp_array[:, :], np_array[:, :])
almost_equals(wp_array[:], np_array[:])
almost_equals(wp_array[1], np_array[1])
almost_equals(wp_array[:,0], np_array[:,0])
almost_equals(wp_array[:, 0], np_array[:, 0])


def test_neg(wp_array, np_array):
almost_equals(-wp_array, -np_array)
almost_equals(-wp_array, -np_array)


def test_where(wp_array, np_array):
bool_array = (np.random.default_rng(0).random([640, 360, 1]) < 0).tolist()
wp_array = wp.broadcast_to(wp_array.reshape([4]), [640, 360, 4])
np_array = np.broadcast_to(np_array.reshape([4]), [640, 360, 4])
wp_where = wp.where(wp.array(bool_array), wp_array, -wp_array)
np_where = np.where(np.array(bool_array), np_array, -np_array)
almost_equals(wp_where, np_where)


def test_braodcast_to():
shape = [1, 160, 1]
new_shape = [1, 160, 2]
bool_array = (np.random.default_rng(0).random(shape) < 0).tolist()
np_array = np.broadcast_to(bool_array, new_shape)
wp_array = wp.broadcast_to(wp.array(bool_array), new_shape)
almost_equals(wp_array, np_array)
20 changes: 11 additions & 9 deletions crates/wgpy_pyo3/wp_tests/test_binary.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import webgpupy as wp
import pytest
import numpy as np
from test_utils import *
from test_utils import assert_values_nin2, assert_values_nin1


@pytest.fixture
def input_array_1():
return [[1], [2], [0], [4]]


@pytest.fixture
def input_array_2():
return [[10], [2], [3], [4]]


@pytest.fixture
def wp_array_1(input_array_1):
return wp.array(input_array_1)


@pytest.fixture
def wp_array_2(input_array_2):
return wp.array(input_array_2)
Expand All @@ -31,14 +35,12 @@ def np_array_2(input_array_2):


def test_bitwise_and_u32(wp_array_1, wp_array_2, np_array_1, np_array_2):
assert_values_nin2(
wp_array_1, wp_array_2, np_array_1, np_array_2, 'bitwise_and'
)
assert_values_nin2(wp_array_1, wp_array_2, np_array_1, np_array_2, "bitwise_and")


def test_bitwise_or_u32(wp_array_1, wp_array_2, np_array_1, np_array_2):
assert_values_nin2(
wp_array_1, wp_array_2, np_array_1, np_array_2, 'bitwise_or'
)
assert_values_nin2(wp_array_1, wp_array_2, np_array_1, np_array_2, "bitwise_or")


def test_bitwise_or_u32(wp_array_1, np_array_1):
assert_values_nin1(wp_array_1, np_array_1, 'invert')
def test_invert(wp_array_1, np_array_1):
assert_values_nin1(wp_array_1, np_array_1, "invert")
1 change: 1 addition & 0 deletions crates/wgpy_pyo3/wp_tests/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_eq_f32(wp_array_1, wp_array_2):
def test_not_eq_f32(wp_array_1, wp_array_2):
assert (wp_array_1 != wp_array_2).tolist() == [[True], [False], [True], [False]]


@pytest.mark.skip(reason="Ignoring temporarily")
def test_bitwise_or_u32(input_array_1, input_array_2, wp_array_1, wp_array_2):
np_where = np.where([True, True, False, False], input_array_1, input_array_2)
Expand Down
Loading

0 comments on commit e90aa0d

Please sign in to comment.