Skip to content

Commit

Permalink
Add functionality for cluster.kill.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Dec 12, 2024
1 parent 2fc71b6 commit 26b26e1
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 0 deletions.
17 changes: 17 additions & 0 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,6 +1786,23 @@ def run_bash_over_ssh(

return return_codes

def kill_process(
    self,
    process: str,
) -> None:
    """Kill a process on the cluster.

    When this code is already running on the cluster, the process's servlet
    contents are deleted directly through the object store; otherwise the
    request is forwarded through ``self.client``.

    Args:
        process (str): Process to kill.

    Example:
        >>> rh.cluster("rh-cpu").kill_process("ray")
    """
    if self.on_this_cluster():
        # Already on the cluster: act on the local object store directly.
        obj_store.delete_servlet_contents(process)
    else:
        # Off-cluster: delegate to the cluster's client.
        self.client.kill_process(process)

def run(
self,
commands: Union[str, List[str]],
Expand Down
8 changes: 8 additions & 0 deletions runhouse/servers/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
GetObjectParams,
handle_response,
InstallPackageParams,
KillProcessParams,
LogsParams,
OutputType,
PutObjectParams,
Expand Down Expand Up @@ -837,6 +838,13 @@ def set_process_env_vars(
).model_dump(),
)

def kill_process(self, process_name: str):
    """Request that the process named ``process_name`` be killed.

    Serializes a ``KillProcessParams`` payload, POSTs it to
    ``/kill_process``, and returns whatever ``request_json`` returns.
    """
    payload = KillProcessParams(process_name=process_name).model_dump()
    return self.request_json(
        "/kill_process",
        req_type="post",
        json_dict=payload,
    )

def install_package(self, package: "Package", conda_env_name: Optional[str] = None):
return self.request_json(
"/install_package",
Expand Down
12 changes: 12 additions & 0 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_token_from_request,
handle_exception_response,
InstallPackageParams,
KillProcessParams,
LogsParams,
OutputType,
PutObjectParams,
Expand Down Expand Up @@ -427,6 +428,17 @@ async def create_process(request: Request, params: CreateProcessParams):
e, traceback.format_exc(), from_http_server=True
)

@staticmethod
@app.post("/kill_process")
@validate_cluster_access
async def kill_process(request: Request, params: KillProcessParams):
    """Handle ``POST /kill_process`` by deleting the named process's servlet contents.

    Any exception raised while deleting is converted into an error response
    through ``handle_exception_response``.
    """
    # NOTE(review): the success path has no explicit return, so the handler
    # yields None; confirm that is what the client side expects.
    try:
        await obj_store.adelete_servlet_contents(params.process_name)
    except Exception as exc:
        formatted_tb = traceback.format_exc()
        return handle_exception_response(
            exc, formatted_tb, from_http_server=True
        )

@staticmethod
@app.post("/process_env_vars")
@validate_cluster_access
Expand Down
4 changes: 4 additions & 0 deletions runhouse/servers/http/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class GetObjectParams(BaseModel):
remote: Optional[bool] = False


class KillProcessParams(BaseModel):
    # Request parameters identifying which process should be killed.
    # ``process_name`` is required: it has no default value.
    process_name: str


class LogsParams(BaseModel):
run_name: str
node_ip: Optional[str] = None
Expand Down
11 changes: 11 additions & 0 deletions tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,3 +1228,14 @@ def test_cluster_run_bash_in_process(self, cluster):
res = cluster.run_bash("echo hello", process=process)
assert res[0][0] == 0
assert res[0][1].strip() == "hello"

@pytest.mark.level("local")
@pytest.mark.clustertest
def test_cluster_kill_process(self, cluster):
    """Killing a process should remove both the process and the objects put in it."""
    proc = cluster.ensure_process_created(name="new_test_process")
    assert proc in cluster.list_processes()

    # Store an object in the new process and confirm it can be read back.
    cluster.put(key="new_key", obj="val", process=proc)
    stored = cluster.get("new_key")
    assert stored == "val"

    cluster.kill_process(proc)

    # After the kill, the stored key and the process itself should be gone.
    assert cluster.get("new_key") is None
    assert proc not in cluster.list_processes()

0 comments on commit 26b26e1

Please sign in to comment.