Skip to content

Commit

Permalink
Merge pull request #2 from gizatechxyz/refactor-serializer
Browse files Browse the repository at this point in the history
Refactor Serializer/Deserializer
  • Loading branch information
raphaelDkhn authored Jan 19, 2024
2 parents 6b44c34 + d4a241e commit c349665
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 173 deletions.
205 changes: 107 additions & 98 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,21 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'):
return deserialize_tensor_uint(serialized)
elif data_type == 'tensor_signed_int':
return deserialize_tensor_signed_int(serialized)
elif data_type == 'tensor_fixed_point':
return deserialize_tensor_fixed_point(serialized, fp_impl)
elif data_type == 'tuple_uint':
return deserialize_tuple_uint(serialized)
elif data_type == 'tuple_signed_int':
return deserialize_tuple_signed_int(serialized)
elif data_type == 'tuple_fixed_point':
return deserialize_tuple_fixed_point(serialized, fp_impl)
elif data_type == 'tuple_tensor_uint':
return deserialize_tuple_tensor_uint(serialized)
elif data_type == 'tuple_tensor_signed_int':
return deserialize_tuple_tensor_signed_int(serialized)
elif data_type == 'tuple_tensor_fixed_point':
return deserialize_tuple_tensor_fixed_point(serialized, fp_impl)
# TODO: Support Tuples
# elif data_type == 'tensor_fixed_point':
# return deserialize_tensor_fixed_point(serialized, fp_impl)
# elif data_type == 'tuple_uint':
# return deserialize_tuple_uint(serialized)
# elif data_type == 'tuple_signed_int':
# return deserialize_tuple_signed_int(serialized)
# elif data_type == 'tuple_fixed_point':
# return deserialize_tuple_fixed_point(serialized, fp_impl)
# elif data_type == 'tuple_tensor_uint':
# return deserialize_tuple_tensor_uint(serialized)
# elif data_type == 'tuple_tensor_signed_int':
# return deserialize_tuple_tensor_signed_int(serialized)
# elif data_type == 'tuple_tensor_fixed_point':
# return deserialize_tuple_tensor_fixed_point(serialized, fp_impl)
else:
raise ValueError(f"Unknown data type: {data_type}")

Expand Down Expand Up @@ -76,153 +77,161 @@ def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64:


def deserialize_arr_uint(serialized: list) -> np.array:
return np.array(serialized[1:], dtype=np.int64)
return np.array(serialized[0], dtype=np.int64)

# ================= ARRAY SIGNED INT =================


def deserialize_arr_signed_int(serialized: list) -> np.array:
num_ele = (len(serialized) - 1) // 2
def deserialize_arr_signed_int(serialized):

deserialized_array = np.empty(num_ele, dtype=np.int64)
serialized = serialized[0]

for i in range(num_ele):
deserialized_array[i] = deserialize_signed_int(
serialized[1 + i*2: 3 + i*2])
if len(serialized) % 2 != 0:
raise ValueError("Array length must be even")

return deserialized_array
deserialized = []
for i in range(0, len(serialized), 2):
mag = serialized[i]
sign = serialized[i + 1]

if sign == 1:
mag = -mag

deserialized.append(mag)

return np.array(deserialized)

# ================= ARRAY FIXED POINT =================


def deserialize_arr_fixed_point(serialized: list, impl='FP16x16') -> np.array:
num_ele = (len(serialized) - 1) // 2
def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'):

deserialized_array = np.empty(num_ele, dtype=np.float64)
serialized = serialized[0]

for i in range(num_ele):
deserialized_array[i] = deserialize_fixed_point(
serialized[1 + i*2: 3 + i*2], impl)
if len(serialized) % 2 != 0:
raise ValueError("Array length must be even")

return deserialized_array
deserialized = []
for i in range(0, len(serialized), 2):
mag = serialized[i]
sign = serialized[i + 1]

deserialized.append(deserialize_fixed_point([mag, sign], impl))

return np.array(deserialized)


# ================= TENSOR UINT =================


