Skip to content

Commit

Permalink
Merge pull request #25 from getmetal/task/sp/allow-for-filter-only-qu…
Browse files Browse the repository at this point in the history
…eries

Allow for filters only queries
  • Loading branch information
Czechh authored Jun 8, 2023
2 parents 2fe79fb + 896aa0a commit 8e669c4
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install httpx
pip install httpx respx
pip install typing_extensions
- name: Test with unittest
run: |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "metal_sdk"
version = "1.0.7"
version = "1.0.8"
authors = [
{ name="Metal Technologies Inc", email="[email protected]" },
]
Expand Down
9 changes: 6 additions & 3 deletions src/metal_sdk/metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __getData(self, index, payload: dict = {}):

return data

def __validateIndexAndSearch(self, index=None, payload={}):
def __validateIndex(self, index=None, payload={}):
if index is None:
raise TypeError("index_id required")

Expand All @@ -64,7 +64,7 @@ def __validateIndexAndSearch(self, index=None, payload={}):

def index(self, payload: IndexPayload = {}, index_id=None):
index = self.index_id or index_id
self.__validateIndexAndSearch(index, payload)
self.__validateIndex(index, payload)
data = self.__getData(index, payload)
url = "/v1/index"

Expand All @@ -84,7 +84,10 @@ def search(
self, payload: SearchPayload = {}, index_id=None, ids_only=False, limit=1
):
index = index_id or self.index_id
self.__validateIndexAndSearch(index, payload)

if index is None:
raise TypeError("index_id required")

data = self.__getData(index, payload)

url = "/v1/search?limit=" + str(limit)
Expand Down
9 changes: 6 additions & 3 deletions src/metal_sdk/metal_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __getData(self, index, payload: dict = {}):

return data

def __validateIndexAndSearch(self, index=None, payload={}):
def __validateIndex(self, index=None, payload={}):
if index is None:
raise TypeError("index_id required")

Expand All @@ -64,7 +64,7 @@ def __validateIndexAndSearch(self, index=None, payload={}):

async def index(self, payload: IndexPayload = {}, index_id=None):
index = self.index_id or index_id
self.__validateIndexAndSearch(index, payload)
self.__validateIndex(index, payload)
data = self.__getData(index, payload)
url = "/v1/index"

Expand All @@ -84,7 +84,10 @@ async def search(
self, payload: SearchPayload = {}, index_id=None, ids_only=False, limit=1
):
index = index_id or self.index_id
self.__validateIndexAndSearch(index, payload)

if index is None:
raise TypeError("index_id required")

data = self.__getData(index, payload)

url = "/v1/search?limit=" + str(limit)
Expand Down
38 changes: 24 additions & 14 deletions src/metal_sdk/motorhead.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,53 @@
import httpx
from .typings import MotorheadPayload

API_URL = 'https://api.getmetal.io/v1/motorhead'
API_URL = 'https://api.getmetal.io/v1/motorhead/'


class Motorhead:
class Motorhead(httpx.Client):
def __init__(self, payload: MotorheadPayload = {}):
super().__init__()
self.api_key = payload.get("api_key")
self.client_id = payload.get("client_id")
self.base_url = payload.get("base_url") or API_URL

if self.base_url == API_URL and not (self.api_key and self.client_id):
has_api_key = self.api_key is not None
has_client_id = self.client_id is not None
has_key_and_id = has_api_key and has_client_id

if self.base_url == API_URL and not has_key_and_id:
raise ValueError('api_key and client_id required for managed motorhead')

self.client = httpx.Client(headers={
self.headers.update({
'Content-Type': 'application/json',
'x-metal-api-key': self.api_key,
'x-metal-client-id': self.client_id,
})

def request(self, method, url, *args, **kwargs):
return super().request(method, url, *args, **kwargs)

def add_memory(self, sessionId, payload):
response = self.client.post(f'{self.base_url}/sessions/{sessionId}/memory', json=payload)
response.raise_for_status()
url = f'/sessions/{sessionId}/memory'
res = self.request("post", url, json=payload)
res.raise_for_status()

data = response.json()
data = res.json()
memory = data.get('data', data)
return memory

def get_memory(self, sessionId):
response = self.client.get(f'{self.base_url}/sessions/{sessionId}/memory')
response.raise_for_status()

data = response.json()
url = f'/sessions/{sessionId}/memory'
res = self.request("get", url)
res.raise_for_status()
data = res.json()
memory = data.get('data', data)
return memory

def delete_memory(self, sessionId):
response = self.client.delete(f'{self.base_url}/sessions/{sessionId}/memory')
response.raise_for_status()
url = f'/sessions/{sessionId}/memory'
res = self.request("delete", url)
res.raise_for_status()

data = response.json()
data = res.json()
return data.get('data', data)
43 changes: 26 additions & 17 deletions src/metal_sdk/motorhead_async.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import httpx
from .typings import MotorheadPayload

API_URL = 'https://api.getmetal.io/v1/motorhead'
API_URL = 'https://api.getmetal.io/v1/motorhead/'


class Motorhead(httpx.AsyncClient):
def __init__(self, api_key=None, client_id=None, base_url=API_URL):
def __init__(self, payload: MotorheadPayload = {}):
super().__init__()
self.api_key = api_key
self.client_id = client_id
self.base_url = base_url
self.api_key = payload.get("api_key")
self.client_id = payload.get("client_id")
self.base_url = payload.get("base_url") or API_URL

if base_url == API_URL and not (api_key and client_id):
has_api_key = self.api_key is not None
has_client_id = self.client_id is not None
has_key_and_id = has_api_key and has_client_id

if self.base_url == API_URL and not has_key_and_id:
raise ValueError('api_key and client_id required for managed motorhead')

self.headers.update({
Expand All @@ -19,25 +24,29 @@ def __init__(self, api_key=None, client_id=None, base_url=API_URL):
'x-metal-client-id': self.client_id,
})

async def request(self, method, url, *args, **kwargs):
return await super().request(method, url, *args, **kwargs)

async def add_memory(self, sessionId, payload):
response = await self.post(f'{self.base_url}/sessions/{sessionId}/memory', json=payload)
response.raise_for_status()
url = f'/sessions/{sessionId}/memory'
res = await self.request("post", url, json=payload)
res.raise_for_status()

data = response.json()
data = res.json()
memory = data.get('data', data)
return memory

async def get_memory(self, sessionId):
response = await self.get(f'{self.base_url}/sessions/{sessionId}/memory')
response.raise_for_status()

data = response.json()
url = f'/sessions/{sessionId}/memory'
res = await self.request("get", url)
res.raise_for_status()
data = res.json()
memory = data.get('data', data)
return memory

async def delete_memory(self, sessionId):
response = await self.delete(f'{self.base_url}/sessions/{sessionId}/memory')
response.raise_for_status()
url = f'/sessions/{sessionId}/memory'
res = await self.request("delete", url)
res.raise_for_status()

data = response.json()
return data.get('data', data)
return
32 changes: 29 additions & 3 deletions tests/test_metal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import respx
from httpx import Response
from unittest import TestCase, mock
from src.metal_sdk.metal import Metal

Expand All @@ -15,6 +17,18 @@ def test_metal_instantiate(self):
self.assertEqual(metal.client_id, CLIENT_ID)
self.assertEqual(metal.index_id, index_id)

@respx.mock
def test_request(self):
url = 'https://api.getmetal.io/foo/bar'
method = 'GET'
respx.get(url).mock(return_value=Response(200))

index_id = "index-id"
metal = Metal(API_KEY, CLIENT_ID, index_id)

response = metal.request(method, "/foo/bar")
assert response.status_code == 200

def test_metal_index_without_index(self):
metal = Metal(API_KEY, CLIENT_ID)
with self.assertRaises(TypeError) as ctx:
Expand Down Expand Up @@ -77,10 +91,22 @@ def test_metal_search_without_payload(self):
my_index = "my-index"
metal = Metal(API_KEY, CLIENT_ID, my_index)

with self.assertRaises(TypeError) as ctx:
metal.search()
metal.request = mock.MagicMock(return_value=mock.Mock(status_code=200))
metal.search({"filters": [{"field": "foo", "value": "bar"}]}, limit=666)

self.assertEqual(metal.request.call_count, 1)
self.assertEqual(
str(ctx.exception), "imageBase64, imageUrl, text, or embedding required"
metal.request.call_args[0][0],
"post",
)
self.assertEqual(
metal.request.call_args[0][1],
"/v1/search?limit=666",
)
self.assertEqual(metal.request.call_args[1]["json"]["index"], my_index)
self.assertEqual(metal.request.call_args[1]["json"].get("text"), None)
self.assertEqual(
metal.request.call_args[1]["json"]["filters"], [{"field": "foo", "value": "bar"}]
)

def test_metal_search_with_text(self):
Expand Down
Loading

0 comments on commit 8e669c4

Please sign in to comment.