Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gracefully handle Auth API rate-limiting #306

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions lib/duo.c
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,47 @@ _duo_json_response(struct duo_ctx *ctx) {
return code;
}

int
_duo_https_exchange(struct duo_ctx *ctx, const char *method, const char *uri, int msecs, int *code)
{
const int max_int_digits = (241 * sizeof(int) / 100 + 1);
const int max_backoff_wait_secs = 32;
const int initial_backof_wait_secs = 1;
const int backoff_factor = 2;

static const char fmt[] = "Rate-limiting response received from server. Waiting for %d seconds before retrying.";
char msg[(sizeof fmt) + max_int_digits];
int wait_secs = initial_backof_wait_secs;

while (1) {
HTTPScode rc;
time_t retry_after;

rc = https_send(ctx->https, method, uri,
ctx->argc, ctx->argv, ctx->ikey, ctx->skey, ctx->useragent);
if (rc != HTTPS_OK)
return rc;
rc = https_recv(ctx->https, code, &ctx->body, &ctx->body_len, &retry_after, msecs);
if (retry_after != (time_t)-1)
wait_secs = retry_after - time(NULL);

if (rc != HTTPS_OK || *code != 429 || wait_secs > max_backoff_wait_secs)
return rc;

struct timespec timeout = {
.tv_sec = wait_secs,
.tv_nsec = (float)rand() / RAND_MAX * 1000000000
};

snprintf(msg, sizeof(msg), fmt, timeout.tv_sec);
if (ctx->conv_status)
ctx->conv_status(NULL, msg);
nanosleep(&timeout, NULL);
if (retry_after == (time_t)-1)
wait_secs *= backoff_factor;
}
}

static duo_code_t
duo_call(struct duo_ctx *ctx, const char *method, const char *uri, int msecs)
{
Expand All @@ -361,12 +402,8 @@ duo_call(struct duo_ctx *ctx, const char *method, const char *uri, int msecs)
}
break;
}
if ((err = https_send(ctx->https, method, uri,
ctx->argc, ctx->argv, ctx->ikey, ctx->skey, ctx->useragent)) == HTTPS_OK &&
(err = https_recv(ctx->https, &code,
&ctx->body, &ctx->body_len, msecs)) == HTTPS_OK) {
if (_duo_https_exchange(ctx, method, uri, msecs, &code) == HTTPS_OK)
break;
}
https_close(&ctx->https);
}
duo_reset(ctx);
Expand Down
98 changes: 97 additions & 1 deletion lib/https.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ struct https_ctx {

struct https_ctx ctx;

typedef enum
{
CB_NONE = 0, /* First callback*/
CB_KEY, /* Last was key */
CB_VAL /* Last was value */
} callback_status_t;

struct https_request {
BIO *cbio;
BIO *body;
Expand All @@ -70,6 +77,14 @@ struct https_request {

int sigpipe_ignored;
struct sigaction old_sigpipe;

time_t retry_after;
AaronAtDuo marked this conversation as resolved.
Show resolved Hide resolved

char *value;
size_t value_size;
char* key; /* current header name */
size_t key_size; /* size of header name */
callback_status_t last_cb;
};

static int
Expand All @@ -80,15 +95,92 @@ __on_body(http_parser *p, const char *buf, size_t len)
return (BIO_write(req->body, buf, len) != len);
}

time_t
_parse_retry_after(const char *header_value)
{
if (header_value == NULL) {
return (time_t)-1;
}

/* Try to parse as an integer (delay in seconds) */
char *endptr;
long delay_seconds = strtol(header_value, &endptr, 10);
if (*endptr == '\0') {
return time(NULL) + delay_seconds;
}

/* Try to parse as a date */
struct tm tm;
memset(&tm, 0, sizeof(struct tm));
if (strptime(header_value, "%a, %d %b %Y %H:%M:%S %Z", &tm) != NULL) {
return mktime(&tm);
}

return (time_t)-1;
}

static int
__on_message_complete(http_parser *p)
{
struct https_request *req = (struct https_request *)p->data;

req->retry_after = _parse_retry_after(req->value);

free(req->value);
req->value = NULL;
req->value_size = 0;
free(req->key);
req->key = NULL;
req->key_size = 0;
req->last_cb = CB_NONE;

req->done = 1;
return (0);
}

static const char retry_after_header[] = "Retry-After";
static const char x_retry_after_header[] = "X-Retry-After";

static int
__on_header_field(http_parser* p, const char* at, size_t length)
{
struct https_request *client = p->data;

if (client->last_cb == CB_VAL)
client->key_size = 0;

client->key = realloc(client->key, client->key_size + length + 1);
memcpy(client->key + client->key_size, at, length);
client->key_size += length;
client->key[client->key_size] = 0;

client->last_cb = CB_KEY;

return 0;
}

static int
__on_header_value(http_parser* p, const char* at, size_t length)
{
struct https_request *client = p->data;

if (strcasecmp(client->key, retry_after_header) == 0
|| strcasecmp(client->key, x_retry_after_header) == 0)
{
if (client->last_cb != CB_VAL)
client->value_size = 0;

client->value = realloc(client->value, client->value_size + length + 1);
memcpy(client->value + client->value_size, at, length);
client->value_size += length;
client->value[client->value_size] = 0;
}

client->last_cb = CB_VAL;

return 0;
}

static const char *
_SSL_strerror(void)
{
Expand Down Expand Up @@ -504,6 +596,8 @@ https_init(const char *cafile, const char *http_proxy)
/* Set HTTP parser callbacks */
ctx.parse_settings.on_body = __on_body;
ctx.parse_settings.on_message_complete = __on_message_complete;
ctx.parse_settings.on_header_field = __on_header_field;
ctx.parse_settings.on_header_value = __on_header_value;

return (0);
}
Expand Down Expand Up @@ -774,7 +868,7 @@ https_send(struct https_request *req, const char *method, const char *uri,

HTTPScode
https_recv(struct https_request *req, int *code, const char **body, int *len,
int msecs)
time_t *retry_after, int msecs)
{
int n, err;

Expand All @@ -799,6 +893,8 @@ https_recv(struct https_request *req, int *code, const char **body, int *len,
}
*len = BIO_get_mem_data(req->body, (char **)body);
*code = req->parser->status_code;
if (retry_after)
*retry_after = req->retry_after;

return (HTTPS_OK);
}
Expand Down
1 change: 1 addition & 0 deletions lib/https.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ HTTPScode https_recv(
int *code,
const char **body,
int *length,
time_t *retry_after,
int msecs
);

Expand Down
3 changes: 3 additions & 0 deletions login_duo/login_duo.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ struct login_ctx {
uid_t uid;
};

static void
die(const char *fmt, ...) __attribute__((noreturn));

static void
die(const char *fmt, ...)
{
Expand Down
30 changes: 29 additions & 1 deletion tests/common_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import os
import subprocess
import time
import unittest
import sys

Expand Down Expand Up @@ -325,6 +326,33 @@ def test_preauth_allow_bad_response(self):
"preauth-allow-bad_response", "JSON missing valid 'status'"
)

def test_preauth_allow_retry_after(self):
start_time = time.time()
self.check_preauth_state(
"retry-after-3-preauth-allow", "preauth-allowed", prefix="Skipped"
)
execution_time = time.time() - start_time
# 3.x seconds executed twice
self.assertGreater(execution_time, 6)

def test_preauth_allow_retry_after_date(self):
start_time = time.time()
self.check_preauth_state(
"retry-after-date-preauth-allow", "preauth-allowed", prefix="Skipped"
)
execution_time = time.time() - start_time
# 3.x seconds executed twice
self.assertGreater(execution_time, 6)

def test_preauth_allow_rate_limited(self):
start_time = time.time()
self.check_preauth_state(
"rate-limited-preauth-allow", "preauth-allowed", prefix="Skipped"
)
execution_time = time.time() - start_time
# 1.x seconds + 2.x seconds executed twice
self.assertGreater(execution_time, 6)

class Hosts(CommonTestCase):
def run(self, result=None):
with MockDuo(NORMAL_CERT):
Expand Down Expand Up @@ -538,7 +566,7 @@ def test_configuration_with_extra_space(self):
)

