Skip to content

Commit

Permalink
Merge pull request #39 from gizatechxyz/feature/upgrade-pydantic
Browse files Browse the repository at this point in the history
Feature/upgrade pydantic
  • Loading branch information
Gonmeso authored Mar 11, 2024
2 parents 3d06015 + d69db81 commit 3bc27d4
Show file tree
Hide file tree
Showing 24 changed files with 455 additions and 360 deletions.
8 changes: 4 additions & 4 deletions examples/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ You can also create an API key for the current user. This API key will be stored

```console
> giza users create-api-key
[giza][2024-01-17 15:27:27.936] Creating API Key ✅
[giza][2024-01-17 15:27:27.936] Creating API Key ✅
[giza][2024-01-17 15:27:53.605] API Key written to: /Users/gizabrain/.giza/.api_key.json
[giza][2024-01-17 15:27:53.606] Successfully created API Key. It will be used for future requests ✅
[giza][2024-01-17 15:27:53.606] Successfully created API Key. It will be used for future requests ✅
```

**NOTE: The usage of API key is less secure than JWT, so use it with caution.**
Expand All @@ -81,10 +81,10 @@ But don't worry, `giza` makes this process a breeze with a simple command! Let's
```console
> giza transpile awesome_model.onnx --output-path cairo_model

[giza][2023-09-13 12:56:43.725] No model id provided, checking if model exists ✅
[giza][2023-09-13 12:56:43.725] No model id provided, checking if model exists ✅
[giza][2023-09-13 12:56:43.726] Model name is: awesome_model
[giza][2023-09-13 12:56:43.978] Model Created with id -> 25! ✅
[giza][2023-09-13 12:56:44.568] Sending model for transpilation ✅
[giza][2023-09-13 12:56:44.568] Sending model for transpilation ✅
[giza][2023-09-13 12:56:55.577] Transpilation recieved! ✅
[giza][2023-09-13 12:56:55.583] Transpilation saved at: cairo_model
```
Expand Down
2 changes: 1 addition & 1 deletion examples/reset_password.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Once you have received your reset token, you can use it to reset your password.
> giza reset-password --token your_reset_token

Please enter your new password 🔑: # Your new password goes here
Please confirm your new password 🔑:
Please confirm your new password 🔑:
[giza][2023-08-30 12:55:32.128] Password updated successfully
[giza][2023-08-30 12:55:32.132] Password reset was successful 🎉
```
Expand Down
26 changes: 13 additions & 13 deletions giza/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def create(
]
),
headers=headers,
params=endpoint_create.dict(),
params=endpoint_create.model_dump(),
data={"model_id": model_id, "version_id": version_id},
files={"sierra": f} if f is not None else None,
)
Expand Down Expand Up @@ -507,7 +507,7 @@ def list(self, params: Optional[Dict[str, str]] = None) -> EndpointsList:
response.raise_for_status()

return EndpointsList(
__root__=[Endpoint(**endpoint) for endpoint in response.json()]
root=[Endpoint(**endpoint) for endpoint in response.json()]
)

@auth
Expand Down Expand Up @@ -536,7 +536,7 @@ def list_jobs(self, endpoint_id: int) -> JobList:

response.raise_for_status()

return JobList(__root__=[Job(**job) for job in response.json()])
return JobList(root=[Job(**job) for job in response.json()])

@auth
def list_proofs(self, endpoint_id: int) -> ProofList:
Expand Down Expand Up @@ -564,7 +564,7 @@ def list_proofs(self, endpoint_id: int) -> ProofList:

response.raise_for_status()

return ProofList(__root__=[Proof(**proof) for proof in response.json()])
return ProofList(root=[Proof(**proof) for proof in response.json()])

@auth
def get_proof(self, endpoint_id: int, proof_id: int) -> Proof:
Expand Down Expand Up @@ -814,7 +814,7 @@ def list(self, **kwargs) -> ModelList:

response.raise_for_status()

return ModelList(__root__=[Model(**model) for model in response.json()])
return ModelList(root=[Model(**model) for model in response.json()])

def get_by_name(self, model_name: str, **kwargs) -> Union[Model, None]:
"""
Expand All @@ -832,7 +832,7 @@ def get_by_name(self, model_name: str, **kwargs) -> Union[Model, None]:
except HTTPError as e:
self._echo_debug(f"Could not retrieve model by name: {str(e)}")
return None
return model.__root__[0]
return model.root[0]

@auth
def create(self, model_create: ModelCreate) -> Model:
Expand All @@ -854,7 +854,7 @@ def create(self, model_create: ModelCreate) -> Model:
response = self.session.post(
f"{self.url}/{self.MODELS_ENDPOINT}",
headers=headers,
json=model_create.dict(),
json=model_create.model_dump(),
)
self._echo_debug(str(response))

Expand All @@ -880,7 +880,7 @@ def update(self, model_id: int, model_update: ModelUpdate) -> Model:
response = self.session.put(
f"{self.url}/{self.MODELS_ENDPOINT}/{model_id}",
headers=headers,
json=model_update.dict(),
json=model_update.model_dump(),
)
self._echo_debug(str(response))

Expand Down Expand Up @@ -950,7 +950,7 @@ def create(
response = self.session.post(
f"{self.url}/{self.JOBS_ENDPOINT}",
headers=headers,
params=job_create.dict(),
params=job_create.model_dump(),
files=files,
)
self._echo_debug(str(response))
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def create(
]
),
headers=headers,
params=job_create.dict(),
params=job_create.model_dump(),
files={"file": f},
)
self._echo_debug(str(response))
Expand Down Expand Up @@ -1307,7 +1307,7 @@ def create(
response = self.session.post(
f"{self._get_version_url(model_id)}",
headers=headers,
json=version_create.dict(),
json=version_create.model_dump(),
params={"filename": filename} if filename else None,
)
self._echo_debug(str(response))
Expand Down Expand Up @@ -1447,7 +1447,7 @@ def list(self, model_id: int) -> VersionList:

response.raise_for_status()

return VersionList(__root__=[Version(**version) for version in response.json()])
return VersionList(root=[Version(**version) for version in response.json()])

@auth
def update(
Expand All @@ -1470,7 +1470,7 @@ def update(
response = self.session.put(
f"{self._get_version_url(model_id)}/{version_id}",
headers=headers,
json=version_update.dict(),
json=version_update.model_dump(),
)
self._echo_debug(str(response))

Expand Down
8 changes: 4 additions & 4 deletions giza/commands/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def list(
if debug:
raise e
sys.exit(1)
print_json(deployments.json())
print_json(deployments.model_dump_json())


# giza/commands/deployments.py
Expand Down Expand Up @@ -161,7 +161,7 @@ def get(
if debug:
raise e
sys.exit(1)
print_json(deployment.json())
print_json(deployment.model_dump_json())


@app.command(
Expand Down Expand Up @@ -235,7 +235,7 @@ def list_proofs(
if debug:
raise e
sys.exit(1)
print_json(proofs.json(exclude_unset=True))
print_json(proofs.model_dump_json(exclude_unset=True))


@app.command(
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_proof(
if debug:
raise e
sys.exit(1)
print_json(proof.json(exclude_unset=True))
print_json(proof.model_dump_json(exclude_unset=True))


@app.command(
Expand Down
6 changes: 3 additions & 3 deletions giza/commands/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get(
if debug:
raise e
sys.exit(1)
print_json(model.json())
print_json(model.model_dump_json())


@app.command(
Expand Down Expand Up @@ -116,7 +116,7 @@ def list(
if debug:
raise e
sys.exit(1)
print_json(models.json())
print_json(models.model_dump_json())


@app.command(
Expand Down Expand Up @@ -174,4 +174,4 @@ def create(
if debug:
raise e
sys.exit(1)
print_json(model.json())
print_json(model.model_dump_json())
13 changes: 7 additions & 6 deletions giza/commands/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Optional

import typer
from pydantic import EmailError, EmailStr, SecretStr, ValidationError
from email_validator import EmailNotValidError, validate_email
from pydantic import SecretStr, ValidationError
from requests import HTTPError
from rich import print_json

Expand Down Expand Up @@ -59,11 +60,11 @@ def create(debug: Optional[bool] = DEBUG_OPTION) -> None:
echo("Creating user in Giza ✅ ")
try:
user_create = users.UserCreate(
username=user, password=SecretStr(password), email=EmailStr(email)
username=user, password=SecretStr(password), email=email
)
client = UsersClient(API_HOST)
client.create(user_create)
except ValidationError as e:
except (ValidationError, EmailNotValidError) as e:
echo.error("⛔️Could not create the user⛔️")
echo.error("Review the provided information")
if debug:
Expand Down Expand Up @@ -198,7 +199,7 @@ def me(debug: Optional[bool] = DEBUG_OPTION) -> None:
client = UsersClient(API_HOST, debug=debug)
user = client.me()

print_json(user.json())
print_json(user.model_dump_json())


@app.command(
Expand All @@ -223,8 +224,8 @@ def resend_email(debug: Optional[bool] = DEBUG_OPTION) -> None:
echo("Resending verification email ✅ ")
try:
client = UsersClient(API_HOST)
client.resend_email(EmailStr.validate(email))
except (ValidationError, EmailError) as e:
client.resend_email(validate_email(email).normalized)
except (ValidationError, EmailNotValidError) as e:
echo.error("⛔️Could not resend the email⛔️")
echo.error("Review the provided information")
if debug:
Expand Down
6 changes: 3 additions & 3 deletions giza/commands/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get(
with ExceptionHandler(debug=debug):
client = VersionsClient(API_HOST)
version: Version = client.get(model_id, version_id)
print_json(version.json())
print_json(version.model_dump_json())


def transpile(
Expand Down Expand Up @@ -180,7 +180,7 @@ def update(
zip_path = zip_folder(model_path, tmp_dir)
version = client.upload_cairo(model_id, version_id, zip_path)
echo("Version updated ✅ ")
print_json(version.json())
print_json(version.model_dump_json())


@app.command(
Expand All @@ -201,7 +201,7 @@ def list(
with ExceptionHandler(debug=debug):
client = VersionsClient(API_HOST)
versions: VersionList = client.list(model_id)
print_json(versions.json())
print_json(versions.model_dump_json())


@app.command(
Expand Down
2 changes: 1 addition & 1 deletion giza/commands/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get(
echo.error("⛔️Please delete the workspace and create a new one⛔️")
else:
echo.info(f"✅ Workspace URL: {workspace.url} ✅")
print_json(workspace.json())
print_json(workspace.model_dump_json())


@app.command(
Expand Down
2 changes: 1 addition & 1 deletion giza/frameworks/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def deploy(
endpoints_list: EndpointsList = client.list(
params={"model_id": model_id, "version_id": version_id, "is_active": True}
)
endpoints: dict = json.loads(endpoints_list.json())
endpoints: dict = json.loads(endpoints_list.model_dump_json())

if len(endpoints) > 0:
echo.info(
Expand Down
2 changes: 1 addition & 1 deletion giza/frameworks/ezkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def deploy(
endpoints_list: EndpointsList = client.list(
params={"model_id": model_id, "version_id": version_id, "is_active": True}
)
endpoints: dict = json.loads(endpoints_list.json())
endpoints: dict = json.loads(endpoints_list.model_dump_json())

if len(endpoints) > 0:
echo.info(
Expand Down
11 changes: 7 additions & 4 deletions giza/schemas/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from giza.utils.enums import Framework, ServiceSize

Expand All @@ -15,6 +15,9 @@ class EndpointCreate(BaseModel):
service_name: Optional[str] = None
framework: Framework = Framework.CAIRO

model_config = ConfigDict(from_attributes=True)
model_config["protected_namespaces"] = ()


class Endpoint(BaseModel):
id: int
Expand All @@ -26,9 +29,9 @@ class Endpoint(BaseModel):
version_id: Optional[int] = None
is_active: bool

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)
model_config["protected_namespaces"] = ()


class EndpointsList(BaseModel):
__root__: list[Endpoint]
root: list[Endpoint]
7 changes: 5 additions & 2 deletions giza/schemas/jobs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from giza.utils.enums import Framework, JobKind, JobSize, JobStatus

Expand All @@ -25,6 +25,9 @@ class JobCreate(BaseModel):
version_id: Optional[int] = None
proof_id: Optional[int] = None

model_config = ConfigDict(from_attributes=True)
model_config["protected_namespaces"] = ()


class JobList(BaseModel):
__root__: list[Job]
root: list[Job]
6 changes: 3 additions & 3 deletions giza/schemas/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, RootModel


class Model(BaseModel):
Expand All @@ -18,5 +18,5 @@ class ModelUpdate(BaseModel):
description: str


class ModelList(BaseModel):
__root__: list[Model]
class ModelList(RootModel):
root: list[Model]
6 changes: 3 additions & 3 deletions giza/schemas/proofs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, RootModel


class Proof(BaseModel):
Expand All @@ -14,5 +14,5 @@ class Proof(BaseModel):
request_id: Optional[str] = None


class ProofList(BaseModel):
__root__: list[Proof]
class ProofList(RootModel):
root: list[Proof]
Loading

0 comments on commit 3bc27d4

Please sign in to comment.