From 5b23af2090b011a50240cc745e40927e821f0447 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:43:47 +0800 Subject: [PATCH 01/17] ignore venvs --- .gitignore | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f55732f..fbaa9fb 100644 --- a/.gitignore +++ b/.gitignore @@ -67,4 +67,9 @@ local_test.py defog_metadata.csv golden_queries.csv golden_queries.json -glossary.txt \ No newline at end of file +glossary.txt + +# Ignore virtual environment directories +.virtual/ +myenv/ +venv/ \ No newline at end of file From e96d0412bed3906cc1d7def55c9ac899e36d1711 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:45:48 +0800 Subject: [PATCH 02/17] init for defog base, sync and async classes --- defog/__init__.py | 103 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 4 deletions(-) diff --git a/defog/__init__.py b/defog/__init__.py index 8f166ca..e064d82 100644 --- a/defog/__init__.py +++ b/defog/__init__.py @@ -2,7 +2,16 @@ import json import os from importlib.metadata import version -from defog import generate_schema, query_methods, admin_methods, health_methods +from defog import ( + generate_schema, + async_generate_schema, + query_methods, + async_query_methods, + admin_methods, + async_admin_methods, + health_methods, + async_health_methods, +) try: __version__ = version("defog") @@ -20,9 +29,9 @@ ] -class Defog: +class BaseDefog: """ - The main class for Defog + The base class for Defog and AsyncDefog """ def __init__( @@ -37,7 +46,7 @@ def __init__( verbose: bool = False, ): """ - Initializes the Defog class. + Initializes the Base Defog class. We have the possible scenarios detailed below: 1) no config file, no/incomplete params -> success if only db_creds missing, error otherwise 2) no config file, wrong params -> error @@ -204,6 +213,64 @@ def from_base64_creds(self, base64_creds: str): self.db_creds = creds["db_creds"] +class Defog(BaseDefog): + """ + The main class for Defog (Synchronous) + """ + + def __init__( + self, + api_key: str = "", + db_type: str = "", + db_creds: dict = {}, + base64creds: str = "", + save_json: bool = True, + base_url: str = "https://api.defog.ai", + generate_query_url: str = "https://api.defog.ai/generate_query_chat", + verbose: bool = False, + ): + """Initializes the synchronous version of the Defog class""" + super().__init__( + api_key=api_key, + db_type=db_type, + db_creds=db_creds, + base64creds=base64creds, + save_json=save_json, + base_url=base_url, + generate_query_url=generate_query_url, + verbose=verbose, + ) + + +class AsyncDefog(BaseDefog): + """ + The main class for Defog (Asynchronous) + """ + + def __init__( + self, + api_key: str = "", + db_type: str = "", + db_creds: dict = {}, + base64creds: str = "", + save_json: bool = True, + base_url: str = "https://api.defog.ai", + generate_query_url: str = "https://api.defog.ai/generate_query_chat", + verbose: bool = False, + ): + """Initializes the asynchronous version of the Defog class""" + super().__init__( + api_key=api_key, + db_type=db_type, + db_creds=db_creds, + base64creds=base64creds, + save_json=save_json, + base_url=base_url, + generate_query_url=generate_query_url, + verbose=verbose, + ) + + # Add all methods from generate_schema to Defog for name in dir(generate_schema): attr = getattr(generate_schema, name) @@ -211,6 +278,13 @@ def from_base64_creds(self, base64_creds: str): # Add the method to Defog setattr(Defog, name, attr) +# Add all methods from 
async_generate_schema to AsyncDefog +for name in dir(async_generate_schema): + attr = getattr(async_generate_schema, name) + if callable(attr): + # Add the method to AsyncDefog + setattr(AsyncDefog, name, attr) + # Add all methods from query_methods to Defog for name in dir(query_methods): attr = getattr(query_methods, name) @@ -218,6 +292,13 @@ def from_base64_creds(self, base64_creds: str): # Add the method to Defog setattr(Defog, name, attr) +# Add all methods from async_query_methods to AsyncDefog +for name in dir(async_query_methods): + attr = getattr(async_query_methods, name) + if callable(attr): + # Add the method to AsyncDefog + setattr(AsyncDefog, name, attr) + # Add all methods from admin_methods to Defog for name in dir(admin_methods): attr = getattr(admin_methods, name) @@ -225,9 +306,23 @@ def from_base64_creds(self, base64_creds: str): # Add the method to Defog setattr(Defog, name, attr) +# Add all methods from async_admin_methods to AsyncDefog +for name in dir(async_admin_methods): + attr = getattr(async_admin_methods, name) + if callable(attr): + # Add the method to AsyncDefog + setattr(AsyncDefog, name, attr) + # Add all methods from health_methods to Defog for name in dir(health_methods): attr = getattr(health_methods, name) if callable(attr): # Add the method to Defog setattr(Defog, name, attr) + +# Add all methods from async_health_methods to AsyncDefog +for name in dir(async_health_methods): + attr = getattr(async_health_methods, name) + if callable(attr): + # Add the method to AsyncDefog + setattr(AsyncDefog, name, attr) From 3f0c3643bf9037dcf8cb56e964d290b5eb6d02f4 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:48:18 +0800 Subject: [PATCH 03/17] async query functions --- defog/query.py | 270 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 256 insertions(+), 14 deletions(-) diff --git a/defog/query.py b/defog/query.py index 7ec2cbe..36c6533 100644 --- a/defog/query.py +++ b/defog/query.py @@ -1,10 +1,10 @@ import json import re import requests -from defog.util import write_logs +from defog.util import write_logs, async_write_logs, make_async_post_request +import asyncio import os - # execute query for given db_type and return column names and data def execute_query_once(db_type: str, db_creds, query: str): """ @@ -19,10 +19,11 @@ def execute_query_once(db_type: str, db_creds, query: str): cur = conn.cursor() cur.execute(query) colnames = [desc[0] for desc in cur.description] - results = cur.fetchall() + rows = cur.fetchall() cur.close() conn.close() - return colnames, results + return colnames, rows + elif db_type == "redshift": try: import psycopg2 @@ -51,10 +52,11 @@ def execute_query_once(db_type: str, db_creds, query: str): for i, col in enumerate(colnames) ] - results = cur.fetchall() + rows = cur.fetchall() cur.close() conn.close() - return colnames, results + return colnames, rows + elif db_type == "mysql": try: import mysql.connector @@ -64,10 +66,11 @@ def execute_query_once(db_type: str, db_creds, query: str): cur = conn.cursor() cur.execute(query) colnames = [desc[0] for desc in cur.description] - results = cur.fetchall() + rows = cur.fetchall() cur.close() conn.close() - return colnames, results + return colnames, rows + elif db_type == "bigquery": try: from google.cloud import bigquery @@ -83,6 +86,7 @@ def execute_query_once(db_type: str, db_creds, query: str): for row in results: rows.append([row[i] for i in range(len(row))]) return colnames, rows + elif db_type == "snowflake": try: import snowflake.connector 
@@ -99,10 +103,11 @@ def execute_query_once(db_type: str, db_creds, query: str): cur.execute(f"USE DATABASE {db_creds['database']}") cur.execute(query) colnames = [desc[0] for desc in cur.description] - results = cur.fetchall() + rows = cur.fetchall() cur.close() conn.close() - return colnames, results + return colnames, rows + elif db_type == "databricks": try: from databricks import sql @@ -112,8 +117,9 @@ def execute_query_once(db_type: str, db_creds, query: str): with conn.cursor() as cursor: cursor.execute(query) colnames = [desc[0] for desc in cursor.description] - results = cursor.fetchall() - return colnames, results + rows = cursor.fetchall() + return colnames, rows + elif db_type == "sqlserver": try: import pyodbc @@ -129,10 +135,155 @@ def execute_query_once(db_type: str, db_creds, query: str): cur.execute(query) colnames = [desc[0] for desc in cur.description] results = cur.fetchall() - results = [list(row) for row in results] + rows = [list(row) for row in results] cur.close() conn.close() - return colnames, results + return colnames, rows + + else: + raise Exception(f"Database type {db_type} not yet supported.") + + +async def async_execute_query_once(db_type: str, db_creds, query: str): + """ + Asynchrnously executes the query once and returns the column names and results. + """ + if db_type == "postgres": + try: + import asyncpg + except: + raise Exception("asyncpg not installed.") + + conn = await asyncpg.connect(**db_creds) + results = await conn.fetch(query) + + colnames = list(results[0].keys()) + if colnames is None: + colnames = [] + + await conn.close() + # get the results in a list of lists format + rows = [list(row.values()) for row in results] + return colnames, rows + + elif db_type == "redshift": + try: + import asyncpg + except: + raise Exception("asyncpg not installed.") + + if "schema" not in db_creds: + schema = "public" + conn = await asyncpg.connect(**db_creds) + else: + schema = db_creds["schema"] + del db_creds["schema"] + conn = await asyncpg.connect(**db_creds) + + if schema is not None and schema != "public": + await conn.execute(f"SET search_path TO {schema}") + + results = await conn.fetch(query) + colnames = list(results[0].keys()) + + # deduplicate the column names + colnames = [ + f"{col}_{i}" if colnames.count(col) > 1 else col + for i, col in enumerate(colnames) + ] + rows = [list(row.values()) for row in results] + + await conn.close() + return colnames, rows + + elif db_type == "mysql": + try: + import aiomysql + except: + raise Exception("aiomysql not installed.") + conn = await aiomysql.connect(**db_creds) + cur = await conn.cursor() + await cur.execute(query) + colnames = [desc[0] for desc in cur.description] + rows = await cur.fetchall() + await cur.close() + await conn.ensure_closed() + return colnames, rows + + elif db_type == "bigquery": + try: + from google.cloud import bigquery + except: + raise Exception("google.cloud.bigquery not installed.") + # using asynico.to_thread since google-cloud-bigquery is synchronous + json_key = db_creds["json_key_path"] + client = await asyncio.to_thread( + bigquery.Client.from_service_account_json, json_key + ) + query_job = await asyncio.to_thread(client.query, query) + results = await asyncio.to_thread(query_job.result) + colnames = [i.name for i in results.schema] + rows = [] + for row in results: + rows.append([row[i] for i in range(len(row))]) + return colnames, rows + + elif db_type == "snowflake": + try: + import snowflake.connector + except: + raise Exception("snowflake.connector not 
installed.") + conn = await asyncio.to_thread( + snowflake.connector.connect, + user=db_creds["user"], + password=db_creds["password"], + account=db_creds["account"], + ) + cur = await asyncio.to_thread(conn.cursor) + await asyncio.to_thread( + cur.execute, f"USE WAREHOUSE {db_creds['warehouse']}" + ) # set the warehouse + if "database" in db_creds: + await asyncio.to_thread(cur.execute, f"USE DATABASE {db_creds['database']}") + await asyncio.to_thread(cur.execute, query) + colnames = [desc[0] for desc in cur.description] + rows = await asyncio.to_thread(cur.fetchall) + await asyncio.to_thread(cur.close) + await asyncio.to_thread(conn.close) + return colnames, rows + + elif db_type == "databricks": + try: + from databricks import sql + except: + raise Exception("databricks-sql-connector not installed.") + conn = await asyncio.to_thread(sql.connect, **db_creds) + cursor = await asyncio.to_thread(conn.cursor) + + await asyncio.to_thread(cursor.execute, query) + colnames = [desc[0] for desc in cursor.description] + rows = await asyncio.to_thread(cursor.fetchall) + return colnames, rows + + elif db_type == "sqlserver": + try: + import pyodbc + except: + raise Exception("pyodbc not installed.") + + if db_creds["database"] != "": + connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={db_creds['server']};DATABASE={db_creds['database']};UID={db_creds['user']};PWD={db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" + else: + connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={db_creds['server']};UID={db_creds['user']};PWD={db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" + conn = await asyncio.to_thread(pyodbc.connect, connection_string) + cur = await asyncio.to_thread(conn.cursor) + await asyncio.to_thread(cursor.execute, query) + colnames = [desc[0] for desc in cursor.description] + results = await asyncio.to_thread(cursor.fetchall) + rows = [list(row) for row in results] + await asyncio.to_thread(cursor.close) + await asyncio.to_thread(conn.close) + return colnames, rows else: raise Exception(f"Database type {db_type} not yet supported.") @@ -223,6 +374,97 @@ def execute_query( raise Exception(err_msg) +async def async_execute_query( + query: str, + api_key: str, + db_type: str, + db_creds, + question: str = "", + hard_filters: str = "", + retries: int = 3, + schema: dict = None, + dev: bool = False, + temp: bool = False, + base_url: str = None, +): + """ + Execute the query asynchronously and retry with adaptive learning if there is an error. + Raises an Exception if there are no retries left, or if the error is a connection error. + """ + err_msg = None + # if base_url is not explicitly defined, check if DEFOG_BASE_URL is set in the environment + # if not, then use "https://api.defog.ai" as the default + if base_url is None: + base_url = os.environ.get("DEFOG_BASE_URL", "https://api.defog.ai") + + try: + return await async_execute_query_once( + db_type=db_type, db_creds=db_creds, query=query + ) + (query,) + except Exception as e: + err_msg = str(e) + if is_connection_error(err_msg): + raise Exception( + f"There was a connection issue to your database:\n{err_msg}\n\nPlease check your database credentials and try again." 
+ ) + # log this error to our feedback system first (this is a 1-way side-effect) + try: + await make_async_post_request( + url=f"{base_url}/feedback", + payload={ + "api_key": api_key, + "feedback": "bad", + "text": err_msg, + "db_type": db_type, + "question": question, + "query": query, + "dev": dev, + "temp": temp, + }, + timeout=1, + ) + except: + pass + # log locally + await async_write_logs(str(e)) + # retry with adaptive learning + while retries > 0: + await async_write_logs(f"Retries left: {retries}") + try: + retry = { + "api_key": api_key, + "previous_query": query, + "error": err_msg, + "db_type": db_type, + "hard_filters": hard_filters, + "question": question, + "dev": dev, + "temp": temp, + } + if schema is not None: + retry["schema"] = schema + + await async_write_logs(json.dumps(retry)) + + response = await make_async_post_request( + url=f"{base_url}/retry_query_after_error", + payload=retry, + ) + new_query = response["new_query"] + await async_write_logs(f"New query: \n{new_query}") + return await async_execute_query_once(db_type, db_creds, new_query) + ( + new_query, + ) + except Exception as e: + err_msg = str(e) + print( + "There was an error when running the previous query. Retrying with adaptive learning..." + ) + write_logs(str(e)) + retries -= 1 + raise Exception(err_msg) + + def is_connection_error(err_msg: str) -> bool: return ( isinstance(err_msg, str) From 44c4fa224ca3fc6ff5c61ff4bf30a65035cc1f5c Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:56:59 +0800 Subject: [PATCH 04/17] async version of sets of methods and minor fixes in sync versions --- defog/admin_methods.py | 12 +- defog/async_admin_methods.py | 409 ++++++++++++++++++++ defog/async_generate_schema.py | 684 +++++++++++++++++++++++++++++++++ defog/async_health_methods.py | 46 +++ defog/async_query_methods.py | 172 +++++++++ defog/generate_schema.py | 4 + 6 files changed, 1325 insertions(+), 2 deletions(-) create mode 100644 defog/async_admin_methods.py create mode 100644 defog/async_generate_schema.py create mode 100644 defog/async_health_methods.py create mode 100644 defog/async_query_methods.py diff --git a/defog/admin_methods.py b/defog/admin_methods.py index c69fb52..ce125fe 100644 --- a/defog/admin_methods.py +++ b/defog/admin_methods.py @@ -144,6 +144,9 @@ def get_feedback(self, n_rows: int = 50, start_from: int = 0): def get_quota(self) -> Optional[Dict]: + """ + Get the quota for the API key. 
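+
+    Illustrative usage (assumes an initialized Defog instance named defog):
+        quota_info = defog.get_quota()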
+ """ api_key = self.api_key response = requests.post( f"{self.base_url}/check_api_usage", @@ -341,6 +344,7 @@ def create_empty_tables(self, dev: bool = False): conn.commit() conn.close() return True + elif self.db_type == "mysql": import mysql.connector @@ -351,6 +355,7 @@ def create_empty_tables(self, dev: bool = False): conn.commit() conn.close() return True + elif self.db_type == "databricks": from databricks import sql @@ -359,6 +364,7 @@ def create_empty_tables(self, dev: bool = False): conn.commit() conn.close() return True + elif self.db_type == "snowflake": import snowflake.connector @@ -367,15 +373,16 @@ def create_empty_tables(self, dev: bool = False): password=self.db_creds["password"], account=self.db_creds["account"], ) - conn.cursor().execute( + cur = conn.cursor() + cur.execute( f"USE WAREHOUSE {self.db_creds['warehouse']}" ) # set the warehouse - cur = conn.cursor() for statement in ddl.split(";"): cur.execute(statement) conn.commit() conn.close() return True + elif self.db_type == "bigquery": from google.cloud import bigquery @@ -385,6 +392,7 @@ def create_empty_tables(self, dev: bool = False): for statement in ddl.split(";"): client.query(statement) return True + elif self.db_type == "sqlserver": import pyodbc diff --git a/defog/async_admin_methods.py b/defog/async_admin_methods.py new file mode 100644 index 0000000..7a53aea --- /dev/null +++ b/defog/async_admin_methods.py @@ -0,0 +1,409 @@ +import json +from typing import Dict, List, Optional +from defog.util import make_async_post_request +import pandas as pd +import asyncio +import aiofiles + + +async def update_db_schema(self, path_to_csv, dev=False, temp=False): + """ + Update the DB schema via a CSV + """ + schema_df = pd.read_csv(path_to_csv).fillna("") + # check columns + if not all( + col in schema_df.columns + for col in ["table_name", "column_name", "data_type", "column_description"] + ): + raise ValueError( + "The CSV must contain the following columns: table_name, column_name, data_type, column_description" + ) + schema = {} + for table_name in schema_df["table_name"].unique(): + schema[table_name] = schema_df[schema_df["table_name"] == table_name][ + ["column_name", "data_type", "column_description"] + ].to_dict(orient="records") + + payload = { + "api_key": self.api_key, + "table_metadata": schema, + "db_type": self.db_type, + "dev": dev, + "temp": temp, + } + + resp = await make_async_post_request( + url=f"{self.base_url}/update_metadata", payload=payload + ) + return resp + + +async def update_glossary( + self, + glossary: str = "", + customized_glossary: dict = None, + glossary_compulsory: str = "", + glossary_prunable_units: List[str] = [], + dev: bool = False, +): + """ + Updates the glossary on the defog servers. + :param glossary: The glossary to be used. + """ + data = { + "api_key": self.api_key, + "glossary": glossary, + "dev": dev, + "glossary_compulsory": glossary_compulsory, + "glossary_prunable_units": glossary_prunable_units, + } + if customized_glossary: + data["customized_glossary"] = customized_glossary + resp = await make_async_post_request( + url=f"{self.base_url}/update_glossary", payload=data + ) + return resp + + +async def delete_glossary(self, user_type=None, dev=False): + """ + Deletes the glossary on the defog servers. 
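+
+    Illustrative usage (assumes an initialized AsyncDefog instance named defog):
+        await defog.delete_glossary()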
+ """ + data = { + "api_key": self.api_key, + "dev": dev, + } + if user_type: + data["key"] = user_type + r = await make_async_post_request( + url=f"{self.base_url}/delete_glossary", + payload=data, + return_response_object=True, + ) + if r.status_code == 200: + print("Glossary deleted successfully.") + else: + error_message = r.json().get("message", "") + print(f"Glossary deletion failed.\nError message: {error_message}") + + +async def get_glossary(self, mode="general", dev=False): + """ + Gets the glossary on the defog servers. + """ + resp = await make_async_post_request( + url=f"{self.base_url}/get_glossary", + payload={"api_key": self.api_key, "dev": dev}, + ) + if mode == "general": + return resp["glossary"] + elif mode == "customized": + return resp["customized_glossary"] + + +async def get_metadata(self, format="markdown", export_path=None, dev=False): + """ + Gets the metadata on the defog servers. + """ + resp = await make_async_post_request( + url=f"{self.base_url}/get_metadata", + payload={"api_key": self.api_key, "dev": dev}, + ) + items = [] + for table in resp["table_metadata"]: + for item in resp["table_metadata"][table]: + item["table_name"] = table + items.append(item) + if format == "markdown": + return pd.DataFrame(items)[ + ["table_name", "column_name", "data_type", "column_description"] + ].to_markdown(index=False) + elif format == "csv": + if export_path is None: + export_path = "metadata.csv" + pd.DataFrame(items)[ + ["table_name", "column_name", "data_type", "column_description"] + ].to_csv(export_path, index=False) + print(f"Metadata exported to {export_path}") + return True + elif format == "json": + return resp["table_metadata"] + + +async def get_feedback(self, n_rows: int = 50, start_from: int = 0): + """ + Gets the feedback on the defog servers. + """ + resp = await make_async_post_request( + url=f"{self.base_url}/get_feedback", payload={"api_key": self.api_key} + ) + df = pd.DataFrame(resp["data"], columns=resp["columns"]) + df["created_at"] = df["created_at"].apply(lambda x: x[:10]) + for col in ["query_generated", "feedback_text"]: + df[col] = df[col].fillna("") + df[col] = df[col].apply(lambda x: x.replace("\n", "\\n")) + return df.iloc[start_from:].head(n_rows).to_markdown(index=False) + + +async def get_quota(self) -> Optional[Dict]: + """ + Get the quota usage for the API key. + """ + api_key = self.api_key + r = await make_async_post_request( + url=f"{self.base_url}/check_api_usage", + payload={"api_key": api_key}, + return_response_object=True, + ) + # get status code and return None if not 200 + if r.status_code != 200: + return None + return r.json() + + +async def update_golden_queries( + self, + golden_queries: List[Dict] = None, + golden_queries_path: str = None, + scrub: bool = True, + dev: bool = False, +): + """ + Updates the golden queries on the defog servers. + :param golden_queries: The golden queries to be used. + :param golden_queries_path: The path to the golden queries CSV. + :param scrub: Whether to scrub the golden queries. 
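+
+    Illustrative usage (assumes an initialized AsyncDefog instance named defog
+    and a local golden-queries CSV):
+        await defog.update_golden_queries(golden_queries_path="golden_queries.csv")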
+ """ + if golden_queries is None and golden_queries_path is None: + raise ValueError("Please provide either golden_queries or golden_queries_path.") + + if golden_queries is None: + golden_queries = ( + pd.read_csv(golden_queries_path).fillna("").to_dict(orient="records") + ) + + resp = await make_async_post_request( + url=f"{self.base_url}/update_golden_queries", + payload={ + "api_key": self.api_key, + "golden_queries": golden_queries, + "scrub": scrub, + "dev": dev, + }, + ) + print( + "Golden queries have been received by the system, and will be processed shortly..." + ) + print( + "Once that is done, you should be able to see improved results for your questions." + ) + return resp + + +async def delete_golden_queries( + self, + golden_queries: dict = None, + golden_queries_path: str = None, + all: bool = False, + dev: bool = False, +): + """ + Updates the golden queries on the defog servers. + :param golden_queries: The golden queries to be used. + :param golden_queries_path: The path to the golden queries CSV. + :param scrub: Whether to scrub the golden queries. + """ + if golden_queries is None and golden_queries_path is None and not all: + raise ValueError( + "Please provide either golden_queries or golden_queries_path, or set all=True." + ) + + if all: + resp = await make_async_post_request( + url=f"{self.base_url}/delete_golden_queries", + payload={"api_key": self.api_key, "all": True, "dev": dev}, + ) + print("All golden queries have now been deleted.") + else: + if golden_queries is None: + golden_queries = ( + pd.read_csv(golden_queries_path).fillna("").to_dict(orient="records") + ) + resp = await make_async_post_request( + url=f"{self.base_url}/update_golden_queries", + payload={"api_key": self.api_key, "golden_queries": golden_queries}, + ) + return resp + + +async def get_golden_queries( + self, format: str = "csv", export_path: str = None, dev: bool = False +): + """ + Gets the golden queries on the defog servers. 
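+    :param format: Either "csv" or "json"; any other value raises a ValueError.
+    :param export_path: Path for the exported file; defaults to
+        "golden_queries.csv" or "golden_queries.json" depending on the format.
+
+    Illustrative usage (assumes an initialized AsyncDefog instance named defog):
+        golden_queries = await defog.get_golden_queries(format="json")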
+ """ + resp = await make_async_post_request( + url=f"{self.base_url}/get_golden_queries", + payload={"api_key": self.api_key, "dev": dev}, + ) + golden_queries = resp["golden_queries"] + if format == "csv": + if export_path is None: + export_path = "golden_queries.csv" + pd.DataFrame(golden_queries).to_csv(export_path, index=False) + print(f"{len(golden_queries)} golden queries exported to {export_path}") + return golden_queries + elif format == "json": + if export_path is None: + export_path = "golden_queries.json" + # Writing JSON asynchronously + async with aiofiles.open(export_path, "w") as f: + await f.write(json.dumps(resp, indent=4)) + print(f"{len(golden_queries)} golden queries exported to {export_path}") + return golden_queries + else: + raise ValueError("format must be either 'csv' or 'json'.") + + +def create_table_ddl( + table_name: str, columns: List[Dict[str, str]], add_exists=True +) -> str: + """ + Return a DDL statement for creating a table from a list of columns + `columns` is a list of dictionaries with the following keys: + - column_name: str + - data_type: str + - column_description: str + """ + md_create = "" + if add_exists: + md_create += f"CREATE TABLE IF NOT EXISTS {table_name} (\n" + else: + md_create += f"CREATE TABLE {table_name} (\n" + for i, column in enumerate(columns): + col_name = column["column_name"] + # if column name has spaces and hasn't been wrapped in double quotes, wrap it in double quotes + if " " in col_name and not col_name.startswith('"'): + col_name = f'"{col_name}"' + dtype = column["data_type"] + if i < len(columns) - 1: + md_create += f" {col_name} {dtype},\n" + else: + # avoid the trailing comma for the last line + md_create += f" {col_name} {dtype}\n" + md_create += ");\n" + return md_create + + +def create_ddl_from_metadata( + metadata: Dict[str, List[Dict[str, str]]], add_exists=True +) -> str: + """ + Return a DDL statement for creating tables from metadata + `metadata` is a dictionary with table names as keys and lists of dictionaries as values. + Each dictionary in the list has the following keys: + - column_name: str + - data_type: str + - column_description: str + """ + md_create = "" + for table_name, columns in metadata.items(): + if "." 
in table_name: + table_name = table_name.split(".", 1)[1] + schema_name = table_name.split(".")[0] + + md_create += f"CREATE SCHEMA IF NOT EXISTS {schema_name};\n" + md_create += create_table_ddl(table_name, columns, add_exists=add_exists) + return md_create + + +async def create_empty_tables(self, dev: bool = False): + """ + Create empty tables based on metadata + """ + metadata = self.get_metadata(format="json", dev=dev) + if self.db_type == "sqlserver": + ddl = create_ddl_from_metadata(metadata, add_exists=False) + else: + ddl = create_ddl_from_metadata(metadata) + + try: + if self.db_type == "postgres" or self.db_type == "redshift": + import asyncpg + + conn = await asyncpg.connect(**self.db_creds) + await conn.execute(ddl) + await conn.close() + return True + + elif self.db_type == "mysql": + import aiomysql + + conn = await aiomysql.connect(**self.db_creds) + async with conn.cursor() as cur: + for statement in ddl.split(";"): + await cur.execute(statement) + await conn.commit() + await conn.ensure_closed() + return True + + elif self.db_type == "databricks": + from databricks import sql + + conn = await asyncio.to_thread(sql.connect, **self.db_creds) + await asyncio.to_thread(conn.execute, ddl) + await asyncio.to_thread(conn.commit) + await asyncio.to_thread(conn.close) + return True + + elif self.db_type == "snowflake": + import snowflake.connector + + conn = await asyncio.to_thread( + snowflake.connector.connect, + user=self.db_creds["user"], + password=self.db_creds["password"], + account=self.db_creds["account"], + ) + cur = await asyncio.to_thread(conn.cursor) + await asyncio.to_thread( + cur.execute, f"USE WAREHOUSE {self.db_creds['warehouse']}" + ) + for statement in ddl.split(";"): + await asyncio.to_thread(cur.execute, statement) + await asyncio.to_thread(conn.commit) + await asyncio.to_thread(conn.close) + return True + + elif self.db_type == "bigquery": + from google.cloud import bigquery + + client = await asyncio.to_thread( + bigquery.Client.from_service_account_json, + self.db_creds["json_key_path"], + ) + for statement in ddl.split(";"): + await asyncio.to_thread(client.query, statement) + return True + + elif self.db_type == "sqlserver": + import pyodbc + + if self.db_creds["database"] != "": + connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" + else: + connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" + conn = await asyncio.to_thread(pyodbc.connect, connection_string) + cur = await asyncio.to_thread(conn.cursor) + for statement in ddl.split(";"): + await asyncio.to_thread(cur.execute, statement) + await asyncio.to_thread(conn.commit) + await asyncio.to_thread(conn.close) + return True + + else: + raise ValueError(f"Unsupported DB type: {self.db_type}") + except Exception as e: + print(f"Error: {e}") + return False diff --git a/defog/async_generate_schema.py b/defog/async_generate_schema.py new file mode 100644 index 0000000..584a2f7 --- /dev/null +++ b/defog/async_generate_schema.py @@ -0,0 +1,684 @@ +from defog.util import async_identify_categorical_columns, make_async_post_request +import asyncio +from io import StringIO +import pandas as pd +import json +from typing import List + + +async def generate_postgres_schema( 
+ self, + tables: list, + upload: bool = True, + return_format: str = "csv", + scan: bool = True, + return_tables_only: bool = False, + schemas: List[str] = ["public"], +) -> str: + # when upload is True, we send the schema to the defog servers and generate a CSV + # when its false, we return the schema as a dict + try: + import asyncpg + except ImportError: + raise ImportError( + "asyncpg not installed. Please install it with `pip install psycopg2-binary`." + ) + + conn = await asyncpg.connect(**self.db_creds) + schemas = tuple(schemas) + + if len(tables) == 0: + # get all tables + for schema in schemas: + query = """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1; + """ + rows = await conn.fetch(query, schema) + if schema == "public": + tables += [row[0] for row in rows] + else: + tables += [schema + "." + row[0] for row in rows] + + if return_tables_only: + await conn.close() + return tables + + print("Getting schema for each table that you selected...") + + table_columns = {} + + # get the columns for each table + for schema in schemas: + for table_name in tables: + if "." in table_name: + _, table_name = table_name.split(".", 1) + query = """ + SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) + FROM information_schema.columns + WHERE table_name = $1 AND table_schema = $2; + """ + rows = await conn.fetch(query, table_name, schema) + rows = [row for row in rows] + rows = [{"column_name": row[0], "data_type": row[1]} for row in rows] + if len(rows) > 0: + if scan: + rows = await async_identify_categorical_columns( + conn=conn, cur=None, table_name=table_name, rows=rows + ) + if schema == "public": + table_columns[table_name] = rows + else: + table_columns[schema + "." + table_name] = rows + await conn.close() + + print( + "Sending the schema to the defog servers and generating column descriptions. This might take up to 2 minutes..." + ) + if upload: + # send the schemas dict to the defog servers + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": table_columns, + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." + ) + else: + return table_columns + + +async def generate_redshift_schema( + self, + tables: list, + upload: bool = True, + return_format: str = "csv", + scan: bool = True, + return_tables_only: bool = False, +) -> str: + # when upload is True, we send the schema to the defog servers and generate a CSV + # when its false, we return the schema as a dict + try: + import asyncpg + except ImportError: + raise ImportError( + "asyncpg not installed. Please install it with `pip install psycopg2-binary`." 
+ ) + + if "schema" not in self.db_creds: + schema = "public" + conn = await asyncpg.connect(**self.db_creds) + else: + schema = self.db_creds["schema"] + del self.db_creds["schema"] + conn = await asyncpg.connect(**self.db_creds) + + schemas = {} + + if len(tables) == 0: + table_names_query = ( + "SELECT table_name FROM information_schema.tables WHERE table_schema = $1;" + ) + results = await conn.fetch(table_names_query, schema) + tables = [row[0] for row in results] + + if return_tables_only: + await conn.close() + return tables + + print("Getting schema for each table that you selected...") + # get the schema for each table + for table_name in tables: + table_schema_query = "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = $1 AND table_schema= $2;" + rows = await conn.fetch(table_schema_query, table_name, schema) + rows = [row for row in rows] + rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] + if len(rows) > 0: + if scan: + await conn.execute(f"SET search_path TO {schema}") + rows = await async_identify_categorical_columns( + conn=conn, cur=None, table_name=table_name, rows=rows + ) + + schemas[table_name] = rows + + await conn.close() + + if upload: + print( + "Sending the schema to the defog servers and generating column descriptions. This might take up to 2 minutes..." + ) + # send the schemas dict to the defog servers + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": schemas, + "foreign_keys": [], + "indexes": [], + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." 
+ ) + else: + return schemas + + +async def generate_mysql_schema( + self, + tables: list, + upload: bool = True, + return_format: str = "csv", + scan: bool = True, + return_tables_only: bool = False, +) -> str: + try: + import aiomysql + except: + raise Exception("aiomysql not installed.") + + conn = await aiomysql.connect(**self.db_creds) + cur = await conn.cursor() + schemas = {} + + if len(tables) == 0: + # get all tables + db_name = self.db_creds.get("database", "") + await cur.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{db_name}';" + ) + tables = [row[0] for row in await cur.fetchall()] + + if return_tables_only: + await conn.close() + return tables + + print("Getting schema for the relevant table in your database...") + # get the schema for each table + for table_name in tables: + await cur.execute( + "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;", + (table_name,), + ) + rows = await cur.fetchall() + rows = [row for row in rows] + rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] + if scan: + rows = await async_identify_categorical_columns( + conn=None, + cur=cur, + table_name=table_name, + rows=rows, + is_cursor_async=True, + ) + if len(rows) > 0: + schemas[table_name] = rows + + await conn.ensure_closed() + + if upload: + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": schemas, + "foreign_keys": [], + "indexes": [], + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." 
+ ) + else: + return schemas + + +async def generate_databricks_schema( + self, + tables: list, + upload: bool = True, + return_format: str = "csv", + scan: bool = True, + return_tables_only: bool = False, +) -> str: + try: + from databricks import sql + except: + raise Exception("databricks-sql-connector not installed.") + + conn = await asyncio.to_thread(sql.connect, **self.db_creds) + schemas = {} + async with await asyncio.to_thread(conn.cursor) as cur: + print("Getting schema for each table that you selected...") + # get the schema for each table + + if len(tables) == 0: + # get all tables from databricks + await asyncio.to_thread( + cur.tables, schema_name=self.db_creds.get("schema", "default") + ) + tables = [row.TABLE_NAME for row in await asyncio.to_thread(cur.fetchall())] + + if return_tables_only: + await asyncio.to_thread(conn.close) + return tables + + for table_name in tables: + await asyncio.to_thread( + cur.columns, + schema_name=self.db_creds.get("schema", "default"), + table_name=table_name, + ) + rows = await asyncio.to_thread(cur.fetchall) + rows = [row for row in rows] + rows = [ + {"column_name": i.COLUMN_NAME, "data_type": i.TYPE_NAME} for i in rows + ] + if scan: + rows = await async_identify_categorical_columns( + conn=None, + cur=cur, + table_name=table_name, + rows=rows, + is_cursor_async=False, + ) + if len(rows) > 0: + schemas[table_name] = rows + + await asyncio.to_thread(conn.close) + + if upload: + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": schemas, + "foreign_keys": [], + "indexes": [], + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." 
+ ) + else: + return schemas + + +async def generate_snowflake_schema( + self, + tables: list, + upload: bool = True, + return_format: str = "csv", + scan: bool = True, + return_tables_only: bool = False, +) -> str: + try: + import snowflake.connector + except: + raise Exception("snowflake-connector not installed.") + + conn = await asyncio.to_thread( + snowflake.connector.connect, + user=self.db_creds["user"], + password=self.db_creds["password"], + account=self.db_creds["account"], + ) + + await asyncio.to_thread( + conn.cursor().execute, f"USE WAREHOUSE {self.db_creds['warehouse']}" + ) # set the warehouse + + schemas = {} + alt_types = {"DATE": "TIMESTAMP", "TEXT": "VARCHAR", "FIXED": "NUMERIC"} + print("Getting schema for each table that you selected...") + # get the schema for each table + if len(tables) == 0: + # get all tables from Snowflake database + cur = await asyncio.to_thread(conn.cursor().execute, "SHOW TERSE TABLES;") + res = await asyncio.to_thread(cur.fetchall) + tables = [f"{row[3]}.{row[4]}.{row[1]}" for row in res] + + if return_tables_only: + await asyncio.to_thread(conn.close) + return tables + + for table_name in tables: + rows = [] + cur = await asyncio.to_thread(conn.cursor) + fetched_rows = await asyncio.to_thread( + cur.execute, f"SHOW COLUMNS IN {table_name};" + ) + for row in fetched_rows: + rows.append(row) + rows = [ + { + "column_name": i[2], + "data_type": json.loads(i[3])["type"], + "column_description": i[8], + } + for i in rows + ] + for idx, row in enumerate(rows): + if row["data_type"] in alt_types: + row["data_type"] = alt_types[row["data_type"]] + rows[idx] = row + + cur = await asyncio.to_thread(conn.cursor) + if scan: + rows = await async_identify_categorical_columns( + conn=None, + cur=cur, + table_name=table_name, + rows=rows, + is_cursor_async=False, + ) + await asyncio.to_thread(cur.close) + if len(rows) > 0: + schemas[table_name] = rows + + await asyncio.to_thread(conn.close) + + if upload: + print( + "Sending the schema to the defog servers and generating column descriptions. This might take up to 2 minutes..." + ) + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": schemas, + "foreign_keys": [], + "indexes": [], + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." 
+ ) + else: + return schemas + + +async def generate_bigquery_schema( + self, + tables: list, + upload: bool = True, + return_format: str = "csv", + scan: bool = True, + return_tables_only: bool = False, +) -> str: + try: + from google.cloud import bigquery + except: + raise Exception("google-cloud-bigquery not installed.") + + client = await asyncio.to_thread( + bigquery.Client.from_service_account_json, self.db_creds["json_key_path"] + ) + project_id = [p.project_id for p in await asyncio.to_thread(client.list_projects)][ + 0 + ] + datasets = [ + dataset.dataset_id for dataset in await asyncio.to_thread(client.list_datasets) + ] + schemas = {} + + if len(tables) == 0: + # get all tables + tables = [] + for dataset in datasets: + table_list = await asyncio.to_thread(client.list_tables, dataset) + tables += [ + f"{project_id}.{dataset}.{table.table_id}" for table in table_list + ] + + print("Getting the schema for each table that you selected...") + # get the schema for each table + for table_name in tables: + table = await asyncio.to_thread(client.get_table, table_name) + rows = table.schema + rows = [{"column_name": i.name, "data_type": i.field_type} for i in rows] + if len(rows) > 0: + schemas[table_name] = rows + + await asyncio.to_thread(client.close) + + if upload: + print( + "Sending the schema to Defog servers and generating column descriptions. This might take up to 2 minutes..." + ) + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": schemas, + "foreign_keys": [], + "indexes": [], + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." + ) + else: + return schemas + + +async def generate_sqlserver_schema( + self, + tables: list, + upload: bool = True, + return_format: str = "csv", + return_tables_only: bool = False, +) -> str: + try: + import pyodbc + except: + raise Exception("pyodbc not installed.") + + if self.db_creds["database"] != "": + connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" + else: + connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" + conn = await asyncio.to_thread(pyodbc.connect, connection_string) + cur = await asyncio.to_thread(conn.cursor) + schemas = {} + schema = self.db_creds.get("schema", "dbo") + + if len(tables) == 0: + table_names_query = ( + "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;" + ) + await asyncio.to_thread(cur.execute, table_names_query, (schema,)) + if schema == "dbo": + tables += [row[0] for row in await asyncio.to_thread(cur.fetchall)] + else: + tables += [ + schema + "." 
+ row[0] for row in await asyncio.to_thread(cur.fetchall) + ] + + if return_tables_only: + await asyncio.to_thread(conn.close) + return tables + + print("Getting schema for each table in your database...") + # get the schema for each table + for table_name in tables: + await asyncio.to_thread( + cur.execute, + f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';", + ) + rows = await asyncio.to_thread(cur.fetchall) + rows = [row for row in rows] + rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] + if len(rows) > 0: + schemas[table_name] = rows + + await asyncio.to_thread(conn.close) + if upload: + print( + "Sending the schema to Defog servers and generating column descriptions. This might take up to 2 minutes..." + ) + resp = await make_async_post_request( + url=f"{self.base_url}/get_schema_csv", + payload={ + "api_key": self.api_key, + "schemas": schemas, + "foreign_keys": [], + "indexes": [], + }, + ) + if "csv" in resp: + csv = resp["csv"] + if return_format == "csv": + pd.read_csv(StringIO(csv)).to_csv("defog_metadata.csv", index=False) + return "defog_metadata.csv" + else: + return csv + else: + print(f"We got an error!") + if "message" in resp: + print(f"Error message: {resp['message']}") + print( + f"Please feel free to open a github issue at https://github.com/defog-ai/defog-python if this a generic library issue, or email support@defog.ai." + ) + else: + return schemas + + +async def generate_db_schema( + self, + tables: list, + scan: bool = True, + upload: bool = True, + return_tables_only: bool = False, + return_format: str = "csv", +) -> str: + if self.db_type == "postgres": + return await self.generate_postgres_schema( + tables, + return_format=return_format, + scan=scan, + upload=upload, + return_tables_only=return_tables_only, + ) + elif self.db_type == "mysql": + return await self.generate_mysql_schema( + tables, + return_format=return_format, + scan=scan, + upload=upload, + return_tables_only=return_tables_only, + ) + elif self.db_type == "bigquery": + return await self.generate_bigquery_schema( + tables, + return_format=return_format, + scan=scan, + upload=upload, + return_tables_only=return_tables_only, + ) + elif self.db_type == "redshift": + return await self.generate_redshift_schema( + tables, + return_format=return_format, + scan=scan, + upload=upload, + return_tables_only=return_tables_only, + ) + elif self.db_type == "snowflake": + return await self.generate_snowflake_schema( + tables, + return_format=return_format, + scan=scan, + upload=upload, + return_tables_only=return_tables_only, + ) + elif self.db_type == "databricks": + return await self.generate_databricks_schema( + tables, + return_format=return_format, + scan=scan, + upload=upload, + return_tables_only=return_tables_only, + ) + elif self.db_type == "sqlserver": + return await self.generate_sqlserver_schema( + tables, + return_format=return_format, + upload=upload, + return_tables_only=return_tables_only, + ) + else: + raise ValueError( + f"Creation of a DB schema for {self.db_type} is not yet supported via the library. If you are a premium user, please contact us at founder@defog.ai so we can manually add it." 
+ ) diff --git a/defog/async_health_methods.py b/defog/async_health_methods.py new file mode 100644 index 0000000..d7e471f --- /dev/null +++ b/defog/async_health_methods.py @@ -0,0 +1,46 @@ +from defog.util import make_async_post_request + + +async def check_golden_queries_coverage(self, dev: bool = False): + """ + Check the number of tables and columns inside the metadata schema that are covered by the golden queries. + """ + url = f"{self.base_url}/get_golden_queries_coverage" + payload = {"api_key": self.api_key, "dev": dev} + return await make_async_post_request(url, payload) + + +async def check_md_valid(self, dev: bool = False): + """ + Check if the metadata schema is valid. + """ + url = f"{self.base_url}/check_md_valid" + payload = {"api_key": self.api_key, "db_type": self.db_type, "dev": dev} + return await make_async_post_request(url, payload) + + +async def check_gold_queries_valid(self, dev: bool = False): + """ + Check if the golden queries are valid and can be executed on a given database without errors. + """ + url = f"{self.base_url}/check_gold_queries_valid" + payload = {"api_key": self.api_key, "db_type": self.db_type, "dev": dev} + return await make_async_post_request(url, payload) + + +async def check_glossary_valid(self, dev: bool = False): + """ + Check if the glossary is valid by verifying if all schema, table, and column names referenced are present in the metadata. + """ + url = f"{self.base_url}/check_glossary_valid" + payload = {"api_key": self.api_key, "dev": dev} + return await make_async_post_request(url, payload) + + +async def check_glossary_consistency(self, dev: bool = False): + """ + Check if all logic in the glossary is consistent and coherent. + """ + url = f"{self.base_url}/check_glossary_consistency" + payload = {"api_key": self.api_key, "dev": dev} + return await make_async_post_request(url, payload) diff --git a/defog/async_query_methods.py b/defog/async_query_methods.py new file mode 100644 index 0000000..895bae0 --- /dev/null +++ b/defog/async_query_methods.py @@ -0,0 +1,172 @@ +from defog.util import make_async_post_request +from defog.query import execute_query +from datetime import datetime + + +async def get_query( + self, + question: str, + hard_filters: str = "", + previous_context: list = [], + glossary: str = "", + debug: bool = False, + dev: bool = False, + temp: bool = False, + profile: bool = False, + ignore_cache: bool = False, + model: str = "", + use_golden_queries: bool = True, + subtable_pruning: bool = False, + glossary_pruning: bool = False, + prune_max_tokens: int = 2000, + prune_bm25_num_columns: int = 10, + prune_glossary_max_tokens: int = 1000, + prune_glossary_num_cos_sim_units: int = 10, + prune_glossary_bm25_units: int = 10, +): + """ + Asynchronously sends the query to the defog servers, and return the response. + :param question: The question to be asked. + :return: The response from the defog server. 
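+
+    Illustrative usage (assumes an initialized AsyncDefog instance named defog):
+        resp = await defog.get_query("How many users signed up last week?")
+        sql = resp["query_generated"]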
+ """ + try: + data = { + "question": question, + "api_key": self.api_key, + "previous_context": previous_context, + "db_type": self.db_type if self.db_type != "databricks" else "postgres", + "glossary": glossary, + "hard_filters": hard_filters, + "dev": dev, + "temp": temp, + "ignore_cache": ignore_cache, + "model": model, + "use_golden_queries": use_golden_queries, + "subtable_pruning": subtable_pruning, + "glossary_pruning": glossary_pruning, + "prune_max_tokens": prune_max_tokens, + "prune_bm25_num_columns": prune_bm25_num_columns, + "prune_glossary_max_tokens": prune_glossary_max_tokens, + "prune_glossary_num_cos_sim_units": prune_glossary_num_cos_sim_units, + "prune_glossary_bm25_units": prune_glossary_bm25_units, + } + + t_start = datetime.now() + + resp = await make_async_post_request( + url=self.generate_query_url, payload=data, timeout=300 + ) + + t_end = datetime.now() + time_taken = (t_end - t_start).total_seconds() + query_generated = resp.get("sql", resp.get("query_generated")) + ran_successfully = resp.get("ran_successfully") + error_message = resp.get("error_message") + query_db = self.db_type + resp = { + "query_generated": query_generated, + "ran_successfully": ran_successfully, + "error_message": error_message, + "query_db": query_db, + "previous_context": resp.get("previous_context"), + "reason_for_query": resp.get("reason_for_query"), + } + if profile: + resp["time_taken"] = time_taken + + return resp + except Exception as e: + if debug: + print(e) + return { + "ran_successfully": False, + "error_message": "Sorry :( Our server is at capacity right now and we are unable to process your query. Please try again in a few minutes?", + } + + +async def run_query( + self, + question: str, + hard_filters: str = "", + previous_context: list = [], + glossary: str = "", + query: dict = None, + retries: int = 3, + dev: bool = False, + temp: bool = False, + profile: bool = False, + ignore_cache: bool = False, + model: str = "", + use_golden_queries: bool = True, + subtable_pruning: bool = False, + glossary_pruning: bool = False, + prune_max_tokens: int = 2000, + prune_bm25_num_columns: int = 10, + prune_glossary_max_tokens: int = 1000, + prune_glossary_num_cos_sim_units: int = 10, + prune_glossary_bm25_units: int = 10, +): + """ + Asynchronously sends the question to the defog servers, executes the generated SQL, + and returns the response. + :param question: The question to be asked. + :return: The response from the defog server. 
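+
+    Illustrative usage (assumes an initialized AsyncDefog instance named defog):
+        resp = await defog.run_query("How many users signed up last week?")
+        if resp["ran_successfully"]:
+            print(resp["columns"], resp["data"])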
+ """ + if query is None: + print(f"Generating the query for your question: {question}...") + query = await self.get_query( + question, + hard_filters, + previous_context, + glossary=glossary, + dev=dev, + temp=temp, + profile=profile, + model=model, + ignore_cache=ignore_cache, + use_golden_queries=use_golden_queries, + subtable_pruning=subtable_pruning, + glossary_pruning=glossary_pruning, + prune_max_tokens=prune_max_tokens, + prune_bm25_num_columns=prune_bm25_num_columns, + prune_glossary_max_tokens=prune_glossary_max_tokens, + prune_glossary_num_cos_sim_units=prune_glossary_num_cos_sim_units, + prune_glossary_bm25_units=prune_glossary_bm25_units, + ) + if query["ran_successfully"]: + try: + print("Query generated, now running it on your database...") + tstart = datetime.now() + colnames, result, executed_query = await execute_query( + query=query["query_generated"], + api_key=self.api_key, + db_type=self.db_type, + db_creds=self.db_creds, + question=question, + hard_filters=hard_filters, + retries=retries, + dev=dev, + temp=temp, + ) + tend = datetime.now() + time_taken = (tend - tstart).total_seconds() + resp = { + "columns": colnames, + "data": result, + "query_generated": executed_query, + "ran_successfully": True, + "reason_for_query": query.get("reason_for_query"), + "previous_context": query.get("previous_context"), + } + if profile: + resp["execution_time_taken"] = time_taken + resp["generation_time_taken"] = query.get("time_taken") + return resp + except Exception as e: + return { + "ran_successfully": False, + "error_message": str(e), + "query_generated": query["query_generated"], + } + else: + return {"ran_successfully": False, "error_message": query["error_message"]} diff --git a/defog/generate_schema.py b/defog/generate_schema.py index 36196b4..2c0052a 100644 --- a/defog/generate_schema.py +++ b/defog/generate_schema.py @@ -41,6 +41,7 @@ def generate_postgres_schema( tables += [schema + "." + row[0] for row in cur.fetchall()] if return_tables_only: + conn.close() return tables print("Getting schema for each table that you selected...") @@ -139,6 +140,7 @@ def generate_redshift_schema( tables = [row[0] for row in cur.fetchall()] if return_tables_only: + conn.close() return tables print("Getting schema for each table that you selected...") @@ -220,6 +222,7 @@ def generate_mysql_schema( tables = [row[0] for row in cur.fetchall()] if return_tables_only: + conn.close() return tables print("Getting schema for the relevant table in your database...") @@ -293,6 +296,7 @@ def generate_databricks_schema( tables = [row.TABLE_NAME for row in cur.fetchall()] if return_tables_only: + conn.close() return tables for table_name in tables: From 58dd2088f199a2f39d1ce0416cd586b353abd81d Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:57:50 +0800 Subject: [PATCH 05/17] async util functions --- defog/util.py | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/defog/util.py b/defog/util.py index 8be5df9..e2dc778 100644 --- a/defog/util.py +++ b/defog/util.py @@ -4,6 +4,10 @@ from prompt_toolkit import prompt import requests +import aiohttp +import aiofiles +import asyncio + def parse_update( args_list: List[str], attributes_list: List[str], config_dict: dict @@ -53,6 +57,25 @@ def write_logs(msg: str) -> None: pass +async def async_write_logs(msg: str) -> None: + """ + Asynchronously write out log messages to ~/.defog/logs to avoid bloating cli output, + while still preserving more verbose error messages when debugging. 
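+    Uses aiofiles for non-blocking file I/O; any failure while writing is
+    silently ignored so that logging never interrupts the caller.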
+ + Args: + msg (str): The message to write. + """ + log_file_path = os.path.expanduser("~/.defog/logs") + + try: + if not os.path.exists(log_file_path): + os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + async with aiofiles.open(log_file_path, "a") as file: + await file.write(msg + "\n") + except Exception as e: + pass + + def is_str_type(data_type: str) -> bool: """ Check if the given data_type is a string type. @@ -132,6 +155,82 @@ def identify_categorical_columns( return rows +async def async_identify_categorical_columns( + conn=None, + cur=None, # a cursor object for any database + table_name: str = "", + rows: list = [], + is_cursor_async: bool = False, + distinct_threshold: int = 10, + character_length_threshold: int = 50, +): + """ + Identify categorical columns in the table and return the top distinct values for each column. + + Args: + conn (connection): A connection object for any database. + cur (cursor): A cursor object for any database. This cursor should support the following methods: + - execute(sql, params) + - fetchone() + - fetchall() + table_name (str): The name of the table. + rows (list): A list of dictionaries containing the column names and data types.a + distinct_threshold (int): The threshold for the number of distinct values in a column to be considered categorical. + character_length_threshold (int): The threshold for the maximum length of a string column to be considered categorical. + This is a heuristic for pruning columns that might contain arbitrarily long strings like json / configs. + + Returns: + rows (list): The updated list of dictionaries containing the column names, data types and top distinct values. + The list is modified in-place. + """ + # loop through each column, look at whether it is a string column, and then determine if it might be a categorical variable + # if it is a categorical variable, then we want to get the distinct values and their counts + # we will then send this to the defog servers so that we can generate a column description + # for each categorical variable + print( + f"Identifying categorical columns in {table_name}. This might take a while if you have many rows in your table." 
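+        # The run_query helper defined below dispatches each sampling query to an
+        # asyncpg-style conn.fetch(), an awaitable cursor, or a blocking cursor
+        # wrapped in asyncio.to_thread, so the event loop is not blocked.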
+ ) + + async def run_query(query, params=None): + if conn: + # If using an async connection like asyncpg + return await conn.fetch(query, *params if params else ()) + elif cur: + if is_cursor_async: + # If using an async cursor (like aiomysql or others) + await cur.execute(query, params) + return await cur.fetchall() + else: + if params: + await asyncio.to_thread(cur.execute, query, params) + else: + await asyncio.to_thread(cur.execute, query) + return await asyncio.to_thread(cur.fetchall) + + for idx, row in enumerate(rows): + if is_str_type(row["data_type"]): + # get the total number of rows and number of distinct values in the table for this column + column_name = row["column_name"] + + query = f"SELECT COUNT(*) FROM (SELECT DISTINCT {column_name} FROM {table_name} LIMIT 10000) AS temp;" + + result = await run_query(query) + try: + num_distinct_values = result[0][0] + except Exception as e: + num_distinct_values = 0 + if num_distinct_values <= distinct_threshold and num_distinct_values > 0: + # get the top distinct_threshold distinct values + query = f"""SELECT {column_name}, COUNT({column_name}) AS col_count FROM {table_name} GROUP BY {column_name} ORDER BY col_count DESC LIMIT %s;""" + top_values = await run_query(query, (distinct_threshold,)) + top_values = [i[0] for i in top_values if i[0] is not None] + rows[idx]["top_values"] = ",".join(sorted(top_values)) + print( + f"Identified {column_name} as a likely categorical column. The unique values are: {top_values}" + ) + return rows + + def get_feedback( api_key: str, db_type: str, user_question: str, sql_generated: str, base_url: str ): @@ -333,3 +432,28 @@ def get_feedback( except Exception as e: write_logs(f"Error in get_feedback:\n{e}") pass + + +async def make_async_post_request( + url: str, payload: dict, timeout=None, return_response_object=False +): + """ + Helper function to make async POST requests and defaults to return the JSON response. Optionally allows returning the response object itself. + + Args: + url (str): The URL to make the POST request to. + payload (dict): The payload to send with the POST request. + timeout (int): The timeout for the request. + return_response_object (bool): Whether to return the response object itself. + + Returns: + dict: The JSON response from the POST request or the response object itself if return_response_object is True. 
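+
+    Illustrative usage (the endpoint and payload below are placeholders, not a real
+    documented route):
+        resp = await make_async_post_request(
+            url="https://api.defog.ai/some_endpoint",
+            payload={"api_key": "..."},
+            timeout=300,
+        )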
+ """ + try: + async with aiohttp.ClientSession() as session: + async with session.post(url, json=payload, timeout=timeout) as response: + if return_response_object: + return response + return await response.json() + except Exception as e: + return {"error": str(e)} From 5169ef86d859f4eb6e30156deb7a39717a4f72cb Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:58:33 +0800 Subject: [PATCH 06/17] requirements updated --- requirements.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 34c4fa1..a532898 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,13 @@ pandas prompt_toolkit psycopg2-binary>=2.9.5 +asyncpg +aiomysql pwinput requests>=2.28.2 +aiohttp +aiofiles tabulate uvicorn -tqdm \ No newline at end of file +tqdm +setuptools \ No newline at end of file From 529e4ac7ec9693af1241eac9085a2070ac20f018 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 13:59:30 +0800 Subject: [PATCH 07/17] linted --- defog/query.py | 1 + 1 file changed, 1 insertion(+) diff --git a/defog/query.py b/defog/query.py index 36c6533..bb3e0f2 100644 --- a/defog/query.py +++ b/defog/query.py @@ -5,6 +5,7 @@ import asyncio import os + # execute query for given db_type and return column names and data def execute_query_once(db_type: str, db_creds, query: str): """ From 6a52fdd980c6c4f7ec0f8cbb5a34691b9c2125c3 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 14:03:35 +0800 Subject: [PATCH 08/17] made all functions async for consistency --- defog/async_admin_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/defog/async_admin_methods.py b/defog/async_admin_methods.py index 7a53aea..732b428 100644 --- a/defog/async_admin_methods.py +++ b/defog/async_admin_methods.py @@ -266,7 +266,7 @@ async def get_golden_queries( raise ValueError("format must be either 'csv' or 'json'.") -def create_table_ddl( +async def create_table_ddl( table_name: str, columns: List[Dict[str, str]], add_exists=True ) -> str: """ @@ -296,7 +296,7 @@ def create_table_ddl( return md_create -def create_ddl_from_metadata( +async def create_ddl_from_metadata( metadata: Dict[str, List[Dict[str, str]]], add_exists=True ) -> str: """ From 8b3d7af6700cd0bf45e89f8862be9bbdf640c7b9 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 14:12:08 +0800 Subject: [PATCH 09/17] fixed doctsring for async identify categorical columns function --- defog/util.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/defog/util.py b/defog/util.py index e2dc778..ce86cc1 100644 --- a/defog/util.py +++ b/defog/util.py @@ -156,7 +156,7 @@ def identify_categorical_columns( async def async_identify_categorical_columns( - conn=None, + conn=None, # a connection object for any database cur=None, # a cursor object for any database table_name: str = "", rows: list = [], @@ -168,25 +168,22 @@ async def async_identify_categorical_columns( Identify categorical columns in the table and return the top distinct values for each column. Args: - conn (connection): A connection object for any database. - cur (cursor): A cursor object for any database. This cursor should support the following methods: - - execute(sql, params) - - fetchone() - - fetchall() - table_name (str): The name of the table. 
- rows (list): A list of dictionaries containing the column names and data types.a - distinct_threshold (int): The threshold for the number of distinct values in a column to be considered categorical. - character_length_threshold (int): The threshold for the maximum length of a string column to be considered categorical. + conn (optional): Async connection for databases (e.g., asyncpg). + cur (optional): Sync/async cursor object for database queries. + table_name (str): The name of the table to analyze. + rows (list): List of column info dictionaries (with keys like "column_name", "data_type"). + is_cursor_async (bool): Set True if using an async cursor. + distinct_threshold (int): Max distinct values to classify a column as categorical. + character_length_threshold (int): Max length of a string column to be considered categorical. This is a heuristic for pruning columns that might contain arbitrarily long strings like json / configs. + Note: + - The function requires one of conn or cur to be provided. + Returns: rows (list): The updated list of dictionaries containing the column names, data types and top distinct values. The list is modified in-place. """ - # loop through each column, look at whether it is a string column, and then determine if it might be a categorical variable - # if it is a categorical variable, then we want to get the distinct values and their counts - # we will then send this to the defog servers so that we can generate a column description - # for each categorical variable print( f"Identifying categorical columns in {table_name}. This might take a while if you have many rows in your table." ) From dfce92020123e09898980457d52ca167fdfc47d0 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 17:22:46 +0800 Subject: [PATCH 10/17] async queries tested --- tests/test_query.py | 202 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 201 insertions(+), 1 deletion(-) diff --git a/tests/test_query.py b/tests/test_query.py index 3e51542..48021c3 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -4,7 +4,13 @@ import unittest from unittest import mock -from defog.query import is_connection_error, execute_query_once, execute_query +from defog.query import ( + is_connection_error, + execute_query_once, + execute_query, + async_execute_query_once, + async_execute_query, +) class ExecuteQueryOnceTestCase(unittest.TestCase): @@ -176,6 +182,200 @@ def side_effect(db_type, db_creds, query): self.assertIn(json.dumps(json_req), lines[2]) +class ExecuteAsyncQueryOnceTestCase(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(self): + # if connection.json exists, copy it to /tmp since we'll be overwriting it + home_dir = os.path.expanduser("~") + self.logs_path = os.path.join(home_dir, ".defog", "logs") + self.tmp_dir = os.path.join("/tmp") + self.moved = False + if os.path.exists(self.logs_path): + print("Moving logs to /tmp") + if os.path.exists(os.path.join(self.tmp_dir, "logs")): + os.remove(os.path.join(self.tmp_dir, "logs")) + shutil.move(self.logs_path, self.tmp_dir) + self.moved = True + + @classmethod + def tearDownClass(self): + # copy back the original after all tests have completed + if self.moved: + print("Moving logs back to ~/.defog") + shutil.move(os.path.join(self.tmp_dir, "logs"), self.logs_path) + + @mock.patch("asyncpg.connect") + async def test_async_execute_query_once_success(self, mock_connect): + # Mock the asyncpg.connect function + mock_cursor = mock_connect.return_value.fetch + mock_cursor.return_value = [ + 
{"col1": "data1", "col2": "data2"}, + {"col1": "data3", "col2": "data4"}, + ] + + db_type = "postgres" + db_creds = { + "host": "localhost", + "port": 5432, + "database": "test_db", + "user": "test_user", + "password": "test_password", + } + query = "SELECT * FROM table_name;" + + colnames, results = await async_execute_query_once(db_type, db_creds, query) + + # Add your assertions here to validate the results + self.assertEqual(colnames, ["col1", "col2"]) + self.assertEqual(results, [["data1", "data2"], ["data3", "data4"]]) + print("Postgres async query execution test passed!") + + @mock.patch("aiohttp.ClientSession.post") + @mock.patch("defog.query.async_execute_query_once") + async def test_async_execute_query_success( + self, mock_execute_query_once, mock_aiohttp_post + ): + # Mock the execute_query_once function + db_type = "postgres" + db_creds = { + "host": "localhost", + "port": 5432, + "database": "test_db", + "user": "test_user", + "password": "test_password", + } + query1 = "SELECT * FROM table_name;" + query2 = "SELECT * FROM new_table_name;" + api_key = "your_api_key" + question = "your_question" + hard_filters = "your_hard_filters" + retries = 3 + + # Set up the mock responses + mock_execute_query_once.return_value = ( + ["col1", "col2"], + [["data1", "data2"], ["data3", "data4"]], + ) + + # Mock the async aiohttp response + mock_response = mock.Mock() + mock_response.json = mock.AsyncMock(return_value={"new_query": query2}) + mock_aiohttp_post.return_value.__aenter__.return_value = mock_response + + # Call the function being tested + colnames, results, rcv_query = await async_execute_query( + query1, api_key, db_type, db_creds, question, hard_filters, retries + ) + + # Assert the expected behavior + mock_execute_query_once.assert_called_once_with( + db_type="postgres", # Use keyword arguments for db_type + db_creds={ + "host": "localhost", + "port": 5432, + "database": "test_db", + "user": "test_user", + "password": "test_password", + }, # Use keyword arguments for db_creds + query="SELECT * FROM table_name;", # Use keyword arguments for query + ) + + # Since there should be no retry, aiohttp post should not be called + mock_aiohttp_post.assert_not_called() + + self.assertEqual(colnames, ["col1", "col2"]) + self.assertEqual(results, [["data1", "data2"], ["data3", "data4"]]) + self.assertEqual(rcv_query, query1) # should return the original query + + @mock.patch("aiohttp.ClientSession.post") + @mock.patch("defog.query.async_execute_query_once") + async def test_execute_query_success_with_retry( + self, mock_execute_query_once, mock_aiohttp_post + ): + db_type = "postgres" + db_creds = { + "host": "localhost", + "port": 5432, + "database": "test_db", + "user": "test_user", + "password": "test_password", + } + query1 = "SELECT * FROM table_name;" + query2 = "SELECT * FROM table_name WHERE colour='blue';" + api_key = "your_api_key" + question = "your_question" + hard_filters = "your_hard_filters" + retries = 3 + dev = False + temp = False + colnames = (["col1", "col2"],) + results = [("data1", "data2"), ("data3", "data4")] + + # Mock the execute_query_once function to raise an exception the first + # time it is called and return the results the second time it is called + err_msg = "Test exception" + + def side_effect(db_type, db_creds, query): + if query == query1: + raise Exception(err_msg) + else: + return colnames, results + + mock_execute_query_once.side_effect = side_effect + + # Mock the async aiohttp response for the retry query + mock_response = mock.Mock() + 
mock_response.json = mock.AsyncMock(return_value={"new_query": query2}) + mock_aiohttp_post.return_value.__aenter__.return_value = mock_response + + # remove logs if they exist + if os.path.exists(os.path.join(self.logs_path)): + os.remove(os.path.join(self.logs_path)) + + # Call the function being tested + ret = await async_execute_query( + query=query1, + api_key=api_key, + db_type=db_type, + db_creds=db_creds, + question=question, + hard_filters=hard_filters, + retries=retries, + ) + + # should return new query2 instead of query1 + self.assertEqual(ret, (colnames, results, query2)) + + # Assert the mock function calls + mock_execute_query_once.assert_called_with(db_type, db_creds, query2) + + json_req = { + "api_key": api_key, + "previous_query": query1, + "error": err_msg, + "db_type": db_type, + "hard_filters": hard_filters, + "question": question, + "dev": dev, + "temp": temp, + } + + # Assert aiohttp post was called with the correct arguments + mock_aiohttp_post.assert_called_with( + "https://api.defog.ai/retry_query_after_error", + json=json_req, + timeout=None, + ) + + # check that error logs are populated + with open(self.logs_path, "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 5) + self.assertIn(err_msg, lines[0]) + self.assertIn(f"Retries left: {retries}", lines[1]) + self.assertIn(json.dumps(json_req), lines[2]) + + class TestConnectionError(unittest.TestCase): def test_connection_failed(self): self.assertTrue( From a533f21ca23d0c5cde9fc4a68d57ca4e25d6e67a Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Tue, 10 Sep 2024 17:29:06 +0800 Subject: [PATCH 11/17] exactaly the same basic tests as for Defog --- tests/test_async_defog.py | 256 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 tests/test_async_defog.py diff --git a/tests/test_async_defog.py b/tests/test_async_defog.py new file mode 100644 index 0000000..81fa5b1 --- /dev/null +++ b/tests/test_async_defog.py @@ -0,0 +1,256 @@ +import shutil +import unittest +import asyncio +from defog import AsyncDefog +from defog.util import parse_update +import os +from unittest.mock import patch + + +class TestAsyncDefog(unittest.TestCase): + @classmethod + def setUpClass(self): + # if connection.json exists, copy it to /tmp since we'll be overwriting it + home_dir = os.path.expanduser("~") + self.filepath = os.path.join(home_dir, ".defog", "connection.json") + self.tmp_dir = os.path.join("/tmp") + self.moved = False + if os.path.exists(self.filepath): + print("Moving connection.json to /tmp") + if os.path.exists(os.path.join(self.tmp_dir, "connection.json")): + os.remove(os.path.join(self.tmp_dir, "connection.json")) + shutil.move(self.filepath, self.tmp_dir) + self.moved = True + + @classmethod + def tearDownClass(self): + # copy back the original after all tests have completed + if self.moved: + print("Moving connection.json back to ~/.defog") + shutil.move(os.path.join(self.tmp_dir, "connection.json"), self.filepath) + + def tearDown(self): + # clean up connection.json created/saved after each test case + if os.path.exists(self.filepath): + print("Removing connection.json used for testing") + os.remove(self.filepath) + + ### Case 1: + def test_async_defog_bad_init_no_params(self): + with self.assertRaises(ValueError): + print("Testing AsyncDefog with no params") + AsyncDefog() + + # test initialization with partial params + def test_async_defog_good_init_no_db_creds(self): + df = AsyncDefog("test_api_key", "redis") + self.assertEqual(df.api_key, "test_api_key") + 
self.assertEqual(df.db_type, "redis") + self.assertEqual(df.db_creds, {}) + + ### Case 2: + # no connection file, no params + def test_async_defog_bad_init_no_connection_file(self): + with self.assertRaises(ValueError): + print("Testing AsyncDefog with no connection file, no params") + AsyncDefog() + + # no connection file, incomplete db_creds + def test_async_defog_bad_init_incomplete_creds(self): + with self.assertRaises(KeyError): + AsyncDefog("test_api_key", "postgres", {"host": "some_host"}) + + ### Case 3: + def test_async_defog_good_init(self): + print("testing AsyncDefog with good params") + db_creds = { + "host": "some_host", + "port": "some_port", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + df = AsyncDefog("test_api_key", "postgres", db_creds) + self.assertEqual(df.api_key, "test_api_key") + self.assertEqual(df.db_type, "postgres") + self.assertEqual(df.db_creds, db_creds) + + ### Case 4: + def test_async_defog_no_overwrite(self): + db_creds = { + "host": "host", + "port": "port", + "database": "database", + "user": "user", + "password": "password", + } + df1 = AsyncDefog("old_api_key", "postgres", db_creds) + self.assertEqual(df1.api_key, "old_api_key") + self.assertEqual(df1.db_type, "postgres") + self.assertEqual(df1.db_creds, db_creds) + self.assertTrue(os.path.exists(self.filepath)) + del df1 + df2 = AsyncDefog() # should read connection.json + self.assertEqual(df2.api_key, "old_api_key") + self.assertEqual(df2.db_type, "postgres") + self.assertEqual(df2.db_creds, db_creds) + + ### Case 5: + @patch("builtins.input", lambda *args: "y") + def test_async_defog_overwrite(self): + db_creds = { + "host": "host", + "port": "port", + "database": "database", + "user": "user", + "password": "password", + } + df = AsyncDefog("old_api_key", "postgres", db_creds) + self.assertEqual(df.api_key, "old_api_key") + self.assertEqual(df.db_type, "postgres") + self.assertEqual(df.db_creds, db_creds) + self.assertTrue(os.path.exists(self.filepath)) + df = AsyncDefog("new_api_key", "redshift") + self.assertEqual(df.api_key, "new_api_key") + self.assertEqual(df.db_type, "redshift") + self.assertEqual(df.db_creds, {}) + + @patch("builtins.input", lambda *args: "n") + def test_async_defog_no_overwrite(self): + db_creds = { + "host": "host", + "port": "port", + "database": "database", + "user": "user", + "password": "password", + } + df = AsyncDefog("old_api_key", "postgres", db_creds) + self.assertEqual(df.api_key, "old_api_key") + self.assertEqual(df.db_type, "postgres") + self.assertEqual(df.db_creds, db_creds) + self.assertTrue(os.path.exists(self.filepath)) + df = AsyncDefog("new_api_key", "redshift") + self.assertEqual(df.api_key, "new_api_key") + self.assertEqual(df.db_type, "redshift") + self.assertEqual(df.db_creds, {}) + + # test check_db_creds with all the different supported db types + def test_check_db_creds_postgres(self): + db_creds = { + "host": "some_host", + "port": "some_port", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + AsyncDefog.check_db_creds("postgres", db_creds) + AsyncDefog.check_db_creds("postgres", {}) + with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("postgres", {"host": "some_host"}) + + def test_check_db_creds_redshift(self): + db_creds = { + "host": "some_host", + "port": "some_port", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + AsyncDefog.check_db_creds("redshift", db_creds) + AsyncDefog.check_db_creds("redshift", {}) 
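+        # creds missing required keys (e.g. only "host") should raise KeyError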
+ with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("redshift", {"host": "some_host"}) + + def test_check_db_creds_mysql(self): + db_creds = { + "host": "some_host", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + AsyncDefog.check_db_creds("mysql", db_creds) + AsyncDefog.check_db_creds("mysql", {}) + with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("mysql", {"host": "some_host"}) + + async def test_check_db_creds_snowflake(self): + db_creds = { + "account": "some_account", + "warehouse": "some_warehouse", + "user": "some_user", + "password": "some_password", + } + AsyncDefog.check_db_creds("snowflake", db_creds) + AsyncDefog.check_db_creds("snowflake", {}) + with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("snowflake", {"account": "some_account"}) + + def test_check_db_creds_mongo(self): + db_creds = {"connection_string": "some_connection_string"} + AsyncDefog.check_db_creds("mongo", db_creds) + AsyncDefog.check_db_creds("mongo", {}) + with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("mongo", {"account": "some_account"}) + + def test_check_db_creds_sqlserver(self): + db_creds = { + "server": "some_server", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + AsyncDefog.check_db_creds("sqlserver", db_creds) + AsyncDefog.check_db_creds("sqlserver", {}) + with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("sqlserver", {"account": "some_account"}) + + def test_check_db_creds_bigquery(self): + db_creds = {"json_key_path": "some_json_key_path"} + AsyncDefog.check_db_creds("bigquery", db_creds) + AsyncDefog.check_db_creds("bigquery", {}) + with self.assertRaises(KeyError): + AsyncDefog.check_db_creds("bigquery", {"account": "some_account"}) + + def test_base64_encode_decode(self): + db_creds = { + "host": "some_host", + "port": "some_port", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + df1 = AsyncDefog("test_api_key", "postgres", db_creds) + df1_base64creds = df1.to_base64_creds() + df2 = AsyncDefog(base64creds=df1_base64creds) + self.assertEqual(df1.api_key, df2.api_key) + self.assertEqual(df1.api_key, "test_api_key") + self.assertEqual(df1.db_type, df2.db_type) + self.assertEqual(df1.db_type, "postgres") + self.assertEqual(df1.db_creds, df2.db_creds) + self.assertEqual(df1.db_creds, db_creds) + + def test_save_json(self): + db_creds = { + "host": "some_host", + "port": "some_port", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + _ = AsyncDefog("test_api_key", "postgres", db_creds, save_json=True) + self.assertTrue(os.path.exists(self.filepath)) + + def test_no_save_json(self): + db_creds = { + "host": "some_host", + "port": "some_port", + "database": "some_database", + "user": "some_user", + "password": "some_password", + } + df_save = AsyncDefog("test_api_key", "postgres", db_creds, save_json=False) + self.assertTrue(not os.path.exists(self.filepath)) + + +if __name__ == "__main__": + unittest.main() From 60873e76bd8b6dba9545f9b328ac5a49b2b5d43e Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Wed, 11 Sep 2024 12:30:49 +0800 Subject: [PATCH 12/17] close connection was missing --- defog/generate_schema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/defog/generate_schema.py b/defog/generate_schema.py index 2c0052a..2ec9e10 100644 --- a/defog/generate_schema.py +++ b/defog/generate_schema.py @@ -378,6 +378,7 @@ def generate_snowflake_schema( tables = 
[f"{row[3]}.{row[4]}.{row[1]}" for row in res] if return_tables_only: + conn.close() return tables for table_name in tables: From 13b3b96c5f248066dbde43da353616edc7f52fe3 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Wed, 11 Sep 2024 12:32:18 +0800 Subject: [PATCH 13/17] snowflake adjusted in generate schema --- defog/async_generate_schema.py | 51 +++++++++++++++++----------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/defog/async_generate_schema.py b/defog/async_generate_schema.py index 584a2f7..0e6c54c 100644 --- a/defog/async_generate_schema.py +++ b/defog/async_generate_schema.py @@ -353,6 +353,7 @@ async def generate_databricks_schema( return schemas + async def generate_snowflake_schema( self, tables: list, @@ -366,39 +367,46 @@ async def generate_snowflake_schema( except: raise Exception("snowflake-connector not installed.") - conn = await asyncio.to_thread( - snowflake.connector.connect, + conn = snowflake.connector.connect( user=self.db_creds["user"], password=self.db_creds["password"], account=self.db_creds["account"], ) - await asyncio.to_thread( - conn.cursor().execute, f"USE WAREHOUSE {self.db_creds['warehouse']}" - ) # set the warehouse + cur = conn.cursor() + cur.execute_async(f"USE WAREHOUSE {self.db_creds['warehouse']}") # set the warehouse + query_id = cur.sfqid # Get the query ID after execution + + # Check the status of the query + while conn.is_still_running(conn.get_query_status(query_id)): + await asyncio.sleep(1) # Sleep while the query is still running schemas = {} alt_types = {"DATE": "TIMESTAMP", "TEXT": "VARCHAR", "FIXED": "NUMERIC"} print("Getting schema for each table that you selected...") # get the schema for each table if len(tables) == 0: + cur = conn.cursor() # get all tables from Snowflake database - cur = await asyncio.to_thread(conn.cursor().execute, "SHOW TERSE TABLES;") - res = await asyncio.to_thread(cur.fetchall) + cur.execute_async("SHOW TERSE TABLES;") + query_id = cur.sfqid + while conn.is_still_running(conn.get_query_status(query_id)): + await asyncio.sleep(1) + res = cur.fetchall() tables = [f"{row[3]}.{row[4]}.{row[1]}" for row in res] if return_tables_only: - await asyncio.to_thread(conn.close) + conn.close() return tables for table_name in tables: rows = [] - cur = await asyncio.to_thread(conn.cursor) - fetched_rows = await asyncio.to_thread( - cur.execute, f"SHOW COLUMNS IN {table_name};" - ) - for row in fetched_rows: - rows.append(row) + cur = conn.cursor() + cur.execute_async(f"SHOW COLUMNS IN {table_name};") + query_id = cur.sfqid + while conn.is_still_running(conn.get_query_status(query_id)): + await asyncio.sleep(1) + rows = cur.fetchall() rows = [ { "column_name": i[2], @@ -411,21 +419,14 @@ async def generate_snowflake_schema( if row["data_type"] in alt_types: row["data_type"] = alt_types[row["data_type"]] rows[idx] = row - - cur = await asyncio.to_thread(conn.cursor) + cur = conn.cursor() if scan: - rows = await async_identify_categorical_columns( - conn=None, - cur=cur, - table_name=table_name, - rows=rows, - is_cursor_async=False, - ) - await asyncio.to_thread(cur.close) + rows = async_identify_categorical_columns(conn=conn, cur=cur, table_name=table_name, rows=rows, is_cursor_async=False, db_type="snowflake") + cur.close() if len(rows) > 0: schemas[table_name] = rows - await asyncio.to_thread(conn.close) + conn.close() if upload: print( From bda240ebafd79055cd9f562faf966bc6776ccafb Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Wed, 11 Sep 2024 13:56:26 +0800 Subject: [PATCH 14/17] native 
aioobdc for sqlserver --- defog/async_generate_schema.py | 68 +++++++++++++++++----------------- defog/generate_schema.py | 1 + defog/query.py | 44 ++++++++++++---------- 3 files changed, 60 insertions(+), 53 deletions(-) diff --git a/defog/async_generate_schema.py b/defog/async_generate_schema.py index 0e6c54c..7e38e4c 100644 --- a/defog/async_generate_schema.py +++ b/defog/async_generate_schema.py @@ -373,13 +373,9 @@ async def generate_snowflake_schema( account=self.db_creds["account"], ) - cur = conn.cursor() - cur.execute_async(f"USE WAREHOUSE {self.db_creds['warehouse']}") # set the warehouse - query_id = cur.sfqid # Get the query ID after execution - - # Check the status of the query - while conn.is_still_running(conn.get_query_status(query_id)): - await asyncio.sleep(1) # Sleep while the query is still running + conn.cursor().execute( + f"USE WAREHOUSE {self.db_creds['warehouse']}" + ) # set the warehouse schemas = {} alt_types = {"DATE": "TIMESTAMP", "TEXT": "VARCHAR", "FIXED": "NUMERIC"} @@ -388,8 +384,8 @@ async def generate_snowflake_schema( if len(tables) == 0: cur = conn.cursor() # get all tables from Snowflake database - cur.execute_async("SHOW TERSE TABLES;") - query_id = cur.sfqid + cur.execute_async("SHOW TERSE TABLES;") # execute asynchrnously + query_id = cur.sfqid # get the query id to check the status while conn.is_still_running(conn.get_query_status(query_id)): await asyncio.sleep(1) res = cur.fetchall() @@ -421,7 +417,14 @@ async def generate_snowflake_schema( rows[idx] = row cur = conn.cursor() if scan: - rows = async_identify_categorical_columns(conn=conn, cur=cur, table_name=table_name, rows=rows, is_cursor_async=False, db_type="snowflake") + rows = await async_identify_categorical_columns( + conn=conn, + cur=cur, + table_name=table_name, + rows=rows, + is_cursor_async=False, + db_type="snowflake", + ) cur.close() if len(rows) > 0: schemas[table_name] = rows @@ -533,7 +536,6 @@ async def generate_bigquery_schema( else: return schemas - async def generate_sqlserver_schema( self, tables: list, @@ -542,49 +544,49 @@ async def generate_sqlserver_schema( return_tables_only: bool = False, ) -> str: try: - import pyodbc - except: - raise Exception("pyodbc not installed.") - + import aioodbc + except Exception as e: + raise Exception( + "aioodbc not installed. Please install it with `pip install aioodbc`." 
+ ) if self.db_creds["database"] != "": connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" else: connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" - conn = await asyncio.to_thread(pyodbc.connect, connection_string) - cur = await asyncio.to_thread(conn.cursor) + + conn = await aioodbc.connect(dsn=connection_string) + cur = await conn.cursor() schemas = {} schema = self.db_creds.get("schema", "dbo") if len(tables) == 0: - table_names_query = ( - "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;" + # get all tables + await cur.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;", + (schema,), ) - await asyncio.to_thread(cur.execute, table_names_query, (schema,)) if schema == "dbo": - tables += [row[0] for row in await asyncio.to_thread(cur.fetchall)] + tables = [row[0] for row in await cur.fetchall()] else: - tables += [ - schema + "." + row[0] for row in await asyncio.to_thread(cur.fetchall) - ] - + tables = [schema + "." + row[0] for row in await cur.fetchall()] + if return_tables_only: - await asyncio.to_thread(conn.close) + await conn.close() return tables - print("Getting schema for each table in your database...") + print("Getting schema for the relevant table in your database...") # get the schema for each table for table_name in tables: - await asyncio.to_thread( - cur.execute, - f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';", + await cur.execute( + "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s;", + (table_name,), ) - rows = await asyncio.to_thread(cur.fetchall) + rows = await cur.fetchall() rows = [row for row in rows] rows = [{"column_name": i[0], "data_type": i[1]} for i in rows] if len(rows) > 0: schemas[table_name] = rows - - await asyncio.to_thread(conn.close) + await conn.close() if upload: print( "Sending the schema to Defog servers and generating column descriptions. This might take up to 2 minutes..." diff --git a/defog/generate_schema.py b/defog/generate_schema.py index 2ec9e10..c0e854d 100644 --- a/defog/generate_schema.py +++ b/defog/generate_schema.py @@ -544,6 +544,7 @@ def generate_sqlserver_schema( tables += [schema + "." 
+ row[0] for row in cur.fetchall()] if return_tables_only: + conn.close() return tables print("Getting schema for each table in your database...") diff --git a/defog/query.py b/defog/query.py index bb3e0f2..e87fc4d 100644 --- a/defog/query.py +++ b/defog/query.py @@ -234,23 +234,26 @@ async def async_execute_query_once(db_type: str, db_creds, query: str): import snowflake.connector except: raise Exception("snowflake.connector not installed.") - conn = await asyncio.to_thread( - snowflake.connector.connect, + conn = snowflake.connector.connect( user=db_creds["user"], password=db_creds["password"], account=db_creds["account"], ) - cur = await asyncio.to_thread(conn.cursor) - await asyncio.to_thread( - cur.execute, f"USE WAREHOUSE {db_creds['warehouse']}" - ) # set the warehouse + cur = conn.cursor() + cur.execute(f"USE WAREHOUSE {db_creds['warehouse']}") # set the warehouse + if "database" in db_creds: - await asyncio.to_thread(cur.execute, f"USE DATABASE {db_creds['database']}") - await asyncio.to_thread(cur.execute, query) + cur.execute(f"USE DATABASE {db_creds['database']}") # set the database + + cur.execute_async(query) + query_id = cur.sfqid + while conn.is_still_running(conn.get_query_status(query_id)): + await asyncio.sleep(1) + colnames = [desc[0] for desc in cur.description] - rows = await asyncio.to_thread(cur.fetchall) - await asyncio.to_thread(cur.close) - await asyncio.to_thread(conn.close) + rows = cur.fetchall() + cur.close() + conn.close() return colnames, rows elif db_type == "databricks": @@ -268,22 +271,23 @@ async def async_execute_query_once(db_type: str, db_creds, query: str): elif db_type == "sqlserver": try: - import pyodbc + import aioodbc except: - raise Exception("pyodbc not installed.") + raise Exception("aioodbc not installed.") if db_creds["database"] != "": connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={db_creds['server']};DATABASE={db_creds['database']};UID={db_creds['user']};PWD={db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" else: connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={db_creds['server']};UID={db_creds['user']};PWD={db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" - conn = await asyncio.to_thread(pyodbc.connect, connection_string) - cur = await asyncio.to_thread(conn.cursor) - await asyncio.to_thread(cursor.execute, query) - colnames = [desc[0] for desc in cursor.description] - results = await asyncio.to_thread(cursor.fetchall) + conn = await aioodbc.connect(dsn=connection_string) + cur = await conn.cursor() + + await cur.execute(query) + colnames = [desc[0] for desc in cur.description] + results = await cur.fetchall() rows = [list(row) for row in results] - await asyncio.to_thread(cursor.close) - await asyncio.to_thread(conn.close) + await cur.close() + await conn.close() return colnames, rows else: raise Exception(f"Database type {db_type} not yet supported.") From aef12620f90a49f52c5e5973e1807c9471c8f3a5 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Wed, 11 Sep 2024 13:56:52 +0800 Subject: [PATCH 15/17] util function for identifying categorical variables adjusted --- defog/util.py | 11 +++++++++++ requirements.txt | 1 + 2 files changed, 12 insertions(+) diff --git a/defog/util.py b/defog/util.py index ce86cc1..b8fa08d 100644 --- a/defog/util.py +++ b/defog/util.py @@ -161,6 +161,7 @@ async def async_identify_categorical_columns( table_name: str = "", rows: list = [], is_cursor_async: bool = False, + db_type="", distinct_threshold: int = 10, 
character_length_threshold: int = 50, ): @@ -189,6 +190,16 @@ async def async_identify_categorical_columns( ) async def run_query(query, params=None): + if db_type == "snowflake": + if params is not None: + cur.execute_async(query, params) + else: + cur.execute_async(query) + query_id = cur.sfqid + while conn.is_still_running(conn.get_query_status(query_id)): + await asyncio.sleep(1) + return cur.fetchall() + if conn: # If using an async connection like asyncpg return await conn.fetch(query, *params if params else ()) diff --git a/requirements.txt b/requirements.txt index a532898..c4fbb45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ prompt_toolkit psycopg2-binary>=2.9.5 asyncpg aiomysql +aioodbc pwinput requests>=2.28.2 aiohttp From a4ba3edefd202e43d8f28043e5634d92eae28a2e Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Wed, 11 Sep 2024 13:57:16 +0800 Subject: [PATCH 16/17] linted --- defog/async_generate_schema.py | 10 +++++----- defog/query.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/defog/async_generate_schema.py b/defog/async_generate_schema.py index 7e38e4c..b9178c7 100644 --- a/defog/async_generate_schema.py +++ b/defog/async_generate_schema.py @@ -353,7 +353,6 @@ async def generate_databricks_schema( return schemas - async def generate_snowflake_schema( self, tables: list, @@ -384,8 +383,8 @@ async def generate_snowflake_schema( if len(tables) == 0: cur = conn.cursor() # get all tables from Snowflake database - cur.execute_async("SHOW TERSE TABLES;") # execute asynchrnously - query_id = cur.sfqid # get the query id to check the status + cur.execute_async("SHOW TERSE TABLES;") # execute asynchrnously + query_id = cur.sfqid # get the query id to check the status while conn.is_still_running(conn.get_query_status(query_id)): await asyncio.sleep(1) res = cur.fetchall() @@ -536,6 +535,7 @@ async def generate_bigquery_schema( else: return schemas + async def generate_sqlserver_schema( self, tables: list, @@ -553,7 +553,7 @@ async def generate_sqlserver_schema( connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" else: connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" - + conn = await aioodbc.connect(dsn=connection_string) cur = await conn.cursor() schemas = {} @@ -569,7 +569,7 @@ async def generate_sqlserver_schema( tables = [row[0] for row in await cur.fetchall()] else: tables = [schema + "." 
+ row[0] for row in await cur.fetchall()] - + if return_tables_only: await conn.close() return tables diff --git a/defog/query.py b/defog/query.py index e87fc4d..ef0af81 100644 --- a/defog/query.py +++ b/defog/query.py @@ -249,7 +249,7 @@ async def async_execute_query_once(db_type: str, db_creds, query: str): query_id = cur.sfqid while conn.is_still_running(conn.get_query_status(query_id)): await asyncio.sleep(1) - + colnames = [desc[0] for desc in cur.description] rows = cur.fetchall() cur.close() From e0eb4013def7807d142ded3f215c0ff5eb071d24 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Wed, 11 Sep 2024 14:20:46 +0800 Subject: [PATCH 17/17] addednative snowflake and sqlserver in async admin methods --- defog/async_admin_methods.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/defog/async_admin_methods.py b/defog/async_admin_methods.py index 732b428..2a456d1 100644 --- a/defog/async_admin_methods.py +++ b/defog/async_admin_methods.py @@ -360,20 +360,22 @@ async def create_empty_tables(self, dev: bool = False): elif self.db_type == "snowflake": import snowflake.connector - conn = await asyncio.to_thread( - snowflake.connector.connect, + conn = snowflake.connector.connect( user=self.db_creds["user"], password=self.db_creds["password"], account=self.db_creds["account"], ) - cur = await asyncio.to_thread(conn.cursor) - await asyncio.to_thread( - cur.execute, f"USE WAREHOUSE {self.db_creds['warehouse']}" - ) + cur = conn.cursor() + cur.execute( + f"USE WAREHOUSE {self.db_creds['warehouse']}" + ) # set the warehouse for statement in ddl.split(";"): - await asyncio.to_thread(cur.execute, statement) - await asyncio.to_thread(conn.commit) - await asyncio.to_thread(conn.close) + cur.execute_async(statement) + query_id = cur.sfqid + while conn.is_still_running(conn.get_query_status(query_id)): + await asyncio.sleep(1) + cur.close() + conn.close() return True elif self.db_type == "bigquery": @@ -388,18 +390,19 @@ async def create_empty_tables(self, dev: bool = False): return True elif self.db_type == "sqlserver": - import pyodbc + import aioodbc if self.db_creds["database"] != "": connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};DATABASE={self.db_creds['database']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" else: connection_string = f"DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={self.db_creds['server']};UID={self.db_creds['user']};PWD={self.db_creds['password']};TrustServerCertificate=yes;Connection Timeout=120;" - conn = await asyncio.to_thread(pyodbc.connect, connection_string) - cur = await asyncio.to_thread(conn.cursor) + conn = await aioodbc.connect(dsn=connection_string) + cur = await conn.cursor() for statement in ddl.split(";"): - await asyncio.to_thread(cur.execute, statement) - await asyncio.to_thread(conn.commit) - await asyncio.to_thread(conn.close) + await cur.execute(statement) + await conn.commit() + await cur.close() + await conn.close() return True else: