From eb848a55e10910eb7ad3bb5896bbceaafdac7e28 Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Thu, 29 Aug 2024 14:36:18 +0800 Subject: [PATCH] refactor(interactive): Add Encoder/Decoder for Python SDK (#4191) Implement the Encoder/Decoder for the Interactive Python SDK, enabling users to customize query parameters and deserialize the output. --- .../graph_db/database/graph_db_session.h | 2 +- .../interactive/client/utils/Encoder.java | 5 ++ .../python/gs_interactive/client/session.py | 16 ++-- .../sdk/python/gs_interactive/client/utils.py | 73 ++++++++++++++++++- .../python/gs_interactive/tests/test_utils.py | 2 +- flex/tests/interactive/test_call_proc.py | 23 +++--- 6 files changed, 100 insertions(+), 21 deletions(-) diff --git a/flex/engines/graph_db/database/graph_db_session.h b/flex/engines/graph_db/database/graph_db_session.h index 15c374abfdc2..158511299e10 100644 --- a/flex/engines/graph_db/database/graph_db_session.h +++ b/flex/engines/graph_db/database/graph_db_session.h @@ -139,7 +139,7 @@ class GraphDBSession { inline Result> parse_query_type( const std::string& input) { const char* str_data = input.data(); - VLOG(10) << "parse query type for " << input; + VLOG(10) << "parse query type for " << input << " size: " << input.size(); char input_tag = input.back(); VLOG(10) << "input tag: " << static_cast(input_tag); size_t len = input.size(); diff --git a/flex/interactive/sdk/java/src/main/java/com/alibaba/graphscope/interactive/client/utils/Encoder.java b/flex/interactive/sdk/java/src/main/java/com/alibaba/graphscope/interactive/client/utils/Encoder.java index bcaceedc47a9..3324cb2faa9d 100644 --- a/flex/interactive/sdk/java/src/main/java/com/alibaba/graphscope/interactive/client/utils/Encoder.java +++ b/flex/interactive/sdk/java/src/main/java/com/alibaba/graphscope/interactive/client/utils/Encoder.java @@ -87,6 +87,11 @@ public void put_bytes(byte[] bytes) { this.loc = serialize_bytes(this.bs, this.loc, bytes); } + public void put_string(String value) { + byte[] bytes = value.getBytes(); + this.put_bytes(bytes); + } + byte[] bs; int loc; } diff --git a/flex/interactive/sdk/python/gs_interactive/client/session.py b/flex/interactive/sdk/python/gs_interactive/client/session.py index 1ab3671dd6aa..d72da741061b 100644 --- a/flex/interactive/sdk/python/gs_interactive/client/session.py +++ b/flex/interactive/sdk/python/gs_interactive/client/session.py @@ -212,11 +212,11 @@ def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResul raise NotImplementedError @abstractmethod - def call_procedure_raw(self, graph_id: StrictStr, params: str) -> Result[str]: + def call_procedure_raw(self, graph_id: StrictStr, params: bytes) -> Result[str]: raise NotImplementedError @abstractmethod - def call_procedure_current_raw(self, params: str) -> Result[str]: + def call_procedure_current_raw(self, params: bytes) -> Result[str]: raise NotImplementedError @@ -582,7 +582,7 @@ def call_procedure( # Here we add byte of value 1 to denote the input format is in json format response = self._query_api.call_proc_with_http_info( graph_id = graph_id, - body=append_format_byte(params.to_json(), InputFormat.CYPHER_JSON) + body=append_format_byte(params.to_json().encode(), InputFormat.CYPHER_JSON) ) result = CollectiveResults() if response.status_code == 200: @@ -598,7 +598,7 @@ def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResul # gs_interactive currently support four type of inputformat, see flex/engines/graph_db/graph_db_session.h # Here we add byte of value 1 to denote the input format is in json format response = self._query_api.call_proc_current_with_http_info( - body = append_format_byte(params.to_json(), InputFormat.CYPHER_JSON) + body = append_format_byte(params.to_json().encode(), InputFormat.CYPHER_JSON) ) result = CollectiveResults() if response.status_code == 200: @@ -609,25 +609,25 @@ def call_procedure_current(self, params: QueryRequest) -> Result[CollectiveResul except Exception as e: return Result.from_exception(e) - def call_procedure_raw(self, graph_id: StrictStr, params: str) -> Result[str]: + def call_procedure_raw(self, graph_id: StrictStr, params: bytes) -> Result[str]: graph_id = self.ensure_param_str("graph_id", graph_id) try: # gs_interactive currently support four type of inputformat, see flex/engines/graph_db/graph_db_session.h # Here we add byte of value 1 to denote the input format is in encoder/decoder format response = self._query_api.call_proc_with_http_info( graph_id = graph_id, - body = append_format_byte(params, InputFormat.CPP_ENCODER) + body = append_format_byte(params.encode(), InputFormat.CPP_ENCODER) ) return Result.from_response(response) except Exception as e: return Result.from_exception(e) - def call_procedure_current_raw(self, params: str) -> Result[str]: + def call_procedure_current_raw(self, params: bytes) -> Result[str]: try: # gs_interactive currently support four type of inputformat, see flex/engines/graph_db/graph_db_session.h # Here we add byte of value 1 to denote the input format is in encoder/decoder format response = self._query_api.call_proc_current_with_http_info( - body = append_format_byte(params, InputFormat.CPP_ENCODER) + body = append_format_byte(params.encode(), InputFormat.CPP_ENCODER) ) return Result.from_response(response) except Exception as e: diff --git a/flex/interactive/sdk/python/gs_interactive/client/utils.py b/flex/interactive/sdk/python/gs_interactive/client/utils.py index 73944deeb74f..abdc78b4c39f 100644 --- a/flex/interactive/sdk/python/gs_interactive/client/utils.py +++ b/flex/interactive/sdk/python/gs_interactive/client/utils.py @@ -17,6 +17,77 @@ # from enum import Enum + +class Encoder: + def __init__(self, endian = 'little') -> None: + self.byte_array = bytearray() + self.endian = endian + + def put_int(self, value: int): + # put the value in big endian, 4 bytes + self.byte_array.extend(value.to_bytes(4, byteorder=self.endian)) + + def put_long(self, value: int): + self.byte_array.extend(value.to_bytes(8, byteorder=self.endian)) + + def put_string(self, value: str): + self.put_int(len(value)) + self.byte_array.extend(value.encode('utf-8')) + + def put_byte(self, value: int): + self.byte_array.extend(value.to_bytes(1, byteorder=self.endian)) + + def put_bytes(self, value: bytes): + self.byte_array.extend(value) + + def put_double(self, value: float): + self.byte_array.extend(value.to_bytes(8, byteorder=self.endian)) + + def get_bytes(self): + # return bytes not bytearray + return bytes(self.byte_array) + +class Decoder: + def __init__(self, byte_array: bytearray,endian = 'little') -> None: + self.byte_array = byte_array + self.index = 0 + self.endian = endian + + def get_int(self) -> int: + value = int.from_bytes(self.byte_array[self.index:self.index+4], byteorder=self.endian) + self.index += 4 + return value + + def get_long(self) -> int: + value = int.from_bytes(self.byte_array[self.index:self.index+8], byteorder=self.endian) + self.index += 8 + return value + + def get_double(self) -> float: + value = float.from_bytes(self.byte_array[self.index:self.index+8], byteorder=self.endian) + self.index += 8 + return value + + def get_byte(self) -> int: + value = int.from_bytes(self.byte_array[self.index:self.index+1], byteorder=self.endian) + self.index += 1 + return value + + def get_bytes(self, length: int) -> bytes: + value = self.byte_array[self.index:self.index+length] + self.index += length + return value + + def get_string(self) -> str: + length = self.get_int() + value = self.byte_array[self.index:self.index+length].decode('utf-8') + self.index += length + return value + + def is_empty(self) -> bool: + return self.index == len(self.byte_array) + + class InputFormat(Enum): CPP_ENCODER = 0 # raw bytes encoded by encoder/decoder CYPHER_JSON = 1 # json format string @@ -27,6 +98,6 @@ def append_format_byte(input: bytes, input_format: InputFormat): """ Append a byte to the end of the input string to denote the input format """ - new_bytes = str.encode(input) + bytes([input_format.value]) + new_bytes = input + bytes([input_format.value]) return new_bytes \ No newline at end of file diff --git a/flex/interactive/sdk/python/gs_interactive/tests/test_utils.py b/flex/interactive/sdk/python/gs_interactive/tests/test_utils.py index 5e926af93ca5..a3c84a8cd936 100644 --- a/flex/interactive/sdk/python/gs_interactive/tests/test_utils.py +++ b/flex/interactive/sdk/python/gs_interactive/tests/test_utils.py @@ -38,6 +38,6 @@ def tearDown(self): def test_append_format_byte(self): input = "hello" - new_bytes = append_format_byte(input, input_format=InputFormat.CPP_ENCODER) + new_bytes = append_format_byte(input.encode(), input_format=InputFormat.CPP_ENCODER) self.assertEqual(new_bytes, b'hello\x00') self.assertEqual(len(new_bytes), len(input) + 1) diff --git a/flex/tests/interactive/test_call_proc.py b/flex/tests/interactive/test_call_proc.py index d173d222a70e..ed67819fefc8 100644 --- a/flex/tests/interactive/test_call_proc.py +++ b/flex/tests/interactive/test_call_proc.py @@ -23,6 +23,7 @@ sys.path.append("../../interactive/sdk/python") from gs_interactive.client.driver import Driver +from gs_interactive.client.utils import Encoder, Decoder from gs_interactive.models.base_edge_type_vertex_type_pair_relations_inner import ( BaseEdgeTypeVertexTypePairRelationsInner, ) @@ -98,27 +99,29 @@ def callProcedureWithJsonFormat(self, graph_id : str): def callProcedureWithEncoder(self, graph_id : str): # count_vertex_num, should be with id 1 # construct a byte array with bytes: 0x01 - params = chr(1) - resp = self._sess.call_procedure_raw(graph_id, params) + encoder = Encoder() + encoder.put_byte(1) + resp = self._sess.call_procedure_raw(graph_id, encoder.get_bytes()) if not resp.is_ok(): print("call count_vertex_num failed: ", resp.get_status_message()) exit(1) # plus_one, should be with id 2 - # construct a byte array with bytes: the 4 bytes of integer 1, and a byte 2 - value = 1 - byte_string = value.to_bytes(4, byteorder=sys.byteorder) + bytes([2]) - # byte_string = bytes([1,0,0,0,2]) # 4 bytes of integer 1, and a byte 3 - params = byte_string.decode('utf-8') - resp = self._sess.call_procedure_raw(graph_id, params) + encoder2 = Encoder() + encoder2.put_int(1) # The input value 1 + encoder2.put_byte(2) # The procedure id + resp = self._sess.call_procedure_raw(graph_id, encoder2.get_bytes()) if not resp.is_ok(): print("call plus_one failed: ", resp.get_status_message()) exit(1) res = resp.get_value() assert len(res) == 4 # the four byte represent a integer - res = int.from_bytes(res, byteorder=sys.byteorder) - assert(res == 2) + decoder = Decoder(res) + value = decoder.get_int() + print("call plus_one result: ", value) + assert(value == 2) + assert(decoder.is_empty()) if __name__ == "__main__": #parse command line args