Skip to content

Commit

Permalink
aws-backend: add support for running instances on multiple subnets (#700
Browse files Browse the repository at this point in the history
)

Update the AWS backend (and instance config) to add the ability to run
instance on a set of subnets (and
so availability zone).

The main use-case for this is to allow instance to come up in multiple
availability zones. A user would create one (or more) subnets in an
availability zone and the list the subnets in the `subnet_configs`
section. This has two main advantages:
- There is some crude load balancing (subnets are picked at random from
the list)
- If an instance fails to start we can retry in another subnet. This
will help in cases where an AZ has been starved of all the instance and
instance not longer start.

To support this (without breaking existing configs) I added a new
`subnet_configs` section to the instance config. This is mutually
exclusive to the existing `subnet_id` section.
If `subnet_configs` is set, then the backend will choose one of the
subnets at random and try to bring up the instance there. If this fails.
then another subnet will be tried, until one works. If none of the
subnets work, an the exception will be re-raised.
  • Loading branch information
timbrown5 authored Jan 29, 2025
1 parent 8969d71 commit d3feeae
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 13 deletions.
72 changes: 64 additions & 8 deletions runner_manager/backend/aws.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Literal, Optional
from copy import deepcopy
from random import shuffle
from typing import List, Literal, Optional, Sequence

from boto3 import client
from botocore.exceptions import ClientError
Expand All @@ -14,6 +16,7 @@
AWSConfig,
AwsInstance,
AWSInstanceConfig,
AwsSubnetListConfig,
Backends,
)
from runner_manager.models.runner import Runner
Expand All @@ -31,15 +34,68 @@ def client(self) -> EC2Client:

def create(self, runner: Runner) -> Runner:
"""Create a runner."""
instance_resource: AwsInstance = self.instance_config.configure_instance(runner)
try:
instance = self.client.run_instances(**instance_resource)
runner.instance_id = instance["Instances"][0]["InstanceId"]
except Exception as e:
log.error(e)
raise e
if self.instance_config.subnet_id and self.instance_config.subnet_configs:
raise Exception(
"Instance config contains both subnet_id and subnet_configs, only one allowed."
)
if len(self.instance_config.subnet_configs) > 0:
runner = self._create_from_subnet_config(
runner, self.instance_config.subnet_configs
)
log.warn(f"Instance id: {runner.instance_id}")
else:
instance_resource: AwsInstance = self.instance_config.configure_instance(
runner
)
try:
runner = self._create(runner, instance_resource)
log.warn(f"Instance id: {runner.instance_id}")
except Exception as e:
log.error(e)
raise e
return super().create(runner)

def _create_from_subnet_config(
self, runner: Runner, subnet_configs: Sequence[AwsSubnetListConfig]
) -> Runner:
# Randomize the order of the Subnets - very coarse load balancing.
# TODO: Skip subnets that have failed recently. Maybe with an increasing backoff.
order = list(range(len(subnet_configs)))
shuffle(order)
for idx, i in enumerate(order):
subnet_config = subnet_configs[i]
try:
# Copy the object to avoid modifying the object we were passed.
count = self.instance_config.max_count - self.instance_config.min_count
log.info(
f"Trying to launch {count} containers on subnet {subnet_config['subnet_id']}"
)
concrete_instance_config = deepcopy(self.instance_config)
concrete_instance_config.subnet_id = subnet_config["subnet_id"]
subnet_security_groups = subnet_config.get("security_group_ids", [])
if subnet_security_groups:
security_groups = list(concrete_instance_config.security_group_ids)
security_groups += subnet_security_groups
concrete_instance_config.security_group_ids = security_groups
instance_resource: AwsInstance = (
concrete_instance_config.configure_instance(runner)
)
return self._create(runner, instance_resource)
except Exception as e:
log.warn(
f"Creating instance in subnet {subnet_config['subnet_id']} failed with '{e}'. Retrying with another subnet."
)
if idx >= len(order) - 1:
raise e
return runner

def _create(self, runner: Runner, instance_resource: AwsInstance) -> Runner:
instance = self.client.run_instances(**instance_resource)
# Allow this to raise exception as we don't want to track an instance that
# doesn't have an instance ID.
runner.instance_id = instance["Instances"][0]["InstanceId"] # type: ignore
return runner

def delete(self, runner: Runner):
"""Delete a runner."""
if runner.instance_id:
Expand Down
13 changes: 11 additions & 2 deletions runner_manager/models/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from pathlib import Path
from string import Template
from typing import Dict, List, Literal, Optional, Sequence, TypedDict
from typing import Dict, List, Literal, NotRequired, Optional, Sequence, TypedDict

from mypy_boto3_ec2.literals import (
InstanceMetadataTagsStateType,
Expand Down Expand Up @@ -134,6 +134,14 @@ class AWSConfig(BackendConfig):
region: str = "us-west-2"


AwsSubnetListConfig = TypedDict(
"AwsSubnetListConfig",
{
"subnet_id": str,
"security_group_ids": NotRequired[Sequence[str]],
},
)

AwsInstance = TypedDict(
"AwsInstance",
{
Expand All @@ -157,7 +165,7 @@ class AWSInstanceConfig(InstanceConfig):

image: str = "ami-0735c191cf914754d" # Ubuntu 22.04 for us-west-2
instance_type: InstanceTypeType = "t3.micro"
subnet_id: str
subnet_id: str = ""
security_group_ids: Sequence[str] = []
max_count: int = 1
min_count: int = 1
Expand All @@ -167,6 +175,7 @@ class AWSInstanceConfig(InstanceConfig):
disk_size_gb: int = 20
iam_instance_profile_arn: str = ""
instance_metadata_tags: InstanceMetadataTagsStateType = "disabled"
subnet_configs: Sequence[AwsSubnetListConfig] = []

def configure_instance(self, runner: Runner) -> AwsInstance:
"""Configure instance."""
Expand Down
107 changes: 104 additions & 3 deletions tests/unit/backend/test_aws.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from unittest.mock import patch

from mypy_boto3_ec2.type_defs import TagTypeDef
from pytest import fixture, mark, raises
Expand Down Expand Up @@ -36,10 +37,68 @@ def aws_group(settings) -> RunnerGroup:
return runner_group


@fixture()
def aws_multi_subnet_group(settings) -> RunnerGroup:
config = AWSConfig()
subnet_id = os.getenv("AWS_SUBNET_ID", "")
runner_group: RunnerGroup = RunnerGroup(
id=3,
name="default",
organization="test",
manager=settings.name,
backend=AWSBackend(
name=Backends.aws,
config=config,
instance_config=AWSInstanceConfig(
subnet_configs=[
{
"subnet_id": subnet_id,
"security_group_ids": [],
}
]
),
),
labels=[
"label",
],
)
return runner_group


@fixture()
def aws_multi_subnet_group_invalid_subnets(settings) -> RunnerGroup:
config = AWSConfig()
runner_group: RunnerGroup = RunnerGroup(
id=3,
name="default",
organization="test",
manager=settings.name,
backend=AWSBackend(
name=Backends.aws,
config=config,
instance_config=AWSInstanceConfig(
subnet_configs=[
{
"subnet_id": "does-not-exist",
},
{
"subnet_id": "also-does-not-exist",
},
]
),
),
labels=[
"label",
],
)
return runner_group


@fixture()
def aws_runner(runner: Runner, aws_group: RunnerGroup) -> Runner:
# Cleanup and return a runner for testing
aws_group.backend.delete(runner)
runner.instance_id = None
return runner


Expand Down Expand Up @@ -70,7 +129,10 @@ def test_aws_instance_config(runner: Runner):
assert instance["TagSpecifications"][1]["ResourceType"] == "volume"


@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found")
@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_create_delete(aws_runner, aws_group):
runner = aws_group.backend.create(aws_runner)
assert runner.instance_id is not None
Expand All @@ -81,7 +143,10 @@ def test_create_delete(aws_runner, aws_group):
Runner.find(Runner.instance_id == runner.instance_id).first()


@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found")
@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_list(aws_runner, aws_group):
runner = aws_group.backend.create(aws_runner)
runners = aws_group.backend.list()
Expand All @@ -91,7 +156,10 @@ def test_list(aws_runner, aws_group):
aws_group.backend.get(runner.instance_id)


@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found")
@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_update(aws_runner, aws_group):
runner = aws_group.backend.create(aws_runner)
runner.labels = [RunnerLabel(name="test", type="custom")]
Expand All @@ -100,3 +168,36 @@ def test_update(aws_runner, aws_group):
aws_group.backend.delete(runner)
with raises(NotFoundError):
aws_group.backend.get(runner.instance_id)


@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_create_delete_multi_subnet(aws_runner, aws_multi_subnet_group):
runner = aws_multi_subnet_group.backend.create(aws_runner)
print(f"{runner.instance_id}")
assert runner.instance_id is not None
assert runner.backend == "aws"
assert Runner.find(Runner.instance_id == runner.instance_id).first() == runner
aws_multi_subnet_group.backend.delete(runner)
with raises(NotFoundError):
Runner.find(Runner.instance_id == runner.instance_id).first()


@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_create_delete_multi_subnet_invalid_subnets(
aws_runner, aws_multi_subnet_group_invalid_subnets
):
with patch.object(
AWSBackend,
"_create",
wraps=aws_multi_subnet_group_invalid_subnets.backend._create,
) as mock:
with raises(Exception):
aws_multi_subnet_group_invalid_subnets.backend.create(aws_runner)
# Check that the code tries once for each subnet.
assert mock.call_count == 2

0 comments on commit d3feeae

Please sign in to comment.