diff --git a/.clang-tidy b/.clang-tidy index 219ac263ab3..d0e7b7d1371 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -27,6 +27,8 @@ Checks: [ '-bugprone-not-null-terminated-result', '-bugprone-reserved-identifier', # useful but too slow, TODO retry when https://reviews.llvm.org/rG1c282052624f9d0bd273bde0b47b30c96699c6c7 is merged '-bugprone-unchecked-optional-access', + '-bugprone-crtp-constructor-accessibility', + '-bugprone-suspicious-stringview-data-usage', '-cert-dcl16-c', '-cert-dcl37-c', @@ -36,6 +38,7 @@ Checks: [ '-cert-msc51-cpp', '-cert-oop54-cpp', '-cert-oop57-cpp', + '-cert-err33-c', # Misreports on clang-19: it warns about all functions containing 'remove' in the name, not only about the standard library. '-clang-analyzer-optin.performance.Padding', @@ -94,10 +97,12 @@ Checks: [ '-modernize-pass-by-value', '-modernize-return-braced-init-list', '-modernize-use-auto', + '-modernize-use-constraints', # This is a good check, but clang-tidy crashes, see https://github.com/llvm/llvm-project/issues/91872 '-modernize-use-default-member-init', '-modernize-use-emplace', '-modernize-use-nodiscard', '-modernize-use-trailing-return-type', + '-modernize-use-designated-initializers', '-performance-enum-size', '-performance-inefficient-string-concatenation', @@ -121,7 +126,8 @@ Checks: [ '-readability-magic-numbers', '-readability-named-parameter', '-readability-redundant-declaration', - '-readability-redundant-inline-specifier', + '-readability-redundant-inline-specifier', # useful but incompatible with __attribute((always_inline))__ (aka. ALWAYS_INLINE, base/base/defines.h). + # ALWAYS_INLINE only has an effect if combined with `inline`: https://godbolt.org/z/Eefd74qdM '-readability-redundant-member-init', # Useful but triggers another problem. Imagine a struct S with multiple String members. Structs are often instantiated via designated # initializer S s{.s1 = [...], .s2 = [...], [...]}. 
In this case, compiler warning `missing-field-initializers` requires to specify all members which are not in-struct # initialized (example: s1 in struct S { String s1; String s2{};}; is not in-struct initialized, therefore it must be specified at instantiation time). As explicitly @@ -132,12 +138,7 @@ Checks: [ '-readability-uppercase-literal-suffix', '-readability-use-anyofallof', - '-zircon-*', - - # This is a good check, but clang-tidy crashes, see https://github.com/llvm/llvm-project/issues/91872 - '-modernize-use-constraints', - # https://github.com/abseil/abseil-cpp/issues/1667 - '-clang-analyzer-optin.core.EnumCastOutOfRange' + '-zircon-*' ] WarningsAsErrors: '*' @@ -172,4 +173,4 @@ CheckOptions: performance-move-const-arg.CheckTriviallyCopyableMove: false # Workaround clang-tidy bug: https://github.com/llvm/llvm-project/issues/46097 readability-identifier-naming.TypeTemplateParameterIgnoredRegexp: expr-type - cppcoreguidelines-avoid-do-while.IgnoreMacros: true + cppcoreguidelines-avoid-do-while.IgnoreMacros: true \ No newline at end of file diff --git a/.cursorignore b/.cursorignore new file mode 100644 index 00000000000..ad66b56ef27 --- /dev/null +++ b/.cursorignore @@ -0,0 +1 @@ +contrib/ diff --git a/chdb/__init__.py b/chdb/__init__.py index 1cab17ec211..2cbd2682a7a 100644 --- a/chdb/__init__.py +++ b/chdb/__init__.py @@ -85,8 +85,10 @@ def query(sql, output_format="CSV", path="", udf_path=""): PyReader = _chdb.PyReader from . import dbapi, session, udf, utils # noqa: E402 +from .state import connect # noqa: E402 __all__ = [ + "_chdb", "PyReader", "ChdbError", "query", @@ -99,4 +101,5 @@ def query(sql, output_format="CSV", path="", udf_path=""): "session", "udf", "utils", + "connect", ] diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py index 090aa5500f9..773f428cba5 100644 --- a/chdb/dbapi/connections.py +++ b/chdb/dbapi/connections.py @@ -1,7 +1,7 @@ -import json from . import err from .cursors import Cursor from . 
import converters +from ..state import sqlitelike as chdb_stateful DEBUG = False VERBOSE = False @@ -10,56 +10,29 @@ class Connection(object): """ Representation of a connection with chdb. - - The proper way to get an instance of this class is to call - connect(). - - Accepts several arguments: - - :param cursorclass: Custom cursor class to use. - :param path: Optional folder path to store database files on disk. - - See `Connection `_ in the - specification. """ - _closed = False - _session = None - - def __init__(self, cursorclass=Cursor, path=None): - - self._resp = None - - # 1. pre-process params in init - self.encoding = 'utf8' - - self.cursorclass = cursorclass - - self._result = None + def __init__(self, path=None): + self._closed = False + self.encoding = "utf8" self._affected_rows = 0 + self._resp = None - self.connect(path) + # Initialize sqlitelike connection + connection_string = ":memory:" if path is None else f"file:{path}" + self._conn = chdb_stateful.Connection(connection_string) - def connect(self, path=None): - from chdb import session as chs - self._session = chs.Session(path) - self._closed = False - self._execute_command("select 1;") - self._read_query_result() + # Test connection with a simple query + cursor = self._conn.cursor() + cursor.execute("SELECT 1") + cursor.close() def close(self): - """ - Send the quit message and close the socket. - - See `Connection.close() `_ - in the specification. - - :raise Error: If the connection is already closed. - """ + """Send the quit message and close the socket.""" if self._closed: raise err.Error("Already closed") self._closed = True - self._session = None + self._conn.close() @property def open(self): @@ -67,81 +40,41 @@ def open(self): return not self._closed def commit(self): - """ - Commit changes to stable storage. - - See `Connection.commit() `_ - in the specification. 
- """ - return + """Commit changes to stable storage.""" + # No-op for ClickHouse + pass def rollback(self): - """ - Roll back the current transaction. - - See `Connection.rollback() `_ - in the specification. - """ - return + """Roll back the current transaction.""" + # No-op for ClickHouse + pass def cursor(self, cursor=None): - """ - Create a new cursor to execute queries with. - - :param cursor: The type of cursor to create; current only :py:class:`Cursor` - None means use Cursor. - """ + """Create a new cursor to execute queries with.""" + if self._closed: + raise err.Error("Connection closed") if cursor: - return cursor(self) - return self.cursorclass(self) + return Cursor(self) + return Cursor(self) - # The following methods are INTERNAL USE ONLY (called from Cursor) - def query(self, sql): - if isinstance(sql, str): - sql = sql.encode(self.encoding, 'surrogateescape') - self._execute_command(sql) - self._affected_rows = self._read_query_result() - return self._affected_rows - - def _execute_command(self, sql): - """ - :raise InterfaceError: If the connection is closed. - :raise ValueError: If no username was specified. 
- """ + def query(self, sql, fmt="ArrowStream"): + """Execute a query and return the raw result.""" if self._closed: raise err.InterfaceError("Connection closed") if isinstance(sql, str): - sql = sql.encode(self.encoding) + sql = sql.encode(self.encoding, "surrogateescape") - if isinstance(sql, bytearray): - sql = bytes(sql) - - # drop last command return - if self._resp is not None: - self._resp = None - - if DEBUG: - print("DEBUG: query:", sql) try: - res = self._session.query(sql, fmt="JSON") - if res.has_error(): - raise err.DatabaseError(res.error_message()) - self._resp = res.data() + result = self._conn.query(sql.decode(), fmt) + self._resp = result + return result except Exception as error: - raise err.InterfaceError("query err: %s" % error) + raise err.InterfaceError(f"Query error: {error}") def escape(self, obj, mapping=None): - """Escape whatever value you pass to it. - - Non-standard, for internal use; do not use this in your applications. - """ - if isinstance(obj, str): - return "'" + self.escape_string(obj) + "'" - if isinstance(obj, (bytes, bytearray)): - ret = self._quote_bytes(obj) - return ret - return converters.escape_item(obj, mapping=mapping) + """Escape whatever value you pass to it.""" + return converters.escape_item(obj, mapping) def escape_string(self, s): return converters.escape_string(s) @@ -149,13 +82,6 @@ def escape_string(self, s): def _quote_bytes(self, s): return converters.escape_bytes(s) - def _read_query_result(self): - self._result = None - result = CHDBResult(self) - result.read() - self._result = result - return result.affected_rows - def __enter__(self): """Context manager that returns a Cursor""" return self.cursor() @@ -166,52 +92,9 @@ def __exit__(self, exc, value, traceback): self.rollback() else: self.commit() + self.close() @property def resp(self): + """Return the last query response""" return self._resp - - -class CHDBResult(object): - def __init__(self, connection): - """ - :type connection: Connection - """ - 
self.connection = connection - self.affected_rows = 0 - self.insert_id = None - self.warning_count = 0 - self.message = None - self.field_count = 0 - self.description = None - self.rows = None - self.has_next = None - - def read(self): - # Handle empty responses (for instance from CREATE TABLE) - if self.connection.resp is None: - return - - try: - data = json.loads(self.connection.resp) - except Exception as error: - raise err.InterfaceError("Unsupported response format:" % error) - - try: - self.field_count = len(data["meta"]) - description = [] - for meta in data["meta"]: - fields = [meta["name"], meta["type"]] - description.append(tuple(fields)) - self.description = tuple(description) - - rows = [] - for line in data["data"]: - row = [] - for i in range(self.field_count): - column_data = converters.convert_column_data(self.description[i][1], line[self.description[i][0]]) - row.append(column_data) - rows.append(tuple(row)) - self.rows = tuple(rows) - except Exception as error: - raise err.InterfaceError("Read return data err:" % error) diff --git a/chdb/dbapi/cursors.py b/chdb/dbapi/cursors.py index ee9e0fa5e8c..6b16ea29ad8 100644 --- a/chdb/dbapi/cursors.py +++ b/chdb/dbapi/cursors.py @@ -29,13 +29,11 @@ class Cursor(object): def __init__(self, connection): self.connection = connection + self._cursor = connection._conn.cursor() self.description = None self.rowcount = -1 - self.rownumber = 0 self.arraysize = 1 self.lastrowid = None - self._result = None - self._rows = None self._executed = None def __enter__(self): @@ -83,14 +81,7 @@ def close(self): """ Closing a cursor just exhausts all remaining data. 
""" - conn = self.connection - if conn is None: - return - try: - while self.nextset(): - pass - finally: - self.connection = None + self._cursor.close() def _get_db(self): if not self.connection: @@ -121,33 +112,6 @@ def mogrify(self, query, args=None): return query - def _clear_result(self): - self.rownumber = 0 - self._result = None - - self.rowcount = 0 - self.description = None - self.lastrowid = None - self._rows = None - - def _do_get_result(self): - conn = self._get_db() - - self._result = result = conn._result - - self.rowcount = result.affected_rows - self.description = result.description - self.lastrowid = result.insert_id - self._rows = result.rows - - def _query(self, q): - conn = self._get_db() - self._last_executed = q - self._clear_result() - conn.query(q) - self._do_get_result() - return self.rowcount - def execute(self, query, args=None): """Execute a query @@ -162,14 +126,24 @@ def execute(self, query, args=None): If args is a list or tuple, %s can be used as a placeholder in the query. If args is a dict, %(name)s can be used as a placeholder in the query. 
""" - while self.nextset(): - pass + if args is not None: + query = query % self._escape_args(args, self.connection) + + self._cursor.execute(query) - query = self.mogrify(query, args) + # Get description from Arrow schema + if self._cursor._current_table is not None: + self.description = [ + (field.name, field.type.to_pandas_dtype(), None, None, None, None, None) + for field in self._cursor._current_table.schema + ] + self.rowcount = self._cursor._current_table.num_rows + else: + self.description = None + self.rowcount = -1 - result = self._query(query) self._executed = query - return result + return self.rowcount def executemany(self, query, args): # type: (str, list) -> int @@ -233,34 +207,21 @@ def _check_executed(self): def fetchone(self): """Fetch the next row""" - self._check_executed() - if self._rows is None or self.rownumber >= len(self._rows): - return None - result = self._rows[self.rownumber] - self.rownumber += 1 - return result - - def fetchmany(self, size=None): + if not self._executed: + raise err.ProgrammingError("execute() first") + return self._cursor.fetchone() + + def fetchmany(self, size=1): """Fetch several rows""" - self._check_executed() - if self._rows is None: - return () - end = self.rownumber + (size or self.arraysize) - result = self._rows[self.rownumber:end] - self.rownumber = min(end, len(self._rows)) - return result + if not self._executed: + raise err.ProgrammingError("execute() first") + return self._cursor.fetchmany(size) def fetchall(self): """Fetch all the rows""" - self._check_executed() - if self._rows is None: - return () - if self.rownumber: - result = self._rows[self.rownumber:] - else: - result = self._rows - self.rownumber = len(self._rows) - return result + if not self._executed: + raise err.ProgrammingError("execute() first") + return self._cursor.fetchall() def nextset(self): """Get the next query set""" @@ -272,26 +233,3 @@ def setinputsizes(self, *args): def setoutputsizes(self, *args): """Does nothing, required by 
DB API.""" - - -class DictCursor(Cursor): - """A cursor which returns results as a dictionary""" - # You can override this to use OrderedDict or other dict-like types. - dict_type = dict - - def _do_get_result(self): - super()._do_get_result() - fields = [] - if self.description: - for f in self.description: - name = f[0] - fields.append(name) - self._fields = fields - - if fields and self._rows: - self._rows = [self._conv_row(r) for r in self._rows] - - def _conv_row(self, row): - if row is None: - return None - return self.dict_type(zip(self._fields, row)) diff --git a/chdb/state/__init__.py b/chdb/state/__init__.py new file mode 100644 index 00000000000..7c8f7d7ea01 --- /dev/null +++ b/chdb/state/__init__.py @@ -0,0 +1,3 @@ +from .sqlitelike import connect + +__all__ = ["connect"] diff --git a/chdb/state/sqlitelike.py b/chdb/state/sqlitelike.py new file mode 100644 index 00000000000..b99eb5e868d --- /dev/null +++ b/chdb/state/sqlitelike.py @@ -0,0 +1,132 @@ +import io +from typing import Optional, Any +from chdb import _chdb + +# try import pyarrow if failed, raise ImportError with suggestion +try: + import pyarrow as pa # noqa +except ImportError as e: + print(f"ImportError: {e}") + print('Please install pyarrow via "pip install pyarrow"') + raise ImportError("Failed to import pyarrow") from None + + +class Connection: + def __init__(self, connection_string: str): + # print("Connection", connection_string) + self._cursor: Optional[Cursor] = None + self._conn = _chdb.connect(connection_string) + + def cursor(self) -> "Cursor": + self._cursor = Cursor(self._conn) + return self._cursor + + def query(self, query: str, format: str = "ArrowStream") -> Any: + return self._conn.query(query, format) + + def close(self) -> None: + # print("close") + if self._cursor: + self._cursor.close() + self._conn.close() + + +class Cursor: + def __init__(self, connection): + self._conn = connection + self._cursor = self._conn.cursor() + self._current_table: Optional[pa.Table] = None 
+ self._current_row: int = 0 + + def execute(self, query: str) -> None: + self._cursor.execute(query) + result_mv = self._cursor.get_memview() + # print("get_result", result_mv) + if self._cursor.has_error(): + raise Exception(self._cursor.error_message()) + if self._cursor.data_size() == 0: + self._current_table = None + self._current_row = 0 + return + arrow_data = result_mv.tobytes() + reader = pa.ipc.open_stream(io.BytesIO(arrow_data)) + self._current_table = reader.read_all() + self._current_row = 0 + + def commit(self) -> None: + self._cursor.commit() + + def fetchone(self) -> Optional[tuple]: + if not self._current_table or self._current_row >= len(self._current_table): + return None + + row_dict = { + col: self._current_table.column(col)[self._current_row].as_py() + for col in self._current_table.column_names + } + self._current_row += 1 + return tuple(row_dict.values()) + + def fetchmany(self, size: int = 1) -> tuple: + if not self._current_table: + return tuple() + + rows = [] + for _ in range(size): + if (row := self.fetchone()) is None: + break + rows.append(row) + return tuple(rows) + + def fetchall(self) -> tuple: + if not self._current_table: + return tuple() + + remaining_rows = [] + while (row := self.fetchone()) is not None: + remaining_rows.append(row) + return tuple(remaining_rows) + + def close(self) -> None: + self._cursor.close() + + def __iter__(self): + return self + + def __next__(self) -> tuple: + row = self.fetchone() + if row is None: + raise StopIteration + return row + + +def connect(connection_string: str = ":memory:") -> Connection: + """ + Create a connection to chDB background server. + Only one open connection is allowed per process. Use `close` to close the connection. + If called with the same connection string, the same connection object will be returned. + You can use the connection object to create cursor object. `cursor` method will return a cursor object. + + Args: + connection_string (str, optional): Connection string. 
Defaults to ":memory:". + Also support file path like: + - ":memory:" (for in-memory database) + - "test.db" (for relative path) + - "file:test.db" (same as above) + - "/path/to/test.db" (for absolute path) + - "file:/path/to/test.db" (same as above) + - "file:test.db?param1=value1&param2=value2" (for relative path with query params) + - "///path/to/test.db?param1=value1&param2=value2" (for absolute path) + + Connection string args handling: + Connection string can contain query params like "file:test.db?param1=value1&param2=value2" + "param1=value1" will be passed to ClickHouse engine as start up args. + + For more details, see `clickhouse local --help --verbose` + Some special args handling: + - "mode=ro" would be "--readonly=1" for clickhouse (read-only mode) + + Returns: + Connection: Connection object + """ + return Connection(connection_string) diff --git a/chdb/utils/__init__.py b/chdb/utils/__init__.py index b0905b0008e..a5cf4fcbbb6 100644 --- a/chdb/utils/__init__.py +++ b/chdb/utils/__init__.py @@ -5,4 +5,5 @@ "convert_to_columnar", "infer_data_type", "infer_data_types", + "trace", ] diff --git a/chdb/utils/trace.py b/chdb/utils/trace.py new file mode 100644 index 00000000000..62e61a325cf --- /dev/null +++ b/chdb/utils/trace.py @@ -0,0 +1,74 @@ +import functools +import inspect +import sys +import linecache +from datetime import datetime + +enable_print = False + + +def print_lines(func): + if not enable_print: + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Get function name and determine if it's a method + is_method = inspect.ismethod(func) or ( + len(args) > 0 and hasattr(args[0].__class__, func.__name__) + ) + class_name = args[0].__class__.__name__ if is_method else None # type: ignore + + # Get the source code of the function + try: + source_lines, start_line = inspect.getsourcelines(func) + except OSError: + # Handle cases where source might not be available + print(f"Warning: Could not get source for {func.__name__}") + return 
func(*args, **kwargs) + + def trace(frame, event, arg): + if event == "line": + # Get the current line number and code + line_no = frame.f_lineno + line = linecache.getline(frame.f_code.co_filename, line_no).strip() + + # Don't print decorator lines or empty lines + if line and not line.startswith("@"): + # Get local variables + local_vars = frame.f_locals.copy() + if is_method: + # Remove 'self' from local variables for clarity + local_vars.pop("self", None) + + # Format timestamp + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + + # Create context string (class.method or function) + context = ( + f"{class_name}.{func.__name__}" if class_name else func.__name__ + ) + + # Print execution information + print(f"[{timestamp}] {context} line {line_no}: {line}") + + # Print local variables if they exist and have changed + if local_vars: + vars_str = ", ".join( + f"{k}={repr(v)}" for k, v in local_vars.items() + ) + print(f" Variables: {vars_str}") + return trace + + # Set the trace function + sys.settrace(trace) + + # Call the original function + result = func(*args, **kwargs) + + # Disable tracing + sys.settrace(None) + + return result + + return wrapper diff --git a/examples/dbapi.py b/examples/dbapi.py index 82baa6f6f37..b09ec988b3c 100644 --- a/examples/dbapi.py +++ b/examples/dbapi.py @@ -12,23 +12,23 @@ cur1.close() conn1.close() -conn2 = dbapi.connect(cursorclass=DictCursor) -cur2 = conn2.cursor() -cur2.execute(''' -SELECT - town, - district, - count() AS c, - round(avg(price)) AS price -FROM url('https://datasets-documentation.s3.eu-west-3.amazonaws.com/house_parquet/house_0.parquet') -GROUP BY - town, - district -LIMIT 10 -''') -print("description", cur2.description) -for row in cur2: - print(row) +# conn2 = dbapi.connect(cursorclass=DictCursor) +# cur2 = conn2.cursor() +# cur2.execute(''' +# SELECT +# town, +# district, +# count() AS c, +# round(avg(price)) AS price +# FROM 
url('https://datasets-documentation.s3.eu-west-3.amazonaws.com/house_parquet/house_0.parquet') +# GROUP BY +# town, +# district +# LIMIT 10 +# ''') +# print("description", cur2.description) +# for row in cur2: +# print(row) -cur2.close() -conn2.close() +# cur2.close() +# conn2.close() diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index 46605dccd20..053b2362386 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -57,6 +57,7 @@ set (CLICKHOUSE_LOCAL_LINK clickhouse_parsers clickhouse_storages_system clickhouse_table_functions + ch_contrib::parquet ) clickhouse_program_add(local) diff --git a/programs/local/LocalChdb.cpp b/programs/local/LocalChdb.cpp index 7ea577e91ce..606ad070d76 100644 --- a/programs/local/LocalChdb.cpp +++ b/programs/local/LocalChdb.cpp @@ -1,12 +1,14 @@ #include "LocalChdb.h" +#include +#include "chdb.h" #if USE_PYTHON -#include -#include -#include +# include +namespace py = pybind11; + extern bool inside_main = true; @@ -53,6 +55,7 @@ local_result_v2 * queryToBuffer( // Convert std::string to char* std::vector argv_char; + argv_char.reserve(argv.size()); for (auto & arg : argv) argv_char.push_back(const_cast(arg.c_str())); @@ -77,13 +80,203 @@ memoryview_wrapper * query_result::get_memview() return new memoryview_wrapper(this->result_wrapper); } -#ifdef PY_TEST_MAIN -# include -# include -# include -# include -# include -# include + +// Parse SQLite-style connection string +std::pair> connection_wrapper::parse_connection_string(const std::string & conn_str) +{ + std::string path; + std::map params; + + if (conn_str.empty() || conn_str == ":memory:") + { + return {":memory:", params}; + } + + std::string working_str = conn_str; + + // Handle file: prefix + if (working_str.starts_with("file:")) + { + working_str = working_str.substr(5); + + // Handle triple slash for absolute paths + if (working_str.starts_with("///")) + { + working_str = working_str.substr(2); // Remove two slashes, keep 
one + } + } + + // Split path and parameters + auto query_pos = working_str.find('?'); + if (query_pos != std::string::npos) + { + path = working_str.substr(0, query_pos); + std::string query = working_str.substr(query_pos + 1); + + // Parse parameters + std::istringstream params_stream(query); + std::string param; + while (std::getline(params_stream, param, '&')) + { + auto eq_pos = param.find('='); + if (eq_pos != std::string::npos) + { + std::string key = param.substr(0, eq_pos); + std::string value = param.substr(eq_pos + 1); + params[key] = value; + } + else if (!param.empty()) + { + // Handle parameters without values + params[param] = ""; + } + } + } + else + { + path = working_str; + } + + // Convert relative paths to absolute + if (!path.empty() && path[0] != '/') + { + std::error_code ec; + path = std::filesystem::absolute(path, ec).string(); + if (ec) + { + throw std::runtime_error("Failed to resolve path: " + path); + } + } + + return {path, params}; +} + +std::vector +connection_wrapper::build_clickhouse_args(const std::string & path, const std::map & params) +{ + std::vector argv = {"clickhouse"}; + + if (path != ":memory:") + { + argv.push_back("--path=" + path); + } + + // Map SQLite parameters to ClickHouse arguments + for (const auto & [key, value] : params) + { + if (key == "mode") + { + if (value == "ro") + { + is_readonly = true; + argv.push_back("--readonly=1"); + } + } + else if (value.empty()) + { + // Handle parameters without values (like ?withoutarg) + argv.push_back("--" + key); + } + else + { + argv.push_back("--" + key + "=" + value); + } + } + + return argv; +} + +void connection_wrapper::initialize_database() +{ + if (is_readonly) + { + return; + } + if (is_memory_db) + { + // Setup memory engine + query_result * ret = query("CREATE DATABASE IF NOT EXISTS default ENGINE = Memory; USE default"); + if (ret->has_error()) + { + auto err_msg = fmt::format("Failed to create memory database: {}", std::string(ret->error_message())); + delete 
ret; + throw std::runtime_error(err_msg); + } + } + else + { + // Create directory if it doesn't exist + std::filesystem::create_directories(db_path); + // Setup Atomic database + query_result * ret = query("CREATE DATABASE IF NOT EXISTS default ENGINE = Atomic; USE default"); + if (ret->has_error()) + { + auto err_msg = fmt::format("Failed to create database: {}", std::string(ret->error_message())); + delete ret; + throw std::runtime_error(err_msg); + } + } +} + +connection_wrapper::connection_wrapper(const std::string & conn_str) +{ + auto [path, params] = parse_connection_string(conn_str); + + auto argv = build_clickhouse_args(path, params); + std::vector argv_char; + argv_char.reserve(argv.size()); + for (auto & arg : argv) + { + argv_char.push_back(const_cast(arg.c_str())); + } + + conn = connect_chdb(argv_char.size(), argv_char.data()); + db_path = path; + is_memory_db = (path == ":memory:"); + initialize_database(); +} + +connection_wrapper::~connection_wrapper() +{ + close_conn(conn); +} + +void connection_wrapper::close() +{ + close_conn(conn); +} + +cursor_wrapper * connection_wrapper::cursor() +{ + return new cursor_wrapper(this); +} + +void connection_wrapper::commit() +{ + // do nothing +} + +query_result * connection_wrapper::query(const std::string & query_str, const std::string & format) +{ + return new query_result(query_conn(*conn, query_str.c_str(), format.c_str()), true); +} + +void cursor_wrapper::execute(const std::string & query_str) +{ + release_result(); + + // Always use Arrow format internally + current_result = query_conn(conn->get_conn(), query_str.c_str(), "ArrowStream"); +} + + +# ifdef PY_TEST_MAIN +# include +# include +# include +# include +# include +# include std::shared_ptr queryToArrow(const std::string & queryStr) @@ -124,7 +317,7 @@ int main() return 0; } -#else +# else PYBIND11_MODULE(_chdb, m) { m.doc() = "chDB module for query function"; @@ -182,6 +375,32 @@ PYBIND11_MODULE(_chdb, m) "Returns:\n" " List[str, str]: List of 
column name and type pairs."); + py::class_(m, "cursor") + .def(py::init()) + .def("execute", &cursor_wrapper::execute) + .def("commit", &cursor_wrapper::commit) + .def("close", &cursor_wrapper::close) + .def("get_memview", &cursor_wrapper::get_memview) + .def("data_size", &cursor_wrapper::data_size) + .def("rows_read", &cursor_wrapper::rows_read) + .def("bytes_read", &cursor_wrapper::bytes_read) + .def("elapsed", &cursor_wrapper::elapsed) + .def("has_error", &cursor_wrapper::has_error) + .def("error_message", &cursor_wrapper::error_message); + + py::class_(m, "connect") + .def(py::init([](const std::string & path) { return new connection_wrapper(path); }), py::arg("path") = ":memory:") + .def("cursor", &connection_wrapper::cursor) + .def("execute", &connection_wrapper::query) + .def("commit", &connection_wrapper::commit) + .def("close", &connection_wrapper::close) + .def( + "query", + &connection_wrapper::query, + py::arg("query_str"), + py::arg("format") = "CSV", + "Execute a query and return a query_result object"); + m.def( "query", &query, @@ -193,5 +412,5 @@ PYBIND11_MODULE(_chdb, m) "Query chDB and return a query_result object"); } -#endif // PY_TEST_MAIN +# endif // PY_TEST_MAIN #endif // USE_PYTHON diff --git a/programs/local/LocalChdb.h b/programs/local/LocalChdb.h index 6401c04f03b..3193d4893e7 100644 --- a/programs/local/LocalChdb.h +++ b/programs/local/LocalChdb.h @@ -1,31 +1,79 @@ #pragma once +#include +#include #include "config.h" #if USE_PYTHON -#include "chdb.h" -#include -#include -#include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include "chdb.h" namespace py = pybind11; -class local_result_wrapper; +class __attribute__((visibility("default"))) local_result_wrapper; +class __attribute__((visibility("default"))) connection_wrapper; +class __attribute__((visibility("default"))) cursor_wrapper; class __attribute__((visibility("default"))) memoryview_wrapper; class 
__attribute__((visibility("default"))) query_result; +class connection_wrapper +{ +private: + chdb_conn ** conn; + std::string db_path; + bool is_memory_db; + bool is_readonly; + +public: + explicit connection_wrapper(const std::string & conn_str); + chdb_conn * get_conn() { return *conn; } + ~connection_wrapper(); + cursor_wrapper * cursor(); + void commit(); + void close(); + query_result * query(const std::string & query_str, const std::string & format = "CSV"); + + // Move the private methods declarations here + std::pair> parse_connection_string(const std::string & conn_str); + std::vector build_clickhouse_args(const std::string & path, const std::map & params); + void initialize_database(); +}; class local_result_wrapper { private: local_result_v2 * result; + bool keep_buf; // background server mode will handle buf in ClickHouse engine public: - local_result_wrapper(local_result_v2 * result) : result(result) { } + local_result_wrapper(local_result_v2 * result) : result(result), keep_buf(false) { } + local_result_wrapper(local_result_v2 * result, bool keep_buf) : result(result), keep_buf(keep_buf) { } ~local_result_wrapper() { - free_result_v2(result); + if (keep_buf) + { + if (!result) + return; + + result->_vec = nullptr; + delete[] result->error_message; + delete result; + } + else + { + free_result_v2(result); + } } char * data() { @@ -109,7 +157,8 @@ class query_result public: query_result(local_result_v2 * result) : result_wrapper(std::make_shared(result)) { } - ~query_result() { } + query_result(local_result_v2 * result, bool keep_buf) : result_wrapper(std::make_shared(result, keep_buf)) { } + ~query_result() = default; char * data() { return result_wrapper->data(); } py::bytes bytes() { return result_wrapper->bytes(); } py::str str() { return result_wrapper->str(); } @@ -128,11 +177,11 @@ class memoryview_wrapper std::shared_ptr result_wrapper; public: - memoryview_wrapper(std::shared_ptr result) : result_wrapper(result) + explicit 
memoryview_wrapper(std::shared_ptr result) : result_wrapper(result) { // std::cerr << "memoryview_wrapper::memoryview_wrapper" << this->result->bytes() << std::endl; } - ~memoryview_wrapper() { } + ~memoryview_wrapper() = default; size_t size() { @@ -159,4 +208,107 @@ class memoryview_wrapper } } }; + +class cursor_wrapper +{ +private: + connection_wrapper * conn; + local_result_v2 * current_result; + + void release_result() + { + if (current_result) + { + // The free_result_v2 vector is managed by the ClickHouse Engine + // As we don't want to copy the data, so just release the memory here. + // The memory will be released when the ClientBase.query_result_buf is reassigned. + if (current_result->_vec) + { + current_result->_vec = nullptr; + } + free_result_v2(current_result); + + current_result = nullptr; + } + } + +public: + explicit cursor_wrapper(connection_wrapper * connection) : conn(connection), current_result(nullptr) { } + + ~cursor_wrapper() { release_result(); } + + void execute(const std::string & query_str); + + void commit() + { + // do nothing + } + + void close() { release_result(); } + + py::memoryview get_memview() + { + if (current_result == nullptr) + { + return py::memoryview(py::memoryview::from_memory(nullptr, 0, true)); + } + return py::memoryview(py::memoryview::from_memory(current_result->buf, current_result->len, true)); + } + + size_t data_size() + { + if (current_result == nullptr) + { + return 0; + } + return current_result->len; + } + + size_t rows_read() + { + if (current_result == nullptr) + { + return 0; + } + return current_result->rows_read; + } + + size_t bytes_read() + { + if (current_result == nullptr) + { + return 0; + } + return current_result->bytes_read; + } + + double elapsed() + { + if (current_result == nullptr) + { + return 0; + } + return current_result->elapsed; + } + + bool has_error() + { + if (current_result == nullptr) + { + return false; + } + return current_result->error_message != nullptr; + } + + py::str 
error_message() + { + if (has_error()) + { + return py::str(current_result->error_message); + } + return py::str(); + } +}; + + #endif diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index 4239ef0aa59..ff31e18cc3e 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -65,7 +65,9 @@ namespace fs = std::filesystem; - +std::mutex global_connection_mutex; +chdb_conn * global_conn_ptr = nullptr; +std::string global_db_path; namespace CurrentMetrics { extern const Metric MemoryTracking; @@ -493,12 +495,12 @@ try } } - is_interactive = stdin_is_a_tty + is_interactive = !is_background && stdin_is_a_tty && (getClientConfiguration().hasOption("interactive") || (queries.empty() && !getClientConfiguration().has("table-structure") && queries_files.empty() && !getClientConfiguration().has("table-file"))); - if (!is_interactive) + if (!is_interactive && !is_background) { /// We will terminate process on error static KillingErrorHandler error_handler; @@ -523,7 +525,7 @@ try processConfig(); - SCOPE_EXIT({ cleanup(); }); + SCOPE_EXIT({ if (!is_background) cleanup(); }); initTTYBuffer(toProgressOption(getClientConfiguration().getString("progress", "default"))); ASTAlterCommand::setFormatAlterCommandsWithParentheses(true); @@ -560,7 +562,11 @@ try #if defined(FUZZING_MODE) runLibFuzzer(); #else - if (is_interactive && !delayed_interactive) + if (is_background) + { + runBackground(); + } + else if (is_interactive && !delayed_interactive) { runInteractive(); } @@ -1132,7 +1138,7 @@ std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv) app.getProcessedBytes(), app.getElapsedTime()); } else { - return std::make_unique(app.get_error_msg()); + return std::make_unique(app.getErrorMsg()); } } catch (const DB::Exception & e) @@ -1150,6 +1156,47 @@ std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv) } } +DB::LocalServer * bgClickHouseLocal(int argc, char ** argv) +{ + DB::LocalServer * app = nullptr; + try + { + 
app = new DB::LocalServer(); + app->setBackground(true); + app->init(argc, argv); + int ret = app->run(); + if (ret != 0) + { + auto err_msg = app->getErrorMsg(); + LOG_ERROR(&app->logger(), "Error running bgClickHouseLocal: {}", err_msg); + delete app; + app = nullptr; + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Error running bgClickHouseLocal: {}", err_msg); + } + return app; + } + catch (const DB::Exception & e) + { + delete app; + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "bgClickHouseLocal {}", DB::getExceptionMessage(e, false)); + } + catch (const Poco::Exception & e) + { + delete app; + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "bgClickHouseLocal {}", e.displayText()); + } + catch (const std::exception & e) + { + delete app; + throw std::domain_error(e.what()); + } + catch (...) + { + delete app; + throw std::domain_error(DB::getCurrentExceptionMessage(true)); + } +} + // todo fix the memory leak and unnecessary copy local_result * query_stable(int argc, char ** argv) { @@ -1207,7 +1254,7 @@ local_result_v2 * query_stable_v2(int argc, char ** argv) else { // Handle successful data retrieval scenario - res->_vec = new std::vector(*result->buf_); + res->_vec = result->buf_; res->len = result->buf_->size(); res->buf = result->buf_->data(); res->rows_read = result->rows_; @@ -1240,6 +1287,123 @@ void free_result_v2(local_result_v2 * result) delete result; } +chdb_conn ** connect_chdb(int argc, char ** argv) +{ + std::lock_guard lock(global_connection_mutex); + + // Check if we already have a connection with this path + std::string path = ":memory:"; // Default path + for (int i = 1; i < argc; i++) + { + if (strncmp(argv[i], "--path=", 7) == 0) + { + path = argv[i] + 7; + break; + } + } + + if (global_conn_ptr != nullptr) + { + if (path == global_db_path) + { + // Return existing connection + return &global_conn_ptr; + } + throw DB::Exception( + DB::ErrorCodes::BAD_ARGUMENTS, + "Another connection is already active with different path. 
Close the existing connection first."); + } + + // Create new connection + DB::LocalServer * server = bgClickHouseLocal(argc, argv); + auto * conn = new chdb_conn(); + conn->server = server; + conn->connected = true; + + // Store globally + global_conn_ptr = conn; + global_db_path = path; + + return &global_conn_ptr; +} + +void close_conn(chdb_conn ** conn) +{ + std::lock_guard lock(global_connection_mutex); + + if (!conn || !*conn) + return; + + if ((*conn)->connected) + { + DB::LocalServer * server = static_cast((*conn)->server); + server->cleanup(); + delete server; + + if (*conn == global_conn_ptr) + { + global_conn_ptr = nullptr; + global_db_path.clear(); + } + } + + delete *conn; + *conn = nullptr; +} + +struct local_result_v2 * query_conn(chdb_conn * conn, const char * query, const char * format) +{ + auto * result = new local_result_v2{ nullptr, 0, nullptr, 0, 0, 0, nullptr }; + + if (!conn || !conn->connected) + return result; + + std::lock_guard lock(global_connection_mutex); + + try + { + DB::LocalServer * server = static_cast(conn->server); + + // Execute query + if (!server->parseQueryTextWithOutputFormat(query, format)) + { + std::string error = server->getErrorMsg(); + result->error_message = new char[error.length() + 1]; + std::strcpy(result->error_message, error.c_str()); + return result; + } + + // Get query results without copying + auto output_span = server->getQueryOutputSpan(); + if (!output_span.empty()) + { + result->_vec = nullptr; + result->buf = output_span.data(); + result->len = output_span.size(); + } + + result->rows_read = server->getProcessedRows(); + result->bytes_read = server->getProcessedBytes(); + result->elapsed = server->getElapsedTime(); + + return result; + } + catch (const DB::Exception & e) + { + std::string error = DB::getExceptionMessage(e, false); + result->error_message = new char[error.length() + 1]; + std::strcpy(result->error_message, error.c_str()); + return result; + } + catch (...) 
+ { + std::string error = DB::getCurrentExceptionMessage(true); + result->error_message = new char[error.length() + 1]; + std::strcpy(result->error_message, error.c_str()); + return result; + } +} + /** * The dummy_calls function is used to prevent certain functions from being optimized out by the compiler. * It includes calls to 'query_stable' and 'free_result' within a condition that is always false. diff --git a/programs/local/LocalServer.h b/programs/local/LocalServer.h index b18a7a90961..4488299ef97 100644 --- a/programs/local/LocalServer.h +++ b/programs/local/LocalServer.h @@ -30,11 +30,13 @@ class LocalServer : public ClientApplicationBase, public Loggers int main(const std::vector & /*args*/) override; -protected: - Poco::Util::LayeredConfiguration & getClientConfiguration() override; + void cleanup(); void connect() override; +protected: + Poco::Util::LayeredConfiguration & getClientConfiguration() override; + void processError(const String & query) const override; String getName() const override { return "local"; } @@ -60,7 +62,6 @@ class LocalServer : public ClientApplicationBase, public Loggers void tryInitPath(); void setupUsers(); - void cleanup(); void applyCmdOptions(ContextMutablePtr context); void applyCmdSettings(ContextMutablePtr context); diff --git a/programs/local/chdb.h b/programs/local/chdb.h index 1f4ae3835c4..a01b9f77367 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -5,6 +5,7 @@ # include extern "C" { #else +# include # include # include #endif @@ -50,6 +51,16 @@ CHDB_EXPORT void free_result(struct local_result * result); CHDB_EXPORT struct local_result_v2 * query_stable_v2(int argc, char ** argv); CHDB_EXPORT void free_result_v2(struct local_result_v2 * result); +struct chdb_conn +{ + void * server; // LocalServer * server; + bool connected; +}; + +CHDB_EXPORT struct chdb_conn ** connect_chdb(int argc, char ** argv); +CHDB_EXPORT void close_conn(struct chdb_conn ** conn); +CHDB_EXPORT struct local_result_v2 * 
query_conn(struct chdb_conn * conn, const char * query, const char * format); + #ifdef __cplusplus } #endif diff --git a/setup.py b/setup.py index 93aae87ceef..1bb12c5765b 100644 --- a/setup.py +++ b/setup.py @@ -268,6 +268,10 @@ def build_extensions(self): exclude_package_data={"": ["*.pyc", "src/**"]}, ext_modules=ext_modules, python_requires=">=3.8", + install_requires=[ + "pyarrow>=13.0.0", + "pandas>=2.0.0", + ], cmdclass={"build_ext": BuildExt}, test_suite="tests", zip_safe=False, diff --git a/src/Client/ClientBase.cpp b/src/Client/ClientBase.cpp index ae0b526854c..fc6fb1e8c39 100644 --- a/src/Client/ClientBase.cpp +++ b/src/Client/ClientBase.cpp @@ -707,6 +707,12 @@ bool ClientBase::isRegularFile(int fd) return fstat(fd, &file_stat) == 0 && S_ISREG(file_stat.st_mode); } +void ClientBase::setDefaultFormat(const String & format) +{ + default_output_format = format; + is_default_format = false; +} + void ClientBase::setDefaultFormatsAndCompressionFromConfiguration() { if (getClientConfiguration().has("output-format")) @@ -2419,6 +2425,23 @@ bool ClientBase::processQueryText(const String & text) return executeMultiQuery(text); } +bool ClientBase::parseQueryTextWithOutputFormat(const String & query, const String & format) +{ + // Set output format if specified + if (!format.empty()) + { + client_context->setDefaultFormat(format); + setDefaultFormat(format); + } + + // Check connection and reconnect if needed + if (!connection->checkConnected(connection_parameters.timeouts)) + connect(); + + // Execute query + return processQueryText(query); +} + String ClientBase::prompt() const { @@ -2677,6 +2700,21 @@ void ClientBase::runInteractive() } +void ClientBase::runBackground() +{ + initQueryIdFormats(); + + // Initialize DateLUT here to avoid counting time spent here as query execution time. 
+ (void)DateLUT::instance().getTimeZone(); + + if (home_path.empty()) + { + const char * home_path_cstr = getenv("HOME"); // NOLINT(concurrency-mt-unsafe) + if (home_path_cstr) + home_path = home_path_cstr; + } +} + bool ClientBase::processMultiQueryFromFile(const String & file_name) { String queries_from_file; diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index 53a0d142a13..adbfd7f60b5 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -104,10 +104,22 @@ class ClientBase // std::vector vec(buf.begin(), buf.end()); return query_result_memory; } + + std::span getQueryOutputSpan() + { + if (!query_result_memory || !query_result_buf) + return {}; + auto size = query_result_buf->count(); + return std::span(query_result_memory->begin(), size); + } + size_t getProcessedRows() const { return processed_rows; } size_t getProcessedBytes() const { return processed_bytes; } double getElapsedTime() const { return progress_indication.elapsedSeconds(); } - std::string get_error_msg() const { return error_message_oss.str(); } + std::string getErrorMsg() const { return error_message_oss.str(); } + void setDefaultFormat(const String & format); + void setBackground(bool is_background_) { is_background = is_background_; } + bool parseQueryTextWithOutputFormat(const String & query, const String & format); ASTPtr parseQuery(const char *& pos, const char * end, const Settings & settings, bool allow_multi_statements); @@ -115,6 +127,9 @@ class ClientBase void runInteractive(); void runNonInteractive(); + bool is_background = false; + void runBackground(); + char * argv0 = nullptr; void runLibFuzzer(); diff --git a/src/Interpreters/InterpreterUseQuery.cpp b/src/Interpreters/InterpreterUseQuery.cpp index 3999b5a2fef..8068a87d122 100644 --- a/src/Interpreters/InterpreterUseQuery.cpp +++ b/src/Interpreters/InterpreterUseQuery.cpp @@ -33,10 +33,11 @@ BlockIO InterpreterUseQuery::execute() tmp_path_fs << new_database; tmp_path_fs.close(); } - else - { - throw 
Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open file {} for writing", default_database_path.string()); - } + //chdb todo: fix the following code on bgClickHouseLocal mode + // else + // { + // throw Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open file {} for writing", default_database_path.string()); + // } return {}; } diff --git a/tests/test_conn_cursor.py b/tests/test_conn_cursor.py new file mode 100644 index 00000000000..adf40108568 --- /dev/null +++ b/tests/test_conn_cursor.py @@ -0,0 +1,292 @@ +import unittest +import os +import shutil +from typing import List, Any, Dict + +from chdb import connect + +db_path = "test_db_3fdds" + + +class TestCHDB(unittest.TestCase): + def setUp(self): + if os.path.exists(db_path): + shutil.rmtree(db_path) + + def test_conn_query_without_receiving_result(self): + conn = connect() + conn.query("SELECT 1", "CSV") + conn.query("SELECT 1", "Null") + conn.query("SELECT 1", "Null") + conn.close() + + def test_basic_operations(self): + conn = connect(":memory:") + cursor = conn.cursor() + # Create a table + cursor.execute( + """ + CREATE TABLE users ( + id Int32, + name String, + scores Array(UInt8) + ) ENGINE = Memory + """ + ) + + # Insert test data + cursor.execute( + """ + INSERT INTO users VALUES + (1, 'Alice', [95, 87, 92]), + (2, 'Bob', [88, 85, 90]), + (3, 'Charlie', [91, 89, 94]) + """ + ) + + # Test fetchone + cursor.execute("SELECT * FROM users WHERE id = 1") + row = cursor.fetchone() + print(row) + self.assertEqual(row[0], 1) + self.assertEqual(row[1], "Alice") + self.assertEqual(row[2], [95, 87, 92]) + + cursor.execute("SELECT * FROM users WHERE id = 2") + row = cursor.fetchone() + print(row) + self.assertEqual(row[0], 2) + self.assertEqual(row[1], "Bob") + self.assertEqual(row[2], [88, 85, 90]) + + row = cursor.fetchone() + self.assertIsNone(row) + + # Test fetchall + print("fetchall") + cursor.execute("SELECT * FROM users ORDER BY id") + rows = cursor.fetchall() + self.assertEqual(len(rows), 3) + 
self.assertEqual(rows[1][1], "Bob") + + # Test iteration + cursor.execute("SELECT * FROM users ORDER BY id") + rows = [row for row in cursor] + self.assertEqual(len(rows), 3) + self.assertEqual(rows[2][1], "Charlie") + cursor.close() + conn.close() + + def test_connection_management(self): + # Test file-based connection + file_conn = connect(f"file:{db_path}") + self.assertIsNotNone(file_conn) + file_conn.close() + + # Test connection with parameters + readonly_conn = connect(f"file:{db_path}?mode=ro") + self.assertIsNotNone(readonly_conn) + with self.assertRaises(Exception): + cur = readonly_conn.cursor() + cur.execute("CREATE TABLE test (id Int32) ENGINE = Memory") + readonly_conn.close() + + # Test create dir fails + with self.assertRaises(Exception): + # try to create a directory with this test file name + # which will fail surely + connect("test_conn_cursor.py") + + def test_cursor_error_handling(self): + conn = connect(":memory:") + cursor = conn.cursor() + try: + # Test syntax error + with self.assertRaises(Exception): + cursor.execute("INVALID SQL QUERY") + + # Test table not found error + with self.assertRaises(Exception): + cursor.execute("SELECT * FROM nonexistent_table") + finally: + cursor.close() + conn.close() + + def test_transaction_behavior(self): + # Create test table + conn = connect(":memory:") + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE test_transactions ( + id Int32, + value String + ) ENGINE = Memory + """ + ) + + # Test basic insert + cursor.execute("INSERT INTO test_transactions VALUES (1, 'test')") + cursor.commit() # Should work even though Memory engine doesn't support transactions + + # Verify data + cursor.execute("SELECT * FROM test_transactions") + row = cursor.fetchone() + self.assertEqual(row, (1, "test")) + + def test_cursor_data_types(self): + conn = connect(":memory:") + cursor = conn.cursor() + # Test various data types + cursor.execute( + """ + CREATE TABLE type_test ( + int_val Int32, + float_val Float64, 
+ string_val String, + array_val Array(Int32), + nullable_val Nullable(String), + date_val Date, + datetime_val DateTime + ) ENGINE = Memory + """ + ) + + cursor.execute( + """ + INSERT INTO type_test VALUES + (42, 3.14, 'hello', [1,2,3], NULL, '2023-01-01', '2023-01-01 12:00:00') + """ + ) + + cursor.execute("SELECT * FROM type_test") + row = cursor.fetchone() + self.assertEqual(row[0], 42) + self.assertAlmostEqual(row[1], 3.14) + self.assertEqual(row[2], "hello") + self.assertEqual(row[3], [1, 2, 3]) + self.assertIsNone(row[4]) + + def test_cursor_multiple_results(self): + conn = connect(":memory:") + cursor = conn.cursor() + # Create test data + cursor.execute( + """ + CREATE TABLE multi_test (id Int32, value String) ENGINE = Memory; + INSERT INTO multi_test VALUES (1, 'one'), (2, 'two'), (3, 'three'); + """ + ) + + # Test partial fetching + cursor.execute("SELECT * FROM multi_test ORDER BY id") + first_row = cursor.fetchone() + self.assertEqual(first_row, (1, "one")) + + remaining_rows = cursor.fetchall() + self.assertEqual(len(remaining_rows), 2) + self.assertEqual(remaining_rows[0], (2, "two")) + + def test_query_formats(self): + conn = connect(":memory:") + cursor = conn.cursor() + # Create test data + cursor.execute( + """ + CREATE TABLE format_test (id Int32, value String) ENGINE = Memory; + INSERT INTO format_test VALUES (1, 'test'); + """ + ) + + # Test different output formats + csv_result = conn.query("SELECT * FROM format_test", format="CSV") + self.assertIsNotNone(csv_result) + + arrow_result = conn.query("SELECT * FROM format_test", format="ArrowStream") + self.assertIsNotNone(arrow_result) + + def test_cursor_statistics(self): + conn = connect(":memory:") + cursor = conn.cursor() + # Create and populate test table + cursor.execute( + """ + CREATE TABLE stats_test (id Int32, value String) ENGINE = Memory; + INSERT INTO stats_test SELECT number, toString(number) + FROM numbers(1000); + """ + ) + + # Execute query and check statistics + 
cursor.execute("SELECT * FROM stats_test") + self.assertGreater(cursor._cursor.rows_read(), 0) + self.assertGreater(cursor._cursor.bytes_read(), 0) + self.assertGreater(cursor._cursor.elapsed(), 0) + + def test_memory_management(self): + conn = connect(":memory:") + cursor = conn.cursor() + # Test multiple executions + for i in range(10): + cursor = conn.cursor() + cursor.execute("SELECT 1") + self.assertIsNotNone(cursor.fetchone()) + + # Test large result sets + cursor.execute( + """ + SELECT number, toString(number) as str_num + FROM numbers(1000000) + """ + ) + rows = cursor.fetchall() + self.assertEqual(len(rows), 1000000) + + def test_multiple_connections(self): + conn1 = connect(":memory:") + conn2 = connect(":memory:") + cursor1 = conn1.cursor() + cursor2 = conn2.cursor() + + with self.assertRaises(Exception): + connect("file:test.db") + + # Create table in first connection + cursor1.execute( + """ + CREATE TABLE test_table (id Int32, value String) ENGINE = Memory + """ + ) + + # Insert data in second connection + cursor2.execute("INSERT INTO test_table VALUES (1, 'test')") + cursor2.commit() + + # Query data from first connection + cursor1.execute("SELECT * FROM test_table") + row = cursor1.fetchone() + self.assertEqual(row, (1, "test")) + + conn1.close() + conn2.close() + + def test_connection_properties(self): + # conn = connect("{db_path}?log_queries=1&verbose=1&log-level=test") + with self.assertRaises(Exception): + conn = connect(f"{db_path}?not_exist_flag=1") + with self.assertRaises(Exception): + conn = connect(f"{db_path}?verbose=1") + + conn = connect(f"{db_path}?verbose&log-level=test") + ret = conn.query("SELECT 123", "CSV") + print(ret) + print(len(ret)) + self.assertEqual(str(ret), "123\n") + ret = conn.query("show tables in system", "CSV") + self.assertGreater(len(ret), 10) + + conn.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dbapi.py b/tests/test_dbapi.py index 0ed283e4b9b..eb16f6c7cec 100644 --- 
a/tests/test_dbapi.py +++ b/tests/test_dbapi.py @@ -1,80 +1,95 @@ #!/usr/bin/env python3 +import os +import shutil +import tempfile import unittest from chdb import dbapi # version should be string split by '.' # eg. '0.12.0' or '0.12.0rc1' or '0.12.0beta1' or '0.12.0alpha1' or '0.12.0a1' -expected_version_pattern = r'^\d+\.\d+\.\d+(.*)?$' -expected_clickhouse_version_pattern = r'^\d+\.\d+\.\d+.\d+$' +expected_version_pattern = r"^\d+\.\d+\.\d+(.*)?$" +expected_clickhouse_version_pattern = r"^\d+\.\d+\.\d+.\d+$" class TestDBAPI(unittest.TestCase): + def test_select_version(self): conn = dbapi.connect() cur = conn.cursor() - cur.execute('select version()') # ClickHouse version - description = cur.description + cur.execute("select version()") # ClickHouse version + # description = cur.description data = cur.fetchone() cur.close() conn.close() # Add your assertions here to validate the description and data - print(description) + # print(description) print(data) self.assertRegex(data[0], expected_clickhouse_version_pattern) def test_insert_and_read_data(self): - conn = dbapi.connect() - cur = conn.cursor() - cur.execute("CREATE DATABASE IF NOT EXISTS test_db ENGINE = Atomic") - cur.execute("USE test_db") - cur.execute(""" - CREATE TABLE rate ( - day Date, - value Int64 - ) ENGINE = ReplacingMergeTree ORDER BY day""") + # make a tmp dir context + with tempfile.TemporaryDirectory() as tmpdirname: + conn = dbapi.connect(tmpdirname) + print(conn) + cur = conn.cursor() + # cur.execute("CREATE DATABASE IF NOT EXISTS test_db ENGINE = Atomic") + # cur.execute("USE test_db") + cur.execute( + """ + CREATE TABLE rate ( + day Date, + value Int64 + ) ENGINE = ReplacingMergeTree ORDER BY day""" + ) - # Insert single value - cur.execute("INSERT INTO rate VALUES (%s, %s)", ("2021-01-01", 24)) - # Insert multiple values - cur.executemany("INSERT INTO rate VALUES (%s, %s)", [("2021-01-02", 128), ("2021-01-03", 256)]) - # Test executemany outside optimized INSERT/REPLACE path - 
cur.executemany("ALTER TABLE rate UPDATE value = %s WHERE day = %s", [(72, "2021-01-02"), (96, "2021-01-03")]) + # Insert single value + cur.execute("INSERT INTO rate VALUES (%s, %s)", ("2021-01-01", 24)) + # Insert multiple values + cur.executemany( + "INSERT INTO rate VALUES (%s, %s)", + [("2021-01-02", 128), ("2021-01-03", 256)], + ) + # Test executemany outside optimized INSERT/REPLACE path + cur.executemany( + "ALTER TABLE rate UPDATE value = %s WHERE day = %s", + [(72, "2021-01-02"), (96, "2021-01-03")], + ) - # Test fetchone - cur.execute("SELECT value FROM rate ORDER BY day DESC LIMIT 2") - row1 = cur.fetchone() - self.assertEqual(row1, (96,)) - row2 = cur.fetchone() - self.assertEqual(row2, (72,)) - row3 = cur.fetchone() - self.assertIsNone(row3) + # Test fetchone + cur.execute("SELECT value FROM rate ORDER BY day DESC LIMIT 2") + row1 = cur.fetchone() + self.assertEqual(row1, (96,)) + row2 = cur.fetchone() + self.assertEqual(row2, (72,)) + row3 = cur.fetchone() + self.assertIsNone(row3) - # Test fetchmany - cur.execute("SELECT value FROM rate ORDER BY day DESC") - result_set1 = cur.fetchmany(2) - self.assertEqual(result_set1, ((96,), (72,))) - result_set2 = cur.fetchmany(1) - self.assertEqual(result_set2, ((24,),)) + # Test fetchmany + cur.execute("SELECT value FROM rate ORDER BY day DESC") + result_set1 = cur.fetchmany(2) + self.assertEqual(result_set1, ((96,), (72,))) + result_set2 = cur.fetchmany(1) + self.assertEqual(result_set2, ((24,),)) - # Test fetchall - cur.execute("SELECT value FROM rate ORDER BY day DESC") - rows = cur.fetchall() - self.assertEqual(rows, ((96,), (72,), (24,))) + # Test fetchall + cur.execute("SELECT value FROM rate ORDER BY day DESC") + rows = cur.fetchall() + self.assertEqual(rows, ((96,), (72,), (24,))) - # Clean up - cur.close() - conn.close() + # Clean up + cur.close() + conn.close() def test_select_chdb_version(self): ver = dbapi.get_client_info() # chDB version liek '0.12.0' ver_tuple = dbapi.chdb_version # chDB version 
tuple like ('0', '12', '0') print(ver) print(ver_tuple) - self.assertEqual(ver, '.'.join(ver_tuple)) + self.assertEqual(ver, ".".join(ver_tuple)) self.assertRegex(ver, expected_version_pattern) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()