diff --git a/llm_sandbox/base.py b/llm_sandbox/base.py index 46974dd..e7404ad 100644 --- a/llm_sandbox/base.py +++ b/llm_sandbox/base.py @@ -3,22 +3,23 @@ class ConsoleOutput: - def __init__(self, text: str): + def __init__(self, exit_code: Optional[int], text: str): + self._exit_code = exit_code self._text = text + @property + def exit_code(self): + return self._exit_code + @property def text(self): return self._text def __str__(self): - return f"ConsoleOutput(text={self.text})" + return f"ConsoleOutput(text={self.text}, exit_code={self._exit_code})" class KubernetesConsoleOutput(ConsoleOutput): - def __init__(self, exit_code: int, text: str): - super().__init__(text) - self.exit_code = exit_code - def __str__(self): return f"KubernetesConsoleOutput(text={self.text}, exit_code={self.exit_code})" diff --git a/llm_sandbox/docker.py b/llm_sandbox/docker.py index 4ed1811..c266baa 100644 --- a/llm_sandbox/docker.py +++ b/llm_sandbox/docker.py @@ -34,6 +34,7 @@ def __init__( commit_container: bool = True, verbose: bool = False, mounts: Optional[list[Mount]] = None, + stream: bool = True, container_configs: Optional[dict] = None, ): """ @@ -46,6 +47,7 @@ def __init__( :param commit_container: if True, the Docker container will be commited after the session ends :param verbose: if True, print messages :param mounts: List of mounts to be mounted to the container + :param stream: if True, the output will be streamed (enabling this option prevents obtaining an exit code of run command) :param container_configs: Additional configurations for the container, i.e. resources limits (cpu_count, mem_limit), etc. """ super().__init__(lang, verbose) @@ -80,6 +82,7 @@ def __init__( self.is_create_template: bool = False self.verbose = verbose self.mounts = mounts + self.stream = stream self.container_configs = container_configs def open(self): @@ -196,7 +199,7 @@ def run(self, code: str, libraries: Optional[List] = None) -> ConsoleOutput: self.copy_to_runtime(code_file, code_dest_file) - output = ConsoleOutput("") + output = ConsoleOutput(0, "") commands = get_code_execution_command(self.lang, code_dest_file) for command in commands: if self.lang == SupportedLanguage.GO: @@ -263,21 +266,24 @@ def execute_command( if workdir: exit_code, exec_log = self.container.exec_run( - command, stream=True, tty=True, workdir=workdir + command, stream=self.stream, tty=True, workdir=workdir ) else: exit_code, exec_log = self.container.exec_run( - command, stream=True, tty=True + command, stream=self.stream, tty=True ) output = "" if self.verbose: print("Output:", end=" ") + if not self.stream: + exec_log = [exec_log] + for chunk in exec_log: chunk_str = chunk.decode("utf-8") output += chunk_str if self.verbose: print(chunk_str, end="") - return ConsoleOutput(output) + return ConsoleOutput(exit_code, output) diff --git a/llm_sandbox/kubernetes.py b/llm_sandbox/kubernetes.py index 29eef68..96f2d93 100644 --- a/llm_sandbox/kubernetes.py +++ b/llm_sandbox/kubernetes.py @@ -1,8 +1,10 @@ import io import os +import tempfile import time import uuid import tarfile +from pathlib import Path from typing import List, Optional from kubernetes import client as k8s_client, config @@ -27,6 +29,7 @@ def __init__( verbose: bool = False, kube_namespace: Optional[str] = "default", env_vars: Optional[dict] = None, + pod_manifest: Optional[dict] = None, ): """ Create a new sandbox session @@ -37,6 +40,8 @@ def __init__( :param keep_template: if True, the image and container will not be removed after the session ends :param verbose: if True, print messages :param kube_namespace: Kubernetes namespace to use, default is 'default' + :param env_vars: Environment variables to use + :param pod_manifest: Pod manifest to use (ignores other settings: `image`, `kube_namespace` and `env_vars`) """ super().__init__(lang, verbose) if lang not in SupportedLanguageValues: @@ -62,11 +67,10 @@ def __init__( self.keep_template = keep_template self.container = None self.env_vars = env_vars + self.pod_manifest = pod_manifest or self._default_pod_manifest() + self._reconfigure_with_pod_manifest() - def open(self): - self._create_kubernetes_pod() - - def _create_kubernetes_pod(self): + def _default_pod_manifest(self) -> dict: pod_manifest = { "apiVersion": "v1", "kind": "Pod", @@ -83,11 +87,23 @@ def _create_kubernetes_pod(self): } # Add environment variables if provided if self.env_vars: - pod_manifest["spec"]["containers"][0]["env"] = [ + pod_manifest["spec"]["containers"][0]["env"] = [ # type: ignore[index] {"name": key, "value": value} for key, value in self.env_vars.items() ] + return pod_manifest + + def _reconfigure_with_pod_manifest(self): + self.pod_name = self.pod_manifest.get("metadata", {}).get("name", self.pod_name) + self.kube_namespace = self.pod_manifest.get("metadata", {}).get( + "namespace", self.kube_namespace + ) + + def open(self): + self._create_kubernetes_pod() + + def _create_kubernetes_pod(self): self.client.create_namespaced_pod( - namespace=self.kube_namespace, body=pod_manifest + namespace=self.kube_namespace, body=self.pod_manifest ) while True: @@ -142,16 +158,18 @@ def run(self, code: str, libraries: Optional[List] = None) -> ConsoleOutput: f"Failed to install library {library}: {output}" ) - code_file = f"/tmp/code.{get_code_file_extension(self.lang)}" + code_file_name = f"code.{get_code_file_extension(self.lang)}" if self.lang == SupportedLanguage.GO: code_dest_file = "/example/code.go" else: - code_dest_file = code_file + code_dest_file = f"/tmp/{code_file_name}" - with open(code_file, "w") as f: - f.write(code) + with tempfile.TemporaryDirectory() as tmp_dir: + code_file = Path(tmp_dir) / code_file_name + with open(code_file, "w") as f: + f.write(code) + self.copy_to_runtime(str(code_file), code_dest_file) - self.copy_to_runtime(code_file, code_dest_file) commands = get_code_execution_command(self.lang, code_dest_file) output = KubernetesConsoleOutput(0, "") @@ -164,7 +182,7 @@ def run(self, code: str, libraries: Optional[List] = None) -> ConsoleOutput: if output.exit_code != 0: break - return ConsoleOutput(output.text) + return ConsoleOutput(output.exit_code, output.text) def copy_to_runtime(self, src: str, dest: str): if not self.container: diff --git a/llm_sandbox/micromamba.py b/llm_sandbox/micromamba.py index 9228952..3bf236f 100644 --- a/llm_sandbox/micromamba.py +++ b/llm_sandbox/micromamba.py @@ -68,4 +68,4 @@ def execute_command( if self.verbose: print(chunk_str, end="") - return ConsoleOutput(output) + return ConsoleOutput(exit_code, output) diff --git a/llm_sandbox/podman.py b/llm_sandbox/podman.py index 12e543e..f70bba4 100644 --- a/llm_sandbox/podman.py +++ b/llm_sandbox/podman.py @@ -218,7 +218,7 @@ def run(self, code: str, libraries: Optional[List] = None) -> ConsoleOutput: self.copy_to_runtime(code_file, code_dest_file) - output = ConsoleOutput("") + output = ConsoleOutput(0, "") commands = get_code_execution_command(self.lang, code_dest_file) for command in commands: if self.lang == SupportedLanguage.GO: @@ -304,4 +304,4 @@ def execute_command( if self.verbose: print(output) - return ConsoleOutput(output) + return ConsoleOutput(exit_code, output) diff --git a/tests/test_session.py b/tests/test_session.py index 4bc6f21..bbaa754 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -124,6 +124,18 @@ def test_execute_empty_command(self): with self.assertRaises(ValueError): self.session.execute_command("") + def test_execute_failing_command(self): + mock_container = MagicMock() + self.session.container = mock_container + + command = "exit 1" + mock_container.exec_run.return_value = (1, iter([])) + + output = self.session.execute_command(command) + mock_container.exec_run.assert_called_with(command, stream=True, tty=True) + self.assertEqual(output.exit_code, 1) + self.assertEqual(output.text, "") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_session_kubernetes.py b/tests/test_session_kubernetes.py new file mode 100644 index 0000000..36d429c --- /dev/null +++ b/tests/test_session_kubernetes.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import patch, MagicMock +from llm_sandbox.kubernetes import SandboxKubernetesSession + + +class TestSandboxKubernetesSession(unittest.TestCase): + @patch("kubernetes.config.load_kube_config") + def setUp(self, mock_kube_config): + self.image = "python:3.9.19-bullseye" + self.dockerfile = None + self.lang = "python" + self.keep_template = False + self.verbose = False + + self.session = SandboxKubernetesSession( + image=self.image, + dockerfile=self.dockerfile, + lang=self.lang, + keep_template=self.keep_template, + verbose=self.verbose, + ) + + @patch("kubernetes.config.load_kube_config") + def test_with_pod_manifest(self, mock_kube_config): + pod_manifest = { + "apiVersion": "v1", + "kind": "Pod", + "metadata": { + "name": "test", + "namespace": "test", + "labels": {"app": "sandbox"}, + }, + "spec": { + "containers": [ + { + "name": "sandbox-container", + "image": "test", + "tty": True, + "volumeMounts": { + "name": "tmp", + "mountPath": "/tmp", + }, + } + ], + "volumes": [{"name": "tmp", "emptyDir": {"sizeLimit": "5Gi"}}], + }, + } + self.session = SandboxKubernetesSession( + image=self.image, + dockerfile=self.dockerfile, + lang=self.lang, + keep_template=self.keep_template, + verbose=self.verbose, + pod_manifest=pod_manifest, + ) + + self.session.client = MagicMock() + self.session.client.read_namespaced_pod.return_value.status.phase = "Running" + self.session.open() + + self.session.client.create_namespaced_pod.assert_called_with( + namespace="test", + body=pod_manifest, + ) diff --git a/tests/test_session_podman.py b/tests/test_session_podman.py index c9cf651..430b70b 100644 --- a/tests/test_session_podman.py +++ b/tests/test_session_podman.py @@ -149,6 +149,19 @@ def test_execute_empty_command(self): with self.assertRaises(ValueError): self.session.execute_command("") + def test_execute_failing_command(self): + mock_container = MagicMock() + self.session.container = mock_container + + command = "exit 1" + mock_container.exec_run.return_value = (1, iter([])) + + output = self.session.execute_command(command) + + mock_container.exec_run.assert_called_with(command, stream=True, tty=True) + self.assertEqual(output.exit_code, 1) + self.assertEqual(output.text, "") + if __name__ == "__main__": unittest.main()