Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AsyncDefog #61

Merged
merged 17 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,9 @@ local_test.py
defog_metadata.csv
golden_queries.csv
golden_queries.json
glossary.txt
glossary.txt

# Ignore virtual environment directories
.virtual/
myenv/
venv/
103 changes: 99 additions & 4 deletions defog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
import json
import os
from importlib.metadata import version
from defog import generate_schema, query_methods, admin_methods, health_methods
from defog import (
generate_schema,
async_generate_schema,
query_methods,
async_query_methods,
admin_methods,
async_admin_methods,
health_methods,
async_health_methods,
)

try:
__version__ = version("defog")
Expand All @@ -20,9 +29,9 @@
]


class Defog:
class BaseDefog:
"""
The main class for Defog
The base class for Defog and AsyncDefog
"""

def __init__(
Expand All @@ -37,7 +46,7 @@ def __init__(
verbose: bool = False,
):
"""
Initializes the Defog class.
Initializes the Base Defog class.
We have the possible scenarios detailed below:
1) no config file, no/incomplete params -> success if only db_creds missing, error otherwise
2) no config file, wrong params -> error
Expand Down Expand Up @@ -204,30 +213,116 @@ def from_base64_creds(self, base64_creds: str):
self.db_creds = creds["db_creds"]


class Defog(BaseDefog):
"""
The main class for Defog (Synchronous)
"""

def __init__(
self,
api_key: str = "",
db_type: str = "",
db_creds: dict = {},
base64creds: str = "",
save_json: bool = True,
base_url: str = "https://api.defog.ai",
generate_query_url: str = "https://api.defog.ai/generate_query_chat",
verbose: bool = False,
):
"""Initializes the synchronous version of the Defog class"""
super().__init__(
api_key=api_key,
db_type=db_type,
db_creds=db_creds,
base64creds=base64creds,
save_json=save_json,
base_url=base_url,
generate_query_url=generate_query_url,
verbose=verbose,
)


class AsyncDefog(BaseDefog):
"""
The main class for Defog (Asynchronous)
"""

def __init__(
self,
api_key: str = "",
db_type: str = "",
db_creds: dict = {},
base64creds: str = "",
save_json: bool = True,
base_url: str = "https://api.defog.ai",
generate_query_url: str = "https://api.defog.ai/generate_query_chat",
verbose: bool = False,
):
"""Initializes the asynchronous version of the Defog class"""
super().__init__(
api_key=api_key,
db_type=db_type,
db_creds=db_creds,
base64creds=base64creds,
save_json=save_json,
base_url=base_url,
generate_query_url=generate_query_url,
verbose=verbose,
)


# Add all methods from generate_schema to Defog
for name in dir(generate_schema):
attr = getattr(generate_schema, name)
if callable(attr):
# Add the method to Defog
setattr(Defog, name, attr)

# Add all methods from async_generate_schema to AsyncDefog
for name in dir(async_generate_schema):
attr = getattr(async_generate_schema, name)
if callable(attr):
# Add the method to AsyncDefog
setattr(AsyncDefog, name, attr)

# Add all methods from query_methods to Defog
for name in dir(query_methods):
attr = getattr(query_methods, name)
if callable(attr):
# Add the method to Defog
setattr(Defog, name, attr)

# Add all methods from async_query_methods to AsyncDefog
for name in dir(async_query_methods):
attr = getattr(async_query_methods, name)
if callable(attr):
# Add the method to AsyncDefog
setattr(AsyncDefog, name, attr)

# Add all methods from admin_methods to Defog
for name in dir(admin_methods):
attr = getattr(admin_methods, name)
if callable(attr):
# Add the method to Defog
setattr(Defog, name, attr)

# Add all methods from async_admin_methods to AsyncDefog
for name in dir(async_admin_methods):
attr = getattr(async_admin_methods, name)
if callable(attr):
# Add the method to AsyncDefog
setattr(AsyncDefog, name, attr)

# Add all methods from health_methods to Defog
for name in dir(health_methods):
attr = getattr(health_methods, name)
if callable(attr):
# Add the method to Defog
setattr(Defog, name, attr)

# Add all methods from async_health_methods to AsyncDefog
for name in dir(async_health_methods):
attr = getattr(async_health_methods, name)
if callable(attr):
# Add the method to AsyncDefog
setattr(AsyncDefog, name, attr)
12 changes: 10 additions & 2 deletions defog/admin_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def get_feedback(self, n_rows: int = 50, start_from: int = 0):


def get_quota(self) -> Optional[Dict]:
"""
Get the quota for the API key.
"""
api_key = self.api_key
response = requests.post(
f"{self.base_url}/check_api_usage",
Expand Down Expand Up @@ -341,6 +344,7 @@ def create_empty_tables(self, dev: bool = False):
conn.commit()
conn.close()
return True

elif self.db_type == "mysql":
import mysql.connector

Expand All @@ -351,6 +355,7 @@ def create_empty_tables(self, dev: bool = False):
conn.commit()
conn.close()
return True

elif self.db_type == "databricks":
from databricks import sql

Expand All @@ -359,6 +364,7 @@ def create_empty_tables(self, dev: bool = False):
conn.commit()
conn.close()
return True

elif self.db_type == "snowflake":
import snowflake.connector

Expand All @@ -367,15 +373,16 @@ def create_empty_tables(self, dev: bool = False):
password=self.db_creds["password"],
account=self.db_creds["account"],
)
conn.cursor().execute(
cur = conn.cursor()
cur.execute(
f"USE WAREHOUSE {self.db_creds['warehouse']}"
) # set the warehouse
cur = conn.cursor()
for statement in ddl.split(";"):
cur.execute(statement)
conn.commit()
conn.close()
return True

elif self.db_type == "bigquery":
from google.cloud import bigquery

Expand All @@ -385,6 +392,7 @@ def create_empty_tables(self, dev: bool = False):
for statement in ddl.split(";"):
client.query(statement)
return True

elif self.db_type == "sqlserver":
import pyodbc

Expand Down
Loading
Loading