diff --git a/src/awsrun/config.py b/src/awsrun/config.py index 83870c7..d26e479 100644 --- a/src/awsrun/config.py +++ b/src/awsrun/config.py @@ -338,7 +338,7 @@ def __init__(self, type_): self.type = type_ def type_check(self, obj): - return type(obj) == self.type + return type(obj) == self.type # noqa: E721 def __str__(self): return self.type.__name__ @@ -354,7 +354,7 @@ def __init__(self, pattern): self.pattern = pattern def type_check(self, obj): - if type(obj) != str: + if type(obj) != str: # noqa: E721 return False return bool(re.search(self.pattern, obj)) @@ -366,7 +366,7 @@ class IpAddress(Type): """Represents a string matching an IP address (v4 or v6).""" def type_check(self, obj): - if type(obj) != str: + if type(obj) != str: # noqa: E721 return False try: ipaddress.ip_address(obj) @@ -382,7 +382,7 @@ class IpNetwork(Type): """Represents a string matching an IP network (v4 or v6).""" def type_check(self, obj): - if type(obj) != str: + if type(obj) != str: # noqa: E721 return False try: ipaddress.ip_network(obj) @@ -398,7 +398,7 @@ class FileType(Type): """Represents a string pointing to an existing file.""" def type_check(self, obj): - if type(obj) != str: + if type(obj) != str: # noqa: E721 return False return Path(obj).exists() @@ -462,7 +462,7 @@ def __init__(self, element_type): self.element_type = element_type def type_check(self, obj): - if type(obj) != list: + if type(obj) != list: # noqa: E721 return False return all(self.element_type.type_check(e) for e in obj) @@ -485,7 +485,7 @@ def __init__(self, key_type, value_type): self.value_type = value_type def type_check(self, obj): - if type(obj) != dict: + if type(obj) != dict: # noqa: E721 return False return all(self.key_type.type_check(k) for k in obj.keys()) and all( self.value_type.type_check(v) for v in obj.values() diff --git a/src/awsrun/plugins/creds/aws.py b/src/awsrun/plugins/creds/aws.py index bb0a0a2..4a01adf 100644 --- a/src/awsrun/plugins/creds/aws.py +++ b/src/awsrun/plugins/creds/aws.py @@ -137,6 +137,7 @@ class SAML(Plugin): role: STRING* url: STRING* auth_type: ("basic" | "digest" | "ntlm") + http_method: ("GET"| "POST") http_headers: STRING: STRING no_verify: BOOLEAN @@ -175,6 +176,11 @@ class SAML(Plugin): specified, it must be one of `basic`, `digest`, or `ntlm`. The default value is `basic`. If using NTLM, username should be specified as `domain\\username`. + `http_method` + : The HTTP method to use when authenticating with the IdP. If + specified, it must be one of `GET`, `POST`. The default value + is `GET`. + `http_headers` : Additional HTTP headers to send in the request to the IdP. If specified, it must be a dictionary of `key: value` pairs, where keys and values are @@ -275,6 +281,7 @@ def instantiate(self, args): role=args.saml_role, url=cfg("url", type=URL, must_exist=True), auth=auth(args.saml_username, args.saml_password), + http_method=cfg("http_method", type=Choice("GET", "POST"), default="GET"), headers=cfg("http_headers", type=Dict(Str, Str), default={}), duration=args.saml_duration, saml_duration=args.saml_assertion_duration, @@ -458,6 +465,7 @@ class SAMLCrossAccount(AbstractCrossAccount): role: STRING* url: STRING* auth_type: ("basic" | "digest" | "ntlm") + http_method: ("GET"| "POST") http_headers: STRING: STRING no_verify: BOOLEAN @@ -503,6 +511,11 @@ class SAMLCrossAccount(AbstractCrossAccount): specified, it must be one of `basic`, `digest`, or `ntlm`. The default value is `basic`. If using NTLM, username should be specified as `domain\\username`. + `http_method` + : The HTTP method to use when authenticating with the IdP. If + specified, it must be one of `GET`, `POST`. The default value + is `GET`. + `http_headers` : Additional HTTP headers to send in the request to the IdP. If specified, it must be a dictionary of `key: value` pairs, where keys and values are diff --git a/src/awsrun/session/aws.py b/src/awsrun/session/aws.py index aff8bac..e1a19ca 100644 --- a/src/awsrun/session/aws.py +++ b/src/awsrun/session/aws.py @@ -384,6 +384,7 @@ def __init__( role, url, auth, + http_method, headers=None, duration=3600, saml_duration=300, @@ -392,6 +393,7 @@ def __init__( super().__init__(role, duration) self._url = url self._auth = auth + self._http_method = http_method self._headers = {} if headers is None else headers self._cached_saml = ExpiringValue(self._request_assertion, saml_duration) self._no_verify = no_verify @@ -414,7 +416,15 @@ def _request_assertion(self): with requests.Session() as s: s.auth = self._auth s.headers.update(self._headers) - resp = s.get(self._url, verify=not self._no_verify) + if self._http_method == "GET": + resp = s.get(self._url, verify=not self._no_verify) + else: + authData = { + "UserName": s.auth.username, + "Password": s.auth.password, + "AuthMethod": "FormsAuthentication", + } + resp = s.post(self._url, data=authData, verify=not self._no_verify) if resp.status_code == 401: raise IDPAccessDeniedException("Could not authenticate")