Skip to content

Commit

Permalink
modify deserializer
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Jan 19, 2024
1 parent 0063cc0 commit 987d083
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 148 deletions.
205 changes: 107 additions & 98 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,21 @@ def deserializer(serialized: list, data_type: str, fp_impl='FP16x16'):
return deserialize_tensor_uint(serialized)
elif data_type == 'tensor_signed_int':
return deserialize_tensor_signed_int(serialized)
elif data_type == 'tensor_fixed_point':
return deserialize_tensor_fixed_point(serialized, fp_impl)
elif data_type == 'tuple_uint':
return deserialize_tuple_uint(serialized)
elif data_type == 'tuple_signed_int':
return deserialize_tuple_signed_int(serialized)
elif data_type == 'tuple_fixed_point':
return deserialize_tuple_fixed_point(serialized, fp_impl)
elif data_type == 'tuple_tensor_uint':
return deserialize_tuple_tensor_uint(serialized)
elif data_type == 'tuple_tensor_signed_int':
return deserialize_tuple_tensor_signed_int(serialized)
elif data_type == 'tuple_tensor_fixed_point':
return deserialize_tuple_tensor_fixed_point(serialized, fp_impl)
# TODO: Support Tuples
# elif data_type == 'tensor_fixed_point':
# return deserialize_tensor_fixed_point(serialized, fp_impl)
# elif data_type == 'tuple_uint':
# return deserialize_tuple_uint(serialized)
# elif data_type == 'tuple_signed_int':
# return deserialize_tuple_signed_int(serialized)
# elif data_type == 'tuple_fixed_point':
# return deserialize_tuple_fixed_point(serialized, fp_impl)
# elif data_type == 'tuple_tensor_uint':
# return deserialize_tuple_tensor_uint(serialized)
# elif data_type == 'tuple_tensor_signed_int':
# return deserialize_tuple_tensor_signed_int(serialized)
# elif data_type == 'tuple_tensor_fixed_point':
# return deserialize_tuple_tensor_fixed_point(serialized, fp_impl)
else:
raise ValueError(f"Unknown data type: {data_type}")

Expand Down Expand Up @@ -76,153 +77,161 @@ def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64:


def deserialize_arr_uint(serialized: list) -> np.array:
return np.array(serialized[1:], dtype=np.int64)
return np.array(serialized[0], dtype=np.int64)

# ================= ARRAY SIGNED INT =================


def deserialize_arr_signed_int(serialized: list) -> np.array:
num_ele = (len(serialized) - 1) // 2
def deserialize_arr_signed_int(serialized):

deserialized_array = np.empty(num_ele, dtype=np.int64)
serialized = serialized[0]

for i in range(num_ele):
deserialized_array[i] = deserialize_signed_int(
serialized[1 + i*2: 3 + i*2])
if len(serialized) % 2 != 0:
raise ValueError("Array length must be even")

return deserialized_array
deserialized = []
for i in range(0, len(serialized), 2):
mag = serialized[i]
sign = serialized[i + 1]

if sign == 1:
mag = -mag

deserialized.append(mag)

return np.array(deserialized)

# ================= ARRAY FIXED POINT =================


def deserialize_arr_fixed_point(serialized: list, impl='FP16x16') -> np.array:
num_ele = (len(serialized) - 1) // 2
def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'):

deserialized_array = np.empty(num_ele, dtype=np.float64)
serialized = serialized[0]

for i in range(num_ele):
deserialized_array[i] = deserialize_fixed_point(
serialized[1 + i*2: 3 + i*2], impl)
if len(serialized) % 2 != 0:
raise ValueError("Array length must be even")

return deserialized_array
deserialized = []
for i in range(0, len(serialized), 2):
mag = serialized[i]
sign = serialized[i + 1]

deserialized.append(deserialize_fixed_point([mag, sign], impl))

return np.array(deserialized)


# ================= TENSOR UINT =================


def deserialize_tensor_uint(serialized: list) -> np.array:
num_shape_elements = serialized[0]
shape = serialized[1:1 + num_shape_elements]
data = serialized[1 + num_shape_elements + 1:]
shape = serialized[0]
data = serialized[1]

return np.array(data, dtype=np.int64).reshape(shape)

# ================= TENSOR SIGNED INT =================


def deserialize_tensor_signed_int(serialized: list) -> np.array:
num_shape_elements = serialized[0]
shape = serialized[1:1 + num_shape_elements]
data = deserialize_arr_signed_int(
serialized[1 + num_shape_elements:])
shape = serialized[0]
data = deserialize_arr_signed_int([serialized[1]])

return data.reshape(shape)
return np.array(data, dtype=np.int64).reshape(shape)


# ================= TENSOR FIXED POINT =================


def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array:
num_shape_elements = serialized[0]
shape = serialized[1:1 + num_shape_elements]
data = deserialize_arr_fixed_point(
serialized[1 + num_shape_elements:], impl)
shape = serialized[0]
data = deserialize_arr_fixed_point([serialized[1]], impl)

return np.array(data, dtype=np.float64).reshape(shape)

return data.reshape(shape)

# ================= TUPLE UINT =================


def deserialize_tuple_uint(serialized: list):
return np.array(serialized, dtype=np.int64)
# def deserialize_tuple_uint(serialized: list):
# return np.array(serialized[0], dtype=np.int64)


# ================= TUPLE SIGNED INT =================
# # ================= TUPLE SIGNED INT =================


def deserialize_tuple_signed_int(serialized: list):
num_ele = (len(serialized)) // 2
# def deserialize_tuple_signed_int(serialized: list):
# num_ele = (len(serialized)) // 2

deserialized_array = np.empty(num_ele, dtype=np.int64)
# deserialized_array = np.empty(num_ele, dtype=np.int64)

for i in range(num_ele):
deserialized_array[i] = deserialize_signed_int(
serialized[i*2: 3 + i*2])
# for i in range(num_ele):
# deserialized_array[i] = deserialize_signed_int(
# serialized[i*2: 3 + i*2])

return deserialized_array
# return deserialized_array

# ================= TUPLE FIXED POINT =================
# # ================= TUPLE FIXED POINT =================


def deserialize_tuple_fixed_point(serialized: list, impl='FP16x16'):
num_ele = (len(serialized)) // 2
# def deserialize_tuple_fixed_point(serialized: list, impl='FP16x16'):
# num_ele = (len(serialized)) // 2

deserialized_array = np.empty(num_ele, dtype=np.float64)
# deserialized_array = np.empty(num_ele, dtype=np.float64)

for i in range(num_ele):
deserialized_array[i] = deserialize_fixed_point(
serialized[i*2: 3 + i*2], impl)
# for i in range(num_ele):
# deserialized_array[i] = deserialize_fixed_point(
# serialized[i*2: 3 + i*2], impl)

return deserialized_array
# return deserialized_array


# ================= TUPLE TENSOR UINT =================
# # ================= TUPLE TENSOR UINT =================

def deserialize_tuple_tensor_uint(serialized: list):
return deserialize_tuple_tensor(serialized, deserialize_arr_uint)
# def deserialize_tuple_tensor_uint(serialized: list):
# return deserialize_tuple_tensor(serialized, deserialize_arr_uint)

# ================= TUPLE TENSOR SIGNED INT =================
# # ================= TUPLE TENSOR SIGNED INT =================


def deserialize_tuple_tensor_signed_int(serialized: list):
return deserialize_tuple_tensor(serialized, deserialize_arr_signed_int)
# def deserialize_tuple_tensor_signed_int(serialized: list):
# return deserialize_tuple_tensor(serialized, deserialize_arr_signed_int)

# ================= TUPLE TENSOR FIXED POINT =================
# # ================= TUPLE TENSOR FIXED POINT =================


def deserialize_tuple_tensor_fixed_point(serialized: list, impl='FP16x16'):
return deserialize_tuple_tensor(serialized, deserialize_arr_fixed_point, impl)
# def deserialize_tuple_tensor_fixed_point(serialized: list, impl='FP16x16'):
# return deserialize_tuple_tensor(serialized, deserialize_arr_fixed_point, impl)


# ================= HELPERS =================
# # ================= HELPERS =================


def extract_shape(serialized, start_index):
""" Extracts the shape part of a tensor from a serialized list. """
num_shape_elements = serialized[start_index]
shape = serialized[start_index + 1: start_index + 1 + num_shape_elements]
return shape, start_index + 1 + num_shape_elements
# def extract_shape(serialized, start_index):
# """ Extracts the shape part of a tensor from a serialized list. """
# num_shape_elements = serialized[start_index]
# shape = serialized[start_index + 1: start_index + 1 + num_shape_elements]
# return shape, start_index + 1 + num_shape_elements


def extract_data(serialized, start_index, deserialization_func, impl=None):
""" Extracts and deserializes the data part of a tensor from a serialized list. """
num_data_elements = serialized[start_index]
end_index = start_index + 1 + num_data_elements
data_serialized = serialized[start_index: end_index]
if impl:
data = deserialization_func(data_serialized, impl)
else:
data = deserialization_func(data_serialized)
return data, end_index


def deserialize_tuple_tensor(serialized, deserialization_func, impl=None):
""" Generic deserialization function for a tuple of tensors. """
deserialized_tensors = []
i = 0
while i < len(serialized):
shape, i = extract_shape(serialized, i)
data, i = extract_data(serialized, i, deserialization_func, impl)
tensor = data.reshape(shape)
deserialized_tensors.append(tensor)
return tuple(deserialized_tensors)
# def extract_data(serialized, start_index, deserialization_func, impl=None):
# """ Extracts and deserializes the data part of a tensor from a serialized list. """
# num_data_elements = serialized[start_index]
# end_index = start_index + 1 + num_data_elements
# data_serialized = serialized[start_index: end_index]
# if impl:
# data = deserialization_func(data_serialized, impl)
# else:
# data = deserialization_func(data_serialized)
# return data, end_index


# def deserialize_tuple_tensor(serialized, deserialization_func, impl=None):
# """ Generic deserialization function for a tuple of tensors. """
# deserialized_tensors = []
# i = 0
# while i < len(serialized):
# shape, i = extract_shape(serialized, i)
# data, i = extract_data(serialized, i, deserialization_func, impl)
# tensor = data.reshape(shape)
# deserialized_tensors.append(tensor)
# return tuple(deserialized_tensors)
Loading

0 comments on commit 987d083

Please sign in to comment.