Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1D NumPy array encoders for new Python release #490

Merged
merged 12 commits into from
Oct 12, 2019
Empty file.
60 changes: 60 additions & 0 deletions weld-python/tests/encoders/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

import ctypes

def encdec_factory(encoder, decoder, eq=None):
""" Returns a function that encodes and decodes a value.

Parameters
----------
encoder : WeldEncoder
the encoder class to use.
decoder : WeldDecoder
the decoder class to use.
eq : function (T, T) => bool, optional (default=None)
the equality function to use. If this is `None`, the `==` operator is
used.

Returns
-------
function

"""
def encdec(value, ty, assert_equal=True, err=False):
""" Helper function that encodes a value and decodes it.

The function asserts that the original value and the decoded value are
equal.

Parameters
----------
value : any
The value to encode and decode
ty : WeldType
the WeldType of the value
assert_equal : bool (default True)
Checks whether the original value and decoded value are equal.
err : bool (default False)
If True, expects an error.

"""
enc = encoder()
dec = decoder()

try:
result = dec.decode(ctypes.pointer(enc.encode(value, ty)), ty)
except Exception as e:
if err:
return
else:
raise e

if err:
raise RuntimeError("Expected error during encode/decode")

if assert_equal:
if eq is not None:
assert eq(value, result)
else:
assert value == result

return encdec
100 changes: 100 additions & 0 deletions weld-python/tests/encoders/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Tests NumPy encoders and decoders.
"""

import ctypes
import numpy as np

from .helpers import encdec_factory

from weld import WeldConf, WeldContext
from weld.encoders.numpy import weldbasearray, NumPyWeldEncoder, NumPyWeldDecoder
from weld.types import *

encdec = encdec_factory(NumPyWeldEncoder, NumPyWeldDecoder, eq=np.allclose)

def array(dtype, length=5):
"""Creates a 1D NumPy array with the given data type.

The array is filled with data [1...length).

>>> array('int8')
array([0, 1, 2, 3, 4], dtype=int8)
>>> array('float32')
array([0., 1., 2., 3., 4.], dtype=float32)

Parameters
----------
dtype: np.dtype
data type of array elements
length: int
elements in array

Returns
-------
np.ndarray

"""
return np.arange(start=0, stop=length, dtype=dtype)


# Tests for ensuring weldbasearrays propagate their contexts properly

def test_baseweldarray_basics():
x = np.array([1, 2, 3, 4, 5], dtype="int8")

ctx = WeldContext(WeldConf())

welded = weldbasearray(x, weld_context=ctx)
assert welded.dtype == "int8"
assert welded.weld_context is ctx

sliced = welded[1:]
assert np.allclose(sliced, np.array([2,3,4,5]))
assert sliced.base is welded
assert sliced.weld_context is ctx

copied = sliced.copy2numpy()
assert copied.base is None
try:
copied.ctx
assert False
except AttributeError as e:
pass

# Tests for encoding and decoding 1D arrays

def test_bool_vec():
# Booleans in NumPy, like in Weld, are represented as bytes.
encdec(np.array([True, True, False, False, True], dtype='bool'),
WeldVec(Bool()))

def test_i8_vec():
encdec(array('int8'), WeldVec(I8()))

def test_i16_vec():
encdec(array('int16'), WeldVec(I16()))

def test_i32_vec():
encdec(array('int32'), WeldVec(I32()))

def test_i64_vec():
encdec(array('int64'), WeldVec(I64()))

def test_u8_vec():
encdec(array('uint8'), WeldVec(U8()))

def test_u16_vec():
encdec(array('uint16'), WeldVec(U16()))

def test_u32_vec():
encdec(array('uint32'), WeldVec(U32()))

def test_u64_vec():
encdec(array('uint64'), WeldVec(U64()))

def test_float32_vec():
encdec(array('float32'), WeldVec(F32()))

def test_float64_vec():
encdec(array('float64'), WeldVec(F64()))
37 changes: 2 additions & 35 deletions weld-python/tests/encoders/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,13 @@
Tests primitive encoders and decoders.
"""

import copy
import ctypes

from .helpers import encdec_factory
from weld.encoders import PrimitiveWeldEncoder, PrimitiveWeldDecoder
from weld.types import *

def encdec(value, ty, assert_equal=True, err=False):
""" Helper function that encodes a value and decodes it.

The function asserts that the original value and the decoded value are
equal.

Parameters
----------
value : any
The value to encode and decode
ty : WeldType
the WeldType of the value
assert_equal : bool (default True)
Checks whether the original value and decoded value are equal.
err : bool (default False)
If True, expects an error.

"""
enc = PrimitiveWeldEncoder()
dec = PrimitiveWeldDecoder()

try:
result = dec.decode(ctypes.pointer(enc.encode(value, ty)), ty)
except Exception as e:
if err:
return
else:
raise e

if err:
raise RuntimeError("Expected error during encode/decode")

if assert_equal:
assert value == result
encdec = encdec_factory(PrimitiveWeldEncoder, PrimitiveWeldDecoder)

def test_i8_encode():
encdec(-1, I8())
Expand Down
4 changes: 2 additions & 2 deletions weld-python/weld/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def func(*args, context=None):
data = ctypes.cast(result.data(), pointer_type)

if decoder is not None:
result = decoder.decode(data, restype)
result = decoder.decode(data, restype, context)
else:
result = primitive_decoder.decode(data, restype)
result = primitive_decoder.decode(data, restype, context)
return (result, context)

return func
8 changes: 6 additions & 2 deletions weld-python/weld/encoders/encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class WeldDecoder(ABC):
"""

@abstractmethod
def decode(obj, restype):
def decode(self, obj, restype, context):
"""
Decodes the object, assuming object has the WeldType restype.

Expand All @@ -51,11 +51,15 @@ def decode(obj, restype):
An object encoded in the Weld ABI.
restype : WeldType
The WeldType of the object that is being decoded.
context : WeldContext or None
The context backing `obj` if this value was constructed in Weld.

Returns
-------
any
The decoder can return any Python value.
The decoder can return any Python value. If the data is not copied
and context is not `None`, the returned object should hold a
reference to the context to prevent use-after-free bugs.

"""
pass
Loading