Skip to content

Commit

Permalink
Fix auth query password caching segfault (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
reshke authored Oct 26, 2023
1 parent d95e865 commit 5faba9b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 31 deletions.
15 changes: 11 additions & 4 deletions sources/auth_query.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,21 @@ int od_auth_query(od_client_t *client, char *peer)

if (value->data == NULL) {
/* one-time initialize */
value->data = malloc(sizeof(od_auth_cache_value_t));
value->len = sizeof(od_auth_cache_value_t);
value->data = malloc(value->len);
/* OOM */
if (value->data == NULL) {
goto error;
}
memset(((od_auth_cache_value_t *)(value->data)), 0, value->len);
}

cache_value = (od_auth_cache_value_t *)value->data;

current_time = machine_time_us();

if (cache_value != NULL
/* password cached for 10 sec */
&& current_time - cache_value->timestamp < 10 * interval_usec) {
if (/* password cached for 10 sec */
current_time - cache_value->timestamp < 10 * interval_usec) {
od_debug(&instance->logger, "auth_query", NULL, NULL,
"reusing cached password for user %.*s",
user->name_len, user->name);
Expand Down Expand Up @@ -246,6 +250,9 @@ int od_auth_query(od_client_t *client, char *peer)
if (cache_value->passwd != NULL) {
/* drop previous value */
free(cache_value->passwd);

// there should be cache_value->passwd = NULL for sanity
// but this is meaninigless sinse we assing new value just below
}
cache_value->passwd_len = password->password_len;
cache_value->passwd = malloc(password->password_len);
Expand Down
7 changes: 4 additions & 3 deletions sources/scram.c
Original file line number Diff line number Diff line change
Expand Up @@ -844,11 +844,12 @@ od_retcode_t od_scram_verify_client_proof(od_scram_state_t *scram_state,
for (int i = 0; i < OD_SCRAM_MAX_KEY_LEN; i++)
client_key[i] = client_proof[i] ^ client_signature[i];

od_scram_H(client_key, OD_SCRAM_MAX_KEY_LEN, client_stored_key, &errstr);
od_scram_H(client_key, OD_SCRAM_MAX_KEY_LEN, client_stored_key,
&errstr);
od_scram_HMAC_free(ctx);

if (memcmp(client_stored_key, scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN) !=
0)
if (memcmp(client_stored_key, scram_state->stored_key,
OD_SCRAM_MAX_KEY_LEN) != 0)
return NOT_OK_RESPONSE;

return OK_RESPONSE;
Expand Down
49 changes: 25 additions & 24 deletions sources/scram.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,55 +63,56 @@ typedef struct pg_hmac_ctx od_scram_ctx_t;

#endif


#if PG_VERSION_NUM >= 160000

#define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr)
#define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, \
result, errstr)

#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, salt, saltlen, iterations, result, \
errstr)
#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, salt, \
saltlen, iterations, result, errstr)

# define od_scram_H(input, len, result, errstr) \
#define od_scram_H(input, len, result, errstr) \
scram_H(input, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr)

# define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr)
#define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, \
result, errstr)

#else
#else

# if PG_VERSION_NUM >= 150000
# define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result, errstr)
#if PG_VERSION_NUM >= 150000
#define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result, errstr)

# define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr)

# define od_scram_H(input, len, result, errstr) \
#define od_scram_H(input, len, result, errstr) \
scram_H(input, len, result, errstr)

# define od_scram_ClientKey(salted_password, result, errstr) \
#define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, result, errstr)

# else
#else

# define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result)
#define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result)

# define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, salt, saltlen, iterations, result)

# define od_scram_H(input, len, result, errstr) scram_H(input, len, result)
#define od_scram_H(input, len, result, errstr) scram_H(input, len, result)

# define od_scram_ClientKey(salted_password, result, errstr) \
#define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, result)

# endif
#endif
#endif

typedef struct od_scram_state od_scram_state_t;
Expand Down

0 comments on commit 5faba9b

Please sign in to comment.