Skip to content

Commit

Permalink
Fix sccram compile with pg >= 16
Browse files Browse the repository at this point in the history
  • Loading branch information
reshke committed Oct 16, 2023
1 parent abf4e36 commit 6d02a9c
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 44 deletions.
62 changes: 31 additions & 31 deletions sources/scram.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ int od_scram_parse_verifier(od_scram_state_t *scram_state, char *verifier)

int stored_key_len = od_b64_decode(stored_key_raw, stored_key_raw_len,
stored_key, stored_key_dst_len);
if (stored_key_len != SCRAM_KEY_LEN)
if (stored_key_len != OD_SCRAM_MAX_KEY_LEN)
goto error;

memcpy(scram_state->stored_key, stored_key, SCRAM_KEY_LEN);
memcpy(scram_state->stored_key, stored_key, OD_SCRAM_MAX_KEY_LEN);

int server_key_raw_len = strlen(server_key_raw);
int server_key_dst_len = pg_b64_dec_len(server_key_raw_len);
Expand All @@ -94,10 +94,10 @@ int od_scram_parse_verifier(od_scram_state_t *scram_state, char *verifier)

int server_key_len = od_b64_decode(server_key_raw, server_key_raw_len,
server_key, server_key_dst_len);
if (server_key_len != SCRAM_KEY_LEN)
if (server_key_len != OD_SCRAM_MAX_KEY_LEN)
goto error;

memcpy(scram_state->server_key, server_key, SCRAM_KEY_LEN);
memcpy(scram_state->server_key, server_key, OD_SCRAM_MAX_KEY_LEN);

free(stored_key);
free(server_key);
Expand Down Expand Up @@ -133,7 +133,7 @@ int od_scram_init_from_plain_password(od_scram_state_t *scram_state,
char salt[SCRAM_DEFAULT_SALT_LEN];
RAND_bytes((uint8_t *)salt, sizeof(salt));

scram_state->iterations = SCRAM_DEFAULT_ITERATIONS;
scram_state->iterations = OD_SCRAM_SHA_256_DEFAULT_ITERATIONS;

int salt_dst_len = pg_b64_enc_len(sizeof(salt)) + 1;
scram_state->salt = malloc(salt_dst_len);
Expand All @@ -145,12 +145,12 @@ int od_scram_init_from_plain_password(od_scram_state_t *scram_state,
scram_state->salt[base64_salt_len] = '\0';
const char *errstr = NULL;

uint8_t salted_password[SCRAM_KEY_LEN];
uint8_t salted_password[OD_SCRAM_MAX_KEY_LEN];
od_scram_SaltedPassword(password, salt, sizeof(salt),
scram_state->iterations, salted_password,
&errstr);
od_scram_ClientKey(salted_password, scram_state->stored_key, &errstr);
od_scram_H(scram_state->stored_key, SCRAM_KEY_LEN,
od_scram_H(scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN,
scram_state->stored_key, &errstr);
od_scram_ServerKey(salted_password, scram_state->server_key, &errstr);

Expand Down Expand Up @@ -298,7 +298,7 @@ static int calculate_client_proof(od_scram_state_t *scram_state,
if (prepared_password == NULL)
return -1;

scram_state->salted_password = malloc(SCRAM_KEY_LEN);
scram_state->salted_password = malloc(OD_SCRAM_MAX_KEY_LEN);
if (scram_state->salted_password == NULL)
goto error;

Expand All @@ -309,13 +309,13 @@ static int calculate_client_proof(od_scram_state_t *scram_state,
iterations, scram_state->salted_password,
&errstr);

uint8_t client_key[SCRAM_KEY_LEN];
uint8_t client_key[OD_SCRAM_MAX_KEY_LEN];
od_scram_ClientKey(scram_state->salted_password, client_key, &errstr);

uint8_t stored_key[SCRAM_KEY_LEN];
od_scram_H(client_key, SCRAM_KEY_LEN, stored_key, &errstr);
uint8_t stored_key[OD_SCRAM_MAX_KEY_LEN];
od_scram_H(client_key, OD_SCRAM_MAX_KEY_LEN, stored_key, &errstr);

od_scram_HMAC_init(ctx, stored_key, SCRAM_KEY_LEN);
od_scram_HMAC_init(ctx, stored_key, OD_SCRAM_MAX_KEY_LEN);

od_scram_HMAC_update(ctx, scram_state->client_first_message,
strlen(scram_state->client_first_message));
Expand All @@ -326,10 +326,10 @@ static int calculate_client_proof(od_scram_state_t *scram_state,
od_scram_HMAC_update(ctx, client_final_message,
strlen(client_final_message));

uint8_t client_signature[SCRAM_KEY_LEN];
uint8_t client_signature[OD_SCRAM_MAX_KEY_LEN];
od_scram_HMAC_final(client_signature, ctx);

for (int i = 0; i < SCRAM_KEY_LEN; i++)
for (int i = 0; i < OD_SCRAM_MAX_KEY_LEN; i++)
client_proof[i] = client_key[i] ^ client_signature[i];

od_scram_HMAC_free(ctx);
Expand All @@ -347,7 +347,7 @@ static char *calculate_server_signature(od_scram_state_t *scram_state)
{
od_scram_ctx_t *ctx = od_scram_HMAC_create();

od_scram_HMAC_init(ctx, scram_state->server_key, SCRAM_KEY_LEN);
od_scram_HMAC_init(ctx, scram_state->server_key, OD_SCRAM_MAX_KEY_LEN);
od_scram_HMAC_update(ctx, scram_state->client_first_message,
strlen(scram_state->client_first_message));
od_scram_HMAC_update(ctx, ",", 1);
Expand All @@ -357,17 +357,17 @@ static char *calculate_server_signature(od_scram_state_t *scram_state)
od_scram_HMAC_update(ctx, scram_state->client_final_message,
strlen(scram_state->client_final_message));

uint8_t server_signature[SCRAM_KEY_LEN];
uint8_t server_signature[OD_SCRAM_MAX_KEY_LEN];
od_scram_HMAC_final(server_signature, ctx);
od_scram_HMAC_free(ctx);

int base64_signature_dst_len = pg_b64_enc_len(SCRAM_KEY_LEN) + 1;
int base64_signature_dst_len = pg_b64_enc_len(OD_SCRAM_MAX_KEY_LEN) + 1;
char *base64_signature = malloc(base64_signature_dst_len);
if (base64_signature == NULL)
return NULL;

int base64_signature_len =
od_b64_encode((char *)server_signature, SCRAM_KEY_LEN,
od_b64_encode((char *)server_signature, OD_SCRAM_MAX_KEY_LEN,
base64_signature, base64_signature_dst_len);
base64_signature[base64_signature_len] = '\0';

Expand Down Expand Up @@ -404,7 +404,7 @@ od_scram_create_client_final_message(od_scram_state_t *scram_state,
if (scram_state->client_final_message == NULL)
return NULL;

uint8_t client_proof[SCRAM_KEY_LEN];
uint8_t client_proof[OD_SCRAM_MAX_KEY_LEN];
rc = calculate_client_proof(scram_state, password, salt, iterations,
result, client_proof);
if (rc == -1)
Expand All @@ -419,7 +419,7 @@ od_scram_create_client_final_message(od_scram_state_t *scram_state,
result[size++] = 'p';
result[size++] = '=';

size += od_b64_encode((char *)client_proof, SCRAM_KEY_LEN,
size += od_b64_encode((char *)client_proof, OD_SCRAM_MAX_KEY_LEN,
result + size, SCRAM_FINAL_MAX_SIZE - size);
#undef SCRAM_FINAL_MAX_SIZE
result[size] = '\0';
Expand Down Expand Up @@ -458,10 +458,10 @@ int read_server_final_message(char *auth_data, size_t auth_data_size,
decoded_signature_len =
od_b64_decode(signature, signature_size, decoded_signature,
decoded_signature_len);
if (decoded_signature_len != SCRAM_KEY_LEN)
if (decoded_signature_len != OD_SCRAM_MAX_KEY_LEN)
goto error;

memcpy(server_signature, decoded_signature, SCRAM_KEY_LEN);
memcpy(server_signature, decoded_signature, OD_SCRAM_MAX_KEY_LEN);

free(decoded_signature);
return 0;
Expand All @@ -486,9 +486,9 @@ od_retcode_t od_scram_verify_server_signature(od_scram_state_t *scram_state,
od_scram_ctx_t *ctx = od_scram_HMAC_create();

const char *errstr = NULL;
uint8_t server_key[SCRAM_KEY_LEN];
uint8_t server_key[OD_SCRAM_MAX_KEY_LEN];
od_scram_ServerKey(scram_state->salted_password, server_key, &errstr);
od_scram_HMAC_init(ctx, server_key, SCRAM_KEY_LEN);
od_scram_HMAC_init(ctx, server_key, OD_SCRAM_MAX_KEY_LEN);

od_scram_HMAC_update(ctx, scram_state->client_first_message,
strlen(scram_state->client_first_message));
Expand Down Expand Up @@ -823,14 +823,14 @@ od_retcode_t od_scram_verify_final_nonce(od_scram_state_t *scram_state,
od_retcode_t od_scram_verify_client_proof(od_scram_state_t *scram_state,
char *client_proof)
{
uint8_t client_signature[SCRAM_KEY_LEN];
uint8_t client_key[SCRAM_KEY_LEN];
uint8_t client_stored_key[SCRAM_KEY_LEN];
uint8_t client_signature[OD_SCRAM_MAX_KEY_LEN];
uint8_t client_key[OD_SCRAM_MAX_KEY_LEN];
uint8_t client_stored_key[OD_SCRAM_MAX_KEY_LEN];

od_scram_ctx_t *ctx = od_scram_HMAC_create();
const char *errstr = NULL;

od_scram_HMAC_init(ctx, scram_state->stored_key, SCRAM_KEY_LEN);
od_scram_HMAC_init(ctx, scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN);
od_scram_HMAC_update(ctx, scram_state->client_first_message,
strlen(scram_state->client_first_message));
od_scram_HMAC_update(ctx, ",", 1);
Expand All @@ -841,13 +841,13 @@ od_retcode_t od_scram_verify_client_proof(od_scram_state_t *scram_state,
strlen(scram_state->client_final_message));
od_scram_HMAC_final(client_signature, ctx);

for (int i = 0; i < SCRAM_KEY_LEN; i++)
for (int i = 0; i < OD_SCRAM_MAX_KEY_LEN; i++)
client_key[i] = client_proof[i] ^ client_signature[i];

od_scram_H(client_key, SCRAM_KEY_LEN, client_stored_key, &errstr);
od_scram_H(client_key, OD_SCRAM_MAX_KEY_LEN, client_stored_key, &errstr);
od_scram_HMAC_free(ctx);

if (memcmp(client_stored_key, scram_state->stored_key, SCRAM_KEY_LEN) !=
if (memcmp(client_stored_key, scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN) !=
0)
return NOT_OK_RESPONSE;

Expand Down
59 changes: 46 additions & 13 deletions sources/scram.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
* Scalable PostgreSQL connection pooler.
*/

#if PG_VERSION_NUM >= 160000
#define OD_SCRAM_MAX_KEY_LEN SCRAM_MAX_KEY_LEN
#define OD_SCRAM_SHA_256_DEFAULT_ITERATIONS SCRAM_SHA_256_DEFAULT_ITERATIONS
#else
#define OD_SCRAM_MAX_KEY_LEN SCRAM_KEY_LEN
#define OD_SCRAM_SHA_256_DEFAULT_ITERATIONS SCRAM_DEFAULT_ITERATIONS
#endif

#if PG_VERSION_NUM >= 130000
#define od_b64_encode(src, src_len, dst, dst_len) \
pg_b64_encode(src, src_len, dst, dst_len);
Expand Down Expand Up @@ -55,30 +63,55 @@ typedef struct pg_hmac_ctx od_scram_ctx_t;

#endif

#if PG_VERSION_NUM < 150000

#define od_scram_H(input, len, result, errstr) scram_H(input, len, result)
#if PG_VERSION_NUM >= 160000

#define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result)
scram_ServerKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr)

#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, salt, saltlen, iterations, result)
#define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, result)
scram_SaltedPassword(password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, salt, saltlen, iterations, result, \
errstr)

#else
# define od_scram_H(input, len, result, errstr) \
scram_H(input, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr)

#define od_scram_H(input, len, result, errstr) \
scram_H(input, len, result, errstr)
#define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result, errstr)
#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
# define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr)

#else

# if PG_VERSION_NUM >= 150000
# define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result, errstr)

# define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr)
#define od_scram_ClientKey(salted_password, result, errstr) \

# define od_scram_H(input, len, result, errstr) \
scram_H(input, len, result, errstr)

# define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, result, errstr)

# else

# define od_scram_ServerKey(salted_password, result, errstr) \
scram_ServerKey(salted_password, result)

# define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \
errstr) \
scram_SaltedPassword(password, salt, saltlen, iterations, result)

# define od_scram_H(input, len, result, errstr) scram_H(input, len, result)

# define od_scram_ClientKey(salted_password, result, errstr) \
scram_ClientKey(salted_password, result)

# endif
#endif

typedef struct od_scram_state od_scram_state_t;
Expand Down

0 comments on commit 6d02a9c

Please sign in to comment.