diff --git a/runner_manager/backend/aws.py b/runner_manager/backend/aws.py index 8f3f53f8..1d6d9c6e 100644 --- a/runner_manager/backend/aws.py +++ b/runner_manager/backend/aws.py @@ -1,4 +1,6 @@ -from typing import List, Literal, Optional +from copy import deepcopy +from random import shuffle +from typing import List, Literal, Optional, Sequence from boto3 import client from botocore.exceptions import ClientError @@ -14,6 +16,7 @@ AWSConfig, AwsInstance, AWSInstanceConfig, + AwsSubnetListConfig, Backends, ) from runner_manager.models.runner import Runner @@ -31,15 +34,68 @@ def client(self) -> EC2Client: def create(self, runner: Runner) -> Runner: """Create a runner.""" - instance_resource: AwsInstance = self.instance_config.configure_instance(runner) - try: - instance = self.client.run_instances(**instance_resource) - runner.instance_id = instance["Instances"][0]["InstanceId"] - except Exception as e: - log.error(e) - raise e + if self.instance_config.subnet_id and self.instance_config.subnet_configs: + raise Exception( + "Instance config contains both subnet_id and subnet_configs, only one allowed." + ) + if len(self.instance_config.subnet_configs) > 0: + runner = self._create_from_subnet_config( + runner, self.instance_config.subnet_configs + ) + log.warn(f"Instance id: {runner.instance_id}") + else: + instance_resource: AwsInstance = self.instance_config.configure_instance( + runner + ) + try: + runner = self._create(runner, instance_resource) + log.warn(f"Instance id: {runner.instance_id}") + except Exception as e: + log.error(e) + raise e return super().create(runner) + def _create_from_subnet_config( + self, runner: Runner, subnet_configs: Sequence[AwsSubnetListConfig] + ) -> Runner: + # Randomize the order of the Subnets - very coarse load balancing. + # TODO: Skip subnets that have failed recently. Maybe with an increasing backoff. + order = list(range(len(subnet_configs))) + shuffle(order) + for idx, i in enumerate(order): + subnet_config = subnet_configs[i] + try: + # Copy the object to avoid modifying the object we were passed. + count = self.instance_config.max_count - self.instance_config.min_count + log.info( + f"Trying to launch {count} containers on subnet {subnet_config['subnet_id']}" + ) + concrete_instance_config = deepcopy(self.instance_config) + concrete_instance_config.subnet_id = subnet_config["subnet_id"] + subnet_security_groups = subnet_config.get("security_group_ids", []) + if subnet_security_groups: + security_groups = list(concrete_instance_config.security_group_ids) + security_groups += subnet_security_groups + concrete_instance_config.security_group_ids = security_groups + instance_resource: AwsInstance = ( + concrete_instance_config.configure_instance(runner) + ) + return self._create(runner, instance_resource) + except Exception as e: + log.warn( + f"Creating instance in subnet {subnet_config['subnet_id']} failed with '{e}'. Retrying with another subnet." + ) + if idx >= len(order) - 1: + raise e + return runner + + def _create(self, runner: Runner, instance_resource: AwsInstance) -> Runner: + instance = self.client.run_instances(**instance_resource) + # Allow this to raise exception as we don't want to track an instance that + # doesn't have an instance ID. + runner.instance_id = instance["Instances"][0]["InstanceId"] # type: ignore + return runner + def delete(self, runner: Runner): """Delete a runner.""" if runner.instance_id: diff --git a/runner_manager/models/backend.py b/runner_manager/models/backend.py index 29c7e714..98bb60f1 100644 --- a/runner_manager/models/backend.py +++ b/runner_manager/models/backend.py @@ -1,7 +1,7 @@ from enum import Enum from pathlib import Path from string import Template -from typing import Dict, List, Literal, Optional, Sequence, TypedDict +from typing import Dict, List, Literal, NotRequired, Optional, Sequence, TypedDict from mypy_boto3_ec2.literals import ( InstanceMetadataTagsStateType, @@ -134,6 +134,14 @@ class AWSConfig(BackendConfig): region: str = "us-west-2" +AwsSubnetListConfig = TypedDict( + "AwsSubnetListConfig", + { + "subnet_id": str, + "security_group_ids": NotRequired[Sequence[str]], + }, +) + AwsInstance = TypedDict( "AwsInstance", { @@ -157,7 +165,7 @@ class AWSInstanceConfig(InstanceConfig): image: str = "ami-0735c191cf914754d" # Ubuntu 22.04 for us-west-2 instance_type: InstanceTypeType = "t3.micro" - subnet_id: str + subnet_id: str = "" security_group_ids: Sequence[str] = [] max_count: int = 1 min_count: int = 1 @@ -167,6 +175,7 @@ class AWSInstanceConfig(InstanceConfig): disk_size_gb: int = 20 iam_instance_profile_arn: str = "" instance_metadata_tags: InstanceMetadataTagsStateType = "disabled" + subnet_configs: Sequence[AwsSubnetListConfig] = [] def configure_instance(self, runner: Runner) -> AwsInstance: """Configure instance.""" diff --git a/tests/unit/backend/test_aws.py b/tests/unit/backend/test_aws.py index b5633411..2458b4e7 100644 --- a/tests/unit/backend/test_aws.py +++ b/tests/unit/backend/test_aws.py @@ -1,4 +1,5 @@ import os +from unittest.mock import patch from mypy_boto3_ec2.type_defs import TagTypeDef from pytest import fixture, mark, raises @@ -36,10 +37,68 @@ def aws_group(settings) -> RunnerGroup: return runner_group +@fixture() +def aws_multi_subnet_group(settings) -> RunnerGroup: + config = AWSConfig() + subnet_id = os.getenv("AWS_SUBNET_ID", "") + runner_group: RunnerGroup = RunnerGroup( + id=3, + name="default", + organization="test", + manager=settings.name, + backend=AWSBackend( + name=Backends.aws, + config=config, + instance_config=AWSInstanceConfig( + subnet_configs=[ + { + "subnet_id": subnet_id, + "security_group_ids": [], + } + ] + ), + ), + labels=[ + "label", + ], + ) + return runner_group + + +@fixture() +def aws_multi_subnet_group_invalid_subnets(settings) -> RunnerGroup: + config = AWSConfig() + runner_group: RunnerGroup = RunnerGroup( + id=3, + name="default", + organization="test", + manager=settings.name, + backend=AWSBackend( + name=Backends.aws, + config=config, + instance_config=AWSInstanceConfig( + subnet_configs=[ + { + "subnet_id": "does-not-exist", + }, + { + "subnet_id": "also-does-not-exist", + }, + ] + ), + ), + labels=[ + "label", + ], + ) + return runner_group + + @fixture() def aws_runner(runner: Runner, aws_group: RunnerGroup) -> Runner: # Cleanup and return a runner for testing aws_group.backend.delete(runner) + runner.instance_id = None return runner @@ -70,7 +129,10 @@ def test_aws_instance_config(runner: Runner): assert instance["TagSpecifications"][1]["ResourceType"] == "volume" -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_create_delete(aws_runner, aws_group): runner = aws_group.backend.create(aws_runner) assert runner.instance_id is not None @@ -81,7 +143,10 @@ def test_create_delete(aws_runner, aws_group): Runner.find(Runner.instance_id == runner.instance_id).first() -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_list(aws_runner, aws_group): runner = aws_group.backend.create(aws_runner) runners = aws_group.backend.list() @@ -91,7 +156,10 @@ def test_list(aws_runner, aws_group): aws_group.backend.get(runner.instance_id) -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_update(aws_runner, aws_group): runner = aws_group.backend.create(aws_runner) runner.labels = [RunnerLabel(name="test", type="custom")] @@ -100,3 +168,36 @@ def test_update(aws_runner, aws_group): aws_group.backend.delete(runner) with raises(NotFoundError): aws_group.backend.get(runner.instance_id) + + +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) +def test_create_delete_multi_subnet(aws_runner, aws_multi_subnet_group): + runner = aws_multi_subnet_group.backend.create(aws_runner) + print(f"{runner.instance_id}") + assert runner.instance_id is not None + assert runner.backend == "aws" + assert Runner.find(Runner.instance_id == runner.instance_id).first() == runner + aws_multi_subnet_group.backend.delete(runner) + with raises(NotFoundError): + Runner.find(Runner.instance_id == runner.instance_id).first() + + +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) +def test_create_delete_multi_subnet_invalid_subnets( + aws_runner, aws_multi_subnet_group_invalid_subnets +): + with patch.object( + AWSBackend, + "_create", + wraps=aws_multi_subnet_group_invalid_subnets.backend._create, + ) as mock: + with raises(Exception): + aws_multi_subnet_group_invalid_subnets.backend.create(aws_runner) + # Check that the code tries once for each subnet. + assert mock.call_count == 2