diff --git a/lib/duo.c b/lib/duo.c index 9f8ffb5..05ffd6b 100644 --- a/lib/duo.c +++ b/lib/duo.c @@ -343,6 +343,47 @@ _duo_json_response(struct duo_ctx *ctx) { return code; } +int +_duo_https_exchange(struct duo_ctx *ctx, const char *method, const char *uri, int msecs, int *code) +{ + const int max_int_digits = (241 * sizeof(int) / 100 + 1); + const int max_backoff_wait_secs = 32; + const int initial_backof_wait_secs = 1; + const int backoff_factor = 2; + + static const char fmt[] = "Rate-limiting response received from server. Waiting for %d seconds before retrying."; + char msg[(sizeof fmt) + max_int_digits]; + int wait_secs = initial_backof_wait_secs; + + while (1) { + HTTPScode rc; + time_t retry_after; + + rc = https_send(ctx->https, method, uri, + ctx->argc, ctx->argv, ctx->ikey, ctx->skey, ctx->useragent); + if (rc != HTTPS_OK) + return rc; + rc = https_recv(ctx->https, code, &ctx->body, &ctx->body_len, &retry_after, msecs); + if (retry_after != (time_t)-1) + wait_secs = retry_after - time(NULL); + + if (rc != HTTPS_OK || *code != 429 || wait_secs > max_backoff_wait_secs) + return rc; + + struct timespec timeout = { + .tv_sec = wait_secs, + .tv_nsec = (float)rand() / RAND_MAX * 1000000000 + }; + + snprintf(msg, sizeof(msg), fmt, timeout.tv_sec); + if (ctx->conv_status) + ctx->conv_status(NULL, msg); + nanosleep(&timeout, NULL); + if (retry_after == (time_t)-1) + wait_secs *= backoff_factor; + } +} + static duo_code_t duo_call(struct duo_ctx *ctx, const char *method, const char *uri, int msecs) { @@ -361,12 +402,8 @@ duo_call(struct duo_ctx *ctx, const char *method, const char *uri, int msecs) } break; } - if ((err = https_send(ctx->https, method, uri, - ctx->argc, ctx->argv, ctx->ikey, ctx->skey, ctx->useragent)) == HTTPS_OK && - (err = https_recv(ctx->https, &code, - &ctx->body, &ctx->body_len, msecs)) == HTTPS_OK) { + if (_duo_https_exchange(ctx, method, uri, msecs, &code) == HTTPS_OK) break; - } https_close(&ctx->https); } duo_reset(ctx); diff --git a/lib/https.c b/lib/https.c index 22027c9..f4f44c8 100644 --- a/lib/https.c +++ b/lib/https.c @@ -57,6 +57,13 @@ struct https_ctx { struct https_ctx ctx; +typedef enum +{ + CB_NONE = 0, /* First callback*/ + CB_KEY, /* Last was key */ + CB_VAL /* Last was value */ +} callback_status_t; + struct https_request { BIO *cbio; BIO *body; @@ -70,6 +77,14 @@ struct https_request { int sigpipe_ignored; struct sigaction old_sigpipe; + + time_t retry_after; + + char *value; + size_t value_size; + char* key; /* current header name */ + size_t key_size; /* size of header name */ + callback_status_t last_cb; }; static int @@ -80,15 +95,92 @@ __on_body(http_parser *p, const char *buf, size_t len) return (BIO_write(req->body, buf, len) != len); } +time_t +_parse_retry_after(const char *header_value) +{ + if (header_value == NULL) { + return (time_t)-1; + } + + /* Try to parse as an integer (delay in seconds) */ + char *endptr; + long delay_seconds = strtol(header_value, &endptr, 10); + if (*endptr == '\0') { + return time(NULL) + delay_seconds; + } + + /* Try to parse as a date */ + struct tm tm; + memset(&tm, 0, sizeof(struct tm)); + if (strptime(header_value, "%a, %d %b %Y %H:%M:%S %Z", &tm) != NULL) { + return mktime(&tm); + } + + return (time_t)-1; +} + static int __on_message_complete(http_parser *p) { struct https_request *req = (struct https_request *)p->data; + req->retry_after = _parse_retry_after(req->value); + + free(req->value); + req->value = NULL; + req->value_size = 0; + free(req->key); + req->key = NULL; + req->key_size = 0; + req->last_cb = CB_NONE; + req->done = 1; return (0); } +static const char retry_after_header[] = "Retry-After"; +static const char x_retry_after_header[] = "X-Retry-After"; + +static int +__on_header_field(http_parser* p, const char* at, size_t length) +{ + struct https_request *client = p->data; + + if (client->last_cb == CB_VAL) + client->key_size = 0; + + client->key = realloc(client->key, client->key_size + length + 1); + memcpy(client->key + client->key_size, at, length); + client->key_size += length; + client->key[client->key_size] = 0; + + client->last_cb = CB_KEY; + + return 0; +} + +static int +__on_header_value(http_parser* p, const char* at, size_t length) +{ + struct https_request *client = p->data; + + if (strcasecmp(client->key, retry_after_header) == 0 + || strcasecmp(client->key, x_retry_after_header) == 0) + { + if (client->last_cb != CB_VAL) + client->value_size = 0; + + client->value = realloc(client->value, client->value_size + length + 1); + memcpy(client->value + client->value_size, at, length); + client->value_size += length; + client->value[client->value_size] = 0; + } + + client->last_cb = CB_VAL; + + return 0; +} + static const char * _SSL_strerror(void) { @@ -504,6 +596,8 @@ https_init(const char *cafile, const char *http_proxy) /* Set HTTP parser callbacks */ ctx.parse_settings.on_body = __on_body; ctx.parse_settings.on_message_complete = __on_message_complete; + ctx.parse_settings.on_header_field = __on_header_field; + ctx.parse_settings.on_header_value = __on_header_value; return (0); } @@ -774,7 +868,7 @@ https_send(struct https_request *req, const char *method, const char *uri, HTTPScode https_recv(struct https_request *req, int *code, const char **body, int *len, - int msecs) + time_t *retry_after, int msecs) { int n, err; @@ -799,6 +893,8 @@ https_recv(struct https_request *req, int *code, const char **body, int *len, } *len = BIO_get_mem_data(req->body, (char **)body); *code = req->parser->status_code; + if (retry_after) + *retry_after = req->retry_after; return (HTTPS_OK); } diff --git a/lib/https.h b/lib/https.h index 90e0515..e5037a6 100644 --- a/lib/https.h +++ b/lib/https.h @@ -47,6 +47,7 @@ HTTPScode https_recv( int *code, const char **body, int *length, + time_t *retry_after, int msecs ); diff --git a/login_duo/login_duo.c b/login_duo/login_duo.c index 2db25bb..644bd1e 100644 --- a/login_duo/login_duo.c +++ b/login_duo/login_duo.c @@ -46,6 +46,9 @@ struct login_ctx { uid_t uid; }; +static void +die(const char *fmt, ...) __attribute__((noreturn)); + static void die(const char *fmt, ...) { diff --git a/tests/common_suites.py b/tests/common_suites.py index f2fa0e2..f35ba02 100644 --- a/tests/common_suites.py +++ b/tests/common_suites.py @@ -9,6 +9,7 @@ import os import subprocess +import time import unittest import sys @@ -325,6 +326,33 @@ def test_preauth_allow_bad_response(self): "preauth-allow-bad_response", "JSON missing valid 'status'" ) + def test_preauth_allow_retry_after(self): + start_time = time.time() + self.check_preauth_state( + "retry-after-3-preauth-allow", "preauth-allowed", prefix="Skipped" + ) + execution_time = time.time() - start_time + # 3.x seconds executed twice + self.assertGreater(execution_time, 6) + + def test_preauth_allow_retry_after_date(self): + start_time = time.time() + self.check_preauth_state( + "retry-after-date-preauth-allow", "preauth-allowed", prefix="Skipped" + ) + execution_time = time.time() - start_time + # 3.x seconds executed twice + self.assertGreater(execution_time, 6) + + def test_preauth_allow_rate_limited(self): + start_time = time.time() + self.check_preauth_state( + "rate-limited-preauth-allow", "preauth-allowed", prefix="Skipped" + ) + execution_time = time.time() - start_time + # 1.x seconds + 2.x seconds executed twice + self.assertGreater(execution_time, 6) + class Hosts(CommonTestCase): def run(self, result=None): with MockDuo(NORMAL_CERT): @@ -538,7 +566,7 @@ def test_configuration_with_extra_space(self): ) class Interactive(CommonTestCase): - PROMPT_REGEX = ".* or option \(1-4\): $" + PROMPT_REGEX = ".* or option \\(1-4\\): $" PROMPT_TEXT = [ "Duo login for foobar", "Choose or lose:", diff --git a/tests/mockduo.py b/tests/mockduo.py index 93bdc68..5e4c3c9 100755 --- a/tests/mockduo.py +++ b/tests/mockduo.py @@ -58,6 +58,11 @@ class MockDuoHandler(BaseHTTPRequestHandler): server_version = "MockDuo/1.0" protocol_version = "HTTP/1.1" + def __init__(self, *args, **kwargs): + self._rl_req_clock = 0 + self._rl_req_num = 0 + super().__init__(*args, **kwargs) + def _verify_sig(self): authz = base64.b64decode(self.headers["Authorization"].split()[1]).decode( "utf-8" @@ -119,9 +124,12 @@ def _get_tx_response(self, txid, is_async): time.sleep(int(secs)) return rsp - def _send(self, code, buf=b""): + def _send(self, code, buf=b"", headers=None): self.send_response(code) self.send_header("Content-length", str(len(buf))) + if headers: + for key, value in headers.items(): + self.send_header(key, value) if buf: self.send_header("Content-type", "application/json") self.end_headers() @@ -230,6 +238,30 @@ def do_POST(self): ret["response"] = {"result": "enroll", "status": "please enroll"} elif self.args["user"] == "bad-json": buf = b"" + elif self.args["user"] == "retry-after-3-preauth-allow": + if self._rl_req_num == 0: + self._rl_req_num = 1 + return self._send(429, headers={"X-Retry-After": "3"}) + else: + self._rl_req_num = 0 + ret["response"] = {"result": "allow", "status": "preauth-allowed"} + elif self.args["user"] == "retry-after-date-preauth-allow": + if self._rl_req_num == 0: + self._rl_req_num = 1 + timestr = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime(time.time()+3)) + return self._send(429, headers={"Retry-After": timestr}) + else: + self._rl_req_num = 0 + ret["response"] = {"result": "allow", "status": "preauth-allowed"} + elif self.args["user"] == "rate-limited-preauth-allow": + if self._rl_req_num in [0,1]: + self._rl_req_num += 1 + return self._send(429) + elif self._rl_req_num == 2: + self._rl_req_num = 0 + ret["response"] = {"result": "allow", "status": "preauth-allowed"} + else: + return self._send(500, "Wrong timeout") else: ret["response"] = { "result": "auth", diff --git a/tests/test_login_duo.py b/tests/test_login_duo.py index 2d41edf..3f8fdeb 100755 --- a/tests/test_login_duo.py +++ b/tests/test_login_duo.py @@ -237,7 +237,7 @@ def test_help_output(self): def test_version_output(self): """Check version output""" result = login_duo(["-v"]) - self.assertRegex(result["stderr"][0], "login_duo \d+\.\d+.\d+") + self.assertRegex(result["stderr"][0], "login_duo \\d+\\.\\d+.\\d+") class TestLoginDuoEnv(CommonSuites.Env): diff --git a/tests/test_pam_duo.py b/tests/test_pam_duo.py index c3ba640..75ff093 100755 --- a/tests/test_pam_duo.py +++ b/tests/test_pam_duo.py @@ -84,7 +84,7 @@ def pam_duo_interactive(args, env={}, timeout=2): return process -def pam_duo(args, env={}, timeout=2): +def pam_duo(args, env={}, timeout=10): pam_duo_path = [os.path.join(TESTDIR, "testpam.py")] # we don't want to accidentally grab these from the calling environment excluded_keys = ["SSH_CONNECTION", "FALLBACK", "UID", "http_proxy", "TIMEOUT"]