Skip to content

Commit

Permalink
[pri-885] Feature/ux and multi node (#13)
Browse files Browse the repository at this point in the history
* Fix duplicate api key input requested
* Add loading indicator on pod termination, add multi-node availability, group common providers
* Allow picking an SSH connection when multiple connections exist
* Fix short ID creation not including the GPU count
* Minor cleanup
  • Loading branch information
JannikSt authored Jan 8, 2025
1 parent 9cd4b71 commit adc781c
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 25 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [0.2.0] - 2025-01-08
### Added
- Added support for grouping similar GPUs
- Added support for multi-node cluster creation

## [0.1.0] - 2025-01-05
### Added
- Initial CLI implementation
- Initial CLI implementation

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "prime-cli"
version = "0.1.3"
version = "0.2.0"
description = "Prime Intellect CLI"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
32 changes: 26 additions & 6 deletions src/prime_cli/api/availability.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ def get(
gpu_type: Optional[str] = None,
) -> Dict[str, List[GPUAvailability]]:
"""
Get GPU availability information.
Get both single GPU and cluster availability information.
Args:
regions: Optional list of regions to filter by
gpu_count: Optional number of GPUs to filter by
gpu_type: Optional GPU type to filter by
Returns:
Dictionary mapping GPU types to lists of availability information,
combining both single GPU and cluster availability
"""
params = {}
if regions:
Expand All @@ -99,9 +108,20 @@ def get(
if gpu_type:
params["gpu_type"] = gpu_type

response = self.client.get("/availability", params=params)
# Get both single GPU and cluster availability
single_response = self.client.get("/availability", params=params)
cluster_response = self.client.get("/availability/clusters", params=params)

combined: Dict[str, List[GPUAvailability]] = {}
for gpu_type, gpus in single_response.items():
if gpu_type is not None:
combined[gpu_type] = [GPUAvailability(**gpu) for gpu in gpus]

for gpu_type, gpus in cluster_response.items():
if gpu_type is not None:
if gpu_type in combined:
combined[gpu_type].extend([GPUAvailability(**gpu) for gpu in gpus])
else:
combined[gpu_type] = [GPUAvailability(**gpu) for gpu in gpus]

return {
gpu_type: [GPUAvailability(**gpu) for gpu in gpus]
for gpu_type, gpus in response.items()
}
return combined
52 changes: 41 additions & 11 deletions src/prime_cli/commands/availability.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def list(
socket: Optional[str] = typer.Option(
None, help="Filter by socket type (e.g., PCIe, SXM5, SXM4)"
),
group_similar: bool = typer.Option(
True, help="Group similar configurations from same provider"
),
) -> None:
"""List available GPU resources"""
try:
Expand Down Expand Up @@ -97,11 +100,7 @@ def list(
gpu.stock_status, "white"
)

location = (
f"{gpu.country or 'N/A'} - {gpu.data_center or 'N/A'}"
if gpu.country or gpu.data_center
else "N/A"
)
location = f"{gpu.country or 'N/A'}"

short_id = generate_short_id(gpu)
gpu_data = {
Expand All @@ -125,12 +124,43 @@ def list(
# Sort by price and remove duplicates based on short_id
seen_ids = set()
filtered_gpus: List[Dict[str, Any]] = []
for gpu_config in sorted(
all_gpus, key=lambda x: (x["price_value"], x["short_id"])
):
if gpu_config["short_id"] not in seen_ids:
seen_ids.add(gpu_config["short_id"])
filtered_gpus.append(gpu_config)

if group_similar:
grouped_gpus: Dict[str, List[Dict[str, Any]]] = {}
for gpu_config in sorted(
all_gpus, key=lambda x: (x["price_value"], x["short_id"])
):
key = (
f"{gpu_config['provider']}_{gpu_config['gpu_type']}_{gpu_config['gpu_count']}_"
f"{gpu_config['socket']}_{gpu_config['location']}_{gpu_config['security']}_{gpu_config['price']}"
)
if key not in grouped_gpus:
grouped_gpus[key] = []
grouped_gpus[key].append(gpu_config)

# For each group, select representative configuration
for group in grouped_gpus.values():
if len(group) > 1:
# Use first ID but show ranges for variable specs
base = group[0].copy()
min_vcpu = min(g["vcpu"] for g in group)
max_vcpu = max(g["vcpu"] for g in group)
min_mem = min(g["memory"] for g in group)
max_mem = max(g["memory"] for g in group)
vcpu_range = f"{min_vcpu}-{max_vcpu}"
memory_range = f"{min_mem}-{max_mem}"
base["vcpu"] = vcpu_range
base["memory"] = memory_range
filtered_gpus.append(base)
else:
filtered_gpus.append(group[0])
else:
for gpu_config in sorted(
all_gpus, key=lambda x: (x["price_value"], x["short_id"])
):
if gpu_config["short_id"] not in seen_ids:
seen_ids.add(gpu_config["short_id"])
filtered_gpus.append(gpu_config)

for gpu_entry in filtered_gpus:
table.add_row(
Expand Down
36 changes: 31 additions & 5 deletions src/prime_cli/commands/pods.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,9 @@ def terminate(pod_id: str) -> None:
console.print("Termination cancelled")
raise typer.Exit(0)

# Delete the pod
pods_client.delete(pod_id)
with console.status("[bold blue]Terminating pod...", spinner="dots"):
pods_client.delete(pod_id)

console.print(f"[green]Successfully terminated pod {pod_id}[/green]")

except APIError as e:
Expand Down Expand Up @@ -620,13 +621,38 @@ def ssh(pod_id: str) -> None:
if not os.path.exists(ssh_key_path):
console.print(f"[red]SSH key not found at {ssh_key_path}[/red]")
raise typer.Exit(1)

ssh_conn = status.ssh_connection
# Handle ssh_conn being either a string or list of strings
connection_str: str
connections: List[str] = []
if isinstance(ssh_conn, List):
connection_str = ssh_conn[0] if ssh_conn else ""
# Filter out None values and convert to strings
connections = [str(conn) for conn in ssh_conn if conn is not None]
else:
connections = [str(ssh_conn)] if ssh_conn else []

if not connections:
console.print("[red]No valid SSH connections available[/red]")
raise typer.Exit(1)

# If multiple connections available, let user choose
connection_str: str
if len(connections) > 1:
console.print("\nMultiple nodes available. Please select one:")
for idx, conn in enumerate(connections):
console.print(f"[blue]{idx + 1}[/blue]) {conn}")

choice = typer.prompt(
"Enter node number", type=int, default=1, show_default=False
)

if choice < 1 or choice > len(connections):
console.print("[red]Invalid selection[/red]")
raise typer.Exit(1)

connection_str = connections[choice - 1]
else:
connection_str = str(ssh_conn) if ssh_conn else ""
connection_str = connections[0]

connection_parts = connection_str.split(" -p ")
host = connection_parts[0]
Expand Down
3 changes: 2 additions & 1 deletion src/prime_cli/helper/short_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def generate_short_id(gpu_config: GPUAvailability) -> str:
f"{location}-"
f"{gpu_config.provider or 'N/A'}-"
f"{gpu_config.memory.default_count}-"
f"{gpu_config.vcpu.default_count}"
f"{gpu_config.vcpu.default_count}-"
f"{gpu_config.gpu_count}"
)
return hashlib.md5(config_str.encode()).hexdigest()[:6]

0 comments on commit adc781c

Please sign in to comment.