Skip to content

Commit

Permalink
Fix windows credential cache
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jszczerbinski committed Nov 5, 2024
1 parent 3e422e3 commit cb96bf8
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 34 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ if (WIN32)
set(VSDIR "vs15" CACHE STRING "Used to specify visual studio version of libsnowflakeclient dependecies")
add_definitions(-D_CRT_SECURE_NO_DEPRECATE)
find_library(OOB_LIB libtelemetry_a.lib PATHS deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/oob/lib/ REQUIRED NO_DEFAULT_PATH)
# find_library(CURL_LIB libcurl_a.lib PATHS deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/curl/lib/ REQUIRED NO_DEFAULT_PATH)
find_library(CURL_LIB libcurl_a.lib PATHS deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/curl/lib/ REQUIRED NO_DEFAULT_PATH)
find_library(SSL_LIB libssl_a.lib PATHS deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/openssl/lib/ REQUIRED NO_DEFAULT_PATH)
find_library(CRYPTO_LIB libcrypto_a.lib PATHS deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/openssl/lib/ REQUIRED NO_DEFAULT_PATH)
find_library(ZLIB_LIB zlib_a.lib PATHS deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/zlib/lib/ REQUIRED NO_DEFAULT_PATH)
Expand Down
10 changes: 5 additions & 5 deletions cpp/lib/CredentialCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ namespace sf {
};

CredentialCache *CredentialCache::make() {
#if defined(__WIN32) || defined(__APPLE__)
#if defined(_WIN32) || defined(__APPLE__)
return new SecureStorageCredentialCache();
#else
return new FileCredentialCache(std::string("file.txt"));
Expand All @@ -144,14 +144,14 @@ cred_cache_ptr cred_cache_init() {

char* cred_cache_get_credential(cred_cache_ptr tc, const char* account, const char* host, const char* user, CredentialType type)
{
sf::CredentialKey key = { .account = account, .host = host, .user = user, .type = type };
sf::CredentialKey key = { account, host, user, type };
auto tokenOpt = reinterpret_cast<sf::CredentialCache *>(tc)->get(key);
if (!tokenOpt) {
return nullptr;
}
size_t result_size = tokenOpt->size() + 1;
char* result = new char[result_size];
strncpy(result, tokenOpt->c_str(), result_size + 1);
strncpy(result, tokenOpt->c_str(), result_size);
return result;
}

Expand All @@ -161,13 +161,13 @@ void cred_cache_free_credential(char* cred) {

void cred_cache_save_credential(cred_cache_ptr tc, const char* account, const char* host, const char* user, CredentialType type, const char *cred)
{
sf::CredentialKey key = { .account = account, .host = host, .user = user, .type = type };
sf::CredentialKey key = { account, host, user, type };
reinterpret_cast<sf::CredentialCache *>(tc)->save(key, std::string(cred));
}

void cred_cache_remove_credential(cred_cache_ptr tc, const char* account, const char* host, const char* user, CredentialType type)
{
sf::CredentialKey key = { .account = account, .host = host, .user = user, .type = type };
sf::CredentialKey key = { account, host, user, type };
reinterpret_cast<sf::CredentialCache *>(tc)->remove(key);
}

Expand Down
18 changes: 10 additions & 8 deletions cpp/platform/SecureStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
namespace sf
{

using Snowflake::Client::SFLogger;

bool SecureStorage::storeToken(const std::string& host,
const std::string& username,
const std::string& credType,
Expand All @@ -19,11 +21,11 @@ namespace sf
SecureStorageImpl secStorage;
if (secStorage.storeToken(host, username, credType, token) != SecureStorageStatus::Success)
{
log_error("Failed to store secure token%s", "");
CXX_LOG_ERROR("Failed to store secure token");
return false;
}

log_debug("Successfully stored secure token%s", "");
CXX_LOG_DEBUG("Successfully stored secure token");
return true;
}

Expand All @@ -35,11 +37,11 @@ namespace sf
std::string result;
if (secStorage.retrieveToken(host, username, credType, result) != SecureStorageStatus::Success)
{
log_error("Failed to retrieve secure token%s", "");
CXX_LOG_ERROR("Failed to retrieve secure token");
return {};
}

log_debug("Successfully retrieved secure tokeni%s", "");
CXX_LOG_DEBUG("Successfully retrieved secure token");
return result;
}

Expand All @@ -51,11 +53,11 @@ namespace sf
SecureStorageImpl secStorage;
if ( secStorage.updateToken(host, username, credType, token) != SecureStorageStatus::Success)
{
log_error("Failed to update secure token%s", "");
CXX_LOG_ERROR("Failed to update secure token");
return false;
}

log_debug("Successfully updated secure token%s", "");
CXX_LOG_DEBUG("Successfully updated secure token");
return true;
}

Expand All @@ -66,10 +68,10 @@ namespace sf
SecureStorageImpl secStorage;
if ( secStorage.removeToken(host, username, credType) != SecureStorageStatus::Success)
{
log_error("Failed to remove secure token%s", "");
CXX_LOG_ERROR("Failed to remove secure token");
return false;
}
log_debug("Successfully removed secure token%s", "");
CXX_LOG_DEBUG("Successfully removed secure token");
return true;
}

Expand Down
44 changes: 26 additions & 18 deletions cpp/platform/SecureStorageWin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ namespace sf
const std::string& token)
{
std::string target = convertTarget(host, username, credType);
CREDENTIALW creds = { 0 };
std::wstring wide_target = std::wstring(target.begin(), target.end());

creds.TargetName = target_wcs;
creds.CredentialBlobSize = strlen(token);
creds.CredentialBlob = (LPBYTE)token;
CREDENTIALW creds = { 0 };
creds.TargetName = wide_target.data();
creds.CredentialBlobSize = token.size();
creds.CredentialBlob = (LPBYTE)token.c_str();
creds.Persist = CRED_PERSIST_LOCAL_MACHINE;
creds.Type = CRED_TYPE_GENERIC;

Expand All @@ -62,27 +63,33 @@ namespace sf
const std::string& credType,
std::string& token)
{
std::string target = convertTarget(host, username, credType, target_wcs, max_len);
std::string target = convertTarget(host, username, credType);
std::wstring wide_target = std::wstring(target.begin(), target.end());
PCREDENTIALW retcreds = nullptr;

if (!CredReadW(target_wcs, CRED_TYPE_GENERIC, 0, &retcreds))
if (!CredReadW(wide_target.data(), CRED_TYPE_GENERIC, 0, &retcreds))
{
CXX_LOG_ERROR("Failed to read target or could not find it");
return SecureStorageStatus::Error;
}
else

CXX_LOG_DEBUG("Read the token now copying it");

DWORD blobSize = retcreds->CredentialBlobSize;
if (!blobSize)
{
CXX_LOG_DEBUG("Read the token now copying it");

blobSize = retcreds->CredentialBlobSize;
if (!blobSize)
{
return SecureStorageStatus::Error;
}
strncpy_s(token, MAX_TOKEN_LEN, (char *)retcreds->CredentialBlob, size_t(blobSize)+1);
CXX_LOG_DEBUG("Read the token, copied it will return now.");
return SecureStorageStatus::Error;
}

*token_len = size_t(blobSize);
token = "";
std::copy(
retcreds->CredentialBlob,
retcreds->CredentialBlob + blobSize,
std::back_inserter(token)
);

CXX_LOG_DEBUG("Copied token");

CredFree(retcreds);
return SecureStorageStatus::Success;
}
Expand All @@ -100,8 +107,9 @@ namespace sf
const std::string& credType)
{
std::string target = convertTarget(host, username, credType);
std::wstring wide_target = std::wstring(target.begin(), target.end());

if (!CredDeleteW(target_wcs, CRED_TYPE_GENERIC, 0))
if (!CredDeleteW(wide_target.data(), CRED_TYPE_GENERIC, 0))
{
return SecureStorageStatus::Error;
}
Expand Down
4 changes: 3 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ SET(TESTS_CXX
test_unit_snowflake_types_to_string
test_unit_azure_client
test_unit_query_context_cache
test_unit_sfurl)
test_unit_sfurl
test_unit_credential_cache
)

SET(TESTS_PUTGET
test_include_aws
Expand Down
7 changes: 6 additions & 1 deletion tests/test_manual_connect.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ void test_mfa_connect_with_duo_passcodeInPassword(void** unused)

void test_mfa_connect_with_mfa_cache(void** unused)
{
/*
* Should trigger mfa push notification at most once.
* Make sure ALLOW_CLIENT_MFA_CACHING is set to true
* For more details refer to: https://docs.snowflake.com/en/user-guide/security-mfa#using-mfa-token-caching-to-minimize-the-number-of-prompts-during-authentication-optional
*/
for (int i = 0; i < 2; i++) {
SF_CONNECT *sf = snowflake_init();
snowflake_set_attribute(sf, SF_CON_APPLICATION_NAME, "ODBC");
Expand Down Expand Up @@ -205,7 +210,7 @@ void test_none(void** unused) {}

int main(void)
{
initialize_test(SF_BOOLEAN_FALSE);
initialize_test(SF_BOOLEAN_TRUE);
struct CMUnitTest tests[1] = {
cmocka_unit_test(test_none)
};
Expand Down
29 changes: 29 additions & 0 deletions tests/test_unit_credential_cache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//
// Created by Jakub Szczerbinski on 05.11.24.
//
#include "lib/CredentialCache.hpp"
#include "utils/test_setup.h"

void test_credential_cache(void **unused)
{
std::unique_ptr<sf::CredentialCache> cache{sf::CredentialCache::make()};
sf::CredentialKey key { "account", "host", "user", CredentialType::MFA_TOKEN };

std::string token = "example_token";
assert_true(cache->save(key, token));
assert_true(cache->get(key).value() == token);

assert_true(cache->remove(key));
assert_false(cache->get(key).has_value());
}

int main(void) {
/* Testing only file based credential cache */
#ifndef __linux__
return 0;
#endif
const struct CMUnitTest tests[] = {
cmocka_unit_test(test_credential_cache),
};
return cmocka_run_group_tests(tests, NULL, NULL);
}

0 comments on commit cb96bf8

Please sign in to comment.