def deserialize_tensor_uint(serialized: list) -> np.array:
num_shape_elements = serialized[0]
shape = serialized[1:1 + num_shape_elements]
data = serialized[1 + num_shape_elements + 1:]
shape = serialized[0]
data = serialized[1]

return np.array(data, dtype=np.int64).reshape(shape)

# ================= TENSOR SIGNED INT =================


def deserialize_tensor_signed_int(serialized: list) -> np.array:
num_shape_elements = serialized[0]
shape = serialized[1:1 + num_shape_elements]
data = deserialize_arr_signed_int(
serialized[1 + num_shape_elements:])
shape = serialized[0]
data = deserialize_arr_signed_int([serialized[1]])

return data.reshape(shape)
return np.array(data, dtype=np.int64).reshape(shape)


# ================= TENSOR FIXED POINT =================


def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array:
num_shape_elements = serialized[0]
shape = serialized[1:1 + num_shape_elements]
data = deserialize_arr_fixed_point(
serialized[1 + num_shape_elements:], impl)
shape = serialized[0]
data = deserialize_arr_fixed_point([serialized[1]], impl)

return np.array(data, dtype=np.float64).reshape(shape)

return data.reshape(shape)

# ================= TUPLE UINT =================


def deserialize_tuple_uint(serialized: list):
return np.array(serialized, dtype=np.int64)
# def deserialize_tuple_uint(serialized: list):
# return np.array(serialized[0], dtype=np.int64)


# ================= TUPLE SIGNED INT =================
# # ================= TUPLE SIGNED INT =================


def deserialize_tuple_signed_int(serialized: list):
num_ele = (len(serialized)) // 2
# def deserialize_tuple_signed_int(serialized: list):
# num_ele = (len(serialized)) // 2

deserialized_array = np.empty(num_ele, dtype=np.int64)
# deserialized_array = np.empty(num_ele, dtype=np.int64)

for i in range(num_ele):
deserialized_array[i] = deserialize_signed_int(
serialized[i*2: 3 + i*2])
# for i in range(num_ele):
# deserialized_array[i] = deserialize_signed_int(
# serialized[i*2: 3 + i*2])

return deserialized_array
# return deserialized_array

# ================= TUPLE FIXED POINT =================
# # ================= TUPLE FIXED POINT =================


def deserialize_tuple_fixed_point(serialized: list, impl='FP16x16'):
num_ele = (len(serialized)) // 2
# def deserialize_tuple_fixed_point(serialized: list, impl='FP16x16'):
# num_ele = (len(serialized)) // 2

deserialized_array = np.empty(num_ele, dtype=np.float64)
# deserialized_array = np.empty(num_ele, dtype=np.float64)

for i in range(num_ele):
deserialized_array[i] = deserialize_fixed_point(
serialized[i*2: 3 + i*2], impl)
# for i in range(num_ele):
# deserialized_array[i] = deserialize_fixed_point(
# serialized[i*2: 3 + i*2], impl)

return deserialized_array
# return deserialized_array


# ================= TUPLE TENSOR UINT =================
# # ================= TUPLE TENSOR UINT =================

def deserialize_tuple_tensor_uint(serialized: list):
return deserialize_tuple_tensor(serialized, deserialize_arr_uint)
# def deserialize_tuple_tensor_uint(serialized: list):
# return deserialize_tuple_tensor(serialized, deserialize_arr_uint)

# ================= TUPLE TENSOR SIGNED INT =================
# # ================= TUPLE TENSOR SIGNED INT =================


def deserialize_tuple_tensor_signed_int(serialized: list):
return deserialize_tuple_tensor(serialized, deserialize_arr_signed_int)
# def deserialize_tuple_tensor_signed_int(serialized: list):
# return deserialize_tuple_tensor(serialized, deserialize_arr_signed_int)

# ================= TUPLE TENSOR FIXED POINT =================
# # ================= TUPLE TENSOR FIXED POINT =================


def deserialize_tuple_tensor_fixed_point(serialized: list, impl='FP16x16'):
return deserialize_tuple_tensor(serialized, deserialize_arr_fixed_point, impl)
# def deserialize_tuple_tensor_fixed_point(serialized: list, impl='FP16x16'):
# return deserialize_tuple_tensor(serialized, deserialize_arr_fixed_point, impl)


