diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 5299228c1..513ccdeed 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -1786,6 +1786,23 @@ def run_bash_over_ssh( return return_codes + def kill_process( + self, + process: str, + ): + """Kill a process on the cluster. + + Args: + process (str): Process to kill. + + Example: + >>> rh.cluster("rh-cpu").kill("ray") + """ + if self.on_this_cluster(): + obj_store.delete_servlet_contents(process) + else: + self.client.kill_process(process) + def run( self, commands: Union[str, List[str]], diff --git a/runhouse/servers/http/http_client.py b/runhouse/servers/http/http_client.py index 664dbf1b2..4c01bcc35 100644 --- a/runhouse/servers/http/http_client.py +++ b/runhouse/servers/http/http_client.py @@ -30,6 +30,7 @@ GetObjectParams, handle_response, InstallPackageParams, + KillProcessParams, LogsParams, OutputType, PutObjectParams, @@ -837,6 +838,13 @@ def set_process_env_vars( ).model_dump(), ) + def kill_process(self, process_name: str): + return self.request_json( + "/kill_process", + req_type="post", + json_dict=KillProcessParams(process_name=process_name).model_dump(), + ) + def install_package(self, package: "Package", conda_env_name: Optional[str] = None): return self.request_json( "/install_package", diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index 73cb9ed54..4203eb8cd 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -51,6 +51,7 @@ get_token_from_request, handle_exception_response, InstallPackageParams, + KillProcessParams, LogsParams, OutputType, PutObjectParams, @@ -427,6 +428,17 @@ async def create_process(request: Request, params: CreateProcessParams): e, traceback.format_exc(), from_http_server=True ) + @staticmethod + @app.post("/kill_process") + @validate_cluster_access + async def kill_process(request: Request, params: KillProcessParams): + try: + await obj_store.adelete_servlet_contents(params.process_name) + except Exception as e: + return handle_exception_response( + e, traceback.format_exc(), from_http_server=True + ) + @staticmethod @app.post("/process_env_vars") @validate_cluster_access diff --git a/runhouse/servers/http/http_utils.py b/runhouse/servers/http/http_utils.py index 28a1c1585..632565fd1 100644 --- a/runhouse/servers/http/http_utils.py +++ b/runhouse/servers/http/http_utils.py @@ -87,6 +87,10 @@ class GetObjectParams(BaseModel): remote: Optional[bool] = False +class KillProcessParams(BaseModel): + process_name: str + + class LogsParams(BaseModel): run_name: str node_ip: Optional[str] = None diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index ac3f1f59b..4a4f17fa4 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -1228,3 +1228,14 @@ def test_cluster_run_bash_in_process(self, cluster): res = cluster.run_bash("echo hello", process=process) assert res[0][0] == 0 assert res[0][1].strip() == "hello" + + @pytest.mark.level("local") + @pytest.mark.clustertest + def test_cluster_kill_process(self, cluster): + process = cluster.ensure_process_created(name="new_test_process") + assert process in cluster.list_processes() + cluster.put(key="new_key", obj="val", process=process) + assert cluster.get("new_key") == "val" + cluster.kill_process(process) + assert cluster.get("new_key") is None + assert process not in cluster.list_processes()