Skip to content

Commit

Permalink
refactor(interactive): Add Encoder/Decoder for Python SDK (#4191)
Browse files Browse the repository at this point in the history
Implement the Encoder/Decoder for the Interactive Python SDK, enabling
users to customize query parameters and deserialize the output.
zhanglei1949 authored Aug 29, 2024
1 parent 387e33b commit eb848a5
Showing 6 changed files with 100 additions and 21 deletions.
2 changes: 1 addition & 1 deletion flex/engines/graph_db/database/graph_db_session.h
Original file line number Diff line number Diff line change
@@ -139,7 +139,7 @@ class GraphDBSession {
inline Result<std::pair<uint8_t, std::string_view>> parse_query_type(
const std::string& input) {
const char* str_data = input.data();
VLOG(10) << "parse query type for " << input;
VLOG(10) << "parse query type for " << input << " size: " << input.size();
char input_tag = input.back();
VLOG(10) << "input tag: " << static_cast<int>(input_tag);
size_t len = input.size();
Original file line number Diff line number Diff line change
@@ -87,6 +87,11 @@ public void put_bytes(byte[] bytes) {
this.loc = serialize_bytes(this.bs, this.loc, bytes);
}

/**
 * Serialize a string by handing its UTF-8 bytes to {@link #put_bytes(byte[])}.
 *
 * @param value the string to serialize; must not be null
 */
public void put_string(String value) {
    // Pin the charset: the no-arg getBytes() uses the JVM's default charset,
    // which would make the encoded payload depend on the host configuration.
    byte[] bytes = value.getBytes(java.nio.charset.StandardCharsets.UTF_8);
    this.put_bytes(bytes);
}

byte[] bs;
int loc;
}
16 changes: 8 additions & 8 deletions flex/interactive/sdk/python/gs_interactive/client/session.py
Original file line number Diff line number Diff line change
@@ -212,11 +212,11 @@ def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResul
raise NotImplementedError

@abstractmethod
def call_procedure_raw(self, graph_id: StrictStr, params: str) -> Result[str]:
def call_procedure_raw(self, graph_id: StrictStr, params: bytes) -> Result[str]:
raise NotImplementedError

@abstractmethod
def call_procedure_current_raw(self, params: str) -> Result[str]:
def call_procedure_current_raw(self, params: bytes) -> Result[str]:
raise NotImplementedError


@@ -582,7 +582,7 @@ def call_procedure(
# Here we add byte of value 1 to denote the input format is in json format
response = self._query_api.call_proc_with_http_info(
graph_id = graph_id,
body=append_format_byte(params.to_json(), InputFormat.CYPHER_JSON)
body=append_format_byte(params.to_json().encode(), InputFormat.CYPHER_JSON)
)
result = CollectiveResults()
if response.status_code == 200:
@@ -598,7 +598,7 @@ def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResul
# gs_interactive currently supports four input formats, see flex/engines/graph_db/graph_db_session.h
# Here we add byte of value 1 to denote the input format is in json format
response = self._query_api.call_proc_current_with_http_info(
body = append_format_byte(params.to_json(), InputFormat.CYPHER_JSON)
body = append_format_byte(params.to_json().encode(), InputFormat.CYPHER_JSON)
)
result = CollectiveResults()
if response.status_code == 200:
@@ -609,25 +609,25 @@ def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResul
except Exception as e:
return Result.from_exception(e)

def call_procedure_raw(self, graph_id: StrictStr, params: bytes) -> Result[str]:
    """Call a procedure on ``graph_id`` with an already-encoded raw payload.

    Args:
        graph_id: Identifier of the graph to query.
        params: Raw payload bytes. A ``str`` is still accepted for
            backward compatibility and is UTF-8 encoded before sending.

    Returns:
        ``Result.from_response(...)`` on completion, or
        ``Result.from_exception(...)`` if anything raised while sending.
    """
    graph_id = self.ensure_param_str("graph_id", graph_id)
    try:
        # ``bytes`` has no ``.encode()``; only encode when a legacy ``str``
        # payload is passed, otherwise forward the bytes untouched.
        payload = params.encode() if isinstance(params, str) else params
        # gs_interactive currently supports four input formats, see flex/engines/graph_db/graph_db_session.h
        # Here we append the CPP_ENCODER format byte to mark the payload as encoder/decoder bytes.
        response = self._query_api.call_proc_with_http_info(
            graph_id = graph_id,
            body = append_format_byte(payload, InputFormat.CPP_ENCODER)
        )
        return Result.from_response(response)
    except Exception as e:
        return Result.from_exception(e)

def call_procedure_current_raw(self, params: str) -> Result[str]:
def call_procedure_current_raw(self, params: bytes) -> Result[str]:
try:
# gs_interactive currently supports four input formats, see flex/engines/graph_db/graph_db_session.h
# Here we add a byte of value 0 (InputFormat.CPP_ENCODER) to denote that the input is in encoder/decoder format
response = self._query_api.call_proc_current_with_http_info(
body = append_format_byte(params, InputFormat.CPP_ENCODER)
body = append_format_byte(params.encode(), InputFormat.CPP_ENCODER)
)
return Result.from_response(response)
except Exception as e:
73 changes: 72 additions & 1 deletion flex/interactive/sdk/python/gs_interactive/client/utils.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,77 @@
#
import struct
from enum import Enum


class Encoder:
    """Accumulate primitive values into a byte buffer.

    Integers are written as fixed-width values in ``endian`` byte order
    (``'little'`` by default). Strings are written as a 4-byte length
    followed by their UTF-8 bytes.
    """

    def __init__(self, endian = 'little') -> None:
        self.byte_array = bytearray()
        self.endian = endian

    def _put_integer(self, value: int, width: int) -> None:
        """Append ``value`` as ``width`` bytes in the configured byte order."""
        # Non-negative values keep the plain unsigned encoding; negative
        # values are written as two's complement instead of raising.
        self.byte_array.extend(
            value.to_bytes(width, byteorder=self.endian, signed=value < 0)
        )

    def put_int(self, value: int):
        """Append ``value`` as a 4-byte integer."""
        self._put_integer(value, 4)

    def put_long(self, value: int):
        """Append ``value`` as an 8-byte integer."""
        self._put_integer(value, 8)

    def put_string(self, value: str):
        """Append ``value`` as a 4-byte length prefix plus its UTF-8 bytes."""
        # The prefix must count encoded bytes, not characters: the two
        # differ for any non-ASCII text.
        encoded = value.encode('utf-8')
        self.put_int(len(encoded))
        self.byte_array.extend(encoded)

    def put_byte(self, value: int):
        """Append ``value`` as a single byte."""
        self._put_integer(value, 1)

    def put_bytes(self, value: bytes):
        """Append ``value`` verbatim; no length prefix is written."""
        self.byte_array.extend(value)

    def put_double(self, value: float):
        """Append ``value`` as an 8-byte IEEE 754 double."""
        # ``float`` has no ``to_bytes``; ``struct`` does the packing.
        order = '<' if self.endian == 'little' else '>'
        self.byte_array.extend(struct.pack(order + 'd', value))

    def get_bytes(self):
        """Return everything written so far as an immutable ``bytes`` object."""
        return bytes(self.byte_array)

class Decoder:
    """Read primitive values sequentially from a byte buffer.

    ``index`` is the read cursor; every ``get_*`` call consumes bytes from
    the cursor position and advances it. Integers are decoded in ``endian``
    byte order (``'little'`` by default).
    """

    def __init__(self, byte_array: bytearray,endian = 'little') -> None:
        self.byte_array = byte_array
        self.index = 0
        self.endian = endian

    def _take(self, size: int):
        """Consume and return the next ``size`` bytes.

        Raises:
            ValueError: if ``size`` is negative or fewer than ``size`` bytes
                remain, instead of silently decoding a truncated slice.
        """
        end = self.index + size
        if size < 0 or end > len(self.byte_array):
            raise ValueError(
                f"cannot read {size} bytes at offset {self.index}: "
                f"buffer has {len(self.byte_array)} bytes"
            )
        chunk = self.byte_array[self.index:end]
        self.index = end
        return chunk

    def get_int(self, signed: bool = False) -> int:
        """Read a 4-byte integer (unsigned unless ``signed`` is true)."""
        return int.from_bytes(self._take(4), byteorder=self.endian, signed=signed)

    def get_long(self, signed: bool = False) -> int:
        """Read an 8-byte integer (unsigned unless ``signed`` is true)."""
        return int.from_bytes(self._take(8), byteorder=self.endian, signed=signed)

    def get_double(self) -> float:
        """Read an 8-byte IEEE 754 double."""
        # ``float`` has no ``from_bytes``; ``struct`` does the unpacking.
        order = '<' if self.endian == 'little' else '>'
        return struct.unpack(order + 'd', self._take(8))[0]

    def get_byte(self) -> int:
        """Read a single byte as an unsigned integer."""
        return int.from_bytes(self._take(1), byteorder=self.endian)

    def get_bytes(self, length: int) -> bytes:
        """Read ``length`` raw bytes."""
        # Copy into ``bytes`` so the result matches the annotation even when
        # the underlying buffer is a ``bytearray``.
        return bytes(self._take(length))

    def get_string(self) -> str:
        """Read a 4-byte length prefix followed by that many UTF-8 bytes."""
        length = self.get_int()
        return self._take(length).decode('utf-8')

    def is_empty(self) -> bool:
        """Return True once every byte of the buffer has been consumed."""
        return self.index == len(self.byte_array)


class InputFormat(Enum):
CPP_ENCODER = 0 # raw bytes encoded by encoder/decoder
CYPHER_JSON = 1 # json format string
@@ -27,6 +98,6 @@ def append_format_byte(input: bytes, input_format: InputFormat):
"""
Append a byte to the end of the input bytes to denote the input format
"""
new_bytes = str.encode(input) + bytes([input_format.value])
new_bytes = input + bytes([input_format.value])
return new_bytes

Original file line number Diff line number Diff line change
@@ -38,6 +38,6 @@ def tearDown(self):

def test_append_format_byte(self):
input = "hello"
new_bytes = append_format_byte(input, input_format=InputFormat.CPP_ENCODER)
new_bytes = append_format_byte(input.encode(), input_format=InputFormat.CPP_ENCODER)
self.assertEqual(new_bytes, b'hello\x00')
self.assertEqual(len(new_bytes), len(input) + 1)
23 changes: 13 additions & 10 deletions flex/tests/interactive/test_call_proc.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
sys.path.append("../../interactive/sdk/python")

from gs_interactive.client.driver import Driver
from gs_interactive.client.utils import Encoder, Decoder
from gs_interactive.models.base_edge_type_vertex_type_pair_relations_inner import (
BaseEdgeTypeVertexTypePairRelationsInner,
)
@@ -98,27 +99,29 @@ def callProcedureWithJsonFormat(self, graph_id : str):
def callProcedureWithEncoder(self, graph_id : str):
# count_vertex_num, should be with id 1
# construct a byte array with bytes: 0x01
params = chr(1)
resp = self._sess.call_procedure_raw(graph_id, params)
encoder = Encoder()
encoder.put_byte(1)
resp = self._sess.call_procedure_raw(graph_id, encoder.get_bytes())
if not resp.is_ok():
print("call count_vertex_num failed: ", resp.get_status_message())
exit(1)

# plus_one, should be with id 2
# construct a byte array with bytes: the 4 bytes of integer 1, and a byte 2
value = 1
byte_string = value.to_bytes(4, byteorder=sys.byteorder) + bytes([2])
# byte_string = bytes([1,0,0,0,2]) # 4 bytes of integer 1, and a byte 3
params = byte_string.decode('utf-8')
resp = self._sess.call_procedure_raw(graph_id, params)
encoder2 = Encoder()
encoder2.put_int(1) # The input value 1
encoder2.put_byte(2) # The procedure id
resp = self._sess.call_procedure_raw(graph_id, encoder2.get_bytes())
if not resp.is_ok():
print("call plus_one failed: ", resp.get_status_message())
exit(1)
res = resp.get_value()
assert len(res) == 4
# the four byte represent a integer
res = int.from_bytes(res, byteorder=sys.byteorder)
assert(res == 2)
decoder = Decoder(res)
value = decoder.get_int()
print("call plus_one result: ", value)
assert(value == 2)
assert(decoder.is_empty())

if __name__ == "__main__":
#parse command line args

0 comments on commit eb848a5

Please sign in to comment.