Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality for cluster.kill. #1588

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,6 +1786,23 @@ def run_bash_over_ssh(

return return_codes

def kill_process(
    self,
    process: str,
) -> None:
    """Kill a process on the cluster.

    If ``self.on_this_cluster()`` is true, the process's servlet contents are
    deleted directly through the object store; otherwise the request is
    forwarded to the cluster's HTTP client.

    Args:
        process (str): Name of the process to kill.

    Example:
        >>> rh.cluster("rh-cpu").kill_process("my_process")
    """
    if self.on_this_cluster():
        obj_store.delete_servlet_contents(process)
    else:
        self.client.kill_process(process)

def run(
self,
commands: Union[str, List[str]],
Expand Down
8 changes: 8 additions & 0 deletions runhouse/servers/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
GetObjectParams,
handle_response,
InstallPackageParams,
KillProcessParams,
LogsParams,
OutputType,
PutObjectParams,
Expand Down Expand Up @@ -837,6 +838,13 @@ def set_process_env_vars(
).model_dump(),
)

def kill_process(self, process_name: str):
    """Ask the server to kill the process called ``process_name``.

    Sends a POST to ``/kill_process`` with a serialized
    ``KillProcessParams`` body and returns whatever ``request_json`` returns.
    """
    payload = KillProcessParams(process_name=process_name).model_dump()
    return self.request_json("/kill_process", req_type="post", json_dict=payload)

def install_package(self, package: "Package", conda_env_name: Optional[str] = None):
return self.request_json(
"/install_package",
Expand Down
12 changes: 12 additions & 0 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_token_from_request,
handle_exception_response,
InstallPackageParams,
KillProcessParams,
LogsParams,
OutputType,
PutObjectParams,
Expand Down Expand Up @@ -427,6 +428,17 @@ async def create_process(request: Request, params: CreateProcessParams):
e, traceback.format_exc(), from_http_server=True
)

@staticmethod
@app.post("/kill_process")
@validate_cluster_access
async def kill_process(request: Request, params: KillProcessParams):
    # Handler for ``POST /kill_process``: deletes the servlet contents of the
    # process named in the request body via the object store.
    # NOTE: plain comments are used instead of a docstring so the route's
    # generated API description is left untouched.
    try:
        target_process = params.process_name
        await obj_store.adelete_servlet_contents(target_process)
    except Exception as err:
        # Any failure is converted into the server's standard exception
        # response rather than propagating out of the handler.
        formatted_tb = traceback.format_exc()
        return handle_exception_response(err, formatted_tb, from_http_server=True)
    # On success nothing is returned explicitly (the handler yields None).

@staticmethod
@app.post("/process_env_vars")
@validate_cluster_access
Expand Down
4 changes: 4 additions & 0 deletions runhouse/servers/http/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class GetObjectParams(BaseModel):
remote: Optional[bool] = False


class KillProcessParams(BaseModel):
    # Request body for the ``/kill_process`` endpoint: the HTTP client
    # serializes it with ``model_dump()`` and the server handler receives it
    # as its ``params`` argument.

    # Name of the process whose servlet contents should be deleted.
    process_name: str


class LogsParams(BaseModel):
run_name: str
node_ip: Optional[str] = None
Expand Down
11 changes: 11 additions & 0 deletions tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,3 +1228,14 @@ def test_cluster_run_bash_in_process(self, cluster):
res = cluster.run_bash("echo hello", process=process)
assert res[0][0] == 0
assert res[0][1].strip() == "hello"

@pytest.mark.level("local")
@pytest.mark.clustertest
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's 2 existing tests in test_cluster depending on this feature that I marked as skip - can you address those too? either update or remove them if this test covers it (marked with @pytest.mark.skip("pending cluster.kill functionality"))

def test_cluster_kill_process(self, cluster):
    """Killing a process removes it from the cluster along with its stored keys."""
    key, value = "new_key", "val"

    # Set-up: a fresh process that holds one object.
    proc = cluster.ensure_process_created(name="new_test_process")
    assert proc in cluster.list_processes()

    cluster.put(key=key, obj=value, process=proc)
    assert cluster.get(key) == value

    # Kill it, then check that both the object and the process are gone.
    cluster.kill_process(proc)
    assert cluster.get(key) is None
    assert proc not in cluster.list_processes()
Loading