class Interactive(CommonTestCase):
PROMPT_REGEX = ".* or option \(1-4\): $"
PROMPT_REGEX = ".* or option \\(1-4\\): $"
PROMPT_TEXT = [
"Duo login for foobar",
"Choose or lose:",
Expand Down
34 changes: 33 additions & 1 deletion tests/mockduo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class MockDuoHandler(BaseHTTPRequestHandler):
server_version = "MockDuo/1.0"
protocol_version = "HTTP/1.1"

def __init__(self, *args, **kwargs):
self._rl_req_clock = 0
self._rl_req_num = 0
super().__init__(*args, **kwargs)

def _verify_sig(self):
authz = base64.b64decode(self.headers["Authorization"].split()[1]).decode(
"utf-8"
Expand Down Expand Up @@ -119,9 +124,12 @@ def _get_tx_response(self, txid, is_async):
time.sleep(int(secs))
return rsp

def _send(self, code, buf=b""):
def _send(self, code, buf=b"", headers=None):
self.send_response(code)
self.send_header("Content-length", str(len(buf)))
if headers:
for key, value in headers.items():
self.send_header(key, value)
if buf:
self.send_header("Content-type", "application/json")
self.end_headers()
Expand Down Expand Up @@ -230,6 +238,30 @@ def do_POST(self):
ret["response"] = {"result": "enroll", "status": "please enroll"}
elif self.args["user"] == "bad-json":
buf = b""
elif self.args["user"] == "retry-after-3-preauth-allow":
if self._rl_req_num == 0:
self._rl_req_num = 1
return self._send(429, headers={"X-Retry-After": "3"})
else:
self._rl_req_num = 0
ret["response"] = {"result": "allow", "status": "preauth-allowed"}
elif self.args["user"] == "retry-after-date-preauth-allow":
if self._rl_req_num == 0:
self._rl_req_num = 1
timestr = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime(time.time()+3))
return self._send(429, headers={"Retry-After": timestr})
else:
self._rl_req_num = 0
ret["response"] = {"result": "allow", "status": "preauth-allowed"}
elif self.args["user"] == "rate-limited-preauth-allow":
if self._rl_req_num in [0,1]:
self._rl_req_num += 1
return self._send(429)
elif self._rl_req_num == 2:
self._rl_req_num = 0
ret["response"] = {"result": "allow", "status": "preauth-allowed"}
else:
return self._send(500, "Wrong timeout")
else:
ret["response"] = {
"result": "auth",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_login_duo.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_help_output(self):
def test_version_output(self):
"""Check version output"""
result = login_duo(["-v"])
self.assertRegex(result["stderr"][0], "login_duo \d+\.\d+.\d+")
self.assertRegex(result["stderr"][0], "login_duo \\d+\\.\\d+.\\d+")


class TestLoginDuoEnv(CommonSuites.Env):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pam_duo.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def pam_duo_interactive(args, env={}, timeout=2):
return process


def pam_duo(args, env={}, timeout=2):
def pam_duo(args, env={}, timeout=10):
pam_duo_path = [os.path.join(TESTDIR, "testpam.py")]
# we don't want to accidentally grab these from the calling environment
excluded_keys = ["SSH_CONNECTION", "FALLBACK", "UID", "http_proxy", "TIMEOUT"]
Expand Down
Loading