# ================= HELPERS =================
# # ================= HELPERS =================


def extract_shape(serialized, start_index):
""" Extracts the shape part of a tensor from a serialized list. """
num_shape_elements = serialized[start_index]
shape = serialized[start_index + 1: start_index + 1 + num_shape_elements]
return shape, start_index + 1 + num_shape_elements
# def extract_shape(serialized, start_index):
# """ Extracts the shape part of a tensor from a serialized list. """
# num_shape_elements = serialized[start_index]
# shape = serialized[start_index + 1: start_index + 1 + num_shape_elements]
# return shape, start_index + 1 + num_shape_elements


def extract_data(serialized, start_index, deserialization_func, impl=None):
""" Extracts and deserializes the data part of a tensor from a serialized list. """
num_data_elements = serialized[start_index]
end_index = start_index + 1 + num_data_elements
data_serialized = serialized[start_index: end_index]
if impl:
data = deserialization_func(data_serialized, impl)
else:
data = deserialization_func(data_serialized)
return data, end_index


def deserialize_tuple_tensor(serialized, deserialization_func, impl=None):
""" Generic deserialization function for a tuple of tensors. """
deserialized_tensors = []
i = 0
while i < len(serialized):
shape, i = extract_shape(serialized, i)
data, i = extract_data(serialized, i, deserialization_func, impl)
tensor = data.reshape(shape)
deserialized_tensors.append(tensor)
return tuple(deserialized_tensors)
# def extract_data(serialized, start_index, deserialization_func, impl=None):
# """ Extracts and deserializes the data part of a tensor from a serialized list. """
# num_data_elements = serialized[start_index]
# end_index = start_index + 1 + num_data_elements
# data_serialized = serialized[start_index: end_index]
# if impl:
# data = deserialization_func(data_serialized, impl)
# else:
# data = deserialization_func(data_serialized)
# return data, end_index


# def deserialize_tuple_tensor(serialized, deserialization_func, impl=None):
# """ Generic deserialization function for a tuple of tensors. """
# deserialized_tensors = []
# i = 0
# while i < len(serialized):
# shape, i = extract_shape(serialized, i)
# data, i = extract_data(serialized, i, deserialization_func, impl)
# tensor = data.reshape(shape)
# deserialized_tensors.append(tensor)
# return tuple(deserialized_tensors)
30 changes: 12 additions & 18 deletions osiris/cairo/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,24 @@
)


def serializer(data) -> list[str]:
def serializer(data):
if isinstance(data, bool):
return ["1"] if data else ["0"]
return "1" if data else "0"
elif isinstance(data, int):
return [str(data)]
if data >= 0:
return f"{data}"
else:
raise ValueError("Native signed integers are not supported yet")
# TODO: Support native singned-int
elif isinstance(data, (list, tuple)):
serialized_list = [str(len(data))]
for item in data:
serialized_list.extend(serializer(item))
return serialized_list
elif isinstance(data, dict):
serialized_dict = [str(len(data))]
for key, value in data.items():
serialized_dict.extend(serializer(key))
serialized_dict.extend(serializer(value))
return serialized_dict
joined_elements = ' '.join(serializer(e) for e in data)
return f"[{joined_elements}]"
elif isinstance(data, Tensor):
serialized_tensor = serializer(data.shape)
serialized_tensor.extend(serializer(data.data))
return serialized_tensor
return f"{serializer(data.shape)} {serializer(data.data)}"
elif isinstance(data, (SignedInt, FixedPoint)):
return [str(data.mag), str(data.sign)]
return f"{serializer(data.mag)} {serializer(data.sign)}"
elif isinstance(data, UnsignedInt):
return [str(data.mag)]
return f"{data.mag}"

else:
raise ValueError("Unsupported data type for serialization")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-osiris"
version = "0.1.6"
version = "0.1.7"
description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs"
authors = ["Fran Algaba <[email protected]>"]
readme = "README.md"
Expand Down
Loading

0 comments on commit c349665

Please sign in to comment.