diff --git a/src/metal_sdk/motorhead.py b/src/metal_sdk/motorhead.py index d3b92dc..13154cc 100644 --- a/src/metal_sdk/motorhead.py +++ b/src/metal_sdk/motorhead.py @@ -1,15 +1,16 @@ import httpx +from .typings import MotorheadPayload API_URL = 'https://api.getmetal.io/v1/motorhead' class Motorhead: - def __init__(self, api_key=None, client_id=None, base_url=API_URL): - self.api_key = api_key - self.client_id = client_id - self.base_url = base_url + def __init__(self, payload: MotorheadPayload = {}): + self.api_key = payload.get("api_key") + self.client_id = payload.get("client_id") + self.base_url = payload.get("base_url") or API_URL - if base_url == API_URL and not (api_key and client_id): + if self.base_url == API_URL and not (self.api_key and self.client_id): raise ValueError('api_key and client_id required for managed motorhead') self.client = httpx.Client(headers={ diff --git a/src/metal_sdk/typings.py b/src/metal_sdk/typings.py index 51766af..190a4aa 100644 --- a/src/metal_sdk/typings.py +++ b/src/metal_sdk/typings.py @@ -50,3 +50,9 @@ class TunePayload(TypedDict): idA: str idB: str label: TuneLabel + + +class MotorheadPayload(TypedDict): + api_key: NotRequired[str] + client_id: NotRequired[str] + base_url: NotRequired[str] diff --git a/tests/test_motorhead.py b/tests/test_motorhead.py index 56b9170..b9bfbff 100644 --- a/tests/test_motorhead.py +++ b/tests/test_motorhead.py @@ -7,7 +7,7 @@ class TestMotorhead(unittest.TestCase): def setUp(self): - self.motorhead = Motorhead(api_key='test_key', client_id='test_client') + self.motorhead = Motorhead({"api_key": "test_key", "client_id": "test_client"}) def test_initialization(self): self.assertEqual(self.motorhead.api_key, 'test_key') diff --git a/tests/test_motorhead_async.py b/tests/test_motorhead_async.py index 96f2cb6..d3ac321 100644 --- a/tests/test_motorhead_async.py +++ b/tests/test_motorhead_async.py @@ -8,7 +8,7 @@ class TestMotorheadAsync(TestCase): async def test_initialization(self): - m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID) + m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID}) self.assertEqual(m.api_key, API_KEY) self.assertEqual(m.client_id, CLIENT_ID) @@ -16,21 +16,21 @@ async def test_initialization(self): Motorhead() async def test_add_memory(self): - m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID) + m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID}) m.client.post = mock.MagicMock(return_value=mock.Mock(status_code=201)) memory = await m.add_memory('test_session', {'key': 'value'}) self.assertEqual(memory, 'mock_memory') async def test_get_memory(self): - m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID) + m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID}) m.client.get = mock.MagicMock(return_value=mock.Mock(status_code=200)) memory = await m.get_memory('test_session') self.assertEqual(memory, 'mock_memory') async def test_delete_memory(self): - m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID) + m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID}) m.client.delete = mock.MagicMock(return_value=mock.Mock(status_code=204)) memory = await m.delete_memory('test_session')