Skip to content

Commit

Permalink
Add vector_data_type option; add deployment and connection to portal …
Browse files Browse the repository at this point in the history
…object
  • Loading branch information
kesmit13 committed Jul 12, 2024
1 parent 23f9817 commit e5bb10a
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 0 deletions.
11 changes: 11 additions & 0 deletions singlestoredb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@
environ='SINGLESTOREDB_ENABLE_EXTENDED_DATA_TYPES',
)

register_option(
'vector_data_format', 'string',
functools.partial(
check_str,
valid_values=['json', 'binary'],
),
'binary',
'Format for vector data values',
environ='SINGLESTOREDB_VECTOR_DATA_FORMAT',
)

register_option(
'fusion.enabled', 'bool', check_bool, False,
'Should Fusion SQL queries be enabled?',
Expand Down
3 changes: 3 additions & 0 deletions singlestoredb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,7 @@ def connect(
encoding_errors: Optional[str] = None,
track_env: Optional[bool] = None,
enable_extended_data_types: Optional[bool] = None,
vector_data_format: Optional[str] = None,
) -> Connection:
"""
Return a SingleStoreDB connection.
Expand Down Expand Up @@ -1387,6 +1388,8 @@ def connect(
Should the connection track the SINGLESTOREDB_URL environment variable?
enable_extended_data_types : bool, optional
Should extended data types (BSON, vector) be enabled?
vector_data_format : str, optional
Format for vector types: json or binary
Examples
--------
Expand Down
2 changes: 2 additions & 0 deletions singlestoredb/http/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,5 +1243,7 @@ def connect(
inf_as_null: Optional[bool] = None,
encoding_errors: Optional[str] = None,
track_env: Optional[bool] = None,
enable_extended_data_types: Optional[bool] = None,
vector_data_format: Optional[str] = None,
) -> Connection:
return Connection(**dict(locals()))
19 changes: 19 additions & 0 deletions singlestoredb/mysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class Connection(BaseConnection):
Should the connection track the SINGLESTOREDB_URL environment variable?
enable_extended_data_types : bool, optional
Should extended data types (BSON, vector) be enabled?
vector_data_format : str, optional
Specify the data type of vector values: json or binary
See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_
in the specification.
Expand Down Expand Up @@ -350,6 +352,7 @@ def __init__( # noqa: C901
encoding_errors='strict',
track_env=False,
enable_extended_data_types=True,
vector_data_format='binary',
):
BaseConnection.__init__(**dict(locals()))

Expand Down Expand Up @@ -634,6 +637,13 @@ def _config(key, arg):
self._in_sync = False
self._track_env = bool(track_env) or self.host == 'singlestore.com'
self._enable_extended_data_types = enable_extended_data_types
if vector_data_format.lower() in ['json', 'binary']:
self._vector_data_format = vector_data_format
else:
raise ValueError(
'unknown value for vector_data_format, '
f'expecting "json" or "binary": {vector_data_format}',
)
self._connection_info = {}
events.subscribe(self._handle_event)

Expand Down Expand Up @@ -1117,6 +1127,15 @@ def connect(self, sock=None):
pass
c.close()

if self._vector_data_format:
c = self.cursor()
try:
val = self._vector_data_format
c.execute(f'SET @@SESSION.vector_type_project_format={val}')
except self.OperationalError:
pass
c.close()

if self.init_command is not None:
c = self.cursor()
c.execute(self.init_command)
Expand Down
32 changes: 32 additions & 0 deletions singlestoredb/notebook/_portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from . import _objects as obj
from ..management import workspace as mgr
Expand Down Expand Up @@ -187,6 +188,37 @@ def workspace(self, name_or_id: str) -> None:
timeout_message='timeout waiting for workspace update',
)

deployment = workspace

@property
def connection(self) -> Tuple[obj.Workspace, Optional[str]]:
"""Workspace and default database name."""
return self.workspace, self.default_database

@connection.setter
def connection(self, workspace_and_default_database: Tuple[str, str]) -> None:
"""Set workspace and default database name."""
name_or_id, default_database = workspace_and_default_database
if re.match(
r'[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}',
name_or_id, flags=re.I,
):
w = mgr.get_workspace(name_or_id)
else:
w = mgr.get_workspace_group(self.workspace_group_id).workspaces[name_or_id]

if w.state and w.state.lower() not in ['active', 'resumed']:
raise RuntimeError('workspace is not active')

id = w.id

self._call_javascript(
'changeConnection', [id, default_database],
wait_on_condition=lambda: self.workspace_id == id and
self.default_database == default_database, # type: ignore
timeout_message='timeout waiting for workspace update',
)

@property
def cluster_id(self) -> Optional[str]:
"""Cluster ID."""
Expand Down

0 comments on commit e5bb10a

Please sign in to comment.