Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ibis): introduce Local file connector #1029

Merged
merged 13 commits into from
Jan 7, 2025
5 changes: 5 additions & 0 deletions ibis-server/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ def custom_http_error_handler(request, exc: CustomHttpError):
with logger.contextualize(correlation_id=request.headers.get("X-Correlation-ID")):
logger.opt(exception=exc).error("Request failed")
return PlainTextResponse(str(exc), status_code=exc.status_code)


@app.exception_handler(NotImplementedError)
def not_implemented_error_handler(request, exc: NotImplementedError):
    """Translate any NotImplementedError escaping a route into HTTP 501."""
    return PlainTextResponse(content=str(exc), status_code=501)
2 changes: 2 additions & 0 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def _get_read_dialect(cls, experiment) -> str | None:
def _get_write_dialect(cls, data_source: DataSource) -> str:
    """Return the SQLGlot write dialect for *data_source*.

    Canner speaks Trino SQL and local files are queried through DuckDB;
    every other source uses its own name as the dialect.
    """
    dialect_overrides = {
        DataSource.canner: "trino",
        DataSource.local_file: "duckdb",
    }
    return dialect_overrides.get(data_source, data_source.name)


Expand Down
12 changes: 12 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class QueryTrinoDTO(QueryDTO):
connection_info: ConnectionUrl | TrinoConnectionInfo = connection_info_field


# Query DTO for the local-file data source. Unlike the other sources it has
# no ConnectionUrl variant: only the file-based connection info is accepted.
class QueryLocalFileDTO(QueryDTO):
    connection_info: LocalFileConnectionInfo = connection_info_field


class BigQueryConnectionInfo(BaseModel):
project_id: SecretStr
dataset_id: SecretStr
Expand Down Expand Up @@ -133,6 +137,13 @@ class TrinoConnectionInfo(BaseModel):
password: SecretStr | None = None


class LocalFileConnectionInfo(BaseModel):
    # Root path of the local data files; SecretStr so it is masked when the
    # model is logged or serialized.
    url: SecretStr
    # File format used to read the data; defaults to csv.
    format: str = Field(
        description="File format", default="csv", examples=["csv", "parquet", "json"]
    )


ConnectionInfo = (
BigQueryConnectionInfo
| CannerConnectionInfo
Expand All @@ -142,6 +153,7 @@ class TrinoConnectionInfo(BaseModel):
| PostgresConnectionInfo
| SnowflakeConnectionInfo
| TrinoConnectionInfo
| LocalFileConnectionInfo
)


Expand Down
15 changes: 15 additions & 0 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
self._connector = CannerConnector(connection_info)
elif data_source == DataSource.bigquery:
self._connector = BigQueryConnector(connection_info)
elif data_source == DataSource.local_file:
self._connector = DuckDBConnector(connection_info)
else:
self._connector = SimpleConnector(data_source, connection_info)

Expand Down Expand Up @@ -144,6 +146,19 @@ def query(self, sql: str, limit: int) -> pd.DataFrame:
raise e


class DuckDBConnector:
    """Connector that runs SQL against an in-process DuckDB database.

    Used for the local-file data source, where files are queried directly
    by DuckDB instead of going through an ibis backend.
    """

    def __init__(self, _connection_info: ConnectionInfo):
        # Imported lazily so duckdb is only required when this connector
        # is actually selected.
        import duckdb

        self.connection = duckdb.connect()

    def query(self, sql: str, limit: int) -> pd.DataFrame:
        """Execute *sql* and return at most *limit* rows as a DataFrame."""
        result = self.connection.execute(sql)
        return result.fetch_df().head(limit)

    def dry_run(self, sql: str) -> None:
        """Execute *sql* purely for validation; the result is discarded."""
        self.connection.execute(sql)


@cache
def _get_pg_type_names(connection: BaseBackend) -> dict[int, str]:
cur = connection.raw_sql("SELECT oid, typname FROM pg_type")
Expand Down
7 changes: 7 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
QueryCannerDTO,
QueryClickHouseDTO,
QueryDTO,
QueryLocalFileDTO,
QueryMSSqlDTO,
QueryMySqlDTO,
QueryPostgresDTO,
Expand All @@ -39,6 +40,7 @@ class DataSource(StrEnum):
postgres = auto()
snowflake = auto()
trino = auto()
local_file = auto()

def get_connection(self, info: ConnectionInfo) -> BaseBackend:
try:
Expand All @@ -62,6 +64,7 @@ class DataSourceExtension(Enum):
postgres = QueryPostgresDTO
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO
local_file = QueryLocalFileDTO
goldmedal marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, dto: QueryDTO):
    # Bind each enum member to its query DTO class for later dispatch.
    self.dto = dto
Expand All @@ -70,6 +73,10 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend:
try:
if hasattr(info, "connection_url"):
return ibis.connect(info.connection_url.get_secret_value())
if self.name == "local_file":
raise NotImplementedError(
"Local file connection is not implemented to get ibis backend"
)
return getattr(self, f"get_{self.name}_connection")(info)
except KeyError:
raise NotImplementedError(f"Unsupported data source: {self}")
Expand Down
3 changes: 3 additions & 0 deletions ibis-server/app/model/metadata/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class TableProperties(BaseModel):
schema_: str | None = Field(alias="schema", default=None)
catalog: str | None
table: str | None # only table name without schema or catalog
path: str | None = Field(
alias="path", default=None
) # the full path of the table for file-based table


class Table(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from app.model.metadata.metadata import Metadata
from app.model.metadata.mssql import MSSQLMetadata
from app.model.metadata.mysql import MySQLMetadata
from app.model.metadata.object_storage import LocalFileMetadata
from app.model.metadata.postgres import PostgresMetadata
from app.model.metadata.snowflake import SnowflakeMetadata
from app.model.metadata.trino import TrinoMetadata
Expand All @@ -18,6 +19,7 @@
DataSource.postgres: PostgresMetadata,
DataSource.trino: TrinoMetadata,
DataSource.snowflake: SnowflakeMetadata,
DataSource.local_file: LocalFileMetadata,
}


Expand Down
156 changes: 156 additions & 0 deletions ibis-server/app/model/metadata/object_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os

import duckdb
import opendal
from loguru import logger

from app.model import LocalFileConnectionInfo
from app.model.metadata.dto import (
Column,
RustWrenEngineColumnType,
Table,
TableProperties,
)
from app.model.metadata.metadata import Metadata


class ObjectStorageMetadata(Metadata):
    """Discover tables from files stored under an object-storage root.

    Every readable entry below the configured root becomes one table: a
    directory is treated as a multi-file table (all files of the configured
    format inside it), a plain file as a single-file table. Column names and
    types are sniffed by reading each entry with an in-process DuckDB
    connection.
    """

    def __init__(self, connection_info):
        super().__init__(connection_info)

    def get_table_list(self) -> list[Table]:
        """Scan the storage root and return one Table per readable entry.

        Best-effort discovery: entries that cannot be read in the configured
        format, or whose column types cannot be determined, are skipped.
        """
        root = self.connection_info.url.get_secret_value()
        op = opendal.Operator("fs", root=root)
        conn = self._get_connection()
        unique_tables: dict[str, Table] = {}
        for file in op.list("/"):
            if file.path == "/":
                continue
            stat = op.stat(file.path)
            if stat.mode.is_dir():
                # A directory becomes one table spanning every file of the
                # configured format inside it.
                table_name = os.path.basename(os.path.normpath(file.path))
                full_path = f"{root}/{table_name}/*.{self.connection_info.format}"
            else:
                # A plain file becomes a table named after the file stem.
                table_name = os.path.splitext(os.path.basename(file.path))[0]
                full_path = f"{root}/{file.path}"

            # Read the entry with the target format; if unreadable, skip it.
            df = self._read_df(conn, full_path)
            if df is None:
                continue
            try:
                columns = [
                    Column(
                        name=col,
                        type=self._to_column_type(str(df[col].dtypes[0])),
                        notNull=False,
                    )
                    for col in df.columns
                ]
            except Exception as e:
                logger.debug(f"Failed to read column types: {e}")
                continue

            # Later entries with the same table name overwrite earlier ones.
            unique_tables[table_name] = Table(
                name=table_name,
                description=None,
                columns=columns,
                properties=TableProperties(
                    table=table_name,
                    schema=None,
                    catalog=None,
                    path=full_path,
                ),
                primaryKey=None,
            )

        return list(unique_tables.values())

    def get_constraints(self):
        # File-based tables carry no constraint metadata.
        return []

    def get_version(self):
        raise NotImplementedError("Subclasses must implement `get_version` method")

    def _read_df(self, conn, path):
        """Read *path* with DuckDB using the configured format.

        Returns the DuckDB relation, or None when the entry cannot be read
        (callers treat None as "skip this entry").

        Raises:
            NotImplementedError: for formats other than parquet/csv/json.
        """
        fmt = self.connection_info.format
        readers = {
            "parquet": conn.read_parquet,
            "csv": conn.read_csv,
            "json": conn.read_json,
        }
        if fmt not in readers:
            raise NotImplementedError(f"Unsupported format: {fmt}")
        try:
            if fmt == "csv":
                logger.debug(f"Reading csv file: {path}")
            return readers[fmt](path)
        except Exception as e:
            logger.debug(f"Failed to read {fmt} file: {e}")
            return None

    def _to_column_type(self, col_type: str) -> RustWrenEngineColumnType:
        """Map a DuckDB type name to the engine's column type.

        DECIMAL(p, s) maps to DECIMAL; STRUCT, array, and unrecognized
        types map to UNKNOWN.
        """
        if col_type.startswith("DECIMAL"):
            return RustWrenEngineColumnType.DECIMAL

        # TODO: support struct
        if col_type.startswith("STRUCT"):
            return RustWrenEngineColumnType.UNKNOWN

        # TODO: support array
        if col_type.endswith("[]"):
            return RustWrenEngineColumnType.UNKNOWN

        # refer to https://duckdb.org/docs/sql/data_types/overview#general-purpose-data-types
        switcher = {
            "BIGINT": RustWrenEngineColumnType.INT64,
            "BIT": RustWrenEngineColumnType.INT2,
            "BLOB": RustWrenEngineColumnType.BYTES,
            "BOOLEAN": RustWrenEngineColumnType.BOOL,
            "DATE": RustWrenEngineColumnType.DATE,
            "DOUBLE": RustWrenEngineColumnType.DOUBLE,
            "FLOAT": RustWrenEngineColumnType.FLOAT,
            "INTEGER": RustWrenEngineColumnType.INT,
            # TODO: Wren engine does not support HUGEINT. Map to INT64 for now.
            "HUGEINT": RustWrenEngineColumnType.INT64,
            "INTERVAL": RustWrenEngineColumnType.INTERVAL,
            "JSON": RustWrenEngineColumnType.JSON,
            "SMALLINT": RustWrenEngineColumnType.INT2,
            "TIME": RustWrenEngineColumnType.TIME,
            "TIMESTAMP": RustWrenEngineColumnType.TIMESTAMP,
            "TIMESTAMP WITH TIME ZONE": RustWrenEngineColumnType.TIMESTAMPTZ,
            "TINYINT": RustWrenEngineColumnType.INT2,
            "UBIGINT": RustWrenEngineColumnType.INT64,
            # TODO: Wren engine does not support UHUGEINT. Map to INT64 for now.
            "UHUGEINT": RustWrenEngineColumnType.INT64,
            "UINTEGER": RustWrenEngineColumnType.INT,
            "USMALLINT": RustWrenEngineColumnType.INT2,
            "UTINYINT": RustWrenEngineColumnType.INT2,
            "UUID": RustWrenEngineColumnType.UUID,
            "VARCHAR": RustWrenEngineColumnType.STRING,
        }
        return switcher.get(col_type, RustWrenEngineColumnType.UNKNOWN)

    def _get_connection(self):
        # Fresh in-memory DuckDB connection, used only to sniff file schemas.
        return duckdb.connect()


class LocalFileMetadata(ObjectStorageMetadata):
    """Object-storage metadata specialized for the local file system."""

    def __init__(self, connection_info: LocalFileConnectionInfo):
        super().__init__(connection_info)

    def get_version(self):
        # Fixed label: there is no remote server whose version could be
        # reported for local files.
        return "Local File System"
Loading
Loading