From d95e8650bcc71c1b20de16a498ebd02dea170a5e Mon Sep 17 00:00:00 2001 From: reshke Date: Mon, 16 Oct 2023 14:30:34 +0500 Subject: [PATCH] Fix sccram compile with pg >= 16 (#535) --- sources/scram.c | 62 ++++++++++++++++++++++++------------------------- sources/scram.h | 59 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 77 insertions(+), 44 deletions(-) diff --git a/sources/scram.c b/sources/scram.c index b011208ec..b88417860 100644 --- a/sources/scram.c +++ b/sources/scram.c @@ -81,10 +81,10 @@ int od_scram_parse_verifier(od_scram_state_t *scram_state, char *verifier) int stored_key_len = od_b64_decode(stored_key_raw, stored_key_raw_len, stored_key, stored_key_dst_len); - if (stored_key_len != SCRAM_KEY_LEN) + if (stored_key_len != OD_SCRAM_MAX_KEY_LEN) goto error; - memcpy(scram_state->stored_key, stored_key, SCRAM_KEY_LEN); + memcpy(scram_state->stored_key, stored_key, OD_SCRAM_MAX_KEY_LEN); int server_key_raw_len = strlen(server_key_raw); int server_key_dst_len = pg_b64_dec_len(server_key_raw_len); @@ -94,10 +94,10 @@ int od_scram_parse_verifier(od_scram_state_t *scram_state, char *verifier) int server_key_len = od_b64_decode(server_key_raw, server_key_raw_len, server_key, server_key_dst_len); - if (server_key_len != SCRAM_KEY_LEN) + if (server_key_len != OD_SCRAM_MAX_KEY_LEN) goto error; - memcpy(scram_state->server_key, server_key, SCRAM_KEY_LEN); + memcpy(scram_state->server_key, server_key, OD_SCRAM_MAX_KEY_LEN); free(stored_key); free(server_key); @@ -133,7 +133,7 @@ int od_scram_init_from_plain_password(od_scram_state_t *scram_state, char salt[SCRAM_DEFAULT_SALT_LEN]; RAND_bytes((uint8_t *)salt, sizeof(salt)); - scram_state->iterations = SCRAM_DEFAULT_ITERATIONS; + scram_state->iterations = OD_SCRAM_SHA_256_DEFAULT_ITERATIONS; int salt_dst_len = pg_b64_enc_len(sizeof(salt)) + 1; scram_state->salt = malloc(salt_dst_len); @@ -145,12 +145,12 @@ int od_scram_init_from_plain_password(od_scram_state_t *scram_state, scram_state->salt[base64_salt_len] = '\0'; const char *errstr = NULL; - uint8_t salted_password[SCRAM_KEY_LEN]; + uint8_t salted_password[OD_SCRAM_MAX_KEY_LEN]; od_scram_SaltedPassword(password, salt, sizeof(salt), scram_state->iterations, salted_password, &errstr); od_scram_ClientKey(salted_password, scram_state->stored_key, &errstr); - od_scram_H(scram_state->stored_key, SCRAM_KEY_LEN, + od_scram_H(scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN, scram_state->stored_key, &errstr); od_scram_ServerKey(salted_password, scram_state->server_key, &errstr); @@ -298,7 +298,7 @@ static int calculate_client_proof(od_scram_state_t *scram_state, if (prepared_password == NULL) return -1; - scram_state->salted_password = malloc(SCRAM_KEY_LEN); + scram_state->salted_password = malloc(OD_SCRAM_MAX_KEY_LEN); if (scram_state->salted_password == NULL) goto error; @@ -309,13 +309,13 @@ static int calculate_client_proof(od_scram_state_t *scram_state, iterations, scram_state->salted_password, &errstr); - uint8_t client_key[SCRAM_KEY_LEN]; + uint8_t client_key[OD_SCRAM_MAX_KEY_LEN]; od_scram_ClientKey(scram_state->salted_password, client_key, &errstr); - uint8_t stored_key[SCRAM_KEY_LEN]; - od_scram_H(client_key, SCRAM_KEY_LEN, stored_key, &errstr); + uint8_t stored_key[OD_SCRAM_MAX_KEY_LEN]; + od_scram_H(client_key, OD_SCRAM_MAX_KEY_LEN, stored_key, &errstr); - od_scram_HMAC_init(ctx, stored_key, SCRAM_KEY_LEN); + od_scram_HMAC_init(ctx, stored_key, OD_SCRAM_MAX_KEY_LEN); od_scram_HMAC_update(ctx, scram_state->client_first_message, strlen(scram_state->client_first_message)); @@ -326,10 +326,10 @@ static int calculate_client_proof(od_scram_state_t *scram_state, od_scram_HMAC_update(ctx, client_final_message, strlen(client_final_message)); - uint8_t client_signature[SCRAM_KEY_LEN]; + uint8_t client_signature[OD_SCRAM_MAX_KEY_LEN]; od_scram_HMAC_final(client_signature, ctx); - for (int i = 0; i < SCRAM_KEY_LEN; i++) + for (int i = 0; i < OD_SCRAM_MAX_KEY_LEN; i++) client_proof[i] = client_key[i] ^ client_signature[i]; od_scram_HMAC_free(ctx); @@ -347,7 +347,7 @@ static char *calculate_server_signature(od_scram_state_t *scram_state) { od_scram_ctx_t *ctx = od_scram_HMAC_create(); - od_scram_HMAC_init(ctx, scram_state->server_key, SCRAM_KEY_LEN); + od_scram_HMAC_init(ctx, scram_state->server_key, OD_SCRAM_MAX_KEY_LEN); od_scram_HMAC_update(ctx, scram_state->client_first_message, strlen(scram_state->client_first_message)); od_scram_HMAC_update(ctx, ",", 1); @@ -357,17 +357,17 @@ static char *calculate_server_signature(od_scram_state_t *scram_state) od_scram_HMAC_update(ctx, scram_state->client_final_message, strlen(scram_state->client_final_message)); - uint8_t server_signature[SCRAM_KEY_LEN]; + uint8_t server_signature[OD_SCRAM_MAX_KEY_LEN]; od_scram_HMAC_final(server_signature, ctx); od_scram_HMAC_free(ctx); - int base64_signature_dst_len = pg_b64_enc_len(SCRAM_KEY_LEN) + 1; + int base64_signature_dst_len = pg_b64_enc_len(OD_SCRAM_MAX_KEY_LEN) + 1; char *base64_signature = malloc(base64_signature_dst_len); if (base64_signature == NULL) return NULL; int base64_signature_len = - od_b64_encode((char *)server_signature, SCRAM_KEY_LEN, + od_b64_encode((char *)server_signature, OD_SCRAM_MAX_KEY_LEN, base64_signature, base64_signature_dst_len); base64_signature[base64_signature_len] = '\0'; @@ -404,7 +404,7 @@ od_scram_create_client_final_message(od_scram_state_t *scram_state, if (scram_state->client_final_message == NULL) return NULL; - uint8_t client_proof[SCRAM_KEY_LEN]; + uint8_t client_proof[OD_SCRAM_MAX_KEY_LEN]; rc = calculate_client_proof(scram_state, password, salt, iterations, result, client_proof); if (rc == -1) @@ -419,7 +419,7 @@ od_scram_create_client_final_message(od_scram_state_t *scram_state, result[size++] = 'p'; result[size++] = '='; - size += od_b64_encode((char *)client_proof, SCRAM_KEY_LEN, + size += od_b64_encode((char *)client_proof, OD_SCRAM_MAX_KEY_LEN, result + size, SCRAM_FINAL_MAX_SIZE - size); #undef SCRAM_FINAL_MAX_SIZE result[size] = '\0'; @@ -458,10 +458,10 @@ int read_server_final_message(char *auth_data, size_t auth_data_size, decoded_signature_len = od_b64_decode(signature, signature_size, decoded_signature, decoded_signature_len); - if (decoded_signature_len != SCRAM_KEY_LEN) + if (decoded_signature_len != OD_SCRAM_MAX_KEY_LEN) goto error; - memcpy(server_signature, decoded_signature, SCRAM_KEY_LEN); + memcpy(server_signature, decoded_signature, OD_SCRAM_MAX_KEY_LEN); free(decoded_signature); return 0; @@ -486,9 +486,9 @@ od_retcode_t od_scram_verify_server_signature(od_scram_state_t *scram_state, od_scram_ctx_t *ctx = od_scram_HMAC_create(); const char *errstr = NULL; - uint8_t server_key[SCRAM_KEY_LEN]; + uint8_t server_key[OD_SCRAM_MAX_KEY_LEN]; od_scram_ServerKey(scram_state->salted_password, server_key, &errstr); - od_scram_HMAC_init(ctx, server_key, SCRAM_KEY_LEN); + od_scram_HMAC_init(ctx, server_key, OD_SCRAM_MAX_KEY_LEN); od_scram_HMAC_update(ctx, scram_state->client_first_message, strlen(scram_state->client_first_message)); @@ -823,14 +823,14 @@ od_retcode_t od_scram_verify_final_nonce(od_scram_state_t *scram_state, od_retcode_t od_scram_verify_client_proof(od_scram_state_t *scram_state, char *client_proof) { - uint8_t client_signature[SCRAM_KEY_LEN]; - uint8_t client_key[SCRAM_KEY_LEN]; - uint8_t client_stored_key[SCRAM_KEY_LEN]; + uint8_t client_signature[OD_SCRAM_MAX_KEY_LEN]; + uint8_t client_key[OD_SCRAM_MAX_KEY_LEN]; + uint8_t client_stored_key[OD_SCRAM_MAX_KEY_LEN]; od_scram_ctx_t *ctx = od_scram_HMAC_create(); const char *errstr = NULL; - od_scram_HMAC_init(ctx, scram_state->stored_key, SCRAM_KEY_LEN); + od_scram_HMAC_init(ctx, scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN); od_scram_HMAC_update(ctx, scram_state->client_first_message, strlen(scram_state->client_first_message)); od_scram_HMAC_update(ctx, ",", 1); @@ -841,13 +841,13 @@ od_retcode_t od_scram_verify_client_proof(od_scram_state_t *scram_state, strlen(scram_state->client_final_message)); od_scram_HMAC_final(client_signature, ctx); - for (int i = 0; i < SCRAM_KEY_LEN; i++) + for (int i = 0; i < OD_SCRAM_MAX_KEY_LEN; i++) client_key[i] = client_proof[i] ^ client_signature[i]; - od_scram_H(client_key, SCRAM_KEY_LEN, client_stored_key, &errstr); + od_scram_H(client_key, OD_SCRAM_MAX_KEY_LEN, client_stored_key, &errstr); od_scram_HMAC_free(ctx); - if (memcmp(client_stored_key, scram_state->stored_key, SCRAM_KEY_LEN) != + if (memcmp(client_stored_key, scram_state->stored_key, OD_SCRAM_MAX_KEY_LEN) != 0) return NOT_OK_RESPONSE; diff --git a/sources/scram.h b/sources/scram.h index a05d4924d..2096f5bc1 100644 --- a/sources/scram.h +++ b/sources/scram.h @@ -7,6 +7,14 @@ * Scalable PostgreSQL connection pooler. */ +#if PG_VERSION_NUM >= 160000 +#define OD_SCRAM_MAX_KEY_LEN SCRAM_MAX_KEY_LEN +#define OD_SCRAM_SHA_256_DEFAULT_ITERATIONS SCRAM_SHA_256_DEFAULT_ITERATIONS +#else +#define OD_SCRAM_MAX_KEY_LEN SCRAM_KEY_LEN +#define OD_SCRAM_SHA_256_DEFAULT_ITERATIONS SCRAM_DEFAULT_ITERATIONS +#endif + #if PG_VERSION_NUM >= 130000 #define od_b64_encode(src, src_len, dst, dst_len) \ pg_b64_encode(src, src_len, dst, dst_len); @@ -55,30 +63,55 @@ typedef struct pg_hmac_ctx od_scram_ctx_t; #endif -#if PG_VERSION_NUM < 150000 -#define od_scram_H(input, len, result, errstr) scram_H(input, len, result) +#if PG_VERSION_NUM >= 160000 + #define od_scram_ServerKey(salted_password, result, errstr) \ - scram_ServerKey(salted_password, result) + scram_ServerKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr) + #define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \ errstr) \ - scram_SaltedPassword(password, salt, saltlen, iterations, result) -#define od_scram_ClientKey(salted_password, result, errstr) \ - scram_ClientKey(salted_password, result) + scram_SaltedPassword(password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, salt, saltlen, iterations, result, \ + errstr) -#else +# define od_scram_H(input, len, result, errstr) \ + scram_H(input, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr) -#define od_scram_H(input, len, result, errstr) \ - scram_H(input, len, result, errstr) -#define od_scram_ServerKey(salted_password, result, errstr) \ - scram_ServerKey(salted_password, result, errstr) -#define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \ +# define od_scram_ClientKey(salted_password, result, errstr) \ + scram_ClientKey(salted_password, PG_SHA256, SCRAM_SHA_256_KEY_LEN, result, errstr) + +#else + +# if PG_VERSION_NUM >= 150000 +# define od_scram_ServerKey(salted_password, result, errstr) \ + scram_ServerKey(salted_password, result, errstr) + +# define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \ errstr) \ scram_SaltedPassword(password, salt, saltlen, iterations, result, \ errstr) -#define od_scram_ClientKey(salted_password, result, errstr) \ + +# define od_scram_H(input, len, result, errstr) \ + scram_H(input, len, result, errstr) + +# define od_scram_ClientKey(salted_password, result, errstr) \ scram_ClientKey(salted_password, result, errstr) +# else + +# define od_scram_ServerKey(salted_password, result, errstr) \ + scram_ServerKey(salted_password, result) + +# define od_scram_SaltedPassword(password, salt, saltlen, iterations, result, \ + errstr) \ + scram_SaltedPassword(password, salt, saltlen, iterations, result) + +# define od_scram_H(input, len, result, errstr) scram_H(input, len, result) + +# define od_scram_ClientKey(salted_password, result, errstr) \ + scram_ClientKey(salted_password, result) + +# endif #endif typedef struct od_scram_state od_scram_state_t;