Skip to content

Commit

Permalink
Fix various typing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
adanaja committed Oct 7, 2024
1 parent 89f69e6 commit 2a49b04
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 36 deletions.
21 changes: 13 additions & 8 deletions src/e3/aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Any, TypedDict, Callable
import botocore.client
import botocore.stub
from datetime import datetime

class AWSCredentials(TypedDict, total=False):
"""Annotate a dict containing AWS credentials.
Expand All @@ -44,7 +45,7 @@ class AWSCredentials(TypedDict, total=False):
AccessKeyId: str
SecretAccessKey: str
SessionToken: str
Expiration: str
Expiration: datetime


class AWSSessionRunError(E3Error):
Expand Down Expand Up @@ -372,7 +373,7 @@ def wrapper(*args, **kwargs):
return decorator


def assume_profile_main():
def assume_profile_main() -> None:
"""Generate shell commands to set credentials for a profile."""
argument_parser = argparse.ArgumentParser()
argument_parser.add_argument(
Expand Down Expand Up @@ -416,7 +417,7 @@ def assume_profile_main():
print(f"export {k}={v}")


def assume_role_main():
def assume_role_main() -> None:
"""Generate shell commands to set credentials for a role."""
argument_parser = argparse.ArgumentParser()
argument_parser.add_argument(
Expand Down Expand Up @@ -449,14 +450,18 @@ def assume_role_main():
credentials = s.assume_role_get_credentials(
args.role_arn, role_session_name, session_duration=session_duration
)
credentials["Expiration"] = credentials["Expiration"].timestamp()
credentials_float = credentials | {
"Expiration": credentials["Expiration"].timestamp()
}
if args.json:
print(json.dumps(credentials))
print(json.dumps(credentials_float))
else:
credentials = {
key_to_envvar[k]: v for k, v in credentials.items() if k in key_to_envvar
credentials_float = {
key_to_envvar[k]: v
for k, v in credentials_float.items()
if k in key_to_envvar
}
for k, v in credentials.items():
for k, v in credentials_float.items():
print(f"export {k}={v}")


Expand Down
4 changes: 2 additions & 2 deletions src/e3/aws/pricing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from e3.aws.util import get_region_name

if TYPE_CHECKING:
from typing import Any
from typing import Any, Union
import botocore

_CacheKey = tuple[str | None, str | None, str | None]
_CacheKey = tuple[Union[str, None], Union[str, None], Union[str, None]]

# This is only to avoid repeating the type everywhere
PriceInformation = dict[str, Any]
Expand Down
6 changes: 5 additions & 1 deletion tests/tests_e3_aws/assume_profile_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def get_frozen_credentials(self) -> ReadOnlyCredentials:
"json,expected_output",
[(False, EXPECTED_DEFAULT_OUTPUT), (True, EXPECTED_JSON_OUTPUT)],
)
def test_assume_profile_main_json(json: bool, expected_output: str, capfd):
def test_assume_profile_main_json(
json: bool,
expected_output: str,
capfd: pytest.CaptureFixture[str],
) -> None:
"""Test the credentials returned by assume_profile_main."""
with (
mock.patch(
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_e3_aws/dynamodb/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_update_item(client: DynamoDB) -> None:
client.update_item(
item=customers[0],
table_name=TABLE_NAME,
keys=PRIMARY_KEYS,
keys=("name", "S"),
data={"age": 33},
)

Expand All @@ -138,7 +138,7 @@ def test_update_item_condition(client: DynamoDB) -> None:
client.update_item(
item=customers[0],
table_name=TABLE_NAME,
keys=PRIMARY_KEYS,
keys=("name", "S"),
data={"age": 33},
condition_expression="attribute_exists(#n) AND #a = :a",
expression_attribute_names={"#n": "name", "#a": "age"},
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_e3_aws/pricing/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@
{"Field": "capacitystatus", "Type": "TERM_MATCH", "Value": "Used"},
{"Field": "preInstalledSw", "Type": "TERM_MATCH", "Value": "NA"},
{"Field": "tenancy", "Type": "TERM_MATCH", "Value": "shared"},
]
+ GET_PRODUCTS_PARAMS["Filters"],
*GET_PRODUCTS_PARAMS["Filters"],
],
}


Expand Down
5 changes: 4 additions & 1 deletion tests/tests_e3_aws/troposphere/apigateway/apigateway_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import Any, cast
import json
import os
import pytest
Expand Down Expand Up @@ -283,7 +284,9 @@
},
"TestapiIntegration": {
"Properties": {
**EXPECTED_TEMPLATE["TestapiIntegration"]["Properties"],
**cast(
dict[str, Any], EXPECTED_TEMPLATE["TestapiIntegration"]["Properties"]
),
"IntegrationUri": "arn:aws:lambda:eu-west-1:123456789012:function:"
"mypylambda:${stageVariables.lambdaAlias}",
},
Expand Down
35 changes: 22 additions & 13 deletions tests/tests_e3_aws/troposphere/awslambda/awslambda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
from e3.pytest import require_tool

if TYPE_CHECKING:
from typing import Iterable
from typing import Iterable, Callable
from flask import Application, Response
from pathlib import Path


SOURCE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "source_dir")
Expand Down Expand Up @@ -414,7 +415,7 @@ def test_pyfunction(stack: Stack) -> None:
assert stack.export()["Resources"] == EXPECTED_PYFUNCTION_TEMPLATE


def test_pyfunction_with_requirements(tmp_path, stack: Stack) -> None:
def test_pyfunction_with_requirements(tmp_path: Path, stack: Stack) -> None:
"""Test PyFunction creation."""
stack.s3_bucket = "cfn_bucket"
stack.s3_key = "templates/"
Expand Down Expand Up @@ -472,7 +473,7 @@ def test_pyfunction_policy_document(stack: Stack) -> None:
@pytest.mark.skip(
reason="This test does not work in GitLab CI jobs. Disable it for now.",
)
def test_docker_function(stack: Stack, has_docker: pytest.Fixture) -> None:
def test_docker_function(stack: Stack, has_docker: Callable) -> None:
"""Test adding docker function to stack."""
aws_env = AWSEnv(regions=["us-east-1"], stub=True)
stubber_ecr = aws_env.stub("ecr")
Expand Down Expand Up @@ -607,10 +608,14 @@ def test_autoversion_default(stack: Stack, simple_lambda_function: PyFunction) -
stack.add(auto_version)
print(stack.export()["Resources"])
assert stack.export()["Resources"] == EXPECTED_AUTOVERSION_DEFAULT_TEMPLATE
assert auto_version.get_version(1).name == "mypylambdaVersion1"
assert auto_version.get_version(2).name == "mypylambdaVersion2"
assert auto_version.previous.name == "mypylambdaVersion1"
assert auto_version.latest.name == "mypylambdaVersion2"
assert (
version := auto_version.get_version(1)
) and version.name == "mypylambdaVersion1"
assert (
version := auto_version.get_version(2)
) and version.name == "mypylambdaVersion2"
assert (version := auto_version.previous) and version.name == "mypylambdaVersion1"
assert (version := auto_version.latest) and version.name == "mypylambdaVersion2"


def test_autoversion_single(stack: Stack, simple_lambda_function: PyFunction) -> None:
Expand Down Expand Up @@ -641,10 +646,14 @@ def test_autoversion(stack: Stack, simple_lambda_function: PyFunction) -> None:
stack.add(auto_version)
print(stack.export()["Resources"])
assert stack.export()["Resources"] == EXPECTED_AUTOVERSION_TEMPLATE
assert auto_version.get_version(2).name == "mypylambdaVersion2"
assert auto_version.get_version(3).name == "mypylambdaVersion3"
assert auto_version.previous.name == "mypylambdaVersion2"
assert auto_version.latest.name == "mypylambdaVersion3"
assert (
version := auto_version.get_version(2)
) and version.name == "mypylambdaVersion2"
assert (
version := auto_version.get_version(3)
) and version.name == "mypylambdaVersion3"
assert (version := auto_version.previous) and version.name == "mypylambdaVersion2"
assert (version := auto_version.latest) and version.name == "mypylambdaVersion3"


def test_bluegreenaliases_default(
Expand Down Expand Up @@ -798,7 +807,7 @@ def get_base64_response() -> Response:
yield app


def test_text_response(base64_response_server: Application):
def test_text_response(base64_response_server: Application) -> None:
"""Query a route sending back a plain text response."""
with open(
os.path.join(
Expand All @@ -815,7 +824,7 @@ def test_text_response(base64_response_server: Application):
assert response["body"] == b"world"


def test_base64_response(base64_response_server: Application):
def test_base64_response(base64_response_server: Application) -> None:
"""Query a route sending back a base64 encoded response."""
with open(
os.path.join(
Expand Down
7 changes: 4 additions & 3 deletions tests/tests_e3_aws/troposphere/cfn_project_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


if TYPE_CHECKING:
from e3.aws.troposphere import Stack
import pytest
from e3.aws.cfn import Stack


TEST_DIR = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -20,7 +21,7 @@
class MyCFNProject(CFNProjectMain):
"""Provide CLI to manage MyCFNProject."""

def create_stack(self) -> list[Stack]:
def create_stack(self) -> Stack | list[Stack]:
"""Return MyCFNProject stack."""
self.add(
(
Expand All @@ -34,7 +35,7 @@ def create_stack(self) -> list[Stack]:
return self.stack


def test_cfn_project_main(capfd) -> None:
def test_cfn_project_main(capfd: pytest.CaptureFixture[str]) -> None:
"""Test CFNProjectMain."""
aws_env = AWSEnv(regions=["eu-west-1"], stub=True)
test = MyCFNProject(
Expand Down
6 changes: 4 additions & 2 deletions tests/tests_e3_aws/troposphere/cloudwatch/cloudwatch_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from typing import Any, cast
from troposphere import Ref
from e3.aws.troposphere import Stack
from e3.aws.troposphere.cloudwatch import Alarm
Expand All @@ -22,7 +22,9 @@
EXPECTED_ALARM_TEMPLATE = {
"Myalarm": {
"Properties": {
**EXPECTED_ALARM_DEFAULT_TEMPLATE["Myalarm"]["Properties"],
**cast(
dict[str, Any], EXPECTED_ALARM_DEFAULT_TEMPLATE["Myalarm"]["Properties"]
),
**{
"AlarmActions": ["StrAction", {"Ref": "RefAction"}],
"Dimensions": [
Expand Down
5 changes: 4 additions & 1 deletion tests/tests_e3_aws/troposphere/dynamodb/dynamodb_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import Any, cast
import os
import json
from troposphere import Ref
Expand Down Expand Up @@ -35,7 +36,9 @@
EXPECTED_TABLE_TEMPLATE = {
"Mytable": {
"Properties": {
**EXPECTED_TABLE_DEFAULT_TEMPLATE["Mytable"]["Properties"],
**cast(
dict[str, Any], EXPECTED_TABLE_DEFAULT_TEMPLATE["Mytable"]["Properties"]
),
**{
"Tags": [{"Key": "tagkey", "Value": "tagvalue"}],
"TimeToLiveSpecification": {
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_e3_aws/troposphere/iam/iam_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_trust_roles(stack: Stack) -> None:
name="TestRole",
description="TestRole description",
trust=Trust(
roles=[(123456789012, "OtherRole")], actions=["sts:SetSourceIdentity"]
roles=[("123456789012", "OtherRole")], actions=["sts:SetSourceIdentity"]
),
)
)
Expand Down

0 comments on commit 2a49b04

Please sign in to comment.