Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor osiris deserializer #9

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 72 additions & 156 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,172 +1,88 @@
import json

import numpy as np

from .utils import felt_to_int, from_fp


def deserializer(serialized: str, dtype: str):
    """Deserialize runner output into Python/NumPy values according to ``dtype``.

    Args:
        serialized: Either the raw JSON string (converted with
            ``convert_data``) or the already-converted list of values.
        dtype: Type descriptor: an integer type (``"u8"`` ... ``"i128"``),
            a fixed-point type starting with ``"FP"``, ``"Span<T>"``,
            ``"Tensor<T>"`` or a tuple such as ``"(Span<u32>, u32)"``.

    Returns:
        The deserialized value; a ``tuple`` when ``dtype`` is a tuple type.

    Raises:
        ValueError: If ``dtype`` is not supported, or if the number of
            serialized values does not match the tuple description.
    """
    # Check if the serialized data is a string and needs conversion
    if isinstance(serialized, str):
        serialized = convert_data(serialized)

    # Function to deserialize individual elements within a tuple
    def deserialize_element(element, element_type):
        if element_type in ("u8", "u16", "u32", "u64", "u128",
                            "i8", "i16", "i32", "i64", "i128"):
            return deserialize_int(element)
        if element_type.startswith("FP"):
            return deserialize_fixed_point(element, element_type)
        if element_type.startswith("Span<") and element_type.endswith(">"):
            inner_type = element_type[5:-1]
            if inner_type.startswith("FP"):
                return deserialize_arr_fixed_point(element, inner_type)
            return deserialize_arr_int(element)
        if element_type.startswith("Tensor<") and element_type.endswith(">"):
            inner_type = element_type[7:-1]
            if inner_type.startswith("FP"):
                return deserialize_tensor_fixed_point(element, inner_type)
            return deserialize_tensor_int(element)
        if element_type.startswith("(") and element_type.endswith(")"):
            # Recursive call for nested tuples
            return deserializer(element, element_type)
        raise ValueError(f"Unsupported data type: {element_type}")

    # Anything that is not a tuple is handled as a single element.
    if not (dtype.startswith("(") and dtype.endswith(")")):
        return deserialize_element(serialized, dtype)

    mismatch = "Serialized data length does not match tuple length"
    types = dtype[1:-1].split(", ")
    deserialized_elements = []
    i = 0  # Position of the next unread serialized value

    while i < len(serialized):
        # More serialized values than declared types: report the mismatch
        # instead of failing with an IndexError on ``types``.
        if len(deserialized_elements) == len(types):
            raise ValueError(mismatch)
        ele_type = types[len(deserialized_elements)]

        if ele_type.startswith(("Tensor<", "FP")):
            # Tensors occupy two values (shape and data) and fixed points
            # occupy two values (magnitude and sign).
            ele = serialized[i:i + 2]
            i += 2
        else:
            # Every other type occupies one value; the element helpers
            # expect it wrapped in a list.
            ele = [serialized[i]]
            i += 1

        deserialized_elements.append(deserialize_element(ele, ele_type))

    if len(deserialized_elements) != len(types):
        raise ValueError(mismatch)

    return tuple(deserialized_elements)


def parse_return_value(return_value):
    """
    Turn a ReturnValue dictionary into plain Python integers.

    An ``'Int'`` entry holds a hexadecimal string and becomes an ``int``;
    an ``'Array'`` entry holds a list of ReturnValues and becomes a list of
    parsed values (nested arrays are parsed recursively). ``'Int'`` takes
    precedence when both keys are present.

    Raises:
        ValueError: If neither ``'Int'`` nor ``'Array'`` is present.
    """
    if 'Int' in return_value:
        hex_text = return_value['Int']
        return int(hex_text, 16)

    if 'Array' not in return_value:
        raise ValueError("Invalid ReturnValue format")

    return list(map(parse_return_value, return_value['Array']))


def convert_data(data):
    """
    Decode a JSON string holding a list of ReturnValue items.

    Each item must carry an ``'Array'`` or an ``'Int'`` key and is parsed
    with ``parse_return_value``; the parsed values are returned in order.

    Raises:
        ValueError: If an item has neither an ``'Array'`` nor an ``'Int'`` key.
    """
    converted = []
    for entry in json.loads(data):
        # Array items and single int items go through the same parser.
        if 'Array' not in entry and 'Int' not in entry:
            raise ValueError("Invalid data format")
        converted.append(parse_return_value(entry))
    return converted


# ================= INT =================


def deserialize_int(serialized: list) -> np.int64:
    """Convert the first entry of ``serialized`` with ``felt_to_int`` and wrap it as ``np.int64``."""
    felt = serialized[0]
    converted = felt_to_int(felt)
    return np.int64(converted)


# ================= FIXED POINT =================


def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64:
    """Decode a ``[magnitude, sign]`` pair into a signed ``np.float64``.

    The magnitude is converted with ``from_fp`` using ``impl``; a sign of
    ``0`` keeps the value positive, any other sign negates it.
    """
    magnitude = from_fp(serialized[0], impl)
    sign = serialized[1]

    if sign == 0:
        return np.float64(magnitude)
    return np.float64(-magnitude)
from osiris.cairo.serde.utils import felt_to_int, from_fp


# ================= ENTRY POINT =================
def deserializer(serialized, dtype):
    """Deserialize a whitespace-separated text value according to ``dtype``.

    Args:
        serialized: Text such as ``'42'``, ``'2780037 0'``, ``'[1 2]'`` or
            ``'[2 2] [1 2 3 4]'``.
        dtype: An integer type (``'u8'`` ... ``'i128'``), a fixed-point type
            starting with ``'FP'``, ``'Span<T>'``, ``'Tensor<T>'`` or a tuple
            type starting with ``'('``.

    Returns:
        The deserialized value produced by the matching helper.

    Raises:
        ValueError: If ``dtype`` is not one of the supported kinds.
    """
    if dtype in ("u8", "u16", "u32", "u64", "u128",
                 "i8", "i16", "i32", "i64", "i128"):
        return felt_to_int(int(serialized))

    if dtype.startswith("FP"):
        # Forward the concrete fixed-point type so that it is not always
        # decoded with from_fp's default implementation.
        return deserialize_fp(serialized, dtype)

    if dtype.startswith('Span<'):
        return deserialize_span(serialized, dtype)

    if dtype.startswith('Tensor<'):
        return deserialize_tensor(serialized, dtype)

    if dtype.startswith('('):  # Tuple
        return deserialize_tuple(serialized, dtype)

    raise ValueError(f"Unknown data type: {dtype}")


# ================= ARRAY INT =================


def deserialize_arr_int(serialized):
    """Convert every entry of ``serialized[0]`` with ``felt_to_int`` into a NumPy array."""
    serialized = serialized[0]

    deserialized = []
    for ele in serialized:
        deserialized.append(felt_to_int(ele))

    return np.array(deserialized)


# ================= ARRAY FIXED POINT =================


def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'):
    """Decode ``serialized[0]``, a flat ``[mag, sign, mag, sign, ...]`` list.

    Raises:
        ValueError: If the flat list has an odd number of entries.
    """
    serialized = serialized[0]

    if len(serialized) % 2 != 0:
        raise ValueError("Array length must be even")

    deserialized = []
    for i in range(0, len(serialized), 2):
        mag = serialized[i]
        sign = serialized[i + 1]

        deserialized.append(deserialize_fixed_point([mag, sign], impl))

    return np.array(deserialized)


# ================= TENSOR INT =================


def deserialize_tensor_int(serialized: list) -> np.array:
    """Build an ``int64`` array from ``[shape, data]`` and reshape it to ``shape``."""
    shape = serialized[0]
    data = deserialize_arr_int([serialized[1]])

    return np.array(data, dtype=np.int64).reshape(shape)


def deserialize_fp(serialized, impl='FP16x16'):
    """Decode ``'<magnitude> [<sign>]'`` text into a signed fixed-point value.

    Args:
        serialized: Whitespace-separated text; the first token is the integer
            magnitude, and an optional second token equal to ``'1'`` marks
            the value as negative.
        impl: Fixed-point type name forwarded to ``from_fp``.

    Returns:
        The value returned by ``from_fp``, negated when the sign token is ``'1'``.
    """
    parts = serialized.split()
    value = from_fp(int(parts[0]), impl)
    if len(parts) > 1 and parts[1] == '1':  # Check for negative sign
        value = -value
    return value


# ================= TENSOR FIXED POINT =================
def deserialize_span(serialized, dtype):
    """Deserialize ``'[v0 v1 ...]'`` text into a one-dimensional NumPy array.

    Args:
        serialized: Bracketed, whitespace-separated values. Leading and
            trailing whitespace around the brackets is ignored.
        dtype: Span type such as ``'Span<u32>'`` or ``'Span<FP16x16>'``.

    Returns:
        A ``float64`` array for fixed-point inner types, otherwise an
        ``int64`` array.

    Raises:
        ValueError: If a fixed-point span does not hold complete
            (value, sign) pairs.
    """
    inner_type = dtype[5:-1]
    # Strip first so surrounding whitespace does not survive the bracket slice.
    elements = serialized.strip()[1:-1].split()

    if inner_type.startswith("FP"):
        # For fixed point, elements consist of two parts (value and sign)
        if len(elements) % 2 != 0:
            raise ValueError("Array length must be even")
        pairs = [' '.join(elements[i:i + 2])
                 for i in range(0, len(elements), 2)]
        return np.array([deserializer(pair, inner_type) for pair in pairs],
                        dtype=np.float64)

    return np.array([deserializer(e, inner_type) for e in elements],
                    dtype=np.int64)

def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array:
    """Build a ``float64`` array from ``[shape, data]`` and reshape it to ``shape``.

    ``data`` is decoded with ``deserialize_arr_fixed_point`` using ``impl``.
    """
    dims = serialized[0]
    flat_values = deserialize_arr_fixed_point([serialized[1]], impl)
    as_float = np.array(flat_values, dtype=np.float64)
    return as_float.reshape(dims)
def deserialize_tensor(serialized, dtype):
    """Deserialize ``'[d0 d1 ...] [v0 v1 ...]'`` text into a reshaped array.

    Args:
        serialized: Two bracketed groups: the dimensions, then the flat
            values. Any amount of whitespace (including none) may separate
            the two groups.
        dtype: Tensor type such as ``'Tensor<i32>'`` or ``'Tensor<FP16x16>'``.

    Returns:
        A NumPy array of the deserialized values reshaped to the dimensions.

    Raises:
        ValueError: If the text is not made of two bracketed groups, or a
            fixed-point tensor does not hold complete (value, sign) pairs.
    """
    inner_type = dtype[7:-1]
    text = serialized.strip()

    # Locate both groups explicitly instead of splitting on the literal
    # '] [', which only works for exactly one separating space.
    shape_end = text.find(']')
    data_start = text.find('[', shape_end + 1) if shape_end != -1 else -1
    if (not text.startswith('[') or shape_end == -1
            or data_start == -1 or not text.endswith(']')):
        raise ValueError(
            f"Expected '[shape] [data]' for tensor, got: {serialized!r}")

    dims = [int(d) for d in text[1:shape_end].split()]
    values = text[data_start + 1:-1].split()

    if inner_type.startswith("FP"):
        # For fixed point, values come as (value, sign) pairs.
        if len(values) % 2 != 0:
            raise ValueError("Array length must be even")
        tensor_data = np.array([deserializer(' '.join(values[i:i + 2]), inner_type)
                                for i in range(0, len(values), 2)])
    else:
        tensor_data = np.array(
            [deserializer(v, inner_type) for v in values])
    return tensor_data.reshape(dims)


def _split_top_level(serialized):
    """Split tuple text into bracketed groups and bare whitespace-separated tokens."""
    items = []
    pos = 0
    end = len(serialized)
    while pos < end:
        char = serialized[pos]
        if char.isspace():
            pos += 1
            continue
        if char == '[':
            close = serialized.find(']', pos)
            if close == -1:
                raise ValueError(
                    f"Unterminated '[' in serialized tuple: {serialized!r}")
            stop = close + 1
        else:
            stop = pos
            while (stop < end and not serialized[stop].isspace()
                   and serialized[stop] != '['):
                stop += 1
        items.append(serialized[pos:stop])
        pos = stop
    return items


def deserialize_tuple(serialized, dtype):
    """Deserialize the text of a tuple value element by element.

    Args:
        serialized: The tuple elements written one after another, e.g.
            ``'[1 2] 3'`` or ``'[2 2] [1 2 3 4] [1 2]'``.
        dtype: Tuple type such as ``'(Span<u32>, u32)'``; element types are
            separated by ``', '``.

    Returns:
        A tuple with one deserialized value per element type.

    Raises:
        ValueError: If a bracket is not closed, or the serialized text holds
            fewer or more elements than ``dtype`` declares.
    """
    mismatch = "Serialized data length does not match tuple length"
    types = dtype[1:-1].split(', ')
    items = _split_top_level(serialized)

    # First carve the text into one chunk per element type so that length
    # problems are reported before any element is deserialized.
    chunks = []
    cursor = 0
    for element_type in types:
        is_tensor = element_type.startswith('Tensor<')
        # Tensors span two groups (shape and data); fixed points span two
        # tokens (value and sign). Everything else is a single item.
        width = 2 if is_tensor or element_type.startswith('FP') else 1
        chunk = items[cursor:cursor + width]
        if not chunk or (is_tensor and len(chunk) < 2):
            raise ValueError(mismatch)
        chunks.append(' '.join(chunk))
        cursor += len(chunk)

    if cursor != len(items):
        raise ValueError(mismatch)

    return tuple(deserializer(chunk, element_type)
                 for chunk, element_type in zip(chunks, types))


def find_nth_occurrence(string, sub_string, n):
    """Return the index of the ``n``-th occurrence of ``sub_string`` in ``string``.

    Each search resumes one character after the previous match. Returns
    ``-1`` when there are fewer than ``n`` occurrences; for ``n`` of 1 or
    less the first occurrence is returned.
    """
    position = string.find(sub_string)
    hops_left = n - 1
    while hops_left > 0 and position >= 0:
        position = string.find(sub_string, position + 1)
        hops_left -= 1
    return position
29 changes: 14 additions & 15 deletions tests/test_deserialize.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,87 @@
import numpy as np
import numpy.testing as npt
import pytest
from math import isclose

from osiris.cairo.serde.deserialize import *


def test_deserialize_int():
    """Integer dtypes: a plain value, and a large felt expected to decode to -42."""
    serialized = '42'
    deserialized = deserializer(serialized, 'u32')
    assert deserialized == 42

    serialized = '3618502788666131213697322783095070105623107215331596699973092056135872020439'
    deserialized = deserializer(serialized, 'i32')
    assert deserialized == -42


def test_deserialize_fp():
    """FP16x16 values: sign token 0 keeps the value positive, 1 negates it."""
    serialized = '2780037 0'
    deserialized = deserializer(serialized, 'FP16x16')
    assert isclose(deserialized, 42.42, rel_tol=1e-7)

    serialized = '2780037 1'
    deserialized = deserializer(serialized, 'FP16x16')
    assert isclose(deserialized, -42.42, rel_tol=1e-7)


def test_deserialize_array_int():
    """Integer spans, including an element expected to decode to -42."""
    serialized = '[1 2]'
    deserialized = deserializer(serialized, 'Span<u32>')
    assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64))

    serialized = '[42 3618502788666131213697322783095070105623107215331596699973092056135872020439]'
    deserialized = deserializer(serialized, 'Span<i32>')
    assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64))


def test_deserialize_arr_fixed_point():
    """Fixed-point span made of (value, sign) pairs."""
    serialized = '[2780037 0 2780037 1]'
    deserialized = deserializer(serialized, 'Span<FP16x16>')
    expected = np.array([42.42, -42.42], dtype=np.float64)
    assert np.all(np.isclose(deserialized, expected, atol=1e-7))


def test_deserialize_tensor_int():
    """2x2 integer tensors, with positive and negative elements."""
    serialized = '[2 2] [1 2 3 4]'
    deserialized = deserializer(serialized, 'Tensor<i32>')
    assert np.array_equal(deserialized, np.array(
        ([1, 2], [3, 4]), dtype=np.int64))

    serialized = '[2 2] [42 42 3618502788666131213697322783095070105623107215331596699973092056135872020439 3618502788666131213697322783095070105623107215331596699973092056135872020439]'
    deserialized = deserializer(serialized, 'Tensor<i32>')
    assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]]))


def test_deserialize_tensor_fixed_point():
    """2x2 fixed-point tensor made of (value, sign) pairs."""
    serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1]'
    expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
    deserialized = deserializer(serialized, 'Tensor<FP16x16>')
    assert np.allclose(deserialized, expected_array, atol=1e-7)


def test_deserialize_tuple_int():
    """Tuple of two integers."""
    serialized = '1 3'
    deserialized = deserializer(serialized, '(u32, u32)')
    assert deserialized == (1, 3)


def test_deserialize_tuple_span():
    """Tuple of an integer span followed by an integer."""
    serialized = '[1 2] 3'
    deserialized = deserializer(serialized, '(Span<u32>, u32)')
    expected = (np.array([1, 2]), 3)
    npt.assert_array_equal(deserialized[0], expected[0])
    assert deserialized[1] == expected[1]


def test_deserialize_tuple_span_tensor_fp():
serialized = '[{"Array":[{"Int":"0x1"},{"Int":"0x2"}]},{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
serialized = '[1 2] [2 2] [2780037 0 2780037 0 2780037 1 2780037 1]'
deserialized = deserializer(serialized, '(Span<u32>, Tensor<FP16x16>)')
expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
npt.assert_array_equal(deserialized[0], expected[0])
assert np.allclose(deserialized[1], expected[1], atol=1e-7)

serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}, {"Array":[{"Int":"0x1"},{"Int":"0x2"}]}]'
serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]'
deserialized = deserializer(serialized, '(Tensor<FP16x16>, Span<u32>)')
expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2]))
assert np.allclose(deserialized[0], expected[0], atol=1e-7)
Expand Down
Loading