Skip to content

Commit

Permalink
Merge pull request #21 from getmetal/jo/normalize-motorhead
Browse files Browse the repository at this point in the history
standardize to the payload pattern
  • Loading branch information
softboyjimbo authored May 30, 2023
2 parents 5d6d0aa + bf5b057 commit a51ca27
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
11 changes: 6 additions & 5 deletions src/metal_sdk/motorhead.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import httpx
from .typings import MotorheadPayload

API_URL = 'https://api.getmetal.io/v1/motorhead'


class Motorhead:
def __init__(self, api_key=None, client_id=None, base_url=API_URL):
self.api_key = api_key
self.client_id = client_id
self.base_url = base_url
def __init__(self, payload: MotorheadPayload = {}):
self.api_key = payload.get("api_key")
self.client_id = payload.get("client_id")
self.base_url = payload.get("base_url") or API_URL

if base_url == API_URL and not (api_key and client_id):
if self.base_url == API_URL and not (self.api_key and self.client_id):
raise ValueError('api_key and client_id required for managed motorhead')

self.client = httpx.Client(headers={
Expand Down
6 changes: 6 additions & 0 deletions src/metal_sdk/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ class TunePayload(TypedDict):
idA: str
idB: str
label: TuneLabel


class MotorheadPayload(TypedDict):
api_key: NotRequired[str]
client_id: NotRequired[str]
base_url: NotRequired[str]
2 changes: 1 addition & 1 deletion tests/test_motorhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class TestMotorhead(unittest.TestCase):

def setUp(self):
self.motorhead = Motorhead(api_key='test_key', client_id='test_client')
self.motorhead = Motorhead({"api_key": "test_key", "client_id": "test_client"})

def test_initialization(self):
self.assertEqual(self.motorhead.api_key, 'test_key')
Expand Down
8 changes: 4 additions & 4 deletions tests/test_motorhead_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@

class TestMotorheadAsync(TestCase):
async def test_initialization(self):
m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID)
m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID})
self.assertEqual(m.api_key, API_KEY)
self.assertEqual(m.client_id, CLIENT_ID)

with self.assertRaises(ValueError):
Motorhead()

async def test_add_memory(self):
m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID)
m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID})
m.client.post = mock.MagicMock(return_value=mock.Mock(status_code=201))

memory = await m.add_memory('test_session', {'key': 'value'})
self.assertEqual(memory, 'mock_memory')

async def test_get_memory(self):
m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID)
m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID})
m.client.get = mock.MagicMock(return_value=mock.Mock(status_code=200))

memory = await m.get_memory('test_session')
self.assertEqual(memory, 'mock_memory')

async def test_delete_memory(self):
m = Motorhead(api_key=API_KEY, client_id=CLIENT_ID)
m = Motorhead({"api_key": API_KEY, "client_id": CLIENT_ID})
m.client.delete = mock.MagicMock(return_value=mock.Mock(status_code=204))

memory = await m.delete_memory('test_session')
Expand Down

0 comments on commit a51ca27

Please sign in to comment.