diff --git a/pyk/src/pyk/kore/rpc.py b/pyk/src/pyk/kore/rpc.py index c23dbf7bd61..242dc7a2438 100644 --- a/pyk/src/pyk/kore/rpc.py +++ b/pyk/src/pyk/kore/rpc.py @@ -56,8 +56,52 @@ def __init__(self, message: str, code: int, data: Any = None): class Transport(ContextManager['Transport'], ABC): + _bug_report: BugReport | None + _bug_report_id: str | None + + def __init__(self, bug_report_id: str | None = None, bug_report: BugReport | None = None) -> None: + if (bug_report_id is None and bug_report is not None) or (bug_report_id is not None and bug_report is None): + raise ValueError('bug_report and bug_report_id must be passed together.') + self._bug_report_id = bug_report_id + self._bug_report = bug_report + + def request(self, req: str, request_id: int, method_name: str) -> str: + base_name = self._bug_report_id if self._bug_report_id is not None else 'kore_rpc' + req_name = f'{base_name}/{id(self)}/{request_id:03}' + if self._bug_report: + bug_report_request = f'{req_name}_request.json' + self._bug_report.add_file_contents(req, Path(bug_report_request)) + self._bug_report.add_command(self._command(req_name, bug_report_request)) + + server_addr = self._description() + _LOGGER.info(f'Sending request to {server_addr}: {request_id} - {method_name}') + _LOGGER.debug(f'Sending request to {server_addr}: {req}') + resp = self._request(req) + _LOGGER.info(f'Received response from {server_addr}: {request_id} - {method_name}') + _LOGGER.debug(f'Received response from {server_addr}: {resp}') + + if self._bug_report: + bug_report_response = f'{req_name}_response.json' + self._bug_report.add_file_contents(resp, Path(bug_report_response)) + self._bug_report.add_command( + [ + 'diff', + '-b', + '-s', + f'{req_name}_actual.json', + f'{req_name}_response.json', + ] + ) + return resp + + @abstractmethod + def _command(self, req_name: str, bug_report_request: str) -> list[str]: ... + + @abstractmethod + def _request(self, req: str) -> str: ... + @abstractmethod - def request(self, req: str) -> str: ... + def _description(self) -> str: ... def __enter__(self) -> Transport: return self @@ -68,12 +112,6 @@ def __exit__(self, *args: Any) -> None: @abstractmethod def close(self) -> None: ... - @abstractmethod - def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> list[str]: ... - - @abstractmethod - def description(self) -> str: ... - class TransportType(Enum): SINGLE_SOCKET = auto() @@ -87,7 +125,16 @@ class SingleSocketTransport(Transport): _sock: socket.socket _file: TextIO - def __init__(self, host: str, port: int, *, timeout: int | None = None): + def __init__( + self, + host: str, + port: int, + *, + timeout: int | None = None, + bug_report_id: str | None = None, + bug_report: BugReport | None = None, + ): + super().__init__(bug_report_id, bug_report) self._host = host self._port = port self._sock = self._create_connection(host, port, timeout) @@ -117,7 +164,7 @@ def close(self) -> None: self._file.close() self._sock.close() - def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> list[str]: + def _command(self, req_name: str, bug_report_request: str) -> list[str]: return [ 'cat', bug_report_request, @@ -127,16 +174,16 @@ def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> l self._host, str(self._port), '>', - f'rpc_{bug_report_id}/{old_id:03}_actual.json', + f'{req_name}_actual.json', ] - def request(self, req: str) -> str: + def _request(self, req: str) -> str: self._sock.sendall(req.encode()) - server_addr = self.description() + server_addr = self._description() _LOGGER.debug(f'Waiting for response from {server_addr}...') return self._file.readline().rstrip() - def description(self) -> str: + def _description(self) -> str: return f'{self._host}:{self._port}' @@ -146,7 +193,16 @@ class HttpTransport(Transport): _port: int _timeout: int | None - def __init__(self, host: str, port: int, *, timeout: int | None = None): + def __init__( + self, + host: str, + port: int, + *, + timeout: int | None = None, + bug_report_id: str | None = None, + bug_report: BugReport | None = None, + ): + super().__init__(bug_report_id, bug_report) self._host = host self._port = port self._timeout = timeout @@ -154,7 +210,7 @@ def __init__(self, host: str, port: int, *, timeout: int | None = None): def close(self) -> None: pass - def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> list[str]: + def _command(self, req_name: str, bug_report_request: str) -> list[str]: return [ 'curl', '-X', @@ -165,20 +221,20 @@ def command(self, bug_report_id: str, old_id: int, bug_report_request: str) -> l '@' + bug_report_request, 'http://' + self._host + ':' + str(self._port), '>', - f'rpc_{bug_report_id}/{old_id:03}_actual.json', + f'{req_name}_actual.json', ] - def request(self, req: str) -> str: + def _request(self, req: str) -> str: connection = http.client.HTTPConnection(self._host, self._port, timeout=self._timeout) connection.request('POST', '/', body=req, headers={'Content-Type': 'application/json'}) - server_addr = self.description() + server_addr = self._description() _LOGGER.debug(f'Waiting for response from {server_addr}...') response = connection.getresponse() if response.status != 200: raise JsonRpcError('Internal server error', -32603) return response.read().decode() - def description(self) -> str: + def _description(self) -> str: return f'{self._host}:{self._port}' @@ -258,8 +314,6 @@ class JsonRpcClient(ContextManager['JsonRpcClient']): _transport: Transport _req_id: int - _bug_report: BugReport | None - _bug_report_id: str def __init__( self, @@ -272,14 +326,16 @@ def __init__( transport: TransportType = TransportType.SINGLE_SOCKET, ): if transport is TransportType.SINGLE_SOCKET: - self._transport = SingleSocketTransport(host, port, timeout=timeout) + self._transport = SingleSocketTransport( + host, port, timeout=timeout, bug_report=bug_report, bug_report_id=bug_report_id + ) elif transport is TransportType.HTTP: - self._transport = HttpTransport(host, port, timeout=timeout) + self._transport = HttpTransport( + host, port, timeout=timeout, bug_report=bug_report, bug_report_id=bug_report_id + ) else: raise AssertionError() self._req_id = 1 - self._bug_report = bug_report - self._bug_report_id = bug_report_id if bug_report_id is not None else str(id(self)) def __enter__(self) -> JsonRpcClient: return self @@ -301,38 +357,15 @@ def request(self, method: str, **params: Any) -> dict[str, Any]: 'params': params, } - server_addr = self._transport.description() - _LOGGER.info(f'Sending request to {server_addr}: {old_id} - {method}') req = json.dumps(payload) - if self._bug_report: - bug_report_request = f'rpc_{self._bug_report_id}/{old_id:03}_request.json' - self._bug_report.add_file_contents(req, Path(bug_report_request)) - self._bug_report.add_command(self._transport.command(self._bug_report_id, old_id, bug_report_request)) - - _LOGGER.debug(f'Sending request to {server_addr}: {req}') - resp = self._transport.request(req) + resp = self._transport.request(req, old_id, method) if not resp: raise RuntimeError('Empty response received') - _LOGGER.debug(f'Received response from {server_addr}: {resp}') - - if self._bug_report: - bug_report_response = f'rpc_{self._bug_report_id}/{old_id:03}_response.json' - self._bug_report.add_file_contents(resp, Path(bug_report_response)) - self._bug_report.add_command( - [ - 'diff', - '-b', - '-s', - f'rpc_{self._bug_report_id}/{old_id:03}_actual.json', - f'rpc_{self._bug_report_id}/{old_id:03}_response.json', - ] - ) data = json.loads(resp) self._check(data) assert data['id'] == old_id - _LOGGER.info(f'Received response from {server_addr}: {old_id} - {method}') return data['result'] @staticmethod diff --git a/pyk/src/tests/unit/kore/test_rpc_client.py b/pyk/src/tests/unit/kore/test_rpc_client.py index bb054849e3a..cf69f1c5bf7 100644 --- a/pyk/src/tests/unit/kore/test_rpc_client.py +++ b/pyk/src/tests/unit/kore/test_rpc_client.py @@ -50,7 +50,7 @@ def transport(mock: Mock) -> MockTransport: @pytest.fixture def kore_client(mock: Mock, mock_class: Mock) -> Iterator[KoreClient]: # noqa: N803 client = KoreClient('localhost', 3000) - mock_class.assert_called_with('localhost', 3000, timeout=None) + mock_class.assert_called_with('localhost', 3000, timeout=None, bug_report=None, bug_report_id=None) assert client._client._default_client._transport == mock yield client client.close()