Skip to content

Commit

Permalink
Support specifying internal_ip for SSH fleet hosts (#2056)
Browse files Browse the repository at this point in the history
* Support specifying internal_ip for SSH fleet hosts

* Validate internal_ip

* Handle client backward compatibility

* Remove extra space
  • Loading branch information
r4victor authored Dec 4, 2024
1 parent 3193792 commit c10b1fe
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 19 deletions.
29 changes: 28 additions & 1 deletion src/dstack/_internal/core/models/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,29 @@ class SSHHostParams(CoreModel):
identity_file: Annotated[
Optional[str], Field(description="The private key to use for this host")
] = None
internal_ip: Annotated[
Optional[str],
Field(
description=(
"The internal IP of the host used for communication inside the cluster."
" If not specified, `dstack` will use the IP address from `network` or from the first found internal network."
)
),
] = None
ssh_key: Optional[SSHKey] = None

@validator("internal_ip")
def validate_internal_ip(cls, value):
if value is None:
return value
try:
internal_ip = ipaddress.ip_address(value)
except ValueError as e:
raise ValueError("Invalid IP address") from e
if not internal_ip.is_private:
raise ValueError("IP address is not private")
return value


class SSHParams(CoreModel):
user: Annotated[Optional[str], Field(description="The user to log in with on all hosts")] = (
Expand All @@ -70,7 +91,13 @@ class SSHParams(CoreModel):
]
network: Annotated[
Optional[str],
Field(description="The network address for cluster setup in the format `<ip>/<netmask>`"),
Field(
description=(
"The network address for cluster setup in the format `<ip>/<netmask>`."
" `dstack` will use IP addresses from this network for communication between hosts."
" If not specified, `dstack` will use IPs from the first found internal network."
)
),
]

@validator("network")
Expand Down
29 changes: 24 additions & 5 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
from dstack._internal.server.utils.common import run_async
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.network import get_ip_from_network
from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
from dstack._internal.utils.ssh import (
pkey_from_str,
)
Expand Down Expand Up @@ -290,16 +290,20 @@ async def _add_remote(instance: InstanceModel) -> None:

instance_type = host_info_to_instance_type(host_info)
instance_network = None
internal_ip = None
try:
default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data)
instance_network = default_jpd.instance_network
internal_ip = default_jpd.internal_ip
except ValidationError:
pass

internal_ip = get_ip_from_network(
network=instance_network,
addresses=host_info.get("addresses", []),
)
host_network_addresses = host_info.get("addresses", [])
if internal_ip is None:
internal_ip = get_ip_from_network(
network=instance_network,
addresses=host_network_addresses,
)
if instance_network is not None and internal_ip is None:
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Failed to locate internal IP address on the given network"
Expand All @@ -312,6 +316,21 @@ async def _add_remote(instance: InstanceModel) -> None:
},
)
return
if internal_ip is not None:
if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses):
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = (
"Specified internal IP not found among instance interfaces"
)
logger.warning(
"Failed to add instance %s: specified internal IP not found among instance interfaces",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

region = instance.region
jpd = JobProvisioningData(
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,13 @@ async def create_fleet_ssh_instance_model(
ssh_user = ssh_params.user
ssh_key = ssh_params.ssh_key
port = ssh_params.port
internal_ip = None
else:
hostname = host.hostname
ssh_user = host.user or ssh_params.user
ssh_key = host.ssh_key or ssh_params.ssh_key
port = host.port or ssh_params.port
internal_ip = host.internal_ip

if ssh_user is None or ssh_key is None:
# This should not be reachable but checked by fleet spec validation
Expand All @@ -422,6 +424,7 @@ async def create_fleet_ssh_instance_model(
ssh_user=ssh_user,
ssh_keys=[ssh_key],
env=env,
internal_ip=internal_ip,
instance_network=ssh_params.network,
port=port or 22,
)
Expand Down Expand Up @@ -678,6 +681,7 @@ def _validate_fleet_spec(spec: FleetSpec):
for host in spec.configuration.ssh_config.hosts:
if is_core_model_instance(host, SSHHostParams) and host.ssh_key is not None:
_validate_ssh_key(host.ssh_key)
_validate_internal_ips(spec.configuration.ssh_config)


def _validate_all_ssh_params_specified(ssh_config: SSHParams):
Expand Down Expand Up @@ -706,6 +710,17 @@ def _validate_ssh_key(ssh_key: SSHKey):
)


def _validate_internal_ips(ssh_config: SSHParams):
internal_ips_num = 0
for host in ssh_config.hosts:
if not isinstance(host, str) and host.internal_ip is not None:
internal_ips_num += 1
if internal_ips_num != 0 and internal_ips_num != len(ssh_config.hosts):
raise ServerClientError("internal_ip must be specified for all hosts")
if internal_ips_num > 0 and ssh_config.network is not None:
raise ServerClientError("internal_ip is mutually exclusive with network")


def _get_fleet_nodes_to_provision(spec: FleetSpec) -> int:
if spec.configuration.nodes is None or spec.configuration.nodes.min is None:
return 0
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ async def create_ssh_instance_model(
pool: PoolModel,
instance_name: str,
instance_num: int,
internal_ip: Optional[str],
instance_network: Optional[str],
region: Optional[str],
host: str,
Expand All @@ -676,7 +677,7 @@ async def create_ssh_instance_model(
instance_id=instance_name,
hostname=host,
region=host_region,
internal_ip=None,
internal_ip=internal_ip,
instance_network=instance_network,
price=0,
username=ssh_user,
Expand Down
18 changes: 17 additions & 1 deletion src/dstack/_internal/utils/network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ipaddress
from typing import Optional, Sequence
from typing import List, Optional, Sequence


def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Optional[str]:
Expand Down Expand Up @@ -32,3 +32,19 @@ def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Opt
# return any ipv4
internal_ip = str(ip_addresses[0]) if ip_addresses else None
return internal_ip


def is_ip_among_addresses(ip_address: str, addresses: Sequence[str]) -> bool:
ip_addresses = get_ips_from_addresses(addresses)
return ip_address in ip_addresses


def get_ips_from_addresses(addresses: Sequence[str]) -> List[str]:
ip_addresses = []
for address in addresses:
try:
interface = ipaddress.IPv4Interface(address)
ip_addresses.append(interface.ip)
except ipaddress.AddressValueError:
continue
return [str(ip) for ip in ip_addresses]
30 changes: 19 additions & 11 deletions src/dstack/api/server/_fleets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

from pydantic import parse_obj_as

Expand Down Expand Up @@ -29,11 +29,7 @@ def get_plan(
spec: FleetSpec,
) -> FleetPlan:
body = GetFleetPlanRequest(spec=spec)
body_json = body.json()
if spec.configuration_path is None:
# Handle old server versions that do not accept configuration_path
# TODO: Can be removed in 0.19
body_json = body.json(exclude={"spec": {"configuration_path"}})
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json)
return parse_obj_as(FleetPlan.__response__, resp.json())

Expand All @@ -43,11 +39,7 @@ def create(
spec: FleetSpec,
) -> Fleet:
body = CreateFleetRequest(spec=spec)
body_json = body.json()
if spec.configuration_path is None:
# Handle old server versions that do not accept configuration_path
# TODO: Can be removed in 0.19
body_json = body.json(exclude={"spec": {"configuration_path"}})
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
return parse_obj_as(Fleet.__response__, resp.json())

Expand All @@ -58,3 +50,19 @@ def delete(self, project_name: str, names: List[str]) -> None:
def delete_instances(self, project_name: str, name: str, instance_nums: List[int]) -> None:
body = DeleteFleetInstancesRequest(name=name, instance_nums=instance_nums)
self._request(f"/api/project/{project_name}/fleets/delete_instances", body=body.json())


def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[dict]:
exclude = {}
# TODO: Can be removed in 0.19
if fleet_spec.configuration_path is None:
exclude["spec"] = {"configuration_path"}
if fleet_spec.configuration.ssh_config is not None:
if all(
isinstance(h, str) or h.internal_ip is None
for h in fleet_spec.configuration.ssh_config.hosts
):
exclude["spec"] = {
"configuration": {"ssh_config": {"hosts": {"__all__": {"internal_ip"}}}}
}
return exclude or None

0 comments on commit c10b1fe

Please sign in to comment.