diff --git a/dashboard/modules/job/tests/test_cli.py b/dashboard/modules/job/tests/test_cli.py index d5451405e1a64..71694d09bc16a 100644 --- a/dashboard/modules/job/tests/test_cli.py +++ b/dashboard/modules/job/tests/test_cli.py @@ -79,22 +79,30 @@ def set_env_var(key: str, val: Optional[str] = None): os.environ[key] = old_val +def check_exit_code(result, exit_code): + assert result.exit_code == exit_code, result.output + + def _job_cli_group_test_address(mock_sdk_client, cmd, *args): runner = CliRunner() + create_cluster_if_needed = True if cmd == "submit" else False # Test passing address via command line. result = runner.invoke(job_cli_group, [cmd, "--address=arg_addr", *args]) - assert mock_sdk_client.called_with("arg_addr") - assert result.exit_code == 0 + mock_sdk_client.assert_called_with("arg_addr", create_cluster_if_needed) + with pytest.raises(AssertionError): + mock_sdk_client.assert_called_with("some_other_addr", True) + check_exit_code(result, 0) # Test passing address via env var. with set_env_var("RAY_ADDRESS", "env_addr"): result = runner.invoke(job_cli_group, [cmd, *args]) - assert result.exit_code == 0 - assert mock_sdk_client.called_with("env_addr") + check_exit_code(result, 0) + # RAY_ADDRESS is read inside the SDK client. + mock_sdk_client.assert_called_with(None, create_cluster_if_needed) # Test passing no address. result = runner.invoke(job_cli_group, [cmd, *args]) - assert result.exit_code == 0 - assert mock_sdk_client.called_with(None) + check_exit_code(result, 0) + mock_sdk_client.assert_called_with(None, create_cluster_if_needed) class TestList: @@ -108,16 +116,16 @@ def test_list(self, mock_sdk_client): job_cli_group, ["list"], ) - assert result.exit_code == 0 + check_exit_code(result, 0) result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"]) - assert result.exit_code == 0 + check_exit_code(result, 0) result = runner.invoke( job_cli_group, ["list"], ) - assert result.exit_code == 0 + check_exit_code(result, 0) class TestSubmit: @@ -130,21 +138,31 @@ def test_working_dir(self, mock_sdk_client): with set_env_var("RAY_ADDRESS", "env_addr"): result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"]) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env={}) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id=None, runtime_env={} + ) result = runner.invoke( job_cli_group, - ["submit", "--", "--working-dir", "blah", "--", "echo hello"], + ["submit", "--working-dir", "blah", "--", "echo hello"], + ) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', + submission_id=None, + runtime_env={"working_dir": "blah"}, ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env={"working_dir": "blah"}) result = runner.invoke( - job_cli_group, ["submit", "--", "--working-dir='.'", "--", "echo hello"] + job_cli_group, ["submit", "--working-dir='.'", "--", "echo hello"] + ) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', + submission_id=None, + runtime_env={"working_dir": "'.'"}, ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env={"working_dir": "."}) def test_runtime_env(self, mock_sdk_client, runtime_env_formats): runner = CliRunner() @@ -156,16 +174,20 @@ def test_runtime_env(self, mock_sdk_client, runtime_env_formats): result = runner.invoke( job_cli_group, ["submit", "--runtime-env", env_yaml, "--", "echo hello"] ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env=env_dict) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id=None, runtime_env=env_dict + ) # Test passing via json. result = runner.invoke( job_cli_group, ["submit", "--runtime-env-json", env_json, "--", "echo hello"], ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env=env_dict) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id=None, runtime_env=env_dict + ) # Test passing both throws an error. result = runner.invoke( @@ -180,7 +202,7 @@ def test_runtime_env(self, mock_sdk_client, runtime_env_formats): "echo hello", ], ) - assert result.exit_code == 1 + check_exit_code(result, 1) assert "Only one of" in str(result.exception) # Test overriding working_dir. @@ -197,8 +219,10 @@ def test_runtime_env(self, mock_sdk_client, runtime_env_formats): "echo hello", ], ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env=env_dict) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id=None, runtime_env=env_dict + ) result = runner.invoke( job_cli_group, @@ -212,8 +236,10 @@ def test_runtime_env(self, mock_sdk_client, runtime_env_formats): "echo hello", ], ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(runtime_env=env_dict) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id=None, runtime_env=env_dict + ) def test_job_id(self, mock_sdk_client): runner = CliRunner() @@ -221,15 +247,19 @@ def test_job_id(self, mock_sdk_client): with set_env_var("RAY_ADDRESS", "env_addr"): result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"]) - assert result.exit_code == 0 - assert mock_client_instance.called_with(submission_id=None) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id=None, runtime_env={} + ) result = runner.invoke( job_cli_group, - ["submit", "--", "--submission-id=my_job_id", "echo hello"], + ["submit", "--submission-id=my_job_id", "--", "echo hello"], + ) + check_exit_code(result, 0) + mock_client_instance.submit_job.assert_called_with( + entrypoint='"echo hello"', submission_id="my_job_id", runtime_env={} ) - assert result.exit_code == 0 - assert mock_client_instance.called_with(submission_id="my_job_id") if __name__ == "__main__":