Skip to content

Commit

Permalink
Merge pull request #4479 from sysown/v2.x-sqlite3_pass_exts-2
Browse files Browse the repository at this point in the history
Add new SQLite3 functions for password hash generation
  • Loading branch information
renecannao authored Mar 28, 2024
2 parents b5b49c2 + b656987 commit 2df30b6
Show file tree
Hide file tree
Showing 12 changed files with 976 additions and 76 deletions.
1 change: 1 addition & 0 deletions deps/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ sqlite3/sqlite3/sqlite3.o:
cd sqlite3 && tar -zxf sqlite-amalgamation-*.tar.gz
cd sqlite3/sqlite3 && patch -p1 < ../from_unixtime.patch
cd sqlite3/sqlite3 && patch sqlite3.c < ../sqlite3.c-multiplication-overflow.patch
cd sqlite3/sqlite3 && patch -p0 < ../sqlite3_pass_exts.patch
cd sqlite3/sqlite3 && ${CC} ${MYCFLAGS} -fPIC -c -o sqlite3.o sqlite3.c -DSQLITE_ENABLE_MEMORY_MANAGEMENT -DSQLITE_ENABLE_JSON1 -DSQLITE_DLL=1
cd sqlite3/sqlite3 && ${CC} -shared -o libsqlite3.so sqlite3.o

Expand Down
196 changes: 196 additions & 0 deletions deps/sqlite3/sqlite3_pass_exts.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
--- sqlite3.c 2024-03-22 19:22:47.046093173 +0100
+++ sqlite3-pass-exts.c 2024-03-22 19:24:09.557303716 +0100
@@ -25168,6 +25168,183 @@
sqlite3ResultStrAccum(context, &sRes);
}

+#define DEF_SALT_SIZE 20
+#define SHA_DIGEST_LENGTH 20
+
+/// Forward declarations
+////////////////////////////////////////////////////////////////////////////////
+
+// ctype.h
+extern int toupper (int __c) __THROW;
+
+// SHA256_crypt
+char * sha256_crypt_r (const char *key, const char *salt, char *buffer, int buflen);
+
+// OpenSSL
+unsigned char *SHA1(const unsigned char *d, size_t n, unsigned char *md);
+int RAND_bytes(unsigned char *buf, int num);
+unsigned long ERR_get_error(void);
+char *ERR_error_string(unsigned long e, char *buf);
+
+////////////////////////////////////////////////////////////////////////////////
+
+int check_args_types(int argc, sqlite3_value** argv) {
+ int inv_type = sqlite3_value_type(argv[0]) != SQLITE_TEXT;
+
+ if (inv_type == 0 && argc == 2) {
+ return
+ sqlite3_value_type(argv[1]) != SQLITE_TEXT &&
+ sqlite3_value_type(argv[1]) != SQLITE_BLOB;
+ } else {
+ return inv_type;
+ }
+}
+
+int check_args_lengths(int argc, sqlite3_value** argv) {
+ int inv_size = 1;
+
+ int pass_size = sqlite3_value_bytes(argv[0]);
+ if (pass_size > 0) {
+ inv_size = 0;
+ }
+
+ if (inv_size == 0 && argc == 2) {
+ int salt_size = sqlite3_value_bytes(argv[1]);
+
+ return salt_size <= 0 || salt_size > DEF_SALT_SIZE;
+ } else {
+ return inv_size;
+ }
+}
+
+/**
+ * @brief SQLite3 extension function for hash generation.
+ * @details Computes a hash equivalent to the one generated by MySQL for 'mysql_native_password'.
+ * @param context SQLite3 context used for returning computation result.
+ * @param argc Number of arguments; expected to be 1.
+ * @param argv Argument list; expected to hold one argument with len > 0 of type 'SQLITE_TEXT'.
+ */
+static void mysql_native_passwordFunc(sqlite3_context* context, int argc, sqlite3_value** argv) {
+ if (argc != 1) {
+ sqlite3_result_text(context, "Invalid number of arguments", -1, SQLITE_TRANSIENT);
+ return;
+ } else {
+ if (check_args_types(argc, argv)) {
+ sqlite3_result_text(context, "Invalid argument type", -1, SQLITE_TRANSIENT);
+ return;
+ }
+ if (check_args_lengths(argc, argv)) {
+ sqlite3_result_text(context, "Invalid argument size", -1, SQLITE_TRANSIENT);
+ return;
+ }
+ }
+
+ const unsigned char* input = sqlite3_value_text(argv[0]);
+ int input_len = strlen((const char*)input);
+
+ unsigned char hash1[SHA_DIGEST_LENGTH] = { 0 };
+ unsigned char hash2[SHA_DIGEST_LENGTH] = { 0 };
+
+ SHA1(input, input_len, hash1);
+ SHA1(hash1, SHA_DIGEST_LENGTH, hash2);
+
+ char hex_hash[2 * SHA_DIGEST_LENGTH + 2];
+ unsigned int i = 0;
+
+ for (i = 0; i < SHA_DIGEST_LENGTH; i++) {
+ sprintf(hex_hash + 2 * i + 1, "%02x", hash2[i]);
+
+ hex_hash[2 * i + 1] = toupper(hex_hash[2 * i + 1]);
+ hex_hash[2 * i + 1 + 1] = toupper(hex_hash[2 * i + 1 + 1]);
+ }
+
+ hex_hash[0] = '*';
+ hex_hash[2 * SHA_DIGEST_LENGTH + 1] = '\0';
+
+ sqlite3_result_text(context, hex_hash, -1, SQLITE_TRANSIENT);
+}
+
+/**
+ * @brief SQLite3 extension function for hash generation.
+ * @details Computes a hash equivalent to the one generated by MySQL for 'caching_sha2_password'.
+ * @param context SQLite3 context used for returning computation result.
+ * @param argc Number of arguments; either 1 or 2. One for random salt, two providing salt.
+ * @param argv Argument list; expected to hold either 1 or 2 arguments:
+ * 1. Password to be hashed; with len > 0 and of type 'SQLITE_TEXT'.
+ * 1. Optional salt; with (len > 0 && len <= 20) and of type ('SQLITE_TEXT' || 'SQLITE_BLOB'). If no salt is
+ * provided a randomly generated salt with length 20 will be used.
+ */
+static void caching_sha2_passwordFunc(sqlite3_context* context, int argc, sqlite3_value** argv) {
+ if (argc < 1 || argc > 2) {
+ sqlite3_result_text(context, "Invalid number of arguments", -1, SQLITE_TRANSIENT);
+ return;
+ } else {
+ if (check_args_types(argc, argv)) {
+ sqlite3_result_text(context, "Invalid argument type", -1, SQLITE_TRANSIENT);
+ return;
+ }
+ if (check_args_lengths(argc, argv)) {
+ sqlite3_result_text(context, "Invalid argument size", -1, SQLITE_TRANSIENT);
+ return;
+ }
+ }
+
+ unsigned int salt_size = DEF_SALT_SIZE;
+ const char* cpass = (const char*)sqlite3_value_text(argv[0]);
+ unsigned char salt[DEF_SALT_SIZE + 1] = { 0 };
+
+ if (argc == 2) {
+ salt_size = sqlite3_value_bytes(argv[1]);
+ const void* b_salt = sqlite3_value_blob(argv[1]);
+
+ memcpy(salt, b_salt, salt_size);
+ } else {
+ unsigned char salt_buf[DEF_SALT_SIZE + 1] = { 0 };
+
+ if (RAND_bytes(salt_buf, DEF_SALT_SIZE) != 1) {
+ const char t_msg[] = { "SALT creation failed (%lu:'%s')" };
+ char err_buf[256] = { 0 };
+ char err_msg[sizeof(err_buf)/sizeof(char) + sizeof(t_msg)/sizeof(char) + 20] = { 0 };
+
+ const unsigned long err = ERR_get_error();
+ ERR_error_string(err, err_buf);
+
+ sprintf(err_msg, t_msg, err, err_buf);
+ sqlite3_result_text(context, err_msg, -1, SQLITE_TRANSIENT);
+ return;
+ } else {
+ unsigned int i = 0;
+
+ for (i = 0; i < sizeof(salt_buf)/sizeof(unsigned char); i++) {
+ salt_buf[i] = salt_buf[i] & 0x7f;
+
+ if (salt_buf[i] == '\0' || salt_buf[i] == '$') {
+ salt_buf[i] = salt_buf[i] + 1;
+ }
+ }
+
+ memcpy(salt, salt_buf, salt_size);
+ }
+ }
+
+ #define BASE_SHA2_SALT "$5$rounds=5000$"
+ #define BASE_SHA2_HASH "$A$005$"
+
+ char sha2_buf[100] = { 0 };
+ char sha2_salt[100] = { BASE_SHA2_SALT };
+
+ strcat(sha2_salt, (const char*)salt);
+ sha256_crypt_r(cpass, sha2_salt, sha2_buf, sizeof(sha2_buf));
+
+ char sha2_hash[100] = { BASE_SHA2_HASH };
+ const char* sha256 = sha2_buf + salt_size + strlen(BASE_SHA2_SALT) + 1;
+
+ strcat(sha2_hash, (const char*)salt);
+ strcat(sha2_hash, sha256);
+
+ sqlite3_result_text(context, sha2_hash, -1, SQLITE_TRANSIENT);
+}
+
/*
** current_time()
**
@@ -129263,6 +129440,9 @@
FUNCTION(substr, 3, 0, 0, substrFunc ),
FUNCTION(substring, 2, 0, 0, substrFunc ),
FUNCTION(substring, 3, 0, 0, substrFunc ),
+ FUNCTION(mysql_native_password, 1, 0, 0, mysql_native_passwordFunc ),
+ FUNCTION(caching_sha2_password, 1, 0, 0, caching_sha2_passwordFunc ),
+ FUNCTION(caching_sha2_password, 2, 0, 0, caching_sha2_passwordFunc ),
WAGGREGATE(sum, 1,0,0, sumStep, sumFinalize, sumFinalize, sumInverse, 0),
WAGGREGATE(total, 1,0,0, sumStep,totalFinalize,totalFinalize,sumInverse, 0),
WAGGREGATE(avg, 1,0,0, sumStep, avgFinalize, avgFinalize, sumInverse, 0),
2 changes: 1 addition & 1 deletion lib/MySQL_Protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ extern ClickHouse_Authentication *GloClickHouseAuth;
#include "proxysql_find_charset.h"


char * sha256_crypt_r (const char *key, const char *salt, char *buffer, int buflen);
extern "C" char * sha256_crypt_r (const char *key, const char *salt, char *buffer, int buflen);

static const char *plugins[3] = {
"mysql_native_password",
Expand Down
2 changes: 1 addition & 1 deletion lib/sha256crypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ static const char b64t[65] =
"./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";


char * sha256_crypt_r (const char *key, const char *salt, char *buffer, int buflen)
extern "C" char * sha256_crypt_r (const char *key, const char *salt, char *buffer, int buflen)
{
unsigned char alt_result[32]
__attribute__ ((__aligned__ (__alignof__ (uint32_t))));
Expand Down
13 changes: 11 additions & 2 deletions test/tap/tap/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ endif

OPT := $(STDCPP) -O2 -ggdb -Wl,--no-as-needed $(WASAN)

# NOTE-LWGCOV (LinkWithGCOV):
# Linking against GCOV is required when ProxySQL is build with support for it. This is because
# 'sha256crypt.oo' is being used for 'libtap.a'. This requisite is imposed due to 'sha256_crypt_r'
# being used inside ProxySQL linked 'SQLite3', which is also used by `libtap.so`.
LWGCOV :=
ifeq ($(WITHGCOV),1)
LWGCOV := -lgcov
endif


### main targets

Expand All @@ -85,10 +94,10 @@ tap.o: tap.cpp cpp-dotenv/static/cpp-dotenv/libcpp_dotenv.a libcurl.so libssl.so
$(CXX) -fPIC -c tap.cpp $(IDIRS) $(OPT)

libtap.a: tap.o command_line.o utils.o cpp-dotenv/static/cpp-dotenv/libcpp_dotenv.a
ar rcs libtap.a tap.o command_line.o utils.o $(SQLITE3_LDIR)/sqlite3.o
ar rcs libtap.a tap.o command_line.o utils.o $(SQLITE3_LDIR)/sqlite3.o $(PROXYSQL_LDIR)/obj/sha256crypt.oo

libtap.so: libtap.a cpp-dotenv/dynamic/cpp-dotenv/libcpp_dotenv.so
$(CXX) -shared -o libtap.so -Wl,--whole-archive libtap.a -Wl,--no-whole-archive
$(CXX) -shared -o libtap.so -Wl,--whole-archive libtap.a -Wl,--no-whole-archive $(LWGCOV)


### tap deps targets
Expand Down
99 changes: 62 additions & 37 deletions test/tap/tap/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,63 @@ std::vector<mysql_res_row> extract_mysql_rows(MYSQL_RES* my_res) {
return result;
};

pair<uint32_t,vector<mysql_res_row>> mysql_query_ext_rows(MYSQL* mysql, const string& query) {
int rc = mysql_query(mysql, query.c_str());
if (rc != EXIT_SUCCESS) {
return { mysql_errno(mysql), {} };
}

MYSQL_RES* myres = mysql_store_result(mysql);
if (myres == nullptr) {
return { mysql_errno(mysql), {} };
}

const vector<mysql_res_row> rows { extract_mysql_rows(myres) };
mysql_free_result(myres);

return { EXIT_SUCCESS, rows };
}

ext_val_t<string> ext_single_row_val(const mysql_res_row& row, const string& def_val) {
if (row.empty() || row.front().empty()) {
return { -1, def_val, {} };
} else {
return { EXIT_SUCCESS, string { row[0] }, string { row[0] } };
}
}

ext_val_t<int64_t> ext_single_row_val(const mysql_res_row& row, const int64_t& def_val) {
if (row.empty() || row.front().empty()) {
return { -1, def_val, {} };
} else {
errno = 0;
char* p_end {};
const int64_t val = std::strtoll(row.front().c_str(), &p_end, 10);

if (row[0] == p_end || errno == ERANGE) {
return { -2, def_val, string { row[0] } };
} else {
return { EXIT_SUCCESS, val, string { row[0] } };
}
}
}

ext_val_t<uint64_t> ext_single_row_val(const mysql_res_row& row, const uint64_t& def_val) {
if (row.empty() || row.front().empty()) {
return { -1, def_val, {} };
} else {
errno = 0;
char* p_end {};
const uint64_t val = std::strtoll(row.front().c_str(), &p_end, 10);

if (row[0] == p_end || errno == ERANGE) {
return { -2, def_val, string { row[0] } };
} else {
return { EXIT_SUCCESS, val, string { row[0] } };
}
}
}

struct memory {
char* data;
size_t size;
Expand Down Expand Up @@ -1078,53 +1135,21 @@ int execute_eof_test(const CommandLine& cl, MYSQL* mysql, const string& test, co
}

int get_cur_backend_conns(MYSQL* proxy_admin, const string& conn_type, uint32_t& found_conn_num) {
MYSQL_QUERY(proxy_admin, string {"SELECT " + conn_type + " FROM stats_mysql_connection_pool"}.c_str());
MYSQL_QUERY(proxy_admin, string {"SELECT SUM(" + conn_type + ") FROM stats_mysql_connection_pool"}.c_str());

MYSQL_ROW row = nullptr;
MYSQL_RES* my_res = mysql_store_result(proxy_admin);
uint32_t field_num = mysql_num_fields(my_res);
vector<uint32_t> connfree_vals {};
MYSQL_RES* myres = mysql_store_result(proxy_admin);
uint32_t field_num = mysql_num_fields(myres);

if (field_num != 1) {
diag("Invalid number of columns in resulset from 'stats_mysql_connection_pool': %d", field_num);
} else {
found_conn_num = std::strtol(row[0], NULL, 10);
}

if (my_res != nullptr) {
while ((row = mysql_fetch_row(my_res))) {
connfree_vals.push_back(std::strtol(row[0], NULL, 10));
}
}
mysql_free_result(my_res);

found_conn_num = std::accumulate(connfree_vals.begin(), connfree_vals.end(), 0, std::plus<uint32_t>());
return EXIT_SUCCESS;
}

int wait_for_backend_conns(
MYSQL* proxy_admin, const string& conn_type, uint32_t exp_conn_num, uint32_t timeout
) {
uint32_t total_conn_num = 0;
uint32_t waited = 0;

while (waited < timeout) {
int get_err = get_cur_backend_conns(proxy_admin, conn_type, total_conn_num);
if (get_err != EXIT_SUCCESS) { return EXIT_FAILURE; }

if (total_conn_num == exp_conn_num) {
break;
} else {
sleep(1);
waited += 1;
}
}

if (waited >= timeout) {
return EXIT_FAILURE;
} else {
return EXIT_SUCCESS;
}
}

string join_path(const string& p1, const string& p2) {
if (p1.back() == '/') {
return p1 + p2;
Expand Down
Loading

0 comments on commit 2df30b6

Please sign in to comment.