Skip to content

Commit

Permalink
Merge pull request #4485 from sysown/v2.x-caching_sha2_compress
Browse files Browse the repository at this point in the history
Fix invalid free for `caching_sha2_password` and `CLIENT_COMPRESS`
  • Loading branch information
renecannao authored Mar 28, 2024
2 parents 3187ee6 + a2f3731 commit bec1970
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 70 deletions.
12 changes: 10 additions & 2 deletions lib/mysql_data_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,8 +1328,16 @@ int MySQL_Data_Stream::array2buffer() {
memcpy(&queueOUT.pkt,PSarrayOUT->index(idx), sizeof(PtrSize_t));
idx++;
//VALGRIND_ENABLE_ERROR_REPORTING;
// this is a special case, needed because compression is enabled *after* the first OK
if (DSS==STATE_CLIENT_AUTH_OK) {
// This is a special case, needed because compression is enabled *after* the first OK. In
// case of 'caching_sha2_password', not only the first packet needs to be processed, since
// there are other scenarios in which one extra byte is sent prior to the final OK packet
// flagging auth success. The generation of these extra packets should all be queued at
// the same time, since they represent the final client response. Right now this is
// handled during 'MySQL_Session::handler___status_CONNECTING_CLIENT___STATE_SERVER_HANDSHAKE'.
// Because of this, we can make the assumption that once we have sent all the packets
// currently in 'PSarrayOUT', it's safe to change the 'DSS' status, and enable compression
// if connections requires it.
if (DSS==STATE_CLIENT_AUTH_OK && idx == PSarrayOUT->len) {
DSS=STATE_SLEEP;
// enable compression
if (myconn->options.server_capabilities & CLIENT_COMPRESS) {
Expand Down
89 changes: 21 additions & 68 deletions test/tap/tests/test_auth_methods-t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,69 +98,6 @@ bool match_pass(const char* p1, const char* p2) {
}
}

// TODO: Refactor
///////////////////////////////////////////////////////////////////////////////

int get_query_result(MYSQL* mysql, const string& query, uint64_t& out_val) {
int rc = mysql_query(mysql, query.c_str());
if (rc != EXIT_SUCCESS) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql));
return EXIT_FAILURE;
}

MYSQL_RES* myres = mysql_store_result(mysql);
if (myres == nullptr) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql));
return EXIT_FAILURE;
}

MYSQL_ROW row = mysql_fetch_row(myres);
if (row == nullptr || row[0] == nullptr) {
fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, "Received empty row");
return EXIT_FAILURE;
}

out_val = std::stol(row[0]);

mysql_free_result(myres);

return EXIT_SUCCESS;
}

int wait_target_backend_conns(MYSQL* admin, uint32_t tg_conns, uint32_t timeout, int32_t hg=-1) {
const string query_select { "SELECT SUM(ConnFree + ConnUsed) FROM stats_mysql_connection_pool" };
const string query_where { hg == -1 ? "" : " WHERE hostgroup=" + std::to_string(hg) };
const string query { query_select + query_where };

uint32_t waited = 0;

while (waited < timeout) {
uint64_t conns_count = 0;
int q_res = get_query_result(admin, query.c_str(), conns_count);

if (q_res != EXIT_SUCCESS) {
diag("Failed getting conn stats query:`%s`,error:`%s`", query.c_str(), mysql_error(admin));
return -1;
}

if (conns_count == tg_conns) {
diag("Reached target conn count tg_conns:'%d',conns_count:'%ld'", tg_conns, conns_count);
break;
} else {
waited += 1;
diag(
"Conn count yet unmatched tg_conns:'%d',conns_count:'%ld',checks:'%u'",
tg_conns, conns_count, waited
);
sleep(1);
}
}

return waited < timeout ? 0 : -2;
}

///////////////////////////////////////////////////////////////////////////////

std::string unhex(const std::string& hex) {
if (hex.size() % 2) { return {}; };

Expand Down Expand Up @@ -345,6 +282,8 @@ struct test_conf_t {
bool hashed_pass;
/* @brief Wether to attempt auth under SSL conn or not. */
bool use_ssl;
/* @brief Wether to attempt auth with compression enabled or not. */
bool use_comp;
};

struct sess_info_t {
Expand Down Expand Up @@ -1052,15 +991,18 @@ vector<test_conf_t> get_conf_combs(
const vector<string>& def_auths,
const vector<string>& req_auths,
const vector<bool>& hash_pass,
const vector<bool>& use_ssl
const vector<bool>& use_ssl,
const vector<bool>& use_comp
) {
vector<test_conf_t> confs {};

for (const auto& def_auth : def_auths) {
for (const auto& req_auth : req_auths) {
for (const auto& hashed : hash_pass) {
for (const auto& ssl : use_ssl) {
confs.push_back({def_auth, req_auth, hashed, ssl});
for (const auto& comp : use_comp) {
confs.push_back({def_auth, req_auth, hashed, ssl, comp});
}
}
}
}
Expand Down Expand Up @@ -1156,13 +1098,17 @@ int config_mysql_conn(const CommandLine& cl, const test_conf_t& conf, MYSQL* pro

if (conf.use_ssl) {
mysql_ssl_set(proxy, NULL, NULL, NULL, NULL, NULL);
cflags = CLIENT_SSL;
cflags |= CLIENT_SSL;

if (getenv("SSLKEYLOGFILE") && F_SSLKEYLOGFILE) {
mysql_options(proxy, MARIADB_OPT_SSL_KEYLOG_CALLBACK, reinterpret_cast<void*>(ssl_keylog_callback));
}
}

if (conf.use_comp) {
cflags |= CLIENT_COMPRESS;
}

return cflags;
}

Expand Down Expand Up @@ -1450,7 +1396,11 @@ int backend_conns_cleanup(MYSQL* admin) {
MYSQL_QUERY(admin, "LOAD MYSQL SERVERS TO RUNTIME");

// Wait for backend connection cleanup
int w_res = wait_target_backend_conns(admin, 0, 10, TAP_MYSQL8_BACKEND_HG);
const string check_conn_cleanup {
"SELECT IIF((SELECT SUM(ConnUsed + ConnFree) FROM stats.stats_mysql_connection_pool"
" WHERE hostgroup=" + std::to_string(TAP_MYSQL8_BACKEND_HG) + ")=0, 'TRUE', 'FALSE')"
};
int w_res = wait_for_cond(admin, check_conn_cleanup, 10);
if (w_res != EXIT_SUCCESS) {
diag("Waiting for backend connections failed res:'%d'", w_res);
return EXIT_FAILURE;
Expand Down Expand Up @@ -1745,9 +1695,12 @@ int main(int argc, char** argv) {
};
const vector<bool> hash_pass { false, true };
const vector<bool> use_ssl { false, true };
const vector<bool> use_comp { false, true };

// Sequential access tests; exercising full logic
const vector<test_conf_t> all_conf_combs { get_conf_combs(def_auths, req_auhts, hash_pass, use_ssl) };
const vector<test_conf_t> all_conf_combs {
get_conf_combs(def_auths, req_auhts, hash_pass, use_ssl, use_comp)
};
const auto scs_stats { count_exp_scs(all_conf_combs, cbres.second, tests_creds) };

pair<uint64_t,uint64_t> rnd_scs_stats {};
Expand Down

0 comments on commit bec1970

Please sign in to comment.