Skip to content

Commit

Permalink
Added Mock or keyring to run UTs on GitHub action
Browse files Browse the repository at this point in the history
  • Loading branch information
thijs-nijhuis committed Mar 1, 2024
1 parent a5edcda commit f3fe79c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 6 deletions.
6 changes: 3 additions & 3 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def set_sharded_password(self, service_name: str, username: str, password: str)
keyring.set_password(service_name, username, json.dumps(shard_info))
# then store all shards with the shard number as postfix
for i, s in enumerate(password_shards):
keyring.set_password(service_name, f"{username}|{i}", s)
keyring.set_password(service_name, f"{username}__{i}", s)

def get_sharded_password(self, service_name: str, username: str) -> Optional[str]:
password = keyring.get_password(service_name, username)
Expand All @@ -459,7 +459,7 @@ def get_sharded_password(self, service_name: str, username: str) -> Optional[str

password = ""
for i in range(shard_count):
password += str(keyring.get_password(service_name, f"{username}|{i}"))
password += str(keyring.get_password(service_name, f"{username}__{i}"))
except ValueError:
pass

Expand All @@ -474,7 +474,7 @@ def delete_sharded_password(self, service_name: str, username: str) -> None:
if password_as_dict.get("sharded_password"):
shard_count = int(password_as_dict.get("shard_count"))
for i in range(shard_count):
keyring.delete_password(service_name, f"{username}|{i}")
keyring.delete_password(service_name, f"{username}__{i}")
except ValueError:
pass

Expand Down
69 changes: 66 additions & 3 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from dbt.adapters.databricks.connections import DatabricksCredentials
import keyring.backend
import pytest


Expand Down Expand Up @@ -74,7 +75,10 @@ def test_token(self):


class TestShardedPassword(unittest.TestCase):
def test_store_short_password(self):
def test_store_and_delete_short_password(self):
# set the keyring to mock class
keyring.set_keyring(MockKeyring())

service = "dbt-databricks"
host = "my.cloud.databricks.com"
long_password = "x" * 10
Expand All @@ -87,10 +91,18 @@ def test_store_short_password(self):
retrieved_password = creds.get_sharded_password(service, host)
self.assertEqual(long_password, retrieved_password)

def test_store_long_password(self):
# delete password
creds.delete_sharded_password(service, host)
retrieved_password = creds.get_sharded_password(service, host)
self.assertIsNone(retrieved_password)

def test_store_and_delete_long_password(self):
# set the keyring to mock class
keyring.set_keyring(MockKeyring())

service = "dbt-databricks"
host = "my.cloud.databricks.com"
long_password = "x" * 2000
long_password = "x" * 3000

creds = DatabricksCredentials(
host=host, token="foo", database="andre", http_path="http://foo", schema="dbt"
Expand All @@ -99,3 +111,54 @@ def test_store_long_password(self):

retrieved_password = creds.get_sharded_password(service, host)
self.assertEqual(long_password, retrieved_password)

# delete password
creds.delete_sharded_password(service, host)
retrieved_password = creds.get_sharded_password(service, host)
self.assertIsNone(retrieved_password)


class MockKeyring(keyring.backend.KeyringBackend):
def __init__(self):
self.file_location = self._generate_test_root_dir()

def priority(self):
return 1

def _generate_test_root_dir(self):
import tempfile
return tempfile.mkdtemp(prefix="dbt-unit-test-")

def file_path(self, servicename, username):
from os.path import join

file_location = self.file_location
file_name = f"{servicename}_{username}.txt"
return join(file_location, file_name)

def set_password(self, servicename, username, password):
file_path = self.file_path(servicename, username)

with open(file_path, "w") as file:
file.write(password)

def get_password(self, servicename, username):
import os

file_path = self.file_path(servicename, username)
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
password = file.read()

return password

def delete_password(self, servicename, username):
import os

file_path = self.file_path(servicename, username)
if not os.path.exists(file_path):
return None

os.remove(file_path)

0 comments on commit f3fe79c

Please sign in to comment.