Skip to content

Commit

Permalink
Merge pull request #33 from Eventual-Inc/connect
Browse files Browse the repository at this point in the history
feat: Add connect command
  • Loading branch information
raunakab authored Jan 15, 2025
2 parents f1c4978 + b4f8a74 commit 132ef1d
Showing 1 changed file with 89 additions and 58 deletions.
147 changes: 89 additions & 58 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use std::{
#[cfg(not(test))]
use anyhow::bail;
use aws_config::{BehaviorVersion, Region};
use aws_sdk_ec2::types::InstanceStateName;
use aws_sdk_ec2::Client;
use aws_sdk_ec2::{types::InstanceStateName, Client};
use clap::{Parser, Subcommand};
use comfy_table::{
modifiers, presets, Attribute, Cell, CellAlignment, Color, ContentArrangement, Table,
Expand All @@ -23,10 +22,10 @@ use serde::{Deserialize, Serialize};
use tempdir::TempDir;
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
process::{Child, Command},
time::timeout,
};
use tokio::{io::AsyncBufReadExt, time::timeout};

type StrRef = Arc<str>;
type PathRef = Arc<Path>;
Expand Down Expand Up @@ -65,10 +64,14 @@ enum SubCommand {

/// Submit a job to the Ray cluster.
///
/// The configurations of the job should be placed inside of your daft-launcher configuration
/// file.
/// The configurations of the job should be placed inside of your
/// daft-launcher configuration file.
Submit(Submit),

/// Establish an ssh port-forward connection from your local machine to the
/// Ray cluster.
Connect(Connect),

/// Spin down a given cluster and put the nodes to "sleep".
///
/// This will *not* delete the nodes, only stop them. The nodes can be
Expand Down Expand Up @@ -116,6 +119,16 @@ struct Submit {
config_path: ConfigPath,
}

// Command-line arguments for the `connect` sub-command.
//
// NOTE: plain `//` comments are used for the additions here on purpose; `///`
// doc comments on clap-derived items are picked up as user-facing help text.
#[derive(Debug, Parser, Clone, PartialEq, Eq)]
struct Connect {
/// The local port to connect to the remote Ray cluster.
// Exposed as `--port`; falls back to 8265 when the flag is omitted.
#[arg(long, default_value = "8265")]
port: u16,

// The arguments of `ConfigPath` are flattened into this command rather than
// being nested under a separate group.
#[clap(flatten)]
config_path: ConfigPath,
}

#[derive(Debug, Parser, Clone, PartialEq, Eq)]
struct ConfigPath {
/// Path to configuration file.
Expand Down Expand Up @@ -336,10 +349,18 @@ async fn read_and_convert(
let key_name = daft_config
.setup
.ssh_private_key
.clone()
.file_stem()
.ok_or_else(|| anyhow::anyhow!(""))?
.ok_or_else(|| {
anyhow::anyhow!(r#"Private key doesn't have a name of the format "name.ext""#)
})?
.to_str()
.ok_or_else(|| anyhow::anyhow!(""))?
.ok_or_else(|| {
anyhow::anyhow!(
"The file {:?} does not a valid UTF-8 name",
daft_config.setup.ssh_private_key,
)
})?
.into();
let iam_instance_profile = daft_config
.setup
Expand Down Expand Up @@ -649,25 +670,6 @@ async fn get_region(region: Option<StrRef>, config: impl AsRef<Path>) -> anyhow:
})
}

/// Spawns an `ssh` child process that forwards local port 8265 to port 8265
/// on `addr`, logging in as `user` with the identity file `ssh_private_key`.
///
/// The child's stderr is piped so that the caller can inspect ssh's verbose
/// (`-v`) output, and the process is killed when the returned handle is
/// dropped.
///
/// # Errors
///
/// Returns an error if the `ssh` process could not be spawned.
async fn start_ssh_port_forward(
    user: &str,
    addr: Ipv4Addr,
    ssh_private_key: &Path,
) -> anyhow::Result<Child> {
    let destination = format!("{user}@{addr}");

    let mut ssh = Command::new("ssh");
    ssh.arg("-N")
        .arg("-i")
        .arg(ssh_private_key)
        .arg("-L")
        .arg("8265:localhost:8265")
        .arg(destination)
        .arg("-v")
        .stderr(Stdio::piped())
        .kill_on_drop(true);

    Ok(ssh.spawn()?)
}

async fn get_head_node_ip(ray_path: impl AsRef<Path>) -> anyhow::Result<Ipv4Addr> {
let mut ray_command = Command::new("ray")
.arg("get-head-ip")
Expand Down Expand Up @@ -700,6 +702,54 @@ async fn get_head_node_ip(ray_path: impl AsRef<Path>) -> anyhow::Result<Ipv4Addr
Ok(addr)
}

async fn establish_ssh_portforward(
ray_path: impl AsRef<Path>,
daft_config: &DaftConfig,
port: Option<u16>,
) -> anyhow::Result<Child> {
let user = daft_config.setup.ssh_user.as_ref();
let addr = get_head_node_ip(ray_path).await?;
let port = port.unwrap_or(8265);
let mut child = Command::new("ssh")
.arg("-N")
.arg("-i")
.arg(daft_config.setup.ssh_private_key.as_ref())
.arg("-L")
.arg(format!("{port}:localhost:8265"))
.arg(format!("{user}@{addr}"))
.arg("-v")
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()?;

// We wait for the ssh port-forwarding process to write a specific string to the
// output.
//
// This is a little hacky (and maybe even incorrect across platforms) since we
// are just parsing the output and observing if a specific string has been
// printed. It may be incorrect across platforms because the SSH standard
// does *not* specify a standard "success-message" to printout if the ssh
// port-forward was successful.
timeout(Duration::from_secs(5), {
let stderr = child.stderr.take().expect("stderr must exist");
async move {
let mut lines = BufReader::new(stderr).lines();
loop {
let Some(line) = lines.next_line().await? else {
anyhow::bail!("Failed to establish ssh port-forward to {addr}");
};
if line.starts_with(format!("Authenticated to {addr}").as_str()) {
break Ok(());
}
}
}
})
.await
.map_err(|_| anyhow::anyhow!("Establishing an ssh port-forward to {addr} timed out"))??;

Ok(child)
}

async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {
match daft_launcher.sub_command {
SubCommand::Init(Init { path }) => {
Expand Down Expand Up @@ -752,37 +802,7 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {

let (_temp_dir, ray_path) = create_temp_ray_file()?;
write_ray_config(ray_config, &ray_path).await?;
let addr = get_head_node_ip(ray_path).await?;
let mut child = start_ssh_port_forward(
daft_config.setup.ssh_user.as_ref(),
addr,
daft_config.setup.ssh_private_key.as_ref(),
)
.await?;

// We wait for the ssh port-forwarding process to write a specific string to the
// output.
//
// This is a little hacky (and maybe even incorrect across platforms) since we are just
// parsing the output and observing if a specific string has been printed. It may be
// incorrect across platforms because the SSH standard does *not* specify a standard
// "success-message" to printout if the ssh port-forward was successful.
timeout(Duration::from_secs(5), async move {
let mut lines =
BufReader::new(child.stderr.take().expect("stderr must exist")).lines();
loop {
let Some(line) = lines.next_line().await? else {
anyhow::bail!("Failed to establish ssh port-forward to {addr}");
};
if line.starts_with(format!("Authenticated to {addr}").as_str()) {
break Ok(());
}
}
})
.await
.map_err(|_| {
anyhow::anyhow!("Establishing an ssh port-forward to {addr} timed out")
})??;
let _child = establish_ssh_portforward(ray_path, &daft_config, None).await?;

let exit_status = Command::new("ray")
.env("PYTHONUNBUFFERED", "1")
Expand All @@ -798,6 +818,17 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {
anyhow::bail!("Failed to submit job to the ray cluster");
};
}
SubCommand::Connect(Connect { port, config_path }) => {
let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?;
assert_is_logged_in_with_aws().await?;

let (_temp_dir, ray_path) = create_temp_ray_file()?;
write_ray_config(ray_config, &ray_path).await?;
let _ = establish_ssh_portforward(ray_path, &daft_config, Some(port))
.await?
.wait_with_output()
.await?;
}
SubCommand::Stop(ConfigPath { config }) => {
let (_, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?;
assert_is_logged_in_with_aws().await?;
Expand Down

0 comments on commit 132ef1d

Please sign in to comment.