From 454504fbc73ad1860776bbd01ab3f3d856c7699c Mon Sep 17 00:00:00 2001 From: Eugene Clark Date: Mon, 29 Apr 2024 16:03:45 -0400 Subject: [PATCH] Integrate SQLAlchemy for db conn management and introduce new SqlStorage abstraction (#93) * Update expected VRS IDs for VCF tests * Update VRS IDs for variation tests * Added new SqlStorage implementation as an abstract base class for all RDBMS storage implementations The base class utilizes SqlAlchemy for connection management and SQL statement execution because it is the only connection pooling library that works with the Snowflake connector. The base class includes the background db write capabilities from the Snowflake implementation and actual SQL statement execution where standard SQL is used. Abstract methods are defined for queries where the SQL or database APIs are not standard. * Switch to snowflake-sqlalchemy package * Update the Postgres storage implementation to be a subclass of the new SqlStorage base class Primarily removed code that was included in the base class and reorganized remaining code into the base class API shape. Because the Snowflake connector only supports SqlAlchemy 1.4 which in turn only supports psycopg2, had to modify the batch insert logic to use a different API. * Update Snowflake storage implementation to be a subclass of new SqlStorage base class Removed code that is now included in base class and reorganized remaining code into base class API shape. * Updated unit tests to cover the use of background writes in Postgres storage implementation Refactored mocks for SqlAlchemy based testing into separate module * Rename test file and replace unused var names with underscore * Add storage option to always fully flush on batch context exit * When storage construction does not complete, the batch_thread and conn_pool are sometimes not created leading to spurious errors on close(). Check for these attributes before attempting to clean them up. 
* Depending on the underlying database, the returned column value can be a string or a dict * Add batch add mode settings to control what type of SQL statement to use when adding new VRS objects to the database * Update variation test data to match VRS 2.0 changes * Comment out response model to return full VRS objects instead of serialized version * Make get location/variation behave consistently even when the object store does not throw a KeyError on missing key * Remove code added to make debugging easier * Update queries to use specified table name * Fix bug in detecting column value type on fetch * Batch add mode only makes sense for Snowflake because in Postgres the vrs_objects table has a primary key and uses "ON CONFLICT" on inserts * Switch to using question mark bind variables for Snowflake because named parameters were not working Pick up table name from environment in unit tests * Update to batch insert to play nicely with Snowflake quirks * Update example URL to be SQLAlchemy friendly * Use super() to invoke __init__() * Add support for Snowflake private key auth * Add monkey patch workaround for bug in Snowflake SQLAlchemy * Update collation in temp loading table * Storage implementations should be consistent with MutableMapping API and throw KeyError when an item is not found * Remove VRS model classes from response objects because the serialization used internally is not correct for API responses * Corrected path used for missing allele id test * Get location and get variation should be consistent in behavior when id is not found * Revert unnecessary change * Throw KeyError when id is not found * Add missing argument to _get_connect_args * Code formatting * Suppress SQL injection warning as elsewhere * Code formatting * Adding missing SQL injection warning suppressions * Update README to reflect changes * Address "Incomplete URL substring sanitization" warning --- README.md | 83 ++- setup.cfg | 2 +- src/anyvar/anyvar.py | 3 +- src/anyvar/storage/postgres.py | 410 +++-------- src/anyvar/storage/snowflake.py | 674 ++++++------ src/anyvar/storage/sql_storage.py | 499 +++++++++++++ tests/conftest.py | 2 +- tests/storage/sqlalchemy_mocks.py | 153 ++++ tests/storage/test_postgres.py | 210 ++++++ tests/storage/test_snowflake.py | 406 ++++++----- ...mapping.py => test_sql_storage_mapping.py} | 25 +- 11 files changed, 1486 insertions(+), 981 deletions(-) create mode 100644 src/anyvar/storage/sql_storage.py create mode 100644 tests/storage/sqlalchemy_mocks.py create mode 100644 tests/storage/test_postgres.py rename tests/storage/{test_storage_mapping.py => test_sql_storage_mapping.py} (61%) diff --git a/README.md b/README.md index aa985b7..1808454 100644 --- a/README.md +++ b/README.md @@ -57,40 +57,68 @@ In another terminal: curl http://localhost:8000/info -### Setting up Postgres - -A Postgres-backed *AnyVar* installation may use any Postgres instance, local -or remote. The following instructions are for using a docker-based -Postgres instance. - -First, run the commands in [README-pg.md](src/anyvar/storage/README-pg.md). This will create and start a local Postgres docker instance. - -Next, run the commands in [postgres_init.sql](src/anyvar/storage/postgres_init.sql). This will create the `anyvar` user with the appropriate permissions and create the `anyvar` database. - -### Setting up Snowflake -A Snowflake-backed *AnyVar* installation may use any Snowflake database schema. +### SQL Database Setup +A Postgres or Snowflake database may be used with *AnyVar*. 
The Postgres database +may be either local or remote. Use the `ANYVAR_STORAGE_URI` environment variable +to define the database connection URL. *AnyVar* uses [SQLAlchemy 1.4](https://docs.sqlalchemy.org/en/14/index.html) +to provide database connection management. The default database connection URL +is `"postgresql://postgres@localhost:5432/anyvar"`. + +The database integrations can be modified using the following parameters: +* `ANYVAR_SQL_STORE_BATCH_LIMIT` - in batch mode, limit VRS object upsert batches +to this number; defaults to `100,000` +* `ANYVAR_SQL_STORE_TABLE_NAME` - the name of the table that stores VRS objects; +defaults to `vrs_objects` +* `ANYVAR_SQL_STORE_MAX_PENDING_BATCHES` - the maximum number of pending batches +to allow before blocking; defaults to `50` +* `ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT` - whether or not to flush all pending +database writes when the batch manager exits; defaults to `True` + +The Postgres and Snowflake database connectors utilize a background thread +to write VRS objects to the database when operating in batch mode (e.g. annotating +a VCF file). Queries and statistics run only against the already committed database +state. Therefore, queries issued immediately after a batch operation may not reflect +all pending changes if the `ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT` parameter is set +to `False`. + +#### Setting up Postgres +The following instructions are for using a docker-based Postgres instance. + +First, run the commands in [README-pg.md](src/anyvar/storage/README-pg.md). +This will create and start a local Postgres docker instance. + +Next, run the commands in [postgres_init.sql](src/anyvar/storage/postgres_init.sql). +This will create the `anyvar` user with the appropriate permissions and create the +`anyvar` database. + +#### Setting up Snowflake The Snowflake database and schema must exist prior to starting *AnyVar*. To point *AnyVar* at Snowflake, specify a Snowflake URI in the `ANYVAR_STORAGE_URI` environment variable. For example: - - snowflake://my-sf-acct/?database=sf_db_name&schema=sd_schema_name&user=sf_username&password=sf_password - +``` +snowflake://sf_username:@sf_account_identifier/sf_db_name/sf_schema_name?password=sf_password +``` [Snowflake connection parameter reference](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api) -When running interactively and connecting to a Snowflake account that utilizes federated authentication or SSO, add -the parameter `authenticator=externalbrowser`. Non-interactive execution in a federated authentication or SSO environment -requires a service account to connect. Connections using an encrypted or unencrypted private key are also supported by -specifying the parameter `private_key=path/to/file.p8`. The key material may be URL-encoded and inlined in the connection URI, +When running interactively and connecting to a Snowflake account that utilizes +federated authentication or SSO, add the parameter `authenticator=externalbrowser`. +Non-interactive execution in a federated authentication or SSO environment +requires a service account to connect. Connections using an encrypted or unencrypted +private key are also supported by specifying the parameter `private_key=path/to/file.p8`. 
+The key material may be URL-encoded and inlined in the connection URI, for example: `private_key=-----BEGIN+PRIVATE+KEY-----%0AMIIEvAIBA...` - Environment variables that can be used to modify Snowflake database integration: -* `ANYVAR_SNOWFLAKE_STORE_BATCH_LIMIT` - in batch mode, limit VRS object upsert batches to this number; defaults to `100,000` -* `ANYVAR_SNOWFLAKE_STORE_TABLE_NAME` - the name of the table that stores VRS objects; defaults to `vrs_objects` -* `ANYVAR_SNOWFLAKE_STORE_MAX_PENDING_BATCHES` - the maximum number of pending batches to allow before blocking; defaults to `50` * `ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE` - the passphrase for an encrypted private key - -NOTE: If you choose to create the VRS objects table in advance, the minimal table specification is as follows: +* `ANYVAR_SNOWFLAKE_BATCH_ADD_MODE` - the SQL statement type to use when adding new VRS objects, one of: + * `merge` (default) - use a MERGE statement. This guarantees that duplicate VRS IDs will + not be added, but also locks the VRS object table, limiting throughput. + * `insert_notin` - use INSERT INTO vrs_objects SELECT FROM tmp WHERE vrs_id NOT IN (...). + This narrows the chance of duplicates and does not require a table lock. + * `insert` - use INSERT INTO. This maximizes throughput at the cost of not checking for + duplicates at all. + +If you choose to create the VRS objects table in advance, the minimal table specification is as follows: ```sql CREATE TABLE ... ( vrs_id VARCHAR(500) COLLATE 'utf8', @@ -98,11 +126,6 @@ CREATE TABLE ... ( ) ``` -NOTE: The Snowflake database connector utilizes a background thread to write VRS objects to the database when operating in batch -mode (e.g. annotating a VCF file). Queries and statistics query only against the already committed database state. Therefore, -queries issued immediately after a batch operation may not reflect all pending changes. - - ## Deployment NOTE: The authoritative and sole source for version tags is the diff --git a/setup.cfg b/setup.cfg index c938cb6..1f7c81d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ install_requires = uvicorn ga4gh.vrs[extras]~=2.0.0a5 psycopg[binary] - snowflake-connector-python~=3.4.1 + snowflake-sqlalchemy~=1.5.1 [options.package_data] * = diff --git a/src/anyvar/anyvar.py b/src/anyvar/anyvar.py index f7fbf02..e5257ee 100644 --- a/src/anyvar/anyvar.py +++ b/src/anyvar/anyvar.py @@ -28,8 +28,7 @@ def create_storage(uri: Optional[str] = None) -> _Storage: * PostgreSQL `postgresql://[username]:[password]@[domain]/[database]` * Snowflake - `snowflake://[account_identifier].snowflakecomputing.com/?[param=value]&[param=value]...` - `snowflake://[account_identifier]/?[param=value]&[param=value]...` + `snowflake://[user]:@[account]/[database]/[schema]?[param=value]&[param=value]...` """ uri = uri or os.environ.get("ANYVAR_STORAGE_URI", DEFAULT_STORAGE_URI) diff --git a/src/anyvar/storage/postgres.py b/src/anyvar/storage/postgres.py index c28df48..7a4135a 100644 --- a/src/anyvar/storage/postgres.py +++ b/src/anyvar/storage/postgres.py @@ -1,265 +1,56 @@ +from io import StringIO import json -import logging -from typing import Any, Optional +from typing import Any, Optional, List -import psycopg -from ga4gh.core import is_pydantic_instance -from ga4gh.vrs import models +from sqlalchemy import text as sql_text +from sqlalchemy.engine import Connection -from anyvar.restapi.schema import VariationStatisticType - -from . 
import _BatchManager, _Storage +from .sql_storage import SqlStorage silos = "locations alleles haplotypes genotypes variationsets relations texts".split() -_logger = logging.getLogger(__name__) - - -class PostgresObjectStore(_Storage): +class PostgresObjectStore(SqlStorage): """PostgreSQL storage backend. Currently, this is our recommended storage approach. """ - def __init__(self, db_url: str, batch_limit: int = 65536): - """Initialize PostgreSQL DB handler. - - :param db_url: libpq connection info URL - :param batch_limit: max size of batch insert queue - """ - self.conn = psycopg.connect(db_url, autocommit=True) - self.ensure_schema_exists() - - self.batch_manager = PostgresBatchManager - self.batch_mode = False - self.batch_insert_values = [] - self.batch_limit = batch_limit - - def _create_schema(self): - """Add DB schema.""" - create_statement = """ - CREATE TABLE vrs_objects ( - vrs_id TEXT PRIMARY KEY, - vrs_object JSONB - ); - """ - with self.conn.cursor() as cur: - cur.execute(create_statement) - - def ensure_schema_exists(self): - """Check that DB schema is in place.""" - with self.conn.cursor() as cur: - cur.execute( - "SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 'vrs_objects')" - ) # noqa: E501 - result = cur.fetchone() - if result and result[0]: - return - self._create_schema() - - def __repr__(self): - return str(self.conn) - - def __setitem__(self, name: str, value: Any): - """Add item to database. If batch mode is on, add item to queue and write only - if queue size is exceeded. - - :param name: value for `vrs_id` field - :param value: value for `vrs_object` field - """ - assert is_pydantic_instance(value), "ga4gh.vrs object value required" - name = str(name) # in case str-like - value_json = json.dumps(value.model_dump(exclude_none=True)) - if self.batch_mode: - self.batch_insert_values.append((name, value_json)) - if len(self.batch_insert_values) > self.batch_limit: - self.copy_insert() - else: - insert_query = "INSERT INTO vrs_objects (vrs_id, vrs_object) VALUES (%s, %s) ON CONFLICT DO NOTHING;" # noqa: E501 - with self.conn.cursor() as cur: - cur.execute(insert_query, [name, value_json]) - - def __getitem__(self, name: str) -> Optional[Any]: - """Fetch item from DB given key. - - Future issues: - * Remove reliance on VRS-Python models (requires rewriting the enderef module) - - :param name: key to retrieve VRS object for - :return: VRS object if available - :raise NotImplementedError: if unsupported VRS object type (this is WIP) - """ - with self.conn.cursor() as cur: - cur.execute("SELECT vrs_object FROM vrs_objects WHERE vrs_id = %s;", [name]) - result = cur.fetchone() - if result: - result = result[0] - object_type = result["type"] - if object_type == "Allele": - return models.Allele(**result) - elif object_type == "CopyNumberCount": - return models.CopyNumberCount(**result) - elif object_type == "CopyNumberChange": - return models.CopyNumberChange(**result) - elif object_type == "SequenceLocation": - return models.SequenceLocation(**result) - else: - raise NotImplementedError - else: - raise KeyError(name) - - def __contains__(self, name: str) -> bool: - """Check whether VRS objects table contains ID. 
- - :param name: VRS ID to look up - :return: True if ID is contained in vrs objects table - """ - with self.conn.cursor() as cur: - cur.execute("SELECT EXISTS (SELECT 1 FROM vrs_objects WHERE vrs_id = %s);", [name]) - result = cur.fetchone() - return result[0] if result else False - - def __delitem__(self, name: str) -> None: - """Delete item (not cascading -- doesn't delete referenced items) - - :param name: key to delete object for - """ - name = str(name) # in case str-like - with self.conn.cursor() as cur: - cur.execute("DELETE FROM vrs_objects WHERE vrs_id = %s;", [name]) - self.conn.commit() - - def wait_for_writes(self): - """Returns once any currently pending database modifications have been completed. - The PostgresObjectStore does not implement async writes, therefore this method is a no-op - and present only to maintain compatibility with the `_Storage` base class""" - - def close(self): - """Terminate connection if necessary.""" - if self.conn is not None: - self.conn.close() - - def __del__(self): - """Tear down DB instance.""" - self.close() - - def __len__(self): - with self.conn.cursor() as cur: - cur.execute( - """ - SELECT COUNT(*) AS c FROM vrs_objects - WHERE vrs_object ->> 'type' = 'Allele'; - """ - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def get_variation_count(self, variation_type: VariationStatisticType) -> int: - """Get total # of registered variations of requested type. - - :param variation_type: variation type to check - :return: total count - """ - if variation_type == VariationStatisticType.SUBSTITUTION: - return self._substitution_count() - elif variation_type == VariationStatisticType.INSERTION: - return self._insertion_count() - elif variation_type == VariationStatisticType.DELETION: - return self._deletion_count() - else: - return self._substitution_count() + self._deletion_count() + self._insertion_count() - - def _deletion_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - """ - select count(*) as c from vrs_objects - where length(vrs_object -> 'state' ->> 'sequence') = 0; - """ - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def _substitution_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - """ - select count(*) as c from vrs_objects - where length(vrs_object -> 'state' ->> 'sequence') = 1; - """ - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 + def __init__( + self, + db_url: str, + batch_limit: Optional[int] = None, + table_name: Optional[str] = None, + max_pending_batches: Optional[int] = None, + flush_on_batchctx_exit: Optional[bool] = None, + ): + super().__init__( + db_url, + batch_limit, + table_name, + max_pending_batches, + flush_on_batchctx_exit, + ) - def _insertion_count(self): - with self.conn.cursor() as cur: - cur.execute( - """ - select count(*) as c from vrs_objects - where length(vrs_object -> 'state' ->> 'sequence') > 1 - """ + def create_schema(self, db_conn: Connection): + check_statement = f""" + SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{self.table_name}') + """ # nosec B608 + create_statement = f""" + CREATE TABLE {self.table_name} ( + vrs_id TEXT PRIMARY KEY, + vrs_object JSONB ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def __iter__(self): - with self.conn.cursor() as cur: - cur.execute("SELECT * FROM vrs_objects;") - while True: - _next = cur.fetchone() - if _next is None: - break - yield _next - - def 
keys(self): - with self.conn.cursor() as cur: - cur.execute("SELECT vrs_id FROM vrs_objects;") - result = [row[0] for row in cur.fetchall()] - return result - - def search_variations(self, refget_accession: str, start: int, stop: int): - """Find all alleles that were registered that are in 1 genomic region + """ # nosec B608 + result = db_conn.execute(sql_text(check_statement)) + if not result or not result.scalar(): + db_conn.execute(sql_text(create_statement)) - Args: - refget_accession (str): refget accession (SQ. identifier) - start (int): Start genomic region to query - stop (iint): Stop genomic region to query - - Returns: - A list of VRS Alleles that have locations referenced as identifiers - """ - query_str = """ - SELECT vrs_object FROM vrs_objects - WHERE vrs_object->>'location' IN ( - SELECT vrs_id FROM vrs_objects - WHERE CAST (vrs_object->>'start' as INTEGER) >= %s - AND CAST (vrs_object->>'end' as INTEGER) <= %s - AND vrs_object->'sequenceReference'->>'refgetAccession' = %s - ); - """ - with self.conn.cursor() as cur: - cur.execute(query_str, [start, stop, refget_accession]) - results = cur.fetchall() - return [vrs_object[0] for vrs_object in results if vrs_object] - - def wipe_db(self): - """Remove all stored records from vrs_objects table.""" - with self.conn.cursor() as cur: - cur.execute("DELETE FROM vrs_objects;") + def add_one_item(self, db_conn: Connection, name: str, value: Any): + insert_query = f"INSERT INTO {self.table_name} (vrs_id, vrs_object) VALUES (:vrs_id, :vrs_object) ON CONFLICT DO NOTHING" # nosec B608 + value_json = json.dumps(value.model_dump(exclude_none=True)) + db_conn.execute(sql_text(insert_query), {"vrs_id": name, "vrs_object": value_json}) - def copy_insert(self): + def add_many_items(self, db_conn: Connection, items: list): """Perform copy-based insert, enabling much faster writes for large, repeated insert statements, using insert parameters stored in `self.batch_insert_values`. @@ -269,63 +60,72 @@ def copy_insert(self): conflicts when moving data over from that table to vrs_objects. 
""" tmp_statement = ( - "CREATE TEMP TABLE tmp_table (LIKE vrs_objects INCLUDING DEFAULTS);" # noqa: E501 + f"CREATE TEMP TABLE tmp_table (LIKE {self.table_name} INCLUDING DEFAULTS)" # noqa: E501 ) - copy_statement = "COPY tmp_table (vrs_id, vrs_object) FROM STDIN;" - insert_statement = ( - "INSERT INTO vrs_objects SELECT * FROM tmp_table ON CONFLICT DO NOTHING;" # noqa: E501 + insert_statement = f"INSERT INTO {self.table_name} SELECT * FROM tmp_table ON CONFLICT DO NOTHING" # nosec B608 + drop_statement = "DROP TABLE tmp_table" + db_conn.execute(sql_text(tmp_statement)) + with db_conn.connection.cursor() as cur: + row_data = [ + f"{name}\t{json.dumps(value.model_dump(exclude_none=True))}" + for name, value in items + ] + fl = StringIO("\n".join(row_data)) + cur.copy_from(fl, "tmp_table", columns=["vrs_id", "vrs_object"]) + fl.close() + db_conn.execute(sql_text(insert_statement)) + db_conn.execute(sql_text(drop_statement)) + + def deletion_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + sql_text( + f""" + SELECT COUNT(*) AS c + FROM {self.table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 0 + """ # nosec B608 + ) ) - drop_statement = "DROP TABLE tmp_table;" - with self.conn.cursor() as cur: - cur.execute(tmp_statement) - with cur.copy(copy_statement) as copy: - for row in self.batch_insert_values: - copy.write_row(row) - cur.execute(insert_statement) - cur.execute(drop_statement) - self.conn.commit() - self.batch_insert_values = [] - - -class PostgresBatchManager(_BatchManager): - """Context manager enabling batch insertion statements via Postgres COPY command. - - Use in cases like VCF ingest when intaking large amounts of data at once. - """ - - def __init__(self, storage: PostgresObjectStore): - """Initialize context manager. - - :param storage: Postgres instance to manage. Should be taken from the active - AnyVar instance -- otherwise it won't be able to delay insertions. - :raise ValueError: if `storage` param is not a `PostgresObjectStore` instance - """ - if not isinstance(storage, PostgresObjectStore): - raise ValueError("PostgresBatchManager requires a PostgresObjectStore instance") - self._storage = storage - - def __enter__(self): - """Enter managed context.""" - self._storage.batch_insert_values = [] - self._storage.batch_mode = True - - def __exit__( - self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any] - ) -> bool: - """Handle exit from context management. This method is responsible for - committing or rolling back any staged inserts. 
- - :param exc_type: type of exception encountered, if any - :param exc_value: exception value - :param traceback: traceback for context of exception - :return: True if no exceptions encountered, False otherwise - """ - if exc_type is not None: - self._storage.conn.rollback() - self._storage.batch_insert_values = [] - self._storage.batch_mode = False - _logger.error(f"Postgres batch manager encountered exception {exc_type}: {exc_value}") - return False - self._storage.copy_insert() - self._storage.batch_mode = False - return True + return result.scalar() + + def substitution_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + sql_text( + f""" + SELECT COUNT(*) AS c + FROM {self.table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 1 + """ # nosec B608 + ) + ) + return result.scalar() + + def insertion_count(self, db_conn: Connection): + result = db_conn.execute( + sql_text( + f""" + SELECT COUNT(*) AS c + FROM {self.table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') > 1 + """ # nosec B608 + ) + ) + return result.scalar() + + def search_vrs_objects( + self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int + ) -> List: + query_str = f""" + SELECT vrs_object + FROM {self.table_name} + WHERE vrs_object->>'type' = %s + AND vrs_object->>'location' IN ( + SELECT vrs_id FROM {self.table_name} + WHERE CAST (vrs_object->>'start' AS INTEGER) >= %s + AND CAST (vrs_object->>'end' AS INTEGER) <= %s + AND vrs_object->'sequenceReference'->>'refgetAccession' = %s) + """ # nosec B608 + with db_conn.connection.cursor() as cur: + cur.execute(query_str, [type, start, stop, refget_accession]) + results = cur.fetchall() + return [vrs_object[0] for vrs_object in results if vrs_object] diff --git a/src/anyvar/storage/snowflake.py b/src/anyvar/storage/snowflake.py index 54857c9..4085d20 100644 --- a/src/anyvar/storage/snowflake.py +++ b/src/anyvar/storage/snowflake.py @@ -1,25 +1,66 @@ +from enum import auto, StrEnum import json import logging import os -from threading import Condition, Thread -from typing import Any, List, Optional, Tuple -from urllib.parse import urlparse, parse_qs +import snowflake.connector +from snowflake.sqlalchemy.snowdialect import SnowflakeDialect +from typing import Any, List, Optional +from sqlalchemy import text as sql_text +from sqlalchemy.engine import Connection, URL +from urllib.parse import urlparse, parse_qs, urlunparse, urlencode from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -import ga4gh.core -from ga4gh.vrs import models -import snowflake.connector -from snowflake.connector import SnowflakeConnection -from anyvar.restapi.schema import VariationStatisticType - -from . 
import _BatchManager, _Storage +from .sql_storage import SqlStorage _logger = logging.getLogger(__name__) +snowflake.connector.paramstyle = "qmark" + +# +# Monkey patch to workaround a bug in the Snowflake SQLAlchemy dialect +# https://github.com/snowflakedb/snowflake-sqlalchemy/issues/489 + +# Create a new pointer to the existing create_connect_args method +SnowflakeDialect._orig_create_connect_args = SnowflakeDialect.create_connect_args + + +# Define a new create_connect_args method that calls the original method +# and then fixes the result so that the account name is not mangled +# when using privatelink +def sf_create_connect_args_override(self, url: URL): + + # retval is tuple of empty array and dict ([], {}) + retval = self._orig_create_connect_args(url) + + # the dict has the options including the mangled account name + opts = retval[1] + if ( + "host" in opts + and "account" in opts + and opts["host"].endswith(".privatelink.snowflakecomputing.com") + ): + opts["account"] = opts["host"].split(".")[0] + + return retval + + +# Replace the create_connect_args method with the override +SnowflakeDialect.create_connect_args = sf_create_connect_args_override + +# +# End monkey patch +# -class SnowflakeObjectStore(_Storage): + +class SnowflakeBatchAddMode(StrEnum): + merge = auto() + insert_notin = auto() + insert = auto() + + +class SnowflakeObjectStore(SqlStorage): """Snowflake storage backend. Requires existing Snowflake database.""" def __init__( @@ -28,496 +69,185 @@ def __init__( batch_limit: Optional[int] = None, table_name: Optional[str] = None, max_pending_batches: Optional[int] = None, + flush_on_batchctx_exit: Optional[bool] = None, + batch_add_mode: Optional[SnowflakeBatchAddMode] = None, ): - """Initialize Snowflake DB handler. - - :param db_url: snowflake connection info URL, snowflake://[account_identifier]/?[param=value]&[param=value]... 
- :param batch_limit: max size of batch insert queue, defaults to 100000; can be set with - ANYVAR_SNOWFLAKE_STORE_BATCH_LIMIT environment variable - :param table_name: table name for storing VRS objects, defaults to `vrs_objects`; can be set with - ANYVAR_SNOWFLAKE_STORE_TABLE_NAME environment variable - :param max_pending_batches: maximum number of pending batches allowed before batch queueing blocks; can - be set with ANYVAR_SNOWFLAKE_STORE_MAX_PENDING_BATCHES environment variable - - See https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api for full list - of database connection url parameters """ - # specify that bind variables in queries should be indicated with a question mark - snowflake.connector.paramstyle = "qmark" - - # get table name override from environment - self.table_name = table_name or os.environ.get( - "ANYVAR_SNOWFLAKE_STORE_TABLE_NAME", "vrs_objects" + :param batch_add_mode: what type of SQL statement to use when adding many items at once; one of `merge` + (no duplicates), `insert_notin` (try to avoid duplicates) or `insert` (don't worry about duplicates); + defaults to `merge`; can be set with the ANYVAR_SNOWFLAKE_BATCH_ADD_MODE environment variable + """ + prepared_db_url = self._preprocess_db_url(db_url) + super().__init__( + prepared_db_url, + batch_limit, + table_name, + max_pending_batches, + flush_on_batchctx_exit, + ) + self.batch_add_mode = batch_add_mode or os.environ.get( + "ANYVAR_SNOWFLAKE_BATCH_ADD_MODE", SnowflakeBatchAddMode.merge ) + if self.batch_add_mode not in SnowflakeBatchAddMode.__members__: + raise Exception("batch_add_mode must be one of 'merge', 'insert_notin', or 'insert'") - # parse the db url and extract the account name and conn params + def _preprocess_db_url(self, db_url: str) -> str: + db_url = db_url.replace(".snowflakecomputing.com", "") parsed_uri = urlparse(db_url) - account_name = parsed_uri.hostname.replace(".snowflakecomputing.com", "") conn_params = { key: value[0] if value else None for key, value in parse_qs(parsed_uri.query).items() } + if "private_key" in conn_params: + self.private_key_param = conn_params["private_key"] + del conn_params["private_key"] + parsed_uri = parsed_uri._replace(query=urlencode(conn_params)) + else: + self.private_key_param = None + + return urlunparse(parsed_uri) + def _get_connect_args(self, db_url: str) -> dict: # if there is a private_key param that is a file, read the contents of file - if "private_key" in conn_params: - pk_value = conn_params["private_key"] + if self.private_key_param: p_key = None pk_passphrase = None if "ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE" in os.environ: pk_passphrase = os.environ["ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE"].encode() - if os.path.isfile(pk_value): - with open(pk_value, "rb") as key: + if os.path.isfile(self.private_key_param): + with open(self.private_key_param, "rb") as key: p_key = serialization.load_pem_private_key( key.read(), password=pk_passphrase, backend=default_backend() ) else: p_key = serialization.load_pem_private_key( - pk_value.encode(), password=pk_passphrase, backend=default_backend() + self.private_key_param.encode(), + password=pk_passphrase, + backend=default_backend(), ) - conn_params["private_key"] = p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - # log sanitized connection parameters - if _logger.isEnabledFor(logging.DEBUG): - sanitized_conn_params = conn_params.copy() - for secret_param in 
["password", "private_key"]: - if secret_param in sanitized_conn_params: - sanitized_conn_params[secret_param] = "****sanitized****" - - _logger.debug( - "Connecting to Snowflake account %s with params %s", - account_name, - sanitized_conn_params, - ) - # log connection attempt + return { + "private_key": p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + } else: - _logger.info( - "Connecting to Snowflake account %s", - account_name, - ) + return {} - # create the database connection and ensure it is setup - self.conn = snowflake.connector.connect(account=account_name, **conn_params) - self.ensure_schema_exists() - - # setup batch handling - self.batch_manager = SnowflakeBatchManager - self.batch_mode = False - self.batch_insert_values = [] - self.batch_limit = batch_limit or int( - os.environ.get("ANYVAR_SNOWFLAKE_STORE_BATCH_LIMIT", "100000") - ) - max_pending_batches = max_pending_batches or int( - os.environ.get("ANYVAR_SNOWFLAKE_STORE_MAX_PENDING_BATCHES", "50") - ) - self.batch_thread = SnowflakeBatchThread(self.conn, self.table_name, max_pending_batches) - self.batch_thread.start() - - def _create_schema(self): - """Add the VRS object table if it does not exist""" - # self.table_name is only modifiable via environment variable or direct instantiation of the SnowflakeObjectStore - create_statement = f""" - CREATE TABLE {self.table_name} ( - vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', - vrs_object VARIANT - ); + def create_schema(self, db_conn: Connection): + check_statement = f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{self.table_name}') """ # nosec B608 - _logger.info("Creating VRS object table %s", self.table_name) - with self.conn.cursor() as cur: - cur.execute(create_statement) - - def ensure_schema_exists(self): - """Check that VRS object table exists and create it if it does not""" - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM information_schema.tables - WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() - AND UPPER(table_name) = UPPER('{self.table_name}'); - """ # nosec B608 + create_statement = f""" + CREATE TABLE {self.table_name} ( + vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', + vrs_object VARIANT ) - result = cur.fetchone() - - if result is None or result[0] <= 0: - self._create_schema() - - def __repr__(self): - return str(self.conn) - - def __setitem__(self, name: str, value: Any): - """Add item to database. If batch mode is on, add item to batch and submit batch - for write only if batch size is exceeded. - - :param name: value for `vrs_id` field - :param value: value for `vrs_object` field - """ - assert ga4gh.core.is_pydantic_instance(value), "ga4gh.vrs object value required" - name = str(name) # in case str-like - if self.batch_mode: - self.batch_insert_values.append((name, value)) - _logger.debug("Appended item %s to batch queue", name) - if len(self.batch_insert_values) >= self.batch_limit: - self.batch_thread.queue_batch(self.batch_insert_values) - _logger.info( - "Queued batch of %s VRS objects for write", len(self.batch_insert_values) - ) - self.batch_insert_values = [] - else: - value_json = json.dumps(value.model_dump(exclude_none=True)) - insert_query = f""" - MERGE INTO {self.table_name} t USING (SELECT ? AS vrs_id, ? 
AS vrs_object) s ON t.vrs_id = s.vrs_id - WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)); - """ # nosec B608 - with self.conn.cursor() as cur: - cur.execute(insert_query, [name, value_json]) - _logger.debug("Inserted item %s to %s", name, self.table_name) - - def __getitem__(self, name: str) -> Optional[Any]: - """Fetch item from DB given key. - - Future issues: - * Remove reliance on VRS-Python models (requires rewriting the enderef module) + """ # nosec B608 + result = db_conn.execute(sql_text(check_statement)) + if result.scalar() < 1: + db_conn.execute(sql_text(create_statement)) + + def add_one_item(self, db_conn: Connection, name: str, value: Any): + insert_query = f""" + MERGE INTO {self.table_name} t USING (SELECT ? AS vrs_id, ? AS vrs_object) s ON t.vrs_id = s.vrs_id + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) + """ # nosec B608 + value_json = json.dumps(value.model_dump(exclude_none=True)) + db_conn.execute(insert_query, (name, value_json)) + _logger.debug("Inserted item %s to %s", name, self.table_name) - :param name: key to retrieve VRS object for - :return: VRS object if available - :raise NotImplementedError: if unsupported VRS object type (this is WIP) - """ - with self.conn.cursor() as cur: - cur.execute( - f"SELECT vrs_object FROM {self.table_name} WHERE vrs_id = ?;", [name] # nosec B608 - ) - result = cur.fetchone() - if result: - result = json.loads(result[0]) - object_type = result["type"] - if object_type == "Allele": - return models.Allele(**result) - elif object_type == "CopyNumberCount": - return models.CopyNumberCount(**result) - elif object_type == "CopyNumberChange": - return models.CopyNumberChange(**result) - elif object_type == "SequenceLocation": - return models.SequenceLocation(**result) - else: - raise NotImplementedError + def add_many_items(self, db_conn: Connection, items: list): + """Bulk inserts the batch values into a TEMP table, then merges into the main {self.table_name} table""" + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" + if self.batch_add_mode == SnowflakeBatchAddMode.insert: + merge_statement = f""" + INSERT INTO {self.table_name} (vrs_id, vrs_object) + SELECT vrs_id, PARSE_JSON(vrs_object) FROM tmp_vrs_objects + """ # nosec B608 + elif self.batch_add_mode == SnowflakeBatchAddMode.insert_notin: + merge_statement = f""" + INSERT INTO {self.table_name} (vrs_id, vrs_object) + SELECT t.vrs_id, PARSE_JSON(t.vrs_object) + FROM tmp_vrs_objects t + LEFT OUTER JOIN {self.table_name} v ON v.vrs_id = t.vrs_id + WHERE v.vrs_id IS NULL + """ # nosec B608 else: - raise KeyError(name) - - def __contains__(self, name: str) -> bool: - """Check whether VRS objects table contains ID. 
- - :param name: VRS ID to look up - :return: True if ID is contained in vrs objects table - """ - with self.conn.cursor() as cur: - cur.execute( - f"SELECT COUNT(*) FROM {self.table_name} WHERE vrs_id = ?;", [name] # nosec B608 - ) - result = cur.fetchone() - return result[0] > 0 if result else False - - def __delitem__(self, name: str) -> None: - """Delete item (not cascading -- doesn't delete referenced items) - - :param name: key to delete object for - """ - name = str(name) # in case str-like - with self.conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE vrs_id = ?;", [name]) # nosec B608 - self.conn.commit() - - def wait_for_writes(self): - """Returns once any currently pending database modifications have been completed.""" - if self.batch_thread is not None: - # short circuit if the queue is empty - with self.batch_thread.cond: - if not self.batch_thread.pending_batch_list: - return - - # queue an empty batch - batch = [] - self.batch_thread.queue_batch(batch) - # wait for the batch to be removed from the pending queue - while True: - with self.batch_thread.cond: - if list(filter(lambda x: x is batch, self.batch_thread.pending_batch_list)): - self.batch_thread.cond.wait() - else: - break - - def close(self): - """Stop the batch thread and wait for it to complete""" - if self.batch_thread is not None: - self.batch_thread.stop() - self.batch_thread.join() - self.batch_thread = None - # Terminate connection if necessary. - if self.conn is not None: - self.conn.close() - self.conn = None - - def __del__(self): - """Tear down DB instance.""" - self.close() - - def __len__(self): - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) AS c FROM {self.table_name} - WHERE vrs_object:type = 'Allele'; + merge_statement = f""" + MERGE INTO {self.table_name} v USING tmp_vrs_objects s ON v.vrs_id = s.vrs_id + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def get_variation_count(self, variation_type: VariationStatisticType) -> int: - """Get total # of registered variations of requested type. + drop_statement = "DROP TABLE tmp_vrs_objects" + + # create row data removing duplicates + # because if there are duplicates in the source of the merge + # Snowflake inserts duplicate rows + row_data = [] + row_keys = set() + for name, value in items: + if name not in row_keys: + row_keys.add(name) + row_data.append((name, json.dumps(value.model_dump(exclude_none=True)))) + _logger.info("Created row data for insert, first item is %s", row_data[0]) + + db_conn.execute(sql_text(tmp_statement)) + # NB - enclosing the insert statement in sql_text() + # causes a "Bind variable ? 
not set" error from Snowflake + # It is unclear why this is that case + db_conn.execute(insert_statement, row_data) + db_conn.execute(sql_text(merge_statement)) + db_conn.execute(sql_text(drop_statement)) + + def deletion_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE LENGTH(vrs_object:state:sequence) = 0 + """ # nosec B608 + ) + return result.scalar() + + def substitution_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE LENGTH(vrs_object:state:sequence) = 1 + """ # nosec B608 + ) + return result.scalar() + + def insertion_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE LENGTH(vrs_object:state:sequence) > 1 + """ # nosec B608 + ) + return result.scalar() - :param variation_type: variation type to check - :return: total count - """ - if variation_type == VariationStatisticType.SUBSTITUTION: - return self._substitution_count() - elif variation_type == VariationStatisticType.INSERTION: - return self._insertion_count() - elif variation_type == VariationStatisticType.DELETION: - return self._deletion_count() - else: - return self._substitution_count() + self._deletion_count() + self._insertion_count() - - def _deletion_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM {self.table_name} - WHERE LENGTH(vrs_object:state:sequence) = 0; - """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def _substitution_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM {self.table_name} - WHERE LENGTH(vrs_object:state:sequence) = 1; - """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def _insertion_count(self): - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM {self.table_name} - WHERE LENGTH(vrs_object:state:sequence) > 1 - """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def __iter__(self): - with self.conn.cursor() as cur: - cur.execute(f"SELECT * FROM {self.table_name};") # nosec B608 - while True: - _next = cur.fetchone() - if _next is None: - break - yield _next - - def keys(self): - with self.conn.cursor() as cur: - cur.execute(f"SELECT vrs_id FROM {self.table_name};") # nosec B608 - result = [row[0] for row in cur.fetchall()] - return result - - def search_variations(self, refget_accession: str, start: int, stop: int): - """Find all alleles that were registered that are in 1 genomic region - - Args: - refget_accession (str): refget accession (SQ. identifier) - start (int): Start genomic region to query - stop (iint): Stop genomic region to query - - Returns: - A list of VRS Alleles that have locations referenced as identifiers - """ + def search_vrs_objects( + self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int + ) -> List[Any]: query_str = f""" - SELECT vrs_object FROM {self.table_name} - WHERE vrs_object:location IN ( + SELECT vrs_object + FROM {self.table_name} + WHERE vrs_object:type = ? + AND vrs_object:location IN ( SELECT vrs_id FROM {self.table_name} - WHERE vrs_object:start::INTEGER >= ? - AND vrs_object:end::INTEGER <= ? - AND vrs_object:sequenceReference:refgetAccession = ? - ); + WHERE vrs_object:start::INTEGER >= ? + AND vrs_object:end::INTEGER <= ? 
+ AND vrs_object:sequenceReference:refgetAccession = ?) """ # nosec B608 - with self.conn.cursor() as cur: - cur.execute(query_str, [start, stop, refget_accession]) - results = cur.fetchall() - return [json.loads(vrs_object[0]) for vrs_object in results if vrs_object] - - def wipe_db(self): - """Remove all stored records from {self.table_name} table.""" - with self.conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name};") # nosec B608 - - def num_pending_batches(self): - if self.batch_thread: - return len(self.batch_thread.pending_batch_list) - else: - return 0 - - -class SnowflakeBatchManager(_BatchManager): - """Context manager enabling bulk insertion statements - - Use in cases like VCF ingest when intaking large amounts of data at once. - Insertion batches are processed by a background thread. - """ - - def __init__(self, storage: SnowflakeObjectStore): - """Initialize context manager. - - :param storage: Snowflake instance to manage. Should be taken from the active - AnyVar instance -- otherwise it won't be able to delay insertions. - :raise ValueError: if `storage` param is not a `SnowflakeObjectStore` instance - """ - if not isinstance(storage, SnowflakeObjectStore): - raise ValueError("SnowflakeBatchManager requires a SnowflakeObjectStore instance") - self._storage = storage - - def __enter__(self): - """Enter managed context.""" - self._storage.batch_insert_values = [] - self._storage.batch_mode = True - - def __exit__( - self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any] - ) -> bool: - """Handle exit from context management. Hands off final batch to background bulk insert processor. - - :param exc_type: type of exception encountered, if any - :param exc_value: exception value - :param traceback: traceback for context of exception - :return: True if no exceptions encountered, False otherwise - """ - if exc_type is not None: - self._storage.batch_insert_values = None - self._storage.batch_mode = False - _logger.error(f"Snowflake batch manager encountered exception {exc_type}: {exc_value}") - _logger.exception(exc_value) - return False - self._storage.batch_thread.queue_batch(self._storage.batch_insert_values) - self._storage.batch_mode = False - self._storage.batch_insert_values = None - return True - - -class SnowflakeBatchThread(Thread): - """Background thread that merges VRS objects into the database""" - - def __init__(self, conn: SnowflakeConnection, table_name: str, max_pending_batches: int): - """Constructs a new background thread - - :param conn: Snowflake connection - """ - super().__init__(daemon=True) - self.conn = conn - self.cond = Condition() - self.run_flag = True - self.pending_batch_list = [] - self.table_name = table_name - self.max_pending_batches = max_pending_batches - - def run(self): - """As long as run_flag is true, waits then processes pending batches""" - while self.run_flag: - with self.cond: - self.cond.wait() - self.process_pending_batches() - - def stop(self): - """Sets the run_flag to false and notifies""" - self.run_flag = False - with self.cond: - self.cond.notify() - - def queue_batch(self, batch_insert_values: List[Tuple]): - """Adds a batch to the pending list. 
If the pending batch list is already at its max size, waits until there is room - - :param batch_insert_values: list of tuples where each tuple consists of (vrs_id, vrs_object) - """ - with self.cond: - if batch_insert_values is not None: - _logger.info("Queueing batch of %s items", len(batch_insert_values)) - while len(self.pending_batch_list) >= self.max_pending_batches: - _logger.debug("Pending batch queue is full, waiting for space...") - self.cond.wait() - self.pending_batch_list.append(batch_insert_values) - _logger.info("Queued batch of %s items", len(batch_insert_values)) - self.cond.notify_all() - - def process_pending_batches(self): - """As long as batches are available for processing, merges them into the database""" - _logger.info("Processing %s queued batches", len(self.pending_batch_list)) - while True: - batch_insert_values = None - with self.cond: - if len(self.pending_batch_list) > 0: - batch_insert_values = self.pending_batch_list[0] - del self.pending_batch_list[0] - self.cond.notify_all() - else: - self.cond.notify_all() - break - - if batch_insert_values: - self._run_copy_insert(batch_insert_values) - _logger.info("Processed queued batch of %s items", len(batch_insert_values)) - - def _run_copy_insert(self, batch_insert_values): - """Bulk inserts the batch values into a TEMP table, then merges into the main {self.table_name} table""" - - try: - tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR);" - insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?);" - merge_statement = f""" - MERGE INTO {self.table_name} v USING tmp_vrs_objects s ON v.vrs_id = s.vrs_id - WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)); - """ # nosec B608 - drop_statement = "DROP TABLE tmp_vrs_objects;" - - row_data = [ - (name, json.dumps(value.model_dump(exclude_none=True))) - for name, value in batch_insert_values - ] - _logger.info("Created row data for insert, first item is %s", row_data[0]) - - with self.conn.cursor() as cur: - cur.execute(tmp_statement) - cur.executemany(insert_statement, row_data) - cur.execute(merge_statement) - cur.execute(drop_statement) - self.conn.commit() - except Exception: - _logger.exception("Failed to merge VRS object batch into database") - finally: - self.conn.rollback() + results = db_conn.execute( + query_str, + (type, start, stop, refget_accession), + ) + return [json.loads(row[0]) for row in results if row] diff --git a/src/anyvar/storage/sql_storage.py b/src/anyvar/storage/sql_storage.py new file mode 100644 index 0000000..fbeb78e --- /dev/null +++ b/src/anyvar/storage/sql_storage.py @@ -0,0 +1,499 @@ +from abc import abstractmethod +import json +import logging +import os +from threading import Condition, Thread +from typing import Any, Generator, List, Optional, Tuple + +import ga4gh.core +from ga4gh.vrs import models +from sqlalchemy import text as sql_text, create_engine +from sqlalchemy.engine import Connection + +from anyvar.restapi.schema import VariationStatisticType + +from . import _BatchManager, _Storage + +_logger = logging.getLogger(__name__) + + +class SqlStorage(_Storage): + """Relational database storage backend. Uses SQLAlchemy as a DB abstraction layer and pool. + Methods that utilize straightforward SQL are implemented in this class. Methods that require + specialized SQL statements must be implemented in a database specific subclass. 
+ """ + + def __init__( + self, + db_url: str, + batch_limit: Optional[int] = None, + table_name: Optional[str] = None, + max_pending_batches: Optional[int] = None, + flush_on_batchctx_exit: Optional[bool] = None, + ): + """Initialize DB handler. + + :param db_url: db connection info URL + :param batch_limit: max size of batch insert queue, defaults to 100000; can be set with + ANYVAR_SQL_STORE_BATCH_LIMIT environment variable + :param table_name: table name for storing VRS objects, defaults to `vrs_objects`; can be set with + ANYVAR_SQL_STORE_TABLE_NAME environment variable + :param max_pending_batches: maximum number of pending batches allowed before batch queueing blocks; can + be set with ANYVAR_SQL_STORE_MAX_PENDING_BATCHES environment variable + :param flush_on_batchctx_exit: whether to call `wait_for_writes()` when exiting the batch manager context; + defaults to True; can be set with the ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT environment variable + + See https://docs.sqlalchemy.org/en/20/core/connections.html for connection URL info + """ + + # get table name override from environment + self.table_name = table_name or os.environ.get("ANYVAR_SQL_STORE_TABLE_NAME", "vrs_objects") + + # create the database connection engine + self.conn_pool = create_engine( + db_url, + pool_size=1, + max_overflow=1, + pool_recycle=3600, + connect_args=self._get_connect_args(db_url), + ) + + # create the schema objects if necessary + with self._get_connection() as conn: + self.create_schema(conn) + + # setup batch handling + self.batch_manager = SqlStorageBatchManager + self.batch_mode = False + self.batch_insert_values = [] + self.batch_limit = batch_limit or int( + os.environ.get("ANYVAR_SQL_STORE_BATCH_LIMIT", "100000") + ) + self.flush_on_batchctx_exit = ( + bool(os.environ.get("ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT", "True")) + if flush_on_batchctx_exit is None + else flush_on_batchctx_exit + ) + max_pending_batches = max_pending_batches or int( + os.environ.get("ANYVAR_SQL_STORE_MAX_PENDING_BATCHES", "50") + ) + self.batch_thread = SqlStorageBatchThread(self, max_pending_batches) + self.batch_thread.start() + + def _get_connection(self) -> Connection: + """Returns a database connection""" + return self.conn_pool.connect() + + def _get_connect_args(self, db_url: str) -> dict: + """Returns connect_args for the SQLAlchemy create_engine() call + The default implementation returns None""" + return {} + + @abstractmethod + def create_schema(self, db_conn: Connection): + """Add the VRS object table if it does not exist + + :param db_conn: a database connection + """ + + def __repr__(self): + return str(self.conn_pool) + + def __setitem__(self, name: str, value: Any): + """Add item to database. If batch mode is on, add item to batch and submit batch + for write only if batch size is exceeded. 
+ + :param name: value for `vrs_id` field + :param value: value for `vrs_object` field + """ + assert ga4gh.core.is_pydantic_instance(value), "ga4gh.vrs object value required" + name = str(name) # in case str-like + if self.batch_mode: + self.batch_insert_values.append((name, value)) + _logger.debug("Appended item %s to batch queue", name) + if len(self.batch_insert_values) >= self.batch_limit: + self.batch_thread.queue_batch(self.batch_insert_values) + _logger.info( + "Queued batch of %s VRS objects for write", len(self.batch_insert_values) + ) + self.batch_insert_values = [] + else: + with self._get_connection() as db_conn: + with db_conn.begin(): + self.add_one_item(db_conn, name, value) + _logger.debug("Inserted item %s to %s", name, self.table_name) + + @abstractmethod + def add_one_item(self, db_conn: Connection, name: str, value: Any): + """Adds/merges a single item to the database + + :param db_conn: a database connection + :param name: value for `vrs_id` field + :param value: value for `vrs_object` field + """ + + @abstractmethod + def add_many_items(self, db_conn: Connection, items: list): + """Adds/merges many items to the database + + :param db_conn: a database connection + :param items: a list of (vrs_id, vrs_object) tuples + """ + + def __getitem__(self, name: str) -> Optional[Any]: + """Fetch item from DB given key. + + Future issues: + * Remove reliance on VRS-Python models (requires rewriting the enderef module) + + :param name: key to retrieve VRS object for + :return: VRS object if available + :raise NotImplementedError: if unsupported VRS object type (this is WIP) + """ + with self._get_connection() as conn: + result = self.fetch_vrs_object(conn, name) + if result: + object_type = result["type"] + if object_type == "Allele": + return models.Allele(**result) + elif object_type == "CopyNumberCount": + return models.CopyNumberCount(**result) + elif object_type == "CopyNumberChange": + return models.CopyNumberChange(**result) + elif object_type == "SequenceLocation": + return models.SequenceLocation(**result) + else: + raise NotImplementedError + else: + raise KeyError(name) + + def fetch_vrs_object(self, db_conn: Connection, vrs_id: str) -> Optional[Any]: + """Fetches a single VRS object from the database, return the value as a JSON object + + :param db_conn: a database connection + :param vrs_id: the VRS ID + :return: VRS object if available + """ + result = db_conn.execute( + sql_text( + f"SELECT vrs_object FROM {self.table_name} WHERE vrs_id = :vrs_id" # nosec B608 + ), + {"vrs_id": vrs_id}, + ) + if result: + value = result.scalar() + return json.loads(value) if value and isinstance(value, str) else value + else: + return None + + def __contains__(self, name: str) -> bool: + """Check whether VRS objects table contains ID. 
+ + :param name: VRS ID to look up + :return: True if ID is contained in vrs objects table + """ + with self._get_connection() as conn: + return self.fetch_vrs_object(conn, name) is not None + + def __delitem__(self, name: str) -> None: + """Delete item (not cascading -- doesn't delete referenced items) + + :param name: key to delete object for + """ + name = str(name) # in case str-like + with self._get_connection() as conn: + with conn.begin(): + self.delete_vrs_object(conn, name) + + def delete_vrs_object(self, db_conn: Connection, vrs_id: str): + """Delete a single VRS object + + :param db_conn: a database connection + :param vrs_id: the VRS ID + """ + db_conn.execute( + sql_text(f"DELETE FROM {self.table_name} WHERE vrs_id = :vrs_id"), # nosec B608 + {"vrs_id": vrs_id}, + ) + + def wait_for_writes(self): + """Returns once any currently pending database modifications have been completed.""" + if hasattr(self, "batch_thread") and self.batch_thread is not None: + # short circuit if the queue is empty + with self.batch_thread.cond: + if not self.batch_thread.pending_batch_list: + return + + # queue an empty batch + batch = [] + self.batch_thread.queue_batch(batch) + # wait for the batch to be removed from the pending queue + while True: + with self.batch_thread.cond: + if list(filter(lambda x: x is batch, self.batch_thread.pending_batch_list)): + self.batch_thread.cond.wait() + else: + break + + def close(self): + """Stop the batch thread and wait for it to complete""" + if hasattr(self, "batch_thread") and self.batch_thread is not None: + self.batch_thread.stop() + self.batch_thread.join() + self.batch_thread = None + # Terminate connection if necessary. + if hasattr(self, "conn_pool") and self.conn_pool is not None: + self.conn_pool.dispose() + self.conn_pool = None + + def __del__(self): + """Tear down DB instance.""" + self.close() + + def __len__(self) -> int: + """Returns the total number of VRS objects""" + with self._get_connection() as conn: + return self.get_vrs_object_count(conn) + + def get_vrs_object_count(self, db_conn: Connection) -> int: + """Returns the total number of objects + + :param db_conn: a database connection + """ + result = db_conn.execute(sql_text(f"SELECT COUNT(*) FROM {self.table_name}")) # nosec B608 + return result.scalar() + + def get_variation_count(self, variation_type: VariationStatisticType) -> int: + """Get total # of registered variations of requested type. 
+ + :param variation_type: variation type to check + :return: total count + """ + with self._get_connection() as conn: + if variation_type == VariationStatisticType.SUBSTITUTION: + return self.substitution_count(conn) + elif variation_type == VariationStatisticType.INSERTION: + return self.insertion_count(conn) + elif variation_type == VariationStatisticType.DELETION: + return self.deletion_count(conn) + else: + return ( + self.substitution_count(conn) + + self.deletion_count(conn) + + self.insertion_count(conn) + ) + + @abstractmethod + def deletion_count(self, db_conn: Connection) -> int: + """Returns the total number of deletions + + :param db_conn: a database connection + """ + + @abstractmethod + def substitution_count(self, db_conn: Connection) -> int: + """Returns the total number of substitutions + + :param db_conn: a database connection + """ + + @abstractmethod + def insertion_count(self, db_conn: Connection): + """Returns the total number of insertions + + :param db_conn: a database connection + """ + + def __iter__(self): + """Iterates over all VRS objects in the database""" + with self._get_connection() as conn: + iter = self.fetch_all_vrs_objects(conn) + for obj in iter: + yield obj + + def fetch_all_vrs_objects(self, db_conn: Connection) -> Generator[Any, Any, Any]: + """Returns a generator that iterates over all VRS objects in the database + in no specific order + + :param db_conn: a database connection + """ + result = db_conn.execute( + sql_text(f"SELECT vrs_object FROM {self.table_name}") # nosec B608 + ) + for row in result: + if row: + value = row["vrs_object"] + yield json.loads(value) if value and isinstance(value, str) else value + else: + yield None + + def keys(self): + """Returns a list of all VRS IDs in the database""" + with self._get_connection() as conn: + return self.fetch_all_vrs_ids(conn) + + def fetch_all_vrs_ids(self, db_conn: Connection) -> List: + """Returns a list of all VRS IDs in the database + + :param db_conn: a database connection + """ + result = db_conn.execute(sql_text(f"SELECT vrs_id FROM {self.table_name}")) # nosec B608 + return [row[0] for row in result] + + def search_variations(self, refget_accession: str, start: int, stop: int): + """Find all alleles that were registered that are in 1 genomic region + + :param refget_accession: refget accession (SQ. identifier) + :param start: Start genomic region to query + :param stop: Stop genomic region to query + + :return: a list of VRS Alleles that have locations referenced as identifiers + """ + with self._get_connection() as conn: + return self.search_vrs_objects(conn, "Allele", refget_accession, start, stop) + + @abstractmethod + def search_vrs_objects( + self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int + ) -> List[Any]: + """Find all VRS objects of the particular type and region + + :param type: the type of VRS object to search for + :param refget_accession: refget accession (SQ. 
identifier) + :param start: Start genomic region to query + :param stop: Stop genomic region to query + + :return: a list of VRS objects + """ + + def wipe_db(self): + """Remove all stored records from the database""" + with self._get_connection() as conn: + with conn.begin(): + conn.execute(sql_text(f"DELETE FROM {self.table_name}")) # nosec B608 + + def num_pending_batches(self) -> int: + """Returns the number of pending insert batches""" + if self.batch_thread: + return len(self.batch_thread.pending_batch_list) + else: + return 0 + + +class SqlStorageBatchManager(_BatchManager): + """Context manager enabling bulk insertion statements + + Use in cases like VCF ingest when intaking large amounts of data at once. + Insertion batches are processed by a background thread. + """ + + def __init__(self, storage: SqlStorage): + """Initialize context manager. + + :param storage: SqlStorage instance to manage. Should be taken from the active + AnyVar instance -- otherwise it won't be able to delay insertions. + :raise ValueError: if `storage` param is not a `SqlStorage` instance + """ + if not isinstance(storage, SqlStorage): + raise ValueError("SqlStorageBatchManager requires a SqlStorage instance") + self._storage = storage + + def __enter__(self): + """Enter managed context.""" + self._storage.batch_insert_values = [] + self._storage.batch_mode = True + + def __exit__( + self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any] + ) -> bool: + """Handle exit from context management. Hands off final batch to background bulk insert processor. + + :param exc_type: type of exception encountered, if any + :param exc_value: exception value + :param traceback: traceback for context of exception + :return: True if no exceptions encountered, False otherwise + """ + if exc_type is not None: + self._storage.batch_insert_values = None + self._storage.batch_mode = False + _logger.error( + f"Sql storage batch manager encountered exception {exc_type}: {exc_value}" + ) + _logger.exception(exc_value) + return False + self._storage.batch_thread.queue_batch(self._storage.batch_insert_values) + self._storage.batch_mode = False + self._storage.batch_insert_values = None + if self._storage.flush_on_batchctx_exit: + self._storage.wait_for_writes() + return True + + +class SqlStorageBatchThread(Thread): + """Background thread that merges VRS objects into the database""" + + def __init__(self, sql_store: SqlStorage, max_pending_batches: int): + """Constructs a new background thread + + :param conn_pool: SQLAlchemy connection pool + """ + super().__init__(daemon=True) + self.sql_store = sql_store + self.cond = Condition() + self.run_flag = True + self.pending_batch_list = [] + self.max_pending_batches = max_pending_batches + + def run(self): + """As long as run_flag is true, waits then processes pending batches""" + while self.run_flag: + with self.cond: + self.cond.wait() + self.process_pending_batches() + + def stop(self): + """Sets the run_flag to false and notifies""" + self.run_flag = False + with self.cond: + self.cond.notify() + + def queue_batch(self, batch_insert_values: List[Tuple]): + """Adds a batch to the pending list. 
If the pending batch list is already at its max size, waits until there is room + + :param batch_insert_values: list of tuples where each tuple consists of (vrs_id, vrs_object) + """ + with self.cond: + if batch_insert_values is not None: + _logger.info("Queueing batch of %s items", len(batch_insert_values)) + while len(self.pending_batch_list) >= self.max_pending_batches: + _logger.debug("Pending batch queue is full, waiting for space...") + self.cond.wait() + self.pending_batch_list.append(batch_insert_values) + _logger.info("Queued batch of %s items", len(batch_insert_values)) + self.cond.notify_all() + + def process_pending_batches(self): + """As long as batches are available for processing, merges them into the database""" + _logger.info("Processing %s queued batches", len(self.pending_batch_list)) + while True: + batch_insert_values = None + with self.cond: + if len(self.pending_batch_list) > 0: + batch_insert_values = self.pending_batch_list[0] + del self.pending_batch_list[0] + self.cond.notify_all() + else: + self.cond.notify_all() + break + + if batch_insert_values: + self._run_copy_insert(batch_insert_values) + _logger.info("Processed queued batch of %s items", len(batch_insert_values)) + + def _run_copy_insert(self, batch_insert_values): + try: + with self.sql_store._get_connection() as conn: + with conn.begin(): + self.sql_store.add_many_items(conn, batch_insert_values) + except Exception: + _logger.exception("Failed to merge VRS object batch into database") diff --git a/tests/conftest.py b/tests/conftest.py index 367eb13..7eace7d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ def pytest_collection_modifyitems(items): """Modify test items in place to ensure test modules run in a given order.""" - MODULE_ORDER = ["test_lifespan", "test_variation", "test_general", "test_location", "test_search", "test_vcf", "test_storage_mapping", "test_snowflake"] + MODULE_ORDER = ["test_lifespan", "test_variation", "test_general", "test_location", "test_search", "test_vcf", "test_sql_storage_mapping", "test_postgres", "test_snowflake"] # remember to add new test modules to the order constant: assert len(MODULE_ORDER) == len(list(Path(__file__).parent.rglob("test_*.py"))) items.sort(key=lambda i: MODULE_ORDER.index(i.module.__name__)) diff --git a/tests/storage/sqlalchemy_mocks.py b/tests/storage/sqlalchemy_mocks.py new file mode 100644 index 0000000..e2b5201 --- /dev/null +++ b/tests/storage/sqlalchemy_mocks.py @@ -0,0 +1,153 @@ +import json +import re +import time + +class MockResult(list): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def fetchone(self): + if len(self) > 0: + retval = self[0] + del self[0] + return retval + + def fetchall(self): + new_list = list(self) + while len(self) > 0: + del self[0] + return new_list + + def scalar(self): + if len(self) > 0: + value = self[0][0] + while len(self) > 0: + del self[0] + try: + return int(value) + except: + return str(value) if value else None + else: + return None + +class MockStmt: + def __init__(self, sql: str, params, result, wait_for_secs: int = 0): + self.sql = re.sub(r'\s+', ' ', sql).strip() + self.params = params + self.result = result if not result or isinstance(result, Exception) else MockResult(result) + self.wait_for_secs = wait_for_secs + + def matches(self, sql: str, params): + norm_sql = re.sub(r'\s+', ' ', sql).strip() + if norm_sql == self.sql: + if self.params == True: + return True + elif (self.params is None or len(self.params) == 0) and (params is None or 
len(params) == 0): + return True + elif self.params == params: + return True + return False + +class MockStmtSequence(list): + def __init__(self): + self.execd = [] + + def add_stmt(self, sql: str, params, result, wait_for_secs: int = 0): + self.append(MockStmt(sql, params, result, wait_for_secs)) + return self + + def add_copy_from(self, table_name, data): + self.append(MockStmt(f"COPY FROM fd INTO {table_name}", data, [(1,)])) + return self + + def pop_if_matches(self, sql: str, params) -> list: + if len(self) > 0 and self[0].matches(sql, params): + self.execd.append(self[0]) + wait_for_secs = self[0].wait_for_secs + result = self[0].result + del self[0] + if wait_for_secs > 0: + time.sleep(wait_for_secs) + if isinstance(result, Exception): + raise result + else: + return result + return None + + def were_all_execd(self): + return len(self) <= 0 + +class MockConnection: + def __init__(self): + self.mock_stmt_sequences = [] + self.connection = self + self.last_result = None + + def close(self): + pass + + def cursor(self): + return self + + def fetchall(self): + return self.last_result.fetchall() if self.last_result else None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + if exc_value: + raise exc_value + return True + + def begin(self): + return self + + def execute(self, cmd, params = None): + for seq in self.mock_stmt_sequences: + result = seq.pop_if_matches(str(cmd), params) + if result: + self.last_result = result + return result + + norm_sql = re.sub(r'\s+', ' ', str(cmd)).strip() + raise Exception(f"no mock statement found for {norm_sql} with params {params}") + + def copy_from(self, fd, table_name, columns=None): + data_as_str = str(fd.read()) + for seq in self.mock_stmt_sequences: + result = seq.pop_if_matches(f"COPY FROM fd INTO {table_name}", data_as_str) + if result: + self.last_result = result + return result + + raise Exception(f"no mock statement found for COPY FROM fd INTO {table_name} with data {data_as_str[:10]}...") + +class MockEngine: + def __init__(self): + self.conn = MockConnection() + + def connect(self): + return self.conn + + def dispose(self): + pass + + def add_mock_stmt_sequence(self, stmt_seq: MockStmtSequence): + self.conn.mock_stmt_sequences.append(stmt_seq) + + def were_all_execd(self): + for seq in self.conn.mock_stmt_sequences: + if not seq.were_all_execd(): + return False + return True + +class MockVRSObject: + def __init__(self, id: str): + self.id = id + + def model_dump(self, exclude_none: bool): + return { "id": self.id } + + def to_json(self): + return json.dumps(self.model_dump(exclude_none=True)) diff --git a/tests/storage/test_postgres.py b/tests/storage/test_postgres.py new file mode 100644 index 0000000..d5909e9 --- /dev/null +++ b/tests/storage/test_postgres.py @@ -0,0 +1,210 @@ +""" +Test Postgres specific storage integration methods +and the async batch insertion + +Uses mocks for database integration +""" +import os +from sqlalchemy_mocks import MockEngine, MockStmtSequence, MockVRSObject + +from anyvar.restapi.schema import VariationStatisticType +from anyvar.storage.postgres import PostgresObjectStore + +vrs_object_table_name = os.environ.get("ANYVAR_SQL_STORE_TABLE_NAME", "vrs_objects") + +def test_create_schema(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 
'{vrs_object_table_name}')", None, [(False,)]) + .add_stmt(f"CREATE TABLE {vrs_object_table_name} ( vrs_id TEXT PRIMARY KEY, vrs_object JSONB )", None, [("Table created",)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + sf.close() + assert mock_eng.were_all_execd() + +def test_create_schema_exists(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + sf.close() + assert mock_eng.were_all_execd() + +def test_add_one_item(mocker): + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + INSERT INTO {vrs_object_table_name} (vrs_id, vrs_object) VALUES (:vrs_id, :vrs_object) ON CONFLICT DO NOTHING + """, + {"vrs_id": "ga4gh:VA.01", "vrs_object": MockVRSObject('01').to_json()}, [(1,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + sf["ga4gh:VA.01"] = MockVRSObject('01') + sf.close() + assert mock_eng.were_all_execd() + +def test_add_many_items(mocker): + tmp_statement = "CREATE TEMP TABLE tmp_table (LIKE vrs_objects2 INCLUDING DEFAULTS)" + insert_statement = "INSERT INTO vrs_objects2 SELECT * FROM tmp_table ON CONFLICT DO NOTHING" + drop_statement = "DROP TABLE tmp_table" + + vrs_id_object_pairs = [ + ("ga4gh:VA.01", MockVRSObject('01')), + ("ga4gh:VA.02", MockVRSObject('02')), + ("ga4gh:VA.03", MockVRSObject('03')), + ("ga4gh:VA.04", MockVRSObject('04')), + ("ga4gh:VA.05", MockVRSObject('05')), + ("ga4gh:VA.06", MockVRSObject('06')), + ("ga4gh:VA.07", MockVRSObject('07')), + ("ga4gh:VA.08", MockVRSObject('08')), + ("ga4gh:VA.09", MockVRSObject('09')), + ("ga4gh:VA.10", MockVRSObject('10')), + ("ga4gh:VA.11", MockVRSObject('11')), + ] + + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 'vrs_objects2')", None, [(True,)]) + # Batch 1 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[0:2]])) + .add_stmt(insert_statement, None, [(2,)], 5) + .add_stmt(drop_statement, None, [("Table dropped",)]) + # Batch 2 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[2:4]])) + .add_stmt(insert_statement, None, [(2,)], 4) + .add_stmt(drop_statement, None, [("Table dropped",)]) + # Batch 3 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[4:6]])) + .add_stmt(insert_statement, None, [(2,)], 3) + .add_stmt(drop_statement, None, [("Table dropped",)]) + # Batch 4 + .add_stmt(tmp_statement, None, [("Table created",)]) + 
.add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[6:8]])) + .add_stmt(insert_statement, None, [(2,)], 5) + .add_stmt(drop_statement, None, [("Table dropped",)]) + # Batch 5 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[8:10]])) + .add_stmt(insert_statement, None, Exception("query timeout")) + # Batch 6 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[10:11]])) + .add_stmt(insert_statement, None, [(2,)], 2) + .add_stmt(drop_statement, None, [("Table dropped",)]) + ) + + sf = PostgresObjectStore("postgres://account/?param=value", 2, "vrs_objects2", 4, False) + with sf.batch_manager(sf): + sf.wait_for_writes() + assert sf.num_pending_batches() == 0 + sf[vrs_id_object_pairs[0][0]] = vrs_id_object_pairs[0][1] + sf[vrs_id_object_pairs[1][0]] = vrs_id_object_pairs[1][1] + assert sf.num_pending_batches() > 0 + sf.wait_for_writes() + assert sf.num_pending_batches() == 0 + sf[vrs_id_object_pairs[2][0]] = vrs_id_object_pairs[2][1] + sf[vrs_id_object_pairs[3][0]] = vrs_id_object_pairs[3][1] + sf[vrs_id_object_pairs[4][0]] = vrs_id_object_pairs[4][1] + sf[vrs_id_object_pairs[5][0]] = vrs_id_object_pairs[5][1] + sf[vrs_id_object_pairs[6][0]] = vrs_id_object_pairs[6][1] + sf[vrs_id_object_pairs[7][0]] = vrs_id_object_pairs[7][1] + sf[vrs_id_object_pairs[8][0]] = vrs_id_object_pairs[8][1] + sf[vrs_id_object_pairs[9][0]] = vrs_id_object_pairs[9][1] + sf[vrs_id_object_pairs[10][0]] = vrs_id_object_pairs[10][1] + + assert sf.num_pending_batches() > 0 + sf.close() + assert sf.num_pending_batches() == 0 + assert mock_eng.return_value.were_all_execd() + +def test_insertion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + SELECT COUNT(*) AS c + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') > 1 + """, + None, [(12,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.INSERTION) == 12 + sf.close() + assert mock_eng.were_all_execd() + +def test_substitution_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + SELECT COUNT(*) AS c + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 1 + """, + None, [(13,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.SUBSTITUTION) == 13 + sf.close() + assert mock_eng.were_all_execd() + +def test_deletion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + 
.add_stmt( + f""" + SELECT COUNT(*) AS c + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 0 + """, + None, [(14,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.DELETION) == 14 + sf.close() + assert mock_eng.were_all_execd() + +def test_search_vrs_objects(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + SELECT vrs_object + FROM {vrs_object_table_name} + WHERE vrs_object->>'type' = %s + AND vrs_object->>'location' IN ( + SELECT vrs_id FROM {vrs_object_table_name} + WHERE CAST (vrs_object->>'start' AS INTEGER) >= %s + AND CAST (vrs_object->>'end' AS INTEGER) <= %s + AND vrs_object->'sequenceReference'->>'refgetAccession' = %s) + """, + [ "Allele", 123456, 123457, "MySQAccId" ], [({"id": 1},), ({"id": 2},)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + vars = sf.search_variations("MySQAccId", 123456, 123457) + sf.close() + assert len(vars) == 2 + assert "id" in vars[0] and vars[0]["id"] == 1 + assert "id" in vars[1] and vars[1]["id"] == 2 + assert mock_eng.were_all_execd() \ No newline at end of file diff --git a/tests/storage/test_snowflake.py b/tests/storage/test_snowflake.py index d6123e9..49d5c23 100644 --- a/tests/storage/test_snowflake.py +++ b/tests/storage/test_snowflake.py @@ -1,160 +1,82 @@ """ -Test Snowflake storage async batch management and write feature +Test Snowflake specific storage integration methods +and the async batch insertion -To run an integration test with a real Snowflake connection, run all - the tests with the ANYVAR_TEST_STORAGE_URI env variable set to - a Snowflake URI +Uses mocks for database integration """ -import json -import logging -import re -import time - -from anyvar.storage.snowflake import SnowflakeObjectStore - -class MockStmt: - def __init__(self, sql: str, params: list, result: list, wait_for_secs: int = 0): - self.sql = re.sub(r'\s+', ' ', sql).strip() - self.params = params - self.result = result - self.wait_for_secs = wait_for_secs - - def matches(self, sql: str, params: list): - norm_sql = re.sub(r'\s+', ' ', sql).strip() - if norm_sql == self.sql: - if self.params == True: - return True - elif (self.params is None or len(self.params) == 0) and (params is None or len(params) == 0): - return True - elif self.params == params: - return True - return False - -class MockStmtSequence(list): - def __init__(self): - self.execd = [] - - def add_stmt(self, sql: str, params: list, result: list, wait_for_secs: int = 0): - self.append(MockStmt(sql, params, result, wait_for_secs)) - return self - - def pop_if_matches(self, sql: str, params: list) -> list: - if len(self) > 0 and self[0].matches(sql, params): - self.execd.append(self[0]) - wait_for_secs = self[0].wait_for_secs - result = self[0].result - del self[0] - if wait_for_secs > 0: - time.sleep(wait_for_secs) - if isinstance(result, Exception): - raise result - else: - return result - return None - - def were_all_execd(self): - return len(self) <= 0 - -class MockConnectionCursor: - def __init__(self): - self.mock_stmt_sequences = [] - self.current_result = None - self.closed = False - - def close(self): - self.closed = True - - def cursor(self): - if self.closed: - raise 
Exception("connection is closed") - return self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback) -> bool: - if exc_value: - raise exc_value - return True - - def execute(self, cmd: str, params: list = None): - if self.closed: - raise Exception("connection is closed") - self.current_result = None - for seq in self.mock_stmt_sequences: - result = seq.pop_if_matches(cmd, params) - if result: - self.current_result = result - if not self.current_result: - raise Exception(f"no mock statement found for {cmd} with params {params}") - - def executemany(self, cmd: str, params: list = None): - self.execute(cmd, params) - - def fetchone(self): - if self.closed: - raise Exception("connection is closed") - return self.current_result[0] if self.current_result and len(self.current_result) > 0 else None - - def commit(self): - self.execute("COMMIT;", None) - - def rollback(self): - self.execute("ROLLBACK;", None) - - def add_mock_stmt_sequence(self, stmt_seq: MockStmtSequence): - self.mock_stmt_sequences.append(stmt_seq) - - def were_all_execd(self): - for seq in self.mock_stmt_sequences: - if not seq.were_all_execd(): - return False - return True - -class MockVRSObject: - def __init__(self, id: str): - self.id = id - - def model_dump(self, exclude_none: bool): - return { "id": self.id } - - def to_json(self): - return json.dumps(self.model_dump(exclude_none=True)) - -def test_create_schema(caplog, mocker): - caplog.set_level(logging.DEBUG) - sf_conn = mocker.patch('snowflake.connector.connect') - sf_conn.return_value = MockConnectionCursor() - sf_conn.return_value.add_mock_stmt_sequence(MockStmtSequence() - .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects');", None, [(0,)]) - .add_stmt("CREATE TABLE vrs_objects ( vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', vrs_object VARIANT );", None, [("Table created",)]) - ) +import os +from sqlalchemy_mocks import MockEngine, MockStmtSequence, MockVRSObject - sf = SnowflakeObjectStore("snowflake://account/?param=value") - sf.close() - assert sf_conn.return_value.were_all_execd() +from anyvar.restapi.schema import VariationStatisticType +from anyvar.storage.snowflake import SnowflakeObjectStore, SnowflakeBatchAddMode +vrs_object_table_name = os.environ.get("ANYVAR_SQL_STORE_TABLE_NAME", "vrs_objects") -def test_schema_exists(mocker): - sf_conn = mocker.patch('snowflake.connector.connect') - sf_conn.return_value = MockConnectionCursor() - sf_conn.return_value.add_mock_stmt_sequence(MockStmtSequence() - .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects');", None, [(1,)]) +def test_create_schema(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(0,)]) + .add_stmt(f"CREATE TABLE {vrs_object_table_name} ( vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', vrs_object VARIANT )", None, [("Table created",)]) ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + sf.close() + assert mock_eng.were_all_execd() +def 
test_create_schema_exists(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + ) sf = SnowflakeObjectStore("snowflake://account/?param=value") sf.close() - assert sf_conn.return_value.were_all_execd() + assert mock_eng.were_all_execd() +def test_add_one_item(mocker): + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + MERGE INTO {vrs_object_table_name} t USING (SELECT ? AS vrs_id, ? AS vrs_object) s ON t.vrs_id = s.vrs_id + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) + """, + ("ga4gh:VA.01", MockVRSObject('01').to_json()), [(1,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + sf["ga4gh:VA.01"] = MockVRSObject('01') + sf.close() + assert mock_eng.were_all_execd() -def test_batch_mgmt_and_async_write_single_thread(mocker): - tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR);" - insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?);" +def test_add_many_items(mocker): + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" merge_statement = f""" MERGE INTO vrs_objects2 v USING tmp_vrs_objects s ON v.vrs_id = s.vrs_id - WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)); + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) """ - drop_statement = "DROP TABLE tmp_vrs_objects;" + drop_statement = "DROP TABLE tmp_vrs_objects" vrs_id_object_pairs = [ ("ga4gh:VA.01", MockVRSObject('01')), @@ -170,54 +92,43 @@ def test_batch_mgmt_and_async_write_single_thread(mocker): ("ga4gh:VA.11", MockVRSObject('11')), ] - mocker.patch('ga4gh.core.is_pydantic_instance', return_value=True) - sf_conn = mocker.patch('snowflake.connector.connect') - sf_conn.return_value = MockConnectionCursor() - sf_conn.return_value.add_mock_stmt_sequence(MockStmtSequence() - .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2');", None, [(1,)]) + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2')", None, [(1,)]) # Batch 1 
.add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[0:2]), [(2,)], 5) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 2 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[2:4]), [(2,)], 4) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 3 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[4:6]), [(2,)], 3) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 4 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[6:8]), [(2,)], 5) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 5 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[8:10]), [(2,)], 3) .add_stmt(merge_statement, None, Exception("query timeout")) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 6 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[10:11]), [(2,)], 2) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) ) - sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", 4) + sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", 4, False, SnowflakeBatchAddMode.merge) with sf.batch_manager(sf): sf.wait_for_writes() assert sf.num_pending_batches() == 0 @@ -239,4 +150,179 @@ def test_batch_mgmt_and_async_write_single_thread(mocker): assert sf.num_pending_batches() > 0 sf.close() assert sf.num_pending_batches() == 0 - assert sf_conn.return_value.were_all_execd() \ No newline at end of file + assert mock_eng.return_value.were_all_execd() + +def test_batch_add_mode_insert_notin(mocker): + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" + merge_statement = f""" + INSERT INTO vrs_objects2 (vrs_id, vrs_object) + SELECT t.vrs_id, PARSE_JSON(t.vrs_object) + FROM tmp_vrs_objects t + LEFT OUTER JOIN vrs_objects2 v ON v.vrs_id = t.vrs_id + WHERE v.vrs_id IS NULL + """ + drop_statement = "DROP TABLE tmp_vrs_objects" + + vrs_id_object_pairs = [ + ("ga4gh:VA.01", MockVRSObject('01')), + ("ga4gh:VA.02", MockVRSObject('02')), + ] + + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + 
mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2')", None, [(1,)]) + # Batch 1 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[0:2]), [(2,)], 5) + .add_stmt(merge_statement, None, [(2,)]) + .add_stmt(drop_statement, None, [("Table dropped",)]) + ) + + sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", None, None, SnowflakeBatchAddMode.insert_notin) + with sf.batch_manager(sf): + sf[vrs_id_object_pairs[0][0]] = vrs_id_object_pairs[0][1] + sf[vrs_id_object_pairs[1][0]] = vrs_id_object_pairs[1][1] + + sf.close() + assert mock_eng.return_value.were_all_execd() + +def test_batch_add_mode_insert(mocker): + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" + merge_statement = f""" + INSERT INTO vrs_objects2 (vrs_id, vrs_object) + SELECT vrs_id, PARSE_JSON(vrs_object) FROM tmp_vrs_objects + """ + drop_statement = "DROP TABLE tmp_vrs_objects" + + vrs_id_object_pairs = [ + ("ga4gh:VA.01", MockVRSObject('01')), + ("ga4gh:VA.02", MockVRSObject('02')), + ] + + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2')", None, [(1,)]) + # Batch 1 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[0:2]), [(2,)], 5) + .add_stmt(merge_statement, None, [(2,)]) + .add_stmt(drop_statement, None, [("Table dropped",)]) + ) + + sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", None, None, SnowflakeBatchAddMode.insert) + with sf.batch_manager(sf): + sf[vrs_id_object_pairs[0][0]] = vrs_id_object_pairs[0][1] + sf[vrs_id_object_pairs[1][0]] = vrs_id_object_pairs[1][1] + + sf.close() + assert mock_eng.return_value.were_all_execd() + +def test_insertion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT COUNT(*) + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object:state:sequence) > 1 + """, + None, [(12,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.INSERTION) == 12 + sf.close() + assert mock_eng.were_all_execd() + +def test_substitution_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM 
information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT COUNT(*) + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object:state:sequence) = 1 + """, + None, [(13,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.SUBSTITUTION) == 13 + sf.close() + assert mock_eng.were_all_execd() + +def test_deletion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT COUNT(*) + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object:state:sequence) = 0 + """, + None, [(14,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.DELETION) == 14 + sf.close() + assert mock_eng.were_all_execd() + +def test_search_vrs_objects(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT vrs_object + FROM {vrs_object_table_name} + WHERE vrs_object:type = ? + AND vrs_object:location IN ( + SELECT vrs_id FROM {vrs_object_table_name} + WHERE vrs_object:start::INTEGER >= ? + AND vrs_object:end::INTEGER <= ? + AND vrs_object:sequenceReference:refgetAccession = ?) + """, + ("Allele", 123456, 123457, "MySQAccId"), [("{\"id\": 1}",), ("{\"id\": 2}",)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + vars = sf.search_variations("MySQAccId", 123456, 123457) + sf.close() + assert len(vars) == 2 + assert "id" in vars[0] and vars[0]["id"] == 1 + assert "id" in vars[1] and vars[1]["id"] == 2 + assert mock_eng.were_all_execd() \ No newline at end of file diff --git a/tests/storage/test_storage_mapping.py b/tests/storage/test_sql_storage_mapping.py similarity index 61% rename from tests/storage/test_storage_mapping.py rename to tests/storage/test_sql_storage_mapping.py index b8354a3..7c6c421 100644 --- a/tests/storage/test_storage_mapping.py +++ b/tests/storage/test_sql_storage_mapping.py @@ -1,6 +1,11 @@ -"""Tests the mutable mapping API of the storage backend""" +""" +Tests the SqlStorage methods that are NOT tested through the +REST API tests. 
To test against different SQL backends, this +test must be run with different ANYVAR_TEST_STORAGE_URI settings +and different ANYVAR_SQL_STORE_BATCH_ADD_MODE settings +""" from ga4gh.vrs import vrs_enref - +from anyvar.storage.snowflake import SnowflakeObjectStore, SnowflakeBatchAddMode from anyvar.translate.vrs_python import VrsPythonTranslator # pause for 5 seconds because Snowflake storage is an async write and @@ -11,12 +16,12 @@ def test_waitforsync(): # __getitem__ def test_getitem(storage, alleles): - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): assert storage[allele_id] is not None # __contains__ def test_contains(storage, alleles): - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): assert allele_id in storage # __len__ @@ -24,31 +29,31 @@ def test_len(storage): assert len(storage) > 0 # __iter__ -def test_iter(storage, alleles): +def test_iter(storage): obj_iter = iter(storage) count = 0 while True: try: - obj = next(obj_iter) + next(obj_iter) count += 1 except StopIteration: break - assert count == 14 + assert count == (18 if isinstance(storage, SnowflakeObjectStore) and storage.batch_add_mode == SnowflakeBatchAddMode.insert else 14) # keys def test_keys(storage, alleles): key_list = storage.keys() - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): assert allele_id in key_list # __delitem__ def test_delitem(storage, alleles): - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): del storage[allele_id] # __setitem__ def test_setitem(storage, alleles): - for allele_id, allele in alleles.items(): + for _, allele in alleles.items(): variation = allele["params"] definition = variation["definition"] translated_variation = VrsPythonTranslator().translate_variation(definition)
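For orientation, the following is a minimal usage sketch of the write API that `SqlStorage` introduces in this patch. The store class, method names, and environment-variable fallbacks come from the diff above; the connection URL, the VRS IDs, and the allele contents are illustrative placeholders, and the `SequenceReference`/`LiteralSequenceExpression` field names are assumed to follow the VRS 2.0 `ga4gh.vrs.models` referenced elsewhere in this change.

```python
# Illustrative sketch only -- exercises the SqlStorage API added in this patch.
# The URL, VRS IDs, and allele values below are placeholders, not real data.
from ga4gh.vrs import models

from anyvar.storage.postgres import PostgresObjectStore

# Omitted constructor arguments fall back to environment variables
# (ANYVAR_SQL_STORE_TABLE_NAME, ANYVAR_SQL_STORE_BATCH_LIMIT, etc.).
store = PostgresObjectStore("postgresql://anyvar:anyvar-pw@localhost:5432/anyvar")

# Assumed VRS 2.0 model shape; field names mirror those queried in the diff.
allele = models.Allele(
    location=models.SequenceLocation(
        sequenceReference=models.SequenceReference(refgetAccession="SQ.example"),
        start=100,
        end=101,
    ),
    state=models.LiteralSequenceExpression(sequence="T"),
)

# A single assignment writes straight through to the VRS objects table.
store["ga4gh:VA.example-1"] = allele

# Inside the batch manager context, assignments are queued and merged into
# the database by the background writer thread.
with store.batch_manager(store):
    store["ga4gh:VA.example-2"] = allele

# Block until pending batches are flushed, then stop the thread and dispose
# of the connection pool.
store.wait_for_writes()
store.close()
```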
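To make the abstract surface of the base class concrete, here is a sketch of what a hypothetical SQLite-backed subclass could look like. Only the overridden method names and signatures come from `SqlStorage`; the class name, the SQLite-specific DDL, and the `json_extract` expressions are assumptions for illustration and are not part of this change.

```python
"""Hypothetical SQLite subclass of SqlStorage -- a sketch, not part of this patch."""
import json
from typing import Any, List

from sqlalchemy import text as sql_text
from sqlalchemy.engine import Connection

from anyvar.storage.sql_storage import SqlStorage


class SqliteObjectStore(SqlStorage):
    """Stores VRS objects in a SQLite database (illustrative only)."""

    def create_schema(self, db_conn: Connection):
        # JSON is stored as TEXT; SQLite's JSON1 functions query into it.
        db_conn.execute(
            sql_text(
                f"CREATE TABLE IF NOT EXISTS {self.table_name} "  # nosec B608
                "(vrs_id TEXT PRIMARY KEY, vrs_object TEXT)"
            )
        )

    def add_one_item(self, db_conn: Connection, name: str, value: Any):
        # Mirrors the Postgres "ON CONFLICT DO NOTHING" behavior.
        db_conn.execute(
            sql_text(
                f"INSERT OR IGNORE INTO {self.table_name} (vrs_id, vrs_object) "  # nosec B608
                "VALUES (:vrs_id, :vrs_object)"
            ),
            {"vrs_id": name, "vrs_object": json.dumps(value.model_dump(exclude_none=True))},
        )

    def add_many_items(self, db_conn: Connection, items: list):
        # No temp-table optimization here; reuse the single-row path.
        for name, value in items:
            self.add_one_item(db_conn, name, value)

    def deletion_count(self, db_conn: Connection) -> int:
        return db_conn.execute(
            sql_text(
                f"SELECT COUNT(*) FROM {self.table_name} "  # nosec B608
                "WHERE length(json_extract(vrs_object, '$.state.sequence')) = 0"
            )
        ).scalar()

    def substitution_count(self, db_conn: Connection) -> int:
        return db_conn.execute(
            sql_text(
                f"SELECT COUNT(*) FROM {self.table_name} "  # nosec B608
                "WHERE length(json_extract(vrs_object, '$.state.sequence')) = 1"
            )
        ).scalar()

    def insertion_count(self, db_conn: Connection) -> int:
        return db_conn.execute(
            sql_text(
                f"SELECT COUNT(*) FROM {self.table_name} "  # nosec B608
                "WHERE length(json_extract(vrs_object, '$.state.sequence')) > 1"
            )
        ).scalar()

    def search_vrs_objects(
        self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int
    ) -> List[Any]:
        # Same shape as the Postgres query, translated to SQLite JSON1 syntax.
        rows = db_conn.execute(
            sql_text(
                f"""
                SELECT vrs_object FROM {self.table_name}
                 WHERE json_extract(vrs_object, '$.type') = :type
                   AND json_extract(vrs_object, '$.location') IN (
                       SELECT vrs_id FROM {self.table_name}
                        WHERE CAST(json_extract(vrs_object, '$.start') AS INTEGER) >= :start
                          AND CAST(json_extract(vrs_object, '$.end') AS INTEGER) <= :stop
                          AND json_extract(vrs_object, '$.sequenceReference.refgetAccession') = :acc)
                """  # nosec B608
            ),
            {"type": type, "start": start, "stop": stop, "acc": refget_accession},
        )
        return [json.loads(row[0]) for row in rows]
```

Wiring such a class into AnyVar's storage selection would be a separate change; the point of the sketch is only the shape of the subclass contract that the new base class defines.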