diff --git a/CHANGELOG.md b/CHANGELOG.md index 4186b68..fb42c9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ - Added `WSManFaultError` which contains WSManFault specific information when receiving a 500 WSMan fault response - This contains pre-parsed values like the code, subcode, wsman fault code, wmi error code, and raw response - It can be used by the caller to implement fallback behaviour based on specific error codes +- Added public API `protocol.build_wsman_header` that can create the standard WSMan header used by the protocol + - This can be used to craft custom WSMan messages that are not supported in the existing actions +- Added public API `protocol.get_command_output_raw` + - This can be used to send a single WSMan receive request and get the output + - Unlike `protocol.get_command_output`, it will not loop until the command is done and will not catch a timeout exception ### Version 0.4.3 - Fix invalid regex escape sequences. diff --git a/winrm/protocol.py b/winrm/protocol.py index e955d47..a93d7fa 100644 --- a/winrm/protocol.py +++ b/winrm/protocol.py @@ -157,7 +157,7 @@ def open_shell( @rtype string """ req = { - "env:Envelope": self._get_soap_header( + "env:Envelope": self.build_wsman_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA action="http://schemas.xmlsoap.org/ws/2004/09/transfer/Create", ) @@ -198,13 +198,26 @@ def open_shell( return t.cast(str, next(node for node in root.findall(".//*") if node.get("Name") == "ShellId").text) # Helper method for building SOAP Header - def _get_soap_header( + def build_wsman_header( self, - action: str | None = None, - resource_uri: str | None = None, + action: str, + resource_uri: str, shell_id: str | None = None, - message_id: uuid.UUID | None = None, + message_id: str | uuid.UUID | None = None, ) -> dict[str, t.Any]: + """ + Builds the standard header needed for WSMan operations. The return + value is a dictionary that can be used by xmltodict to generate the + WSMan envelope when sending custom requests. + + @param string action: The WSMan action to perform. + @param string resource_uri: The WSMan resource URI the request is for. + @param string shell_id: The optional shell UUID the request is for. + @param string message_id: A unique message UUID, if unset a random UUID + is used. + @returns The WSMan header as a dictionary. + @rtype dict[str, t.Any] + """ if not message_id: message_id = uuid.uuid4() header: dict[str, t.Any] = { @@ -238,6 +251,11 @@ def _get_soap_header( header["env:Header"]["w:SelectorSet"] = {"w:Selector": {"@Name": "ShellId", "#text": shell_id}} return header + # For backwards compatibility with Ansible. This should not be removed + # until all supported releases of Ansible has been updated to use the new + # method. + _get_soap_header = build_wsman_header + def send_message(self, message: str) -> bytes: # TODO add message_id vs relates_to checking # TODO port error handling code @@ -311,7 +329,7 @@ def close_shell(self, shell_id: str, close_session: bool = True) -> None: try: message_id = uuid.uuid4() req = { - "env:Envelope": self._get_soap_header( + "env:Envelope": self.build_wsman_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA action="http://schemas.xmlsoap.org/ws/2004/09/transfer/Delete", shell_id=shell_id, @@ -356,7 +374,7 @@ def run_command( @rtype string """ req = { - "env:Envelope": self._get_soap_header( + "env:Envelope": self.build_wsman_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Command", # NOQA shell_id=shell_id, @@ -393,7 +411,7 @@ def cleanup_command(self, shell_id: str, command_id: str) -> None: """ message_id = uuid.uuid4() req = { - "env:Envelope": self._get_soap_header( + "env:Envelope": self.build_wsman_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Signal", # NOQA shell_id=shell_id, @@ -430,7 +448,7 @@ def send_command_input(self, shell_id: str, command_id: str, stdin_input: str | if isinstance(stdin_input, str): stdin_input = stdin_input.encode("437") req = { - "env:Envelope": self._get_soap_header( + "env:Envelope": self.build_wsman_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Send", # NOQA shell_id=shell_id, @@ -449,22 +467,22 @@ def send_command_input(self, shell_id: str, command_id: str, stdin_input: str | def get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int]: """ - Get the Output of the given shell and command + Get the Output of the given shell and command. This will wait until the + command is finished before returning the output. + @param string shell_id: The shell id on the remote machine. See #open_shell @param string command_id: The command id on the remote machine. See #run_command - #@return [Hash] Returns a Hash with a key :exitcode and :data. - Data is an Array of Hashes where the corresponding key - # is either :stdout or :stderr. The reason it is in an Array so so - we can get the output in the order it occurs on - # the console. + @return tuple[bytes, bytes, int]: Returns a tuple with the stdout, + stderr, and the return code of the command. The stdout and stderr + value is a byte string and not a normal string. """ stdout_buffer, stderr_buffer = [], [] command_done = False while not command_done: try: - stdout, stderr, return_code, command_done = self._raw_get_command_output(shell_id, command_id) + stdout, stderr, return_code, command_done = self.get_command_output_raw(shell_id, command_id) stdout_buffer.append(stdout) stderr_buffer.append(stderr) except WinRMOperationTimeoutError: @@ -472,9 +490,25 @@ def get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, byt pass return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code - def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int, bool]: + def get_command_output_raw(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int, bool]: + """ + Get the next available output of the given shell and command. This + will wait until the issued WSMan Receive action returns data or times + out with WinRMOperationTimeoutError. + + @param string shell_id: The shell id on the remote machine. + See #open_shell + @param string command_id: The command id on the remote machine. + See #run_command + @return tuple[bytes, bytes, int, bool]: Returns a tuple with the stdout, + stderr, the return code of the command, and whether it has finished + or not. The stdout and stderr value is a byte string and not a + normal string. + @raises WinRMOperationTimeoutError: Raised when there has been no + output from the command + """ req = { - "env:Envelope": self._get_soap_header( + "env:Envelope": self.build_wsman_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Receive", # NOQA shell_id=shell_id, @@ -488,15 +522,16 @@ def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) stream_nodes = [node for node in root.findall(".//*") if node.tag.endswith("Stream")] - stdout = stderr = b"" + stdout = [] + stderr = [] return_code = -1 for stream_node in stream_nodes: if not stream_node.text: continue if stream_node.attrib["Name"] == "stdout": - stdout += base64.b64decode(stream_node.text.encode("ascii")) + stdout.append(base64.b64decode(stream_node.text.encode("ascii"))) elif stream_node.attrib["Name"] == "stderr": - stderr += base64.b64decode(stream_node.text.encode("ascii")) + stderr.append(base64.b64decode(stream_node.text.encode("ascii"))) # We may need to get additional output if the stream has not finished. # The CommandState will change from Running to Done like so: @@ -511,4 +546,10 @@ def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes if command_done: return_code = int(next(node for node in root.findall(".//*") if node.tag.endswith("ExitCode")).text or -1) - return stdout, stderr, return_code, command_done + return b"".join(stdout), b"".join(stderr), return_code, command_done + + # While it was meant to be private it has been treated as a public API. + # This might be removed in a future version but for now keep it as an + # alias for the now public API method 'get_command_output_raw'. + # https://github.com/search?q=_raw_get_command_output+language%3APython&type=code&l=Python + _raw_get_command_output = get_command_output_raw diff --git a/winrm/tests/test_protocol.py b/winrm/tests/test_protocol.py index 2e5d554..d7835bc 100644 --- a/winrm/tests/test_protocol.py +++ b/winrm/tests/test_protocol.py @@ -3,6 +3,17 @@ from winrm.protocol import Protocol +@pytest.mark.parametrize("func_name", ["build_wsman_header", "_get_soap_header"]) +def test_build_wsman_header(func_name, protocol_fake): + func = getattr(protocol_fake, func_name) + actual = func("my action", "resource uri", "shell id", "message id") + + assert actual["env:Header"]["a:Action"]["#text"] == "my action" + assert actual["env:Header"]["w:ResourceURI"]["#text"] == "resource uri" + assert actual["env:Header"]["a:MessageID"] == "uuid:message id" + assert actual["env:Header"]["w:SelectorSet"]["w:Selector"]["#text"] == "shell id" + + def test_open_shell_and_close_shell(protocol_fake): shell_id = protocol_fake.open_shell() assert shell_id == "11111111-1111-1111-1111-111111111113" @@ -40,6 +51,22 @@ def test_get_command_output(protocol_fake): protocol_fake.close_shell(shell_id) +@pytest.mark.parametrize("func_name", ["get_command_output_raw", "_raw_get_command_output"]) +def test_get_command_output_raw(func_name, protocol_fake): + func = getattr(protocol_fake, func_name) + shell_id = protocol_fake.open_shell() + command_id = protocol_fake.run_command(shell_id, "ipconfig", ["/all"]) + + std_out, std_err, status_code, done = func(shell_id, command_id) + assert status_code == 0 + assert b"Windows IP Configuration" in std_out + assert len(std_err) == 0 + assert done is True + + protocol_fake.cleanup_command(shell_id, command_id) + protocol_fake.close_shell(shell_id) + + def test_send_command_input(protocol_fake): shell_id = protocol_fake.open_shell() command_id = protocol_fake.run_command(shell_id, "cmd")