diff --git a/README.md b/README.md index aa985b7..1808454 100644 --- a/README.md +++ b/README.md @@ -57,40 +57,68 @@ In another terminal: curl http://localhost:8000/info -### Setting up Postgres - -A Postgres-backed *AnyVar* installation may use any Postgres instance, local -or remote. The following instructions are for using a docker-based -Postgres instance. - -First, run the commands in [README-pg.md](src/anyvar/storage/README-pg.md). This will create and start a local Postgres docker instance. - -Next, run the commands in [postgres_init.sql](src/anyvar/storage/postgres_init.sql). This will create the `anyvar` user with the appropriate permissions and create the `anyvar` database. - -### Setting up Snowflake -A Snowflake-backed *AnyVar* installation may use any Snowflake database schema. +### SQL Database Setup +A Postgres or Snowflake database may be used with *AnyVar*. The Postgres database +may be either local or remote. Use the `ANYVAR_STORAGE_URI` environment variable +to define the database connection URL. *AnyVar* uses [SQLAlchemy 1.4](https://docs.sqlalchemy.org/en/14/index.html) +to provide database connection management. The default database connection URL +is `"postgresql://postgres@localhost:5432/anyvar"`. + +The database integrations can be modified using the following parameters: +* `ANYVAR_SQL_STORE_BATCH_LIMIT` - in batch mode, limit VRS object upsert batches +to this number; defaults to `100,000` +* `ANYVAR_SQL_STORE_TABLE_NAME` - the name of the table that stores VRS objects; +defaults to `vrs_objects` +* `ANYVAR_SQL_STORE_MAX_PENDING_BATCHES` - the maximum number of pending batches +to allow before blocking; defaults to `50` +* `ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT` - whether or not flush all pending +database writes when the batch manager exists; defaults to `True` + +The Postgres and Snowflake database connectors utilize a background thread +to write VRS objects to the database when operating in batch mode (e.g. annotating +a VCF file). Queries and statistics query only against the already committed database +state. Therefore, queries issued immediately after a batch operation may not reflect +all pending changes if the `ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT` parameter is sett +to `False`. + +#### Setting up Postgres +The following instructions are for using a docker-based Postgres instance. + +First, run the commands in [README-pg.md](src/anyvar/storage/README-pg.md). +This will create and start a local Postgres docker instance. + +Next, run the commands in [postgres_init.sql](src/anyvar/storage/postgres_init.sql). +This will create the `anyvar` user with the appropriate permissions and create the +`anyvar` database. + +#### Setting up Snowflake The Snowflake database and schema must exist prior to starting *AnyVar*. To point *AnyVar* at Snowflake, specify a Snowflake URI in the `ANYVAR_STORAGE_URI` environment variable. For example: - - snowflake://my-sf-acct/?database=sf_db_name&schema=sd_schema_name&user=sf_username&password=sf_password - +``` +snowflake://sf_username:@sf_account_identifier/sf_db_name/sf_schema_name?password=sf_password +``` [Snowflake connection parameter reference](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api) -When running interactively and connecting to a Snowflake account that utilizes federated authentication or SSO, add -the parameter `authenticator=externalbrowser`. Non-interactive execution in a federated authentication or SSO environment -requires a service account to connect. 
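As an illustration of the connection URI format documented in this section, the following sketch composes `ANYVAR_STORAGE_URI` values for both SSO and key-pair authentication. The account, database, schema, and user names are hypothetical placeholders, not values from this repository.

```python
# Illustrative only: composing ANYVAR_STORAGE_URI values for Snowflake.
# The account, database, schema, and user names below are hypothetical placeholders.
import os

account, db, schema, user = "my_account", "SF_DB", "SF_SCHEMA", "ANYVAR_SVC"

# Interactive use with federated authentication / SSO:
sso_uri = f"snowflake://{user}:@{account}/{db}/{schema}?authenticator=externalbrowser"

# Non-interactive use with a key pair (the passphrase for an encrypted key is
# supplied separately via ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE):
key_uri = f"snowflake://{user}:@{account}/{db}/{schema}?private_key=/path/to/rsa_key.p8"

os.environ["ANYVAR_STORAGE_URI"] = key_uri
```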
Connections using an encrypted or unencrypted private key are also supported by -specifying the parameter `private_key=path/to/file.p8`. The key material may be URL-encoded and inlined in the connection URI, +When running interactively and connecting to a Snowflake account that utilizes +federated authentication or SSO, add the parameter `authenticator=externalbrowser`. +Non-interactive execution in a federated authentication or SSO environment +requires a service account to connect. Connections using an encrypted or unencrypted +private key are also supported by specifying the parameter `private_key=path/to/file.p8`. +The key material may be URL-encoded and inlined in the connection URI, for example: `private_key=-----BEGIN+PRIVATE+KEY-----%0AMIIEvAIBA...` - Environment variables that can be used to modify Snowflake database integration: -* `ANYVAR_SNOWFLAKE_STORE_BATCH_LIMIT` - in batch mode, limit VRS object upsert batches to this number; defaults to `100,000` -* `ANYVAR_SNOWFLAKE_STORE_TABLE_NAME` - the name of the table that stores VRS objects; defaults to `vrs_objects` -* `ANYVAR_SNOWFLAKE_STORE_MAX_PENDING_BATCHES` - the maximum number of pending batches to allow before blocking; defaults to `50` * `ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE` - the passphrase for an encrypted private key - -NOTE: If you choose to create the VRS objects table in advance, the minimal table specification is as follows: +* `ANYVAR_SNOWFLAKE_BATCH_ADD_MODE` - the SQL statement type to use when adding new VRS objects, one of: + * `merge` (default) - use a MERGE statement. This guarantees that duplicate VRS IDs will + not be added, but also locks the VRS object table, limiting throughput. + * `insert_notin` - use INSERT INTO vrs_objects SELECT FROM tmp WHERE vrs_id NOT IN (...). + This narrows the chance of duplicates and does not require a table lock. + * `insert` - use INSERT INTO. This maximizes throughput at the cost of not checking for + duplicates at all. + +If you choose to create the VRS objects table in advance, the minimal table specification is as follows: ```sql CREATE TABLE ... ( vrs_id VARCHAR(500) COLLATE 'utf8', @@ -98,11 +126,6 @@ CREATE TABLE ... ( ) ``` -NOTE: The Snowflake database connector utilizes a background thread to write VRS objects to the database when operating in batch -mode (e.g. annotating a VCF file). Queries and statistics query only against the already committed database state. Therefore, -queries issued immediately after a batch operation may not reflect all pending changes. 
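To make the batch-write behavior described above concrete, here is a minimal sketch of registering VRS objects through the batch manager. It assumes `create_storage` from `src/anyvar/anyvar.py`; the `variants` list of `(vrs_id, vrs_object)` pairs is a hypothetical stand-in for real data.

```python
# Minimal sketch, assuming create_storage() from src/anyvar/anyvar.py and a
# hypothetical `variants` list of (vrs_id, vrs_object) pairs.
import os
from anyvar.anyvar import create_storage

os.environ.setdefault("ANYVAR_SQL_STORE_BATCH_LIMIT", "50000")  # read at storage init
variants = []  # hypothetical: (vrs_id, vrs_object) pairs produced elsewhere

storage = create_storage("postgresql://postgres@localhost:5432/anyvar")
with storage.batch_manager(storage):       # queues inserts for the background writer thread
    for vrs_id, vrs_object in variants:
        storage[vrs_id] = vrs_object       # buffered; written in batches of batch_limit

storage.wait_for_writes()                   # block until all pending batches are committed
storage.close()
```

Because `ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT` defaults to `True`, exiting the batch manager context already flushes pending writes; the explicit `wait_for_writes()` call is only needed when that flag is set to `False`.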
- - ## Deployment NOTE: The authoritative and sole source for version tags is the diff --git a/setup.cfg b/setup.cfg index c938cb6..1f7c81d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ install_requires = uvicorn ga4gh.vrs[extras]~=2.0.0a5 psycopg[binary] - snowflake-connector-python~=3.4.1 + snowflake-sqlalchemy~=1.5.1 [options.package_data] * = diff --git a/src/anyvar/anyvar.py b/src/anyvar/anyvar.py index f7fbf02..e5257ee 100644 --- a/src/anyvar/anyvar.py +++ b/src/anyvar/anyvar.py @@ -28,8 +28,7 @@ def create_storage(uri: Optional[str] = None) -> _Storage: * PostgreSQL `postgresql://[username]:[password]@[domain]/[database]` * Snowflake - `snowflake://[account_identifier].snowflakecomputing.com/?[param=value]&[param=value]...` - `snowflake://[account_identifier]/?[param=value]&[param=value]...` + `snowflake://[user]:@[account]/[database]/[schema]?[param=value]&[param=value]...` """ uri = uri or os.environ.get("ANYVAR_STORAGE_URI", DEFAULT_STORAGE_URI) diff --git a/src/anyvar/storage/postgres.py b/src/anyvar/storage/postgres.py index c28df48..7a4135a 100644 --- a/src/anyvar/storage/postgres.py +++ b/src/anyvar/storage/postgres.py @@ -1,265 +1,56 @@ +from io import StringIO import json -import logging -from typing import Any, Optional +from typing import Any, Optional, List -import psycopg -from ga4gh.core import is_pydantic_instance -from ga4gh.vrs import models +from sqlalchemy import text as sql_text +from sqlalchemy.engine import Connection -from anyvar.restapi.schema import VariationStatisticType - -from . import _BatchManager, _Storage +from .sql_storage import SqlStorage silos = "locations alleles haplotypes genotypes variationsets relations texts".split() -_logger = logging.getLogger(__name__) - - -class PostgresObjectStore(_Storage): +class PostgresObjectStore(SqlStorage): """PostgreSQL storage backend. Currently, this is our recommended storage approach. """ - def __init__(self, db_url: str, batch_limit: int = 65536): - """Initialize PostgreSQL DB handler. - - :param db_url: libpq connection info URL - :param batch_limit: max size of batch insert queue - """ - self.conn = psycopg.connect(db_url, autocommit=True) - self.ensure_schema_exists() - - self.batch_manager = PostgresBatchManager - self.batch_mode = False - self.batch_insert_values = [] - self.batch_limit = batch_limit - - def _create_schema(self): - """Add DB schema.""" - create_statement = """ - CREATE TABLE vrs_objects ( - vrs_id TEXT PRIMARY KEY, - vrs_object JSONB - ); - """ - with self.conn.cursor() as cur: - cur.execute(create_statement) - - def ensure_schema_exists(self): - """Check that DB schema is in place.""" - with self.conn.cursor() as cur: - cur.execute( - "SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 'vrs_objects')" - ) # noqa: E501 - result = cur.fetchone() - if result and result[0]: - return - self._create_schema() - - def __repr__(self): - return str(self.conn) - - def __setitem__(self, name: str, value: Any): - """Add item to database. If batch mode is on, add item to queue and write only - if queue size is exceeded. 
- - :param name: value for `vrs_id` field - :param value: value for `vrs_object` field - """ - assert is_pydantic_instance(value), "ga4gh.vrs object value required" - name = str(name) # in case str-like - value_json = json.dumps(value.model_dump(exclude_none=True)) - if self.batch_mode: - self.batch_insert_values.append((name, value_json)) - if len(self.batch_insert_values) > self.batch_limit: - self.copy_insert() - else: - insert_query = "INSERT INTO vrs_objects (vrs_id, vrs_object) VALUES (%s, %s) ON CONFLICT DO NOTHING;" # noqa: E501 - with self.conn.cursor() as cur: - cur.execute(insert_query, [name, value_json]) - - def __getitem__(self, name: str) -> Optional[Any]: - """Fetch item from DB given key. - - Future issues: - * Remove reliance on VRS-Python models (requires rewriting the enderef module) - - :param name: key to retrieve VRS object for - :return: VRS object if available - :raise NotImplementedError: if unsupported VRS object type (this is WIP) - """ - with self.conn.cursor() as cur: - cur.execute("SELECT vrs_object FROM vrs_objects WHERE vrs_id = %s;", [name]) - result = cur.fetchone() - if result: - result = result[0] - object_type = result["type"] - if object_type == "Allele": - return models.Allele(**result) - elif object_type == "CopyNumberCount": - return models.CopyNumberCount(**result) - elif object_type == "CopyNumberChange": - return models.CopyNumberChange(**result) - elif object_type == "SequenceLocation": - return models.SequenceLocation(**result) - else: - raise NotImplementedError - else: - raise KeyError(name) - - def __contains__(self, name: str) -> bool: - """Check whether VRS objects table contains ID. - - :param name: VRS ID to look up - :return: True if ID is contained in vrs objects table - """ - with self.conn.cursor() as cur: - cur.execute("SELECT EXISTS (SELECT 1 FROM vrs_objects WHERE vrs_id = %s);", [name]) - result = cur.fetchone() - return result[0] if result else False - - def __delitem__(self, name: str) -> None: - """Delete item (not cascading -- doesn't delete referenced items) - - :param name: key to delete object for - """ - name = str(name) # in case str-like - with self.conn.cursor() as cur: - cur.execute("DELETE FROM vrs_objects WHERE vrs_id = %s;", [name]) - self.conn.commit() - - def wait_for_writes(self): - """Returns once any currently pending database modifications have been completed. - The PostgresObjectStore does not implement async writes, therefore this method is a no-op - and present only to maintain compatibility with the `_Storage` base class""" - - def close(self): - """Terminate connection if necessary.""" - if self.conn is not None: - self.conn.close() - - def __del__(self): - """Tear down DB instance.""" - self.close() - - def __len__(self): - with self.conn.cursor() as cur: - cur.execute( - """ - SELECT COUNT(*) AS c FROM vrs_objects - WHERE vrs_object ->> 'type' = 'Allele'; - """ - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def get_variation_count(self, variation_type: VariationStatisticType) -> int: - """Get total # of registered variations of requested type. 
- - :param variation_type: variation type to check - :return: total count - """ - if variation_type == VariationStatisticType.SUBSTITUTION: - return self._substitution_count() - elif variation_type == VariationStatisticType.INSERTION: - return self._insertion_count() - elif variation_type == VariationStatisticType.DELETION: - return self._deletion_count() - else: - return self._substitution_count() + self._deletion_count() + self._insertion_count() - - def _deletion_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - """ - select count(*) as c from vrs_objects - where length(vrs_object -> 'state' ->> 'sequence') = 0; - """ - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def _substitution_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - """ - select count(*) as c from vrs_objects - where length(vrs_object -> 'state' ->> 'sequence') = 1; - """ - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 + def __init__( + self, + db_url: str, + batch_limit: Optional[int] = None, + table_name: Optional[str] = None, + max_pending_batches: Optional[int] = None, + flush_on_batchctx_exit: Optional[bool] = None, + ): + super().__init__( + db_url, + batch_limit, + table_name, + max_pending_batches, + flush_on_batchctx_exit, + ) - def _insertion_count(self): - with self.conn.cursor() as cur: - cur.execute( - """ - select count(*) as c from vrs_objects - where length(vrs_object -> 'state' ->> 'sequence') > 1 - """ + def create_schema(self, db_conn: Connection): + check_statement = f""" + SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{self.table_name}') + """ # nosec B608 + create_statement = f""" + CREATE TABLE {self.table_name} ( + vrs_id TEXT PRIMARY KEY, + vrs_object JSONB ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def __iter__(self): - with self.conn.cursor() as cur: - cur.execute("SELECT * FROM vrs_objects;") - while True: - _next = cur.fetchone() - if _next is None: - break - yield _next - - def keys(self): - with self.conn.cursor() as cur: - cur.execute("SELECT vrs_id FROM vrs_objects;") - result = [row[0] for row in cur.fetchall()] - return result - - def search_variations(self, refget_accession: str, start: int, stop: int): - """Find all alleles that were registered that are in 1 genomic region + """ # nosec B608 + result = db_conn.execute(sql_text(check_statement)) + if not result or not result.scalar(): + db_conn.execute(sql_text(create_statement)) - Args: - refget_accession (str): refget accession (SQ. 
identifier) - start (int): Start genomic region to query - stop (iint): Stop genomic region to query - - Returns: - A list of VRS Alleles that have locations referenced as identifiers - """ - query_str = """ - SELECT vrs_object FROM vrs_objects - WHERE vrs_object->>'location' IN ( - SELECT vrs_id FROM vrs_objects - WHERE CAST (vrs_object->>'start' as INTEGER) >= %s - AND CAST (vrs_object->>'end' as INTEGER) <= %s - AND vrs_object->'sequenceReference'->>'refgetAccession' = %s - ); - """ - with self.conn.cursor() as cur: - cur.execute(query_str, [start, stop, refget_accession]) - results = cur.fetchall() - return [vrs_object[0] for vrs_object in results if vrs_object] - - def wipe_db(self): - """Remove all stored records from vrs_objects table.""" - with self.conn.cursor() as cur: - cur.execute("DELETE FROM vrs_objects;") + def add_one_item(self, db_conn: Connection, name: str, value: Any): + insert_query = f"INSERT INTO {self.table_name} (vrs_id, vrs_object) VALUES (:vrs_id, :vrs_object) ON CONFLICT DO NOTHING" # nosec B608 + value_json = json.dumps(value.model_dump(exclude_none=True)) + db_conn.execute(sql_text(insert_query), {"vrs_id": name, "vrs_object": value_json}) - def copy_insert(self): + def add_many_items(self, db_conn: Connection, items: list): """Perform copy-based insert, enabling much faster writes for large, repeated insert statements, using insert parameters stored in `self.batch_insert_values`. @@ -269,63 +60,72 @@ def copy_insert(self): conflicts when moving data over from that table to vrs_objects. """ tmp_statement = ( - "CREATE TEMP TABLE tmp_table (LIKE vrs_objects INCLUDING DEFAULTS);" # noqa: E501 + f"CREATE TEMP TABLE tmp_table (LIKE {self.table_name} INCLUDING DEFAULTS)" # noqa: E501 ) - copy_statement = "COPY tmp_table (vrs_id, vrs_object) FROM STDIN;" - insert_statement = ( - "INSERT INTO vrs_objects SELECT * FROM tmp_table ON CONFLICT DO NOTHING;" # noqa: E501 + insert_statement = f"INSERT INTO {self.table_name} SELECT * FROM tmp_table ON CONFLICT DO NOTHING" # nosec B608 + drop_statement = "DROP TABLE tmp_table" + db_conn.execute(sql_text(tmp_statement)) + with db_conn.connection.cursor() as cur: + row_data = [ + f"{name}\t{json.dumps(value.model_dump(exclude_none=True))}" + for name, value in items + ] + fl = StringIO("\n".join(row_data)) + cur.copy_from(fl, "tmp_table", columns=["vrs_id", "vrs_object"]) + fl.close() + db_conn.execute(sql_text(insert_statement)) + db_conn.execute(sql_text(drop_statement)) + + def deletion_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + sql_text( + f""" + SELECT COUNT(*) AS c + FROM {self.table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 0 + """ # nosec B608 + ) ) - drop_statement = "DROP TABLE tmp_table;" - with self.conn.cursor() as cur: - cur.execute(tmp_statement) - with cur.copy(copy_statement) as copy: - for row in self.batch_insert_values: - copy.write_row(row) - cur.execute(insert_statement) - cur.execute(drop_statement) - self.conn.commit() - self.batch_insert_values = [] - - -class PostgresBatchManager(_BatchManager): - """Context manager enabling batch insertion statements via Postgres COPY command. - - Use in cases like VCF ingest when intaking large amounts of data at once. - """ - - def __init__(self, storage: PostgresObjectStore): - """Initialize context manager. - - :param storage: Postgres instance to manage. Should be taken from the active - AnyVar instance -- otherwise it won't be able to delay insertions. 
- :raise ValueError: if `storage` param is not a `PostgresObjectStore` instance - """ - if not isinstance(storage, PostgresObjectStore): - raise ValueError("PostgresBatchManager requires a PostgresObjectStore instance") - self._storage = storage - - def __enter__(self): - """Enter managed context.""" - self._storage.batch_insert_values = [] - self._storage.batch_mode = True - - def __exit__( - self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any] - ) -> bool: - """Handle exit from context management. This method is responsible for - committing or rolling back any staged inserts. - - :param exc_type: type of exception encountered, if any - :param exc_value: exception value - :param traceback: traceback for context of exception - :return: True if no exceptions encountered, False otherwise - """ - if exc_type is not None: - self._storage.conn.rollback() - self._storage.batch_insert_values = [] - self._storage.batch_mode = False - _logger.error(f"Postgres batch manager encountered exception {exc_type}: {exc_value}") - return False - self._storage.copy_insert() - self._storage.batch_mode = False - return True + return result.scalar() + + def substitution_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + sql_text( + f""" + SELECT COUNT(*) AS c + FROM {self.table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 1 + """ # nosec B608 + ) + ) + return result.scalar() + + def insertion_count(self, db_conn: Connection): + result = db_conn.execute( + sql_text( + f""" + SELECT COUNT(*) AS c + FROM {self.table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') > 1 + """ # nosec B608 + ) + ) + return result.scalar() + + def search_vrs_objects( + self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int + ) -> List: + query_str = f""" + SELECT vrs_object + FROM {self.table_name} + WHERE vrs_object->>'type' = %s + AND vrs_object->>'location' IN ( + SELECT vrs_id FROM {self.table_name} + WHERE CAST (vrs_object->>'start' AS INTEGER) >= %s + AND CAST (vrs_object->>'end' AS INTEGER) <= %s + AND vrs_object->'sequenceReference'->>'refgetAccession' = %s) + """ # nosec B608 + with db_conn.connection.cursor() as cur: + cur.execute(query_str, [type, start, stop, refget_accession]) + results = cur.fetchall() + return [vrs_object[0] for vrs_object in results if vrs_object] diff --git a/src/anyvar/storage/snowflake.py b/src/anyvar/storage/snowflake.py index 54857c9..4085d20 100644 --- a/src/anyvar/storage/snowflake.py +++ b/src/anyvar/storage/snowflake.py @@ -1,25 +1,66 @@ +from enum import auto, StrEnum import json import logging import os -from threading import Condition, Thread -from typing import Any, List, Optional, Tuple -from urllib.parse import urlparse, parse_qs +import snowflake.connector +from snowflake.sqlalchemy.snowdialect import SnowflakeDialect +from typing import Any, List, Optional +from sqlalchemy import text as sql_text +from sqlalchemy.engine import Connection, URL +from urllib.parse import urlparse, parse_qs, urlunparse, urlencode from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -import ga4gh.core -from ga4gh.vrs import models -import snowflake.connector -from snowflake.connector import SnowflakeConnection -from anyvar.restapi.schema import VariationStatisticType - -from . 
import _BatchManager, _Storage +from .sql_storage import SqlStorage _logger = logging.getLogger(__name__) +snowflake.connector.paramstyle = "qmark" + +# +# Monkey patch to workaround a bug in the Snowflake SQLAlchemy dialect +# https://github.com/snowflakedb/snowflake-sqlalchemy/issues/489 + +# Create a new pointer to the existing create_connect_args method +SnowflakeDialect._orig_create_connect_args = SnowflakeDialect.create_connect_args + + +# Define a new create_connect_args method that calls the original method +# and then fixes the result so that the account name is not mangled +# when using privatelink +def sf_create_connect_args_override(self, url: URL): + + # retval is tuple of empty array and dict ([], {}) + retval = self._orig_create_connect_args(url) + + # the dict has the options including the mangled account name + opts = retval[1] + if ( + "host" in opts + and "account" in opts + and opts["host"].endswith(".privatelink.snowflakecomputing.com") + ): + opts["account"] = opts["host"].split(".")[0] + + return retval + + +# Replace the create_connect_args method with the override +SnowflakeDialect.create_connect_args = sf_create_connect_args_override + +# +# End monkey patch +# -class SnowflakeObjectStore(_Storage): + +class SnowflakeBatchAddMode(StrEnum): + merge = auto() + insert_notin = auto() + insert = auto() + + +class SnowflakeObjectStore(SqlStorage): """Snowflake storage backend. Requires existing Snowflake database.""" def __init__( @@ -28,496 +69,185 @@ def __init__( batch_limit: Optional[int] = None, table_name: Optional[str] = None, max_pending_batches: Optional[int] = None, + flush_on_batchctx_exit: Optional[bool] = None, + batch_add_mode: Optional[SnowflakeBatchAddMode] = None, ): - """Initialize Snowflake DB handler. - - :param db_url: snowflake connection info URL, snowflake://[account_identifier]/?[param=value]&[param=value]... 
- :param batch_limit: max size of batch insert queue, defaults to 100000; can be set with - ANYVAR_SNOWFLAKE_STORE_BATCH_LIMIT environment variable - :param table_name: table name for storing VRS objects, defaults to `vrs_objects`; can be set with - ANYVAR_SNOWFLAKE_STORE_TABLE_NAME environment variable - :param max_pending_batches: maximum number of pending batches allowed before batch queueing blocks; can - be set with ANYVAR_SNOWFLAKE_STORE_MAX_PENDING_BATCHES environment variable - - See https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api for full list - of database connection url parameters """ - # specify that bind variables in queries should be indicated with a question mark - snowflake.connector.paramstyle = "qmark" - - # get table name override from environment - self.table_name = table_name or os.environ.get( - "ANYVAR_SNOWFLAKE_STORE_TABLE_NAME", "vrs_objects" + :param batch_add_mode: what type of SQL statement to use when adding many items at one; one of `merge` + (no duplicates), `insert_notin` (try to avoid duplicates) or `insert` (don't worry about duplicates); + defaults to `merge`; can be set with the ANYVAR_SNOWFLAKE_BATCH_ADD_MODE + """ + prepared_db_url = self._preprocess_db_url(db_url) + super().__init__( + prepared_db_url, + batch_limit, + table_name, + max_pending_batches, + flush_on_batchctx_exit, + ) + self.batch_add_mode = batch_add_mode or os.environ.get( + "ANYVAR_SNOWFLAKE_BATCH_ADD_MODE", SnowflakeBatchAddMode.merge ) + if self.batch_add_mode not in SnowflakeBatchAddMode.__members__: + raise Exception("batch_add_mode must be one of 'merge', 'insert_notin', or 'insert'") - # parse the db url and extract the account name and conn params + def _preprocess_db_url(self, db_url: str) -> str: + db_url = db_url.replace(".snowflakecomputing.com", "") parsed_uri = urlparse(db_url) - account_name = parsed_uri.hostname.replace(".snowflakecomputing.com", "") conn_params = { key: value[0] if value else None for key, value in parse_qs(parsed_uri.query).items() } + if "private_key" in conn_params: + self.private_key_param = conn_params["private_key"] + del conn_params["private_key"] + parsed_uri = parsed_uri._replace(query=urlencode(conn_params)) + else: + self.private_key_param = None + + return urlunparse(parsed_uri) + def _get_connect_args(self, db_url: str) -> dict: # if there is a private_key param that is a file, read the contents of file - if "private_key" in conn_params: - pk_value = conn_params["private_key"] + if self.private_key_param: p_key = None pk_passphrase = None if "ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE" in os.environ: pk_passphrase = os.environ["ANYVAR_SNOWFLAKE_STORE_PRIVATE_KEY_PASSPHRASE"].encode() - if os.path.isfile(pk_value): - with open(pk_value, "rb") as key: + if os.path.isfile(self.private_key_param): + with open(self.private_key_param, "rb") as key: p_key = serialization.load_pem_private_key( key.read(), password=pk_passphrase, backend=default_backend() ) else: p_key = serialization.load_pem_private_key( - pk_value.encode(), password=pk_passphrase, backend=default_backend() + self.private_key_param.encode(), + password=pk_passphrase, + backend=default_backend(), ) - conn_params["private_key"] = p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - # log sanitized connection parameters - if _logger.isEnabledFor(logging.DEBUG): - sanitized_conn_params = conn_params.copy() - for secret_param in 
["password", "private_key"]: - if secret_param in sanitized_conn_params: - sanitized_conn_params[secret_param] = "****sanitized****" - - _logger.debug( - "Connecting to Snowflake account %s with params %s", - account_name, - sanitized_conn_params, - ) - # log connection attempt + return { + "private_key": p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + } else: - _logger.info( - "Connecting to Snowflake account %s", - account_name, - ) + return {} - # create the database connection and ensure it is setup - self.conn = snowflake.connector.connect(account=account_name, **conn_params) - self.ensure_schema_exists() - - # setup batch handling - self.batch_manager = SnowflakeBatchManager - self.batch_mode = False - self.batch_insert_values = [] - self.batch_limit = batch_limit or int( - os.environ.get("ANYVAR_SNOWFLAKE_STORE_BATCH_LIMIT", "100000") - ) - max_pending_batches = max_pending_batches or int( - os.environ.get("ANYVAR_SNOWFLAKE_STORE_MAX_PENDING_BATCHES", "50") - ) - self.batch_thread = SnowflakeBatchThread(self.conn, self.table_name, max_pending_batches) - self.batch_thread.start() - - def _create_schema(self): - """Add the VRS object table if it does not exist""" - # self.table_name is only modifiable via environment variable or direct instantiation of the SnowflakeObjectStore - create_statement = f""" - CREATE TABLE {self.table_name} ( - vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', - vrs_object VARIANT - ); + def create_schema(self, db_conn: Connection): + check_statement = f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{self.table_name}') """ # nosec B608 - _logger.info("Creating VRS object table %s", self.table_name) - with self.conn.cursor() as cur: - cur.execute(create_statement) - - def ensure_schema_exists(self): - """Check that VRS object table exists and create it if it does not""" - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM information_schema.tables - WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() - AND UPPER(table_name) = UPPER('{self.table_name}'); - """ # nosec B608 + create_statement = f""" + CREATE TABLE {self.table_name} ( + vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', + vrs_object VARIANT ) - result = cur.fetchone() - - if result is None or result[0] <= 0: - self._create_schema() - - def __repr__(self): - return str(self.conn) - - def __setitem__(self, name: str, value: Any): - """Add item to database. If batch mode is on, add item to batch and submit batch - for write only if batch size is exceeded. - - :param name: value for `vrs_id` field - :param value: value for `vrs_object` field - """ - assert ga4gh.core.is_pydantic_instance(value), "ga4gh.vrs object value required" - name = str(name) # in case str-like - if self.batch_mode: - self.batch_insert_values.append((name, value)) - _logger.debug("Appended item %s to batch queue", name) - if len(self.batch_insert_values) >= self.batch_limit: - self.batch_thread.queue_batch(self.batch_insert_values) - _logger.info( - "Queued batch of %s VRS objects for write", len(self.batch_insert_values) - ) - self.batch_insert_values = [] - else: - value_json = json.dumps(value.model_dump(exclude_none=True)) - insert_query = f""" - MERGE INTO {self.table_name} t USING (SELECT ? AS vrs_id, ? 
AS vrs_object) s ON t.vrs_id = s.vrs_id - WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)); - """ # nosec B608 - with self.conn.cursor() as cur: - cur.execute(insert_query, [name, value_json]) - _logger.debug("Inserted item %s to %s", name, self.table_name) - - def __getitem__(self, name: str) -> Optional[Any]: - """Fetch item from DB given key. - - Future issues: - * Remove reliance on VRS-Python models (requires rewriting the enderef module) + """ # nosec B608 + result = db_conn.execute(sql_text(check_statement)) + if result.scalar() < 1: + db_conn.execute(sql_text(create_statement)) + + def add_one_item(self, db_conn: Connection, name: str, value: Any): + insert_query = f""" + MERGE INTO {self.table_name} t USING (SELECT ? AS vrs_id, ? AS vrs_object) s ON t.vrs_id = s.vrs_id + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) + """ # nosec B608 + value_json = json.dumps(value.model_dump(exclude_none=True)) + db_conn.execute(insert_query, (name, value_json)) + _logger.debug("Inserted item %s to %s", name, self.table_name) - :param name: key to retrieve VRS object for - :return: VRS object if available - :raise NotImplementedError: if unsupported VRS object type (this is WIP) - """ - with self.conn.cursor() as cur: - cur.execute( - f"SELECT vrs_object FROM {self.table_name} WHERE vrs_id = ?;", [name] # nosec B608 - ) - result = cur.fetchone() - if result: - result = json.loads(result[0]) - object_type = result["type"] - if object_type == "Allele": - return models.Allele(**result) - elif object_type == "CopyNumberCount": - return models.CopyNumberCount(**result) - elif object_type == "CopyNumberChange": - return models.CopyNumberChange(**result) - elif object_type == "SequenceLocation": - return models.SequenceLocation(**result) - else: - raise NotImplementedError + def add_many_items(self, db_conn: Connection, items: list): + """Bulk inserts the batch values into a TEMP table, then merges into the main {self.table_name} table""" + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" + if self.batch_add_mode == SnowflakeBatchAddMode.insert: + merge_statement = f""" + INSERT INTO {self.table_name} (vrs_id, vrs_object) + SELECT vrs_id, PARSE_JSON(vrs_object) FROM tmp_vrs_objects + """ # nosec B608 + elif self.batch_add_mode == SnowflakeBatchAddMode.insert_notin: + merge_statement = f""" + INSERT INTO {self.table_name} (vrs_id, vrs_object) + SELECT t.vrs_id, PARSE_JSON(t.vrs_object) + FROM tmp_vrs_objects t + LEFT OUTER JOIN {self.table_name} v ON v.vrs_id = t.vrs_id + WHERE v.vrs_id IS NULL + """ # nosec B608 else: - raise KeyError(name) - - def __contains__(self, name: str) -> bool: - """Check whether VRS objects table contains ID. 
- - :param name: VRS ID to look up - :return: True if ID is contained in vrs objects table - """ - with self.conn.cursor() as cur: - cur.execute( - f"SELECT COUNT(*) FROM {self.table_name} WHERE vrs_id = ?;", [name] # nosec B608 - ) - result = cur.fetchone() - return result[0] > 0 if result else False - - def __delitem__(self, name: str) -> None: - """Delete item (not cascading -- doesn't delete referenced items) - - :param name: key to delete object for - """ - name = str(name) # in case str-like - with self.conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE vrs_id = ?;", [name]) # nosec B608 - self.conn.commit() - - def wait_for_writes(self): - """Returns once any currently pending database modifications have been completed.""" - if self.batch_thread is not None: - # short circuit if the queue is empty - with self.batch_thread.cond: - if not self.batch_thread.pending_batch_list: - return - - # queue an empty batch - batch = [] - self.batch_thread.queue_batch(batch) - # wait for the batch to be removed from the pending queue - while True: - with self.batch_thread.cond: - if list(filter(lambda x: x is batch, self.batch_thread.pending_batch_list)): - self.batch_thread.cond.wait() - else: - break - - def close(self): - """Stop the batch thread and wait for it to complete""" - if self.batch_thread is not None: - self.batch_thread.stop() - self.batch_thread.join() - self.batch_thread = None - # Terminate connection if necessary. - if self.conn is not None: - self.conn.close() - self.conn = None - - def __del__(self): - """Tear down DB instance.""" - self.close() - - def __len__(self): - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) AS c FROM {self.table_name} - WHERE vrs_object:type = 'Allele'; + merge_statement = f""" + MERGE INTO {self.table_name} v USING tmp_vrs_objects s ON v.vrs_id = s.vrs_id + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def get_variation_count(self, variation_type: VariationStatisticType) -> int: - """Get total # of registered variations of requested type. + drop_statement = "DROP TABLE tmp_vrs_objects" + + # create row data removing duplicates + # because if there are duplicates in the source of the merge + # Snowflake inserts duplicate rows + row_data = [] + row_keys = set() + for name, value in items: + if name not in row_keys: + row_keys.add(name) + row_data.append((name, json.dumps(value.model_dump(exclude_none=True)))) + _logger.info("Created row data for insert, first item is %s", row_data[0]) + + db_conn.execute(sql_text(tmp_statement)) + # NB - enclosing the insert statement in sql_text() + # causes a "Bind variable ? 
not set" error from Snowflake + # It is unclear why this is that case + db_conn.execute(insert_statement, row_data) + db_conn.execute(sql_text(merge_statement)) + db_conn.execute(sql_text(drop_statement)) + + def deletion_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE LENGTH(vrs_object:state:sequence) = 0 + """ # nosec B608 + ) + return result.scalar() + + def substitution_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE LENGTH(vrs_object:state:sequence) = 1 + """ # nosec B608 + ) + return result.scalar() + + def insertion_count(self, db_conn: Connection) -> int: + result = db_conn.execute( + f""" + SELECT COUNT(*) + FROM {self.table_name} + WHERE LENGTH(vrs_object:state:sequence) > 1 + """ # nosec B608 + ) + return result.scalar() - :param variation_type: variation type to check - :return: total count - """ - if variation_type == VariationStatisticType.SUBSTITUTION: - return self._substitution_count() - elif variation_type == VariationStatisticType.INSERTION: - return self._insertion_count() - elif variation_type == VariationStatisticType.DELETION: - return self._deletion_count() - else: - return self._substitution_count() + self._deletion_count() + self._insertion_count() - - def _deletion_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM {self.table_name} - WHERE LENGTH(vrs_object:state:sequence) = 0; - """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def _substitution_count(self) -> int: - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM {self.table_name} - WHERE LENGTH(vrs_object:state:sequence) = 1; - """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def _insertion_count(self): - with self.conn.cursor() as cur: - cur.execute( - f""" - SELECT COUNT(*) FROM {self.table_name} - WHERE LENGTH(vrs_object:state:sequence) > 1 - """ # nosec B608 - ) - result = cur.fetchone() - if result: - return result[0] - else: - return 0 - - def __iter__(self): - with self.conn.cursor() as cur: - cur.execute(f"SELECT * FROM {self.table_name};") # nosec B608 - while True: - _next = cur.fetchone() - if _next is None: - break - yield _next - - def keys(self): - with self.conn.cursor() as cur: - cur.execute(f"SELECT vrs_id FROM {self.table_name};") # nosec B608 - result = [row[0] for row in cur.fetchall()] - return result - - def search_variations(self, refget_accession: str, start: int, stop: int): - """Find all alleles that were registered that are in 1 genomic region - - Args: - refget_accession (str): refget accession (SQ. identifier) - start (int): Start genomic region to query - stop (iint): Stop genomic region to query - - Returns: - A list of VRS Alleles that have locations referenced as identifiers - """ + def search_vrs_objects( + self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int + ) -> List[Any]: query_str = f""" - SELECT vrs_object FROM {self.table_name} - WHERE vrs_object:location IN ( + SELECT vrs_object + FROM {self.table_name} + WHERE vrs_object:type = ? + AND vrs_object:location IN ( SELECT vrs_id FROM {self.table_name} - WHERE vrs_object:start::INTEGER >= ? - AND vrs_object:end::INTEGER <= ? - AND vrs_object:sequenceReference:refgetAccession = ? - ); + WHERE vrs_object:start::INTEGER >= ? + AND vrs_object:end::INTEGER <= ? 
+ AND vrs_object:sequenceReference:refgetAccession = ?) """ # nosec B608 - with self.conn.cursor() as cur: - cur.execute(query_str, [start, stop, refget_accession]) - results = cur.fetchall() - return [json.loads(vrs_object[0]) for vrs_object in results if vrs_object] - - def wipe_db(self): - """Remove all stored records from {self.table_name} table.""" - with self.conn.cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name};") # nosec B608 - - def num_pending_batches(self): - if self.batch_thread: - return len(self.batch_thread.pending_batch_list) - else: - return 0 - - -class SnowflakeBatchManager(_BatchManager): - """Context manager enabling bulk insertion statements - - Use in cases like VCF ingest when intaking large amounts of data at once. - Insertion batches are processed by a background thread. - """ - - def __init__(self, storage: SnowflakeObjectStore): - """Initialize context manager. - - :param storage: Snowflake instance to manage. Should be taken from the active - AnyVar instance -- otherwise it won't be able to delay insertions. - :raise ValueError: if `storage` param is not a `SnowflakeObjectStore` instance - """ - if not isinstance(storage, SnowflakeObjectStore): - raise ValueError("SnowflakeBatchManager requires a SnowflakeObjectStore instance") - self._storage = storage - - def __enter__(self): - """Enter managed context.""" - self._storage.batch_insert_values = [] - self._storage.batch_mode = True - - def __exit__( - self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any] - ) -> bool: - """Handle exit from context management. Hands off final batch to background bulk insert processor. - - :param exc_type: type of exception encountered, if any - :param exc_value: exception value - :param traceback: traceback for context of exception - :return: True if no exceptions encountered, False otherwise - """ - if exc_type is not None: - self._storage.batch_insert_values = None - self._storage.batch_mode = False - _logger.error(f"Snowflake batch manager encountered exception {exc_type}: {exc_value}") - _logger.exception(exc_value) - return False - self._storage.batch_thread.queue_batch(self._storage.batch_insert_values) - self._storage.batch_mode = False - self._storage.batch_insert_values = None - return True - - -class SnowflakeBatchThread(Thread): - """Background thread that merges VRS objects into the database""" - - def __init__(self, conn: SnowflakeConnection, table_name: str, max_pending_batches: int): - """Constructs a new background thread - - :param conn: Snowflake connection - """ - super().__init__(daemon=True) - self.conn = conn - self.cond = Condition() - self.run_flag = True - self.pending_batch_list = [] - self.table_name = table_name - self.max_pending_batches = max_pending_batches - - def run(self): - """As long as run_flag is true, waits then processes pending batches""" - while self.run_flag: - with self.cond: - self.cond.wait() - self.process_pending_batches() - - def stop(self): - """Sets the run_flag to false and notifies""" - self.run_flag = False - with self.cond: - self.cond.notify() - - def queue_batch(self, batch_insert_values: List[Tuple]): - """Adds a batch to the pending list. 
If the pending batch list is already at its max size, waits until there is room - - :param batch_insert_values: list of tuples where each tuple consists of (vrs_id, vrs_object) - """ - with self.cond: - if batch_insert_values is not None: - _logger.info("Queueing batch of %s items", len(batch_insert_values)) - while len(self.pending_batch_list) >= self.max_pending_batches: - _logger.debug("Pending batch queue is full, waiting for space...") - self.cond.wait() - self.pending_batch_list.append(batch_insert_values) - _logger.info("Queued batch of %s items", len(batch_insert_values)) - self.cond.notify_all() - - def process_pending_batches(self): - """As long as batches are available for processing, merges them into the database""" - _logger.info("Processing %s queued batches", len(self.pending_batch_list)) - while True: - batch_insert_values = None - with self.cond: - if len(self.pending_batch_list) > 0: - batch_insert_values = self.pending_batch_list[0] - del self.pending_batch_list[0] - self.cond.notify_all() - else: - self.cond.notify_all() - break - - if batch_insert_values: - self._run_copy_insert(batch_insert_values) - _logger.info("Processed queued batch of %s items", len(batch_insert_values)) - - def _run_copy_insert(self, batch_insert_values): - """Bulk inserts the batch values into a TEMP table, then merges into the main {self.table_name} table""" - - try: - tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR);" - insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?);" - merge_statement = f""" - MERGE INTO {self.table_name} v USING tmp_vrs_objects s ON v.vrs_id = s.vrs_id - WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)); - """ # nosec B608 - drop_statement = "DROP TABLE tmp_vrs_objects;" - - row_data = [ - (name, json.dumps(value.model_dump(exclude_none=True))) - for name, value in batch_insert_values - ] - _logger.info("Created row data for insert, first item is %s", row_data[0]) - - with self.conn.cursor() as cur: - cur.execute(tmp_statement) - cur.executemany(insert_statement, row_data) - cur.execute(merge_statement) - cur.execute(drop_statement) - self.conn.commit() - except Exception: - _logger.exception("Failed to merge VRS object batch into database") - finally: - self.conn.rollback() + results = db_conn.execute( + query_str, + (type, start, stop, refget_accession), + ) + return [json.loads(row[0]) for row in results if row] diff --git a/src/anyvar/storage/sql_storage.py b/src/anyvar/storage/sql_storage.py new file mode 100644 index 0000000..fbeb78e --- /dev/null +++ b/src/anyvar/storage/sql_storage.py @@ -0,0 +1,499 @@ +from abc import abstractmethod +import json +import logging +import os +from threading import Condition, Thread +from typing import Any, Generator, List, Optional, Tuple + +import ga4gh.core +from ga4gh.vrs import models +from sqlalchemy import text as sql_text, create_engine +from sqlalchemy.engine import Connection + +from anyvar.restapi.schema import VariationStatisticType + +from . import _BatchManager, _Storage + +_logger = logging.getLogger(__name__) + + +class SqlStorage(_Storage): + """Relational database storage backend. Uses SQLAlchemy as a DB abstraction layer and pool. + Methods that utilize straightforward SQL are implemented in this class. Methods that require + specialized SQL statements must be implemented in a database specific subclass. 
+ """ + + def __init__( + self, + db_url: str, + batch_limit: Optional[int] = None, + table_name: Optional[str] = None, + max_pending_batches: Optional[int] = None, + flush_on_batchctx_exit: Optional[bool] = None, + ): + """Initialize DB handler. + + :param db_url: db connection info URL + :param batch_limit: max size of batch insert queue, defaults to 100000; can be set with + ANYVAR_SQL_STORE_BATCH_LIMIT environment variable + :param table_name: table name for storing VRS objects, defaults to `vrs_objects`; can be set with + ANYVAR_SQL_STORE_TABLE_NAME environment variable + :param max_pending_batches: maximum number of pending batches allowed before batch queueing blocks; can + be set with ANYVAR_SQL_STORE_MAX_PENDING_BATCHES environment variable + :param flush_on_batchctx_exit: whether to call `wait_for_writes()` when exiting the batch manager context; + defaults to True; can be set with the ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT environment variable + + See https://docs.sqlalchemy.org/en/20/core/connections.html for connection URL info + """ + + # get table name override from environment + self.table_name = table_name or os.environ.get("ANYVAR_SQL_STORE_TABLE_NAME", "vrs_objects") + + # create the database connection engine + self.conn_pool = create_engine( + db_url, + pool_size=1, + max_overflow=1, + pool_recycle=3600, + connect_args=self._get_connect_args(db_url), + ) + + # create the schema objects if necessary + with self._get_connection() as conn: + self.create_schema(conn) + + # setup batch handling + self.batch_manager = SqlStorageBatchManager + self.batch_mode = False + self.batch_insert_values = [] + self.batch_limit = batch_limit or int( + os.environ.get("ANYVAR_SQL_STORE_BATCH_LIMIT", "100000") + ) + self.flush_on_batchctx_exit = ( + bool(os.environ.get("ANYVAR_SQL_STORE_FLUSH_ON_BATCHCTX_EXIT", "True")) + if flush_on_batchctx_exit is None + else flush_on_batchctx_exit + ) + max_pending_batches = max_pending_batches or int( + os.environ.get("ANYVAR_SQL_STORE_MAX_PENDING_BATCHES", "50") + ) + self.batch_thread = SqlStorageBatchThread(self, max_pending_batches) + self.batch_thread.start() + + def _get_connection(self) -> Connection: + """Returns a database connection""" + return self.conn_pool.connect() + + def _get_connect_args(self, db_url: str) -> dict: + """Returns connect_args for the SQLAlchemy create_engine() call + The default implementation returns None""" + return {} + + @abstractmethod + def create_schema(self, db_conn: Connection): + """Add the VRS object table if it does not exist + + :param db_conn: a database connection + """ + + def __repr__(self): + return str(self.conn_pool) + + def __setitem__(self, name: str, value: Any): + """Add item to database. If batch mode is on, add item to batch and submit batch + for write only if batch size is exceeded. 
+ + :param name: value for `vrs_id` field + :param value: value for `vrs_object` field + """ + assert ga4gh.core.is_pydantic_instance(value), "ga4gh.vrs object value required" + name = str(name) # in case str-like + if self.batch_mode: + self.batch_insert_values.append((name, value)) + _logger.debug("Appended item %s to batch queue", name) + if len(self.batch_insert_values) >= self.batch_limit: + self.batch_thread.queue_batch(self.batch_insert_values) + _logger.info( + "Queued batch of %s VRS objects for write", len(self.batch_insert_values) + ) + self.batch_insert_values = [] + else: + with self._get_connection() as db_conn: + with db_conn.begin(): + self.add_one_item(db_conn, name, value) + _logger.debug("Inserted item %s to %s", name, self.table_name) + + @abstractmethod + def add_one_item(self, db_conn: Connection, name: str, value: Any): + """Adds/merges a single item to the database + + :param db_conn: a database connection + :param name: value for `vrs_id` field + :param value: value for `vrs_object` field + """ + + @abstractmethod + def add_many_items(self, db_conn: Connection, items: list): + """Adds/merges many items to the database + + :param db_conn: a database connection + :param items: a list of (vrs_id, vrs_object) tuples + """ + + def __getitem__(self, name: str) -> Optional[Any]: + """Fetch item from DB given key. + + Future issues: + * Remove reliance on VRS-Python models (requires rewriting the enderef module) + + :param name: key to retrieve VRS object for + :return: VRS object if available + :raise NotImplementedError: if unsupported VRS object type (this is WIP) + """ + with self._get_connection() as conn: + result = self.fetch_vrs_object(conn, name) + if result: + object_type = result["type"] + if object_type == "Allele": + return models.Allele(**result) + elif object_type == "CopyNumberCount": + return models.CopyNumberCount(**result) + elif object_type == "CopyNumberChange": + return models.CopyNumberChange(**result) + elif object_type == "SequenceLocation": + return models.SequenceLocation(**result) + else: + raise NotImplementedError + else: + raise KeyError(name) + + def fetch_vrs_object(self, db_conn: Connection, vrs_id: str) -> Optional[Any]: + """Fetches a single VRS object from the database, return the value as a JSON object + + :param db_conn: a database connection + :param vrs_id: the VRS ID + :return: VRS object if available + """ + result = db_conn.execute( + sql_text( + f"SELECT vrs_object FROM {self.table_name} WHERE vrs_id = :vrs_id" # nosec B608 + ), + {"vrs_id": vrs_id}, + ) + if result: + value = result.scalar() + return json.loads(value) if value and isinstance(value, str) else value + else: + return None + + def __contains__(self, name: str) -> bool: + """Check whether VRS objects table contains ID. 
+ + :param name: VRS ID to look up + :return: True if ID is contained in vrs objects table + """ + with self._get_connection() as conn: + return self.fetch_vrs_object(conn, name) is not None + + def __delitem__(self, name: str) -> None: + """Delete item (not cascading -- doesn't delete referenced items) + + :param name: key to delete object for + """ + name = str(name) # in case str-like + with self._get_connection() as conn: + with conn.begin(): + self.delete_vrs_object(conn, name) + + def delete_vrs_object(self, db_conn: Connection, vrs_id: str): + """Delete a single VRS object + + :param db_conn: a database connection + :param vrs_id: the VRS ID + """ + db_conn.execute( + sql_text(f"DELETE FROM {self.table_name} WHERE vrs_id = :vrs_id"), # nosec B608 + {"vrs_id": vrs_id}, + ) + + def wait_for_writes(self): + """Returns once any currently pending database modifications have been completed.""" + if hasattr(self, "batch_thread") and self.batch_thread is not None: + # short circuit if the queue is empty + with self.batch_thread.cond: + if not self.batch_thread.pending_batch_list: + return + + # queue an empty batch + batch = [] + self.batch_thread.queue_batch(batch) + # wait for the batch to be removed from the pending queue + while True: + with self.batch_thread.cond: + if list(filter(lambda x: x is batch, self.batch_thread.pending_batch_list)): + self.batch_thread.cond.wait() + else: + break + + def close(self): + """Stop the batch thread and wait for it to complete""" + if hasattr(self, "batch_thread") and self.batch_thread is not None: + self.batch_thread.stop() + self.batch_thread.join() + self.batch_thread = None + # Terminate connection if necessary. + if hasattr(self, "conn_pool") and self.conn_pool is not None: + self.conn_pool.dispose() + self.conn_pool = None + + def __del__(self): + """Tear down DB instance.""" + self.close() + + def __len__(self) -> int: + """Returns the total number of VRS objects""" + with self._get_connection() as conn: + return self.get_vrs_object_count(conn) + + def get_vrs_object_count(self, db_conn: Connection) -> int: + """Returns the total number of objects + + :param db_conn: a database connection + """ + result = db_conn.execute(sql_text(f"SELECT COUNT(*) FROM {self.table_name}")) # nosec B608 + return result.scalar() + + def get_variation_count(self, variation_type: VariationStatisticType) -> int: + """Get total # of registered variations of requested type. 
+ + :param variation_type: variation type to check + :return: total count + """ + with self._get_connection() as conn: + if variation_type == VariationStatisticType.SUBSTITUTION: + return self.substitution_count(conn) + elif variation_type == VariationStatisticType.INSERTION: + return self.insertion_count(conn) + elif variation_type == VariationStatisticType.DELETION: + return self.deletion_count(conn) + else: + return ( + self.substitution_count(conn) + + self.deletion_count(conn) + + self.insertion_count(conn) + ) + + @abstractmethod + def deletion_count(self, db_conn: Connection) -> int: + """Returns the total number of deletions + + :param db_conn: a database connection + """ + + @abstractmethod + def substitution_count(self, db_conn: Connection) -> int: + """Returns the total number of substitutions + + :param db_conn: a database connection + """ + + @abstractmethod + def insertion_count(self, db_conn: Connection): + """Returns the total number of insertions + + :param db_conn: a database connection + """ + + def __iter__(self): + """Iterates over all VRS objects in the database""" + with self._get_connection() as conn: + iter = self.fetch_all_vrs_objects(conn) + for obj in iter: + yield obj + + def fetch_all_vrs_objects(self, db_conn: Connection) -> Generator[Any, Any, Any]: + """Returns a generator that iterates over all VRS objects in the database + in no specific order + + :param db_conn: a database connection + """ + result = db_conn.execute( + sql_text(f"SELECT vrs_object FROM {self.table_name}") # nosec B608 + ) + for row in result: + if row: + value = row["vrs_object"] + yield json.loads(value) if value and isinstance(value, str) else value + else: + yield None + + def keys(self): + """Returns a list of all VRS IDs in the database""" + with self._get_connection() as conn: + return self.fetch_all_vrs_ids(conn) + + def fetch_all_vrs_ids(self, db_conn: Connection) -> List: + """Returns a list of all VRS IDs in the database + + :param db_conn: a database connection + """ + result = db_conn.execute(sql_text(f"SELECT vrs_id FROM {self.table_name}")) # nosec B608 + return [row[0] for row in result] + + def search_variations(self, refget_accession: str, start: int, stop: int): + """Find all alleles that were registered that are in 1 genomic region + + :param refget_accession: refget accession (SQ. identifier) + :param start: Start genomic region to query + :param stop: Stop genomic region to query + + :return: a list of VRS Alleles that have locations referenced as identifiers + """ + with self._get_connection() as conn: + return self.search_vrs_objects(conn, "Allele", refget_accession, start, stop) + + @abstractmethod + def search_vrs_objects( + self, db_conn: Connection, type: str, refget_accession: str, start: int, stop: int + ) -> List[Any]: + """Find all VRS objects of the particular type and region + + :param type: the type of VRS object to search for + :param refget_accession: refget accession (SQ. 
identifier)
+        :param start: Start genomic region to query
+        :param stop: Stop genomic region to query
+
+        :return: a list of VRS objects
+        """
+
+    def wipe_db(self):
+        """Remove all stored records from the database"""
+        with self._get_connection() as conn:
+            with conn.begin():
+                conn.execute(sql_text(f"DELETE FROM {self.table_name}"))  # nosec B608
+
+    def num_pending_batches(self) -> int:
+        """Returns the number of pending insert batches"""
+        if self.batch_thread:
+            return len(self.batch_thread.pending_batch_list)
+        else:
+            return 0
+
+
+class SqlStorageBatchManager(_BatchManager):
+    """Context manager enabling bulk insertion statements
+
+    Use in cases like VCF ingest when intaking large amounts of data at once.
+    Insertion batches are processed by a background thread.
+    """
+
+    def __init__(self, storage: SqlStorage):
+        """Initialize context manager.
+
+        :param storage: SqlStorage instance to manage. Should be taken from the active
+            AnyVar instance -- otherwise it won't be able to delay insertions.
+        :raise ValueError: if `storage` param is not a `SqlStorage` instance
+        """
+        if not isinstance(storage, SqlStorage):
+            raise ValueError("SqlStorageBatchManager requires a SqlStorage instance")
+        self._storage = storage
+
+    def __enter__(self):
+        """Enter managed context."""
+        self._storage.batch_insert_values = []
+        self._storage.batch_mode = True
+
+    def __exit__(
+        self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]
+    ) -> bool:
+        """Handle exit from context management. Hands off the final batch to the background bulk insert processor.
+
+        :param exc_type: type of exception encountered, if any
+        :param exc_value: exception value
+        :param traceback: traceback for context of exception
+        :return: True if no exceptions were encountered, False otherwise
+        """
+        if exc_type is not None:
+            self._storage.batch_insert_values = None
+            self._storage.batch_mode = False
+            _logger.error(
+                f"Sql storage batch manager encountered exception {exc_type}: {exc_value}"
+            )
+            _logger.exception(exc_value)
+            return False
+        self._storage.batch_thread.queue_batch(self._storage.batch_insert_values)
+        self._storage.batch_mode = False
+        self._storage.batch_insert_values = None
+        if self._storage.flush_on_batchctx_exit:
+            self._storage.wait_for_writes()
+        return True
+
+
+class SqlStorageBatchThread(Thread):
+    """Background thread that merges VRS objects into the database"""
+
+    def __init__(self, sql_store: SqlStorage, max_pending_batches: int):
+        """Constructs a new background thread
+
+        :param sql_store: the SqlStorage instance whose queued batches this thread writes
+        :param max_pending_batches: the maximum number of pending batches to allow before blocking
+        """
+        super().__init__(daemon=True)
+        self.sql_store = sql_store
+        self.cond = Condition()
+        self.run_flag = True
+        self.pending_batch_list = []
+        self.max_pending_batches = max_pending_batches
+
+    def run(self):
+        """As long as run_flag is true, waits then processes pending batches"""
+        while self.run_flag:
+            with self.cond:
+                self.cond.wait()
+            self.process_pending_batches()
+
+    def stop(self):
+        """Sets the run_flag to false and notifies"""
+        self.run_flag = False
+        with self.cond:
+            self.cond.notify()
+
+    def queue_batch(self, batch_insert_values: List[Tuple]):
+        """Adds a batch to the pending list.
If the pending batch list is already at its max size, waits until there is room + + :param batch_insert_values: list of tuples where each tuple consists of (vrs_id, vrs_object) + """ + with self.cond: + if batch_insert_values is not None: + _logger.info("Queueing batch of %s items", len(batch_insert_values)) + while len(self.pending_batch_list) >= self.max_pending_batches: + _logger.debug("Pending batch queue is full, waiting for space...") + self.cond.wait() + self.pending_batch_list.append(batch_insert_values) + _logger.info("Queued batch of %s items", len(batch_insert_values)) + self.cond.notify_all() + + def process_pending_batches(self): + """As long as batches are available for processing, merges them into the database""" + _logger.info("Processing %s queued batches", len(self.pending_batch_list)) + while True: + batch_insert_values = None + with self.cond: + if len(self.pending_batch_list) > 0: + batch_insert_values = self.pending_batch_list[0] + del self.pending_batch_list[0] + self.cond.notify_all() + else: + self.cond.notify_all() + break + + if batch_insert_values: + self._run_copy_insert(batch_insert_values) + _logger.info("Processed queued batch of %s items", len(batch_insert_values)) + + def _run_copy_insert(self, batch_insert_values): + try: + with self.sql_store._get_connection() as conn: + with conn.begin(): + self.sql_store.add_many_items(conn, batch_insert_values) + except Exception: + _logger.exception("Failed to merge VRS object batch into database") diff --git a/tests/conftest.py b/tests/conftest.py index 367eb13..7eace7d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ def pytest_collection_modifyitems(items): """Modify test items in place to ensure test modules run in a given order.""" - MODULE_ORDER = ["test_lifespan", "test_variation", "test_general", "test_location", "test_search", "test_vcf", "test_storage_mapping", "test_snowflake"] + MODULE_ORDER = ["test_lifespan", "test_variation", "test_general", "test_location", "test_search", "test_vcf", "test_sql_storage_mapping", "test_postgres", "test_snowflake"] # remember to add new test modules to the order constant: assert len(MODULE_ORDER) == len(list(Path(__file__).parent.rglob("test_*.py"))) items.sort(key=lambda i: MODULE_ORDER.index(i.module.__name__)) diff --git a/tests/storage/sqlalchemy_mocks.py b/tests/storage/sqlalchemy_mocks.py new file mode 100644 index 0000000..e2b5201 --- /dev/null +++ b/tests/storage/sqlalchemy_mocks.py @@ -0,0 +1,153 @@ +import json +import re +import time + +class MockResult(list): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def fetchone(self): + if len(self) > 0: + retval = self[0] + del self[0] + return retval + + def fetchall(self): + new_list = list(self) + while len(self) > 0: + del self[0] + return new_list + + def scalar(self): + if len(self) > 0: + value = self[0][0] + while len(self) > 0: + del self[0] + try: + return int(value) + except: + return str(value) if value else None + else: + return None + +class MockStmt: + def __init__(self, sql: str, params, result, wait_for_secs: int = 0): + self.sql = re.sub(r'\s+', ' ', sql).strip() + self.params = params + self.result = result if not result or isinstance(result, Exception) else MockResult(result) + self.wait_for_secs = wait_for_secs + + def matches(self, sql: str, params): + norm_sql = re.sub(r'\s+', ' ', sql).strip() + if norm_sql == self.sql: + if self.params == True: + return True + elif (self.params is None or len(self.params) == 0) and (params is None or 
len(params) == 0): + return True + elif self.params == params: + return True + return False + +class MockStmtSequence(list): + def __init__(self): + self.execd = [] + + def add_stmt(self, sql: str, params, result, wait_for_secs: int = 0): + self.append(MockStmt(sql, params, result, wait_for_secs)) + return self + + def add_copy_from(self, table_name, data): + self.append(MockStmt(f"COPY FROM fd INTO {table_name}", data, [(1,)])) + return self + + def pop_if_matches(self, sql: str, params) -> list: + if len(self) > 0 and self[0].matches(sql, params): + self.execd.append(self[0]) + wait_for_secs = self[0].wait_for_secs + result = self[0].result + del self[0] + if wait_for_secs > 0: + time.sleep(wait_for_secs) + if isinstance(result, Exception): + raise result + else: + return result + return None + + def were_all_execd(self): + return len(self) <= 0 + +class MockConnection: + def __init__(self): + self.mock_stmt_sequences = [] + self.connection = self + self.last_result = None + + def close(self): + pass + + def cursor(self): + return self + + def fetchall(self): + return self.last_result.fetchall() if self.last_result else None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + if exc_value: + raise exc_value + return True + + def begin(self): + return self + + def execute(self, cmd, params = None): + for seq in self.mock_stmt_sequences: + result = seq.pop_if_matches(str(cmd), params) + if result: + self.last_result = result + return result + + norm_sql = re.sub(r'\s+', ' ', str(cmd)).strip() + raise Exception(f"no mock statement found for {norm_sql} with params {params}") + + def copy_from(self, fd, table_name, columns=None): + data_as_str = str(fd.read()) + for seq in self.mock_stmt_sequences: + result = seq.pop_if_matches(f"COPY FROM fd INTO {table_name}", data_as_str) + if result: + self.last_result = result + return result + + raise Exception(f"no mock statement found for COPY FROM fd INTO {table_name} with data {data_as_str[:10]}...") + +class MockEngine: + def __init__(self): + self.conn = MockConnection() + + def connect(self): + return self.conn + + def dispose(self): + pass + + def add_mock_stmt_sequence(self, stmt_seq: MockStmtSequence): + self.conn.mock_stmt_sequences.append(stmt_seq) + + def were_all_execd(self): + for seq in self.conn.mock_stmt_sequences: + if not seq.were_all_execd(): + return False + return True + +class MockVRSObject: + def __init__(self, id: str): + self.id = id + + def model_dump(self, exclude_none: bool): + return { "id": self.id } + + def to_json(self): + return json.dumps(self.model_dump(exclude_none=True)) diff --git a/tests/storage/test_postgres.py b/tests/storage/test_postgres.py new file mode 100644 index 0000000..d5909e9 --- /dev/null +++ b/tests/storage/test_postgres.py @@ -0,0 +1,210 @@ +""" +Test Postgres specific storage integration methods +and the async batch insertion + +Uses mocks for database integration +""" +import os +from sqlalchemy_mocks import MockEngine, MockStmtSequence, MockVRSObject + +from anyvar.restapi.schema import VariationStatisticType +from anyvar.storage.postgres import PostgresObjectStore + +vrs_object_table_name = os.environ.get("ANYVAR_SQL_STORE_TABLE_NAME", "vrs_objects") + +def test_create_schema(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 
'{vrs_object_table_name}')", None, [(False,)])
+        .add_stmt(f"CREATE TABLE {vrs_object_table_name} ( vrs_id TEXT PRIMARY KEY, vrs_object JSONB )", None, [("Table created",)])
+    )
+    sf = PostgresObjectStore("postgres://account/?param=value")
+    sf.close()
+    assert mock_eng.return_value.were_all_execd()
+
+def test_create_schema_exists(mocker):
+    mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine")
+    mock_eng.return_value = MockEngine()
+    mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence()
+        .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)])
+    )
+    sf = PostgresObjectStore("postgres://account/?param=value")
+    sf.close()
+    assert mock_eng.return_value.were_all_execd()
+
+def test_add_one_item(mocker):
+    mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True)
+    mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine")
+    mock_eng.return_value = MockEngine()
+    mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence()
+        .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)])
+        .add_stmt(
+            f"""
+            INSERT INTO {vrs_object_table_name} (vrs_id, vrs_object) VALUES (:vrs_id, :vrs_object) ON CONFLICT DO NOTHING
+            """,
+            {"vrs_id": "ga4gh:VA.01", "vrs_object": MockVRSObject('01').to_json()}, [(1,)])
+    )
+    sf = PostgresObjectStore("postgres://account/?param=value")
+    sf["ga4gh:VA.01"] = MockVRSObject('01')
+    sf.close()
+    assert mock_eng.return_value.were_all_execd()
+
+def test_add_many_items(mocker):
+    tmp_statement = "CREATE TEMP TABLE tmp_table (LIKE vrs_objects2 INCLUDING DEFAULTS)"
+    insert_statement = "INSERT INTO vrs_objects2 SELECT * FROM tmp_table ON CONFLICT DO NOTHING"
+    drop_statement = "DROP TABLE tmp_table"
+
+    vrs_id_object_pairs = [
+        ("ga4gh:VA.01", MockVRSObject('01')),
+        ("ga4gh:VA.02", MockVRSObject('02')),
+        ("ga4gh:VA.03", MockVRSObject('03')),
+        ("ga4gh:VA.04", MockVRSObject('04')),
+        ("ga4gh:VA.05", MockVRSObject('05')),
+        ("ga4gh:VA.06", MockVRSObject('06')),
+        ("ga4gh:VA.07", MockVRSObject('07')),
+        ("ga4gh:VA.08", MockVRSObject('08')),
+        ("ga4gh:VA.09", MockVRSObject('09')),
+        ("ga4gh:VA.10", MockVRSObject('10')),
+        ("ga4gh:VA.11", MockVRSObject('11')),
+    ]
+
+    mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True)
+    mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine")
+    mock_eng.return_value = MockEngine()
+    mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence()
+        .add_stmt("SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 'vrs_objects2')", None, [(True,)])
+        # Batch 1
+        .add_stmt(tmp_statement, None, [("Table created",)])
+        .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[0:2]]))
+        .add_stmt(insert_statement, None, [(2,)], 5)
+        .add_stmt(drop_statement, None, [("Table dropped",)])
+        # Batch 2
+        .add_stmt(tmp_statement, None, [("Table created",)])
+        .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[2:4]]))
+        .add_stmt(insert_statement, None, [(2,)], 4)
+        .add_stmt(drop_statement, None, [("Table dropped",)])
+        # Batch 3
+        .add_stmt(tmp_statement, None, [("Table created",)])
+        .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[4:6]]))
+        .add_stmt(insert_statement, None, [(2,)], 3)
+        .add_stmt(drop_statement, None, [("Table dropped",)])
+        # Batch 4
+        .add_stmt(tmp_statement, None, [("Table created",)]) +
.add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[6:8]])) + .add_stmt(insert_statement, None, [(2,)], 5) + .add_stmt(drop_statement, None, [("Table dropped",)]) + # Batch 5 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[8:10]])) + .add_stmt(insert_statement, None, Exception("query timeout")) + # Batch 6 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_copy_from("tmp_table", "\n".join([f"{pair[0]}\t{pair[1].to_json()}" for pair in vrs_id_object_pairs[10:11]])) + .add_stmt(insert_statement, None, [(2,)], 2) + .add_stmt(drop_statement, None, [("Table dropped",)]) + ) + + sf = PostgresObjectStore("postgres://account/?param=value", 2, "vrs_objects2", 4, False) + with sf.batch_manager(sf): + sf.wait_for_writes() + assert sf.num_pending_batches() == 0 + sf[vrs_id_object_pairs[0][0]] = vrs_id_object_pairs[0][1] + sf[vrs_id_object_pairs[1][0]] = vrs_id_object_pairs[1][1] + assert sf.num_pending_batches() > 0 + sf.wait_for_writes() + assert sf.num_pending_batches() == 0 + sf[vrs_id_object_pairs[2][0]] = vrs_id_object_pairs[2][1] + sf[vrs_id_object_pairs[3][0]] = vrs_id_object_pairs[3][1] + sf[vrs_id_object_pairs[4][0]] = vrs_id_object_pairs[4][1] + sf[vrs_id_object_pairs[5][0]] = vrs_id_object_pairs[5][1] + sf[vrs_id_object_pairs[6][0]] = vrs_id_object_pairs[6][1] + sf[vrs_id_object_pairs[7][0]] = vrs_id_object_pairs[7][1] + sf[vrs_id_object_pairs[8][0]] = vrs_id_object_pairs[8][1] + sf[vrs_id_object_pairs[9][0]] = vrs_id_object_pairs[9][1] + sf[vrs_id_object_pairs[10][0]] = vrs_id_object_pairs[10][1] + + assert sf.num_pending_batches() > 0 + sf.close() + assert sf.num_pending_batches() == 0 + assert mock_eng.return_value.were_all_execd() + +def test_insertion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + SELECT COUNT(*) AS c + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') > 1 + """, + None, [(12,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.INSERTION) == 12 + sf.close() + assert mock_eng.were_all_execd() + +def test_substitution_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + SELECT COUNT(*) AS c + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 1 + """, + None, [(13,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.SUBSTITUTION) == 13 + sf.close() + assert mock_eng.were_all_execd() + +def test_deletion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + 
.add_stmt( + f""" + SELECT COUNT(*) AS c + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object -> 'state' ->> 'sequence') = 0 + """, + None, [(14,)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.DELETION) == 14 + sf.close() + assert mock_eng.were_all_execd() + +def test_search_vrs_objects(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt(f"SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = '{vrs_object_table_name}')", None, [(True,)]) + .add_stmt( + f""" + SELECT vrs_object + FROM {vrs_object_table_name} + WHERE vrs_object->>'type' = %s + AND vrs_object->>'location' IN ( + SELECT vrs_id FROM {vrs_object_table_name} + WHERE CAST (vrs_object->>'start' AS INTEGER) >= %s + AND CAST (vrs_object->>'end' AS INTEGER) <= %s + AND vrs_object->'sequenceReference'->>'refgetAccession' = %s) + """, + [ "Allele", 123456, 123457, "MySQAccId" ], [({"id": 1},), ({"id": 2},)]) + ) + sf = PostgresObjectStore("postgres://account/?param=value") + vars = sf.search_variations("MySQAccId", 123456, 123457) + sf.close() + assert len(vars) == 2 + assert "id" in vars[0] and vars[0]["id"] == 1 + assert "id" in vars[1] and vars[1]["id"] == 2 + assert mock_eng.were_all_execd() \ No newline at end of file diff --git a/tests/storage/test_snowflake.py b/tests/storage/test_snowflake.py index d6123e9..49d5c23 100644 --- a/tests/storage/test_snowflake.py +++ b/tests/storage/test_snowflake.py @@ -1,160 +1,82 @@ """ -Test Snowflake storage async batch management and write feature +Test Snowflake specific storage integration methods +and the async batch insertion -To run an integration test with a real Snowflake connection, run all - the tests with the ANYVAR_TEST_STORAGE_URI env variable set to - a Snowflake URI +Uses mocks for database integration """ -import json -import logging -import re -import time - -from anyvar.storage.snowflake import SnowflakeObjectStore - -class MockStmt: - def __init__(self, sql: str, params: list, result: list, wait_for_secs: int = 0): - self.sql = re.sub(r'\s+', ' ', sql).strip() - self.params = params - self.result = result - self.wait_for_secs = wait_for_secs - - def matches(self, sql: str, params: list): - norm_sql = re.sub(r'\s+', ' ', sql).strip() - if norm_sql == self.sql: - if self.params == True: - return True - elif (self.params is None or len(self.params) == 0) and (params is None or len(params) == 0): - return True - elif self.params == params: - return True - return False - -class MockStmtSequence(list): - def __init__(self): - self.execd = [] - - def add_stmt(self, sql: str, params: list, result: list, wait_for_secs: int = 0): - self.append(MockStmt(sql, params, result, wait_for_secs)) - return self - - def pop_if_matches(self, sql: str, params: list) -> list: - if len(self) > 0 and self[0].matches(sql, params): - self.execd.append(self[0]) - wait_for_secs = self[0].wait_for_secs - result = self[0].result - del self[0] - if wait_for_secs > 0: - time.sleep(wait_for_secs) - if isinstance(result, Exception): - raise result - else: - return result - return None - - def were_all_execd(self): - return len(self) <= 0 - -class MockConnectionCursor: - def __init__(self): - self.mock_stmt_sequences = [] - self.current_result = None - self.closed = False - - def close(self): - self.closed = True - - def cursor(self): - if self.closed: - raise 
Exception("connection is closed") - return self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback) -> bool: - if exc_value: - raise exc_value - return True - - def execute(self, cmd: str, params: list = None): - if self.closed: - raise Exception("connection is closed") - self.current_result = None - for seq in self.mock_stmt_sequences: - result = seq.pop_if_matches(cmd, params) - if result: - self.current_result = result - if not self.current_result: - raise Exception(f"no mock statement found for {cmd} with params {params}") - - def executemany(self, cmd: str, params: list = None): - self.execute(cmd, params) - - def fetchone(self): - if self.closed: - raise Exception("connection is closed") - return self.current_result[0] if self.current_result and len(self.current_result) > 0 else None - - def commit(self): - self.execute("COMMIT;", None) - - def rollback(self): - self.execute("ROLLBACK;", None) - - def add_mock_stmt_sequence(self, stmt_seq: MockStmtSequence): - self.mock_stmt_sequences.append(stmt_seq) - - def were_all_execd(self): - for seq in self.mock_stmt_sequences: - if not seq.were_all_execd(): - return False - return True - -class MockVRSObject: - def __init__(self, id: str): - self.id = id - - def model_dump(self, exclude_none: bool): - return { "id": self.id } - - def to_json(self): - return json.dumps(self.model_dump(exclude_none=True)) - -def test_create_schema(caplog, mocker): - caplog.set_level(logging.DEBUG) - sf_conn = mocker.patch('snowflake.connector.connect') - sf_conn.return_value = MockConnectionCursor() - sf_conn.return_value.add_mock_stmt_sequence(MockStmtSequence() - .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects');", None, [(0,)]) - .add_stmt("CREATE TABLE vrs_objects ( vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', vrs_object VARIANT );", None, [("Table created",)]) - ) +import os +from sqlalchemy_mocks import MockEngine, MockStmtSequence, MockVRSObject - sf = SnowflakeObjectStore("snowflake://account/?param=value") - sf.close() - assert sf_conn.return_value.were_all_execd() +from anyvar.restapi.schema import VariationStatisticType +from anyvar.storage.snowflake import SnowflakeObjectStore, SnowflakeBatchAddMode +vrs_object_table_name = os.environ.get("ANYVAR_SQL_STORE_TABLE_NAME", "vrs_objects") -def test_schema_exists(mocker): - sf_conn = mocker.patch('snowflake.connector.connect') - sf_conn.return_value = MockConnectionCursor() - sf_conn.return_value.add_mock_stmt_sequence(MockStmtSequence() - .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects');", None, [(1,)]) +def test_create_schema(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(0,)]) + .add_stmt(f"CREATE TABLE {vrs_object_table_name} ( vrs_id VARCHAR(500) PRIMARY KEY COLLATE 'utf8', vrs_object VARIANT )", None, [("Table created",)]) ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + sf.close() + assert mock_eng.were_all_execd() +def 
test_create_schema_exists(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + ) sf = SnowflakeObjectStore("snowflake://account/?param=value") sf.close() - assert sf_conn.return_value.were_all_execd() + assert mock_eng.were_all_execd() +def test_add_one_item(mocker): + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + MERGE INTO {vrs_object_table_name} t USING (SELECT ? AS vrs_id, ? AS vrs_object) s ON t.vrs_id = s.vrs_id + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) + """, + ("ga4gh:VA.01", MockVRSObject('01').to_json()), [(1,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + sf["ga4gh:VA.01"] = MockVRSObject('01') + sf.close() + assert mock_eng.were_all_execd() -def test_batch_mgmt_and_async_write_single_thread(mocker): - tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR);" - insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?);" +def test_add_many_items(mocker): + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" merge_statement = f""" MERGE INTO vrs_objects2 v USING tmp_vrs_objects s ON v.vrs_id = s.vrs_id - WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)); + WHEN NOT MATCHED THEN INSERT (vrs_id, vrs_object) VALUES (s.vrs_id, PARSE_JSON(s.vrs_object)) """ - drop_statement = "DROP TABLE tmp_vrs_objects;" + drop_statement = "DROP TABLE tmp_vrs_objects" vrs_id_object_pairs = [ ("ga4gh:VA.01", MockVRSObject('01')), @@ -170,54 +92,43 @@ def test_batch_mgmt_and_async_write_single_thread(mocker): ("ga4gh:VA.11", MockVRSObject('11')), ] - mocker.patch('ga4gh.core.is_pydantic_instance', return_value=True) - sf_conn = mocker.patch('snowflake.connector.connect') - sf_conn.return_value = MockConnectionCursor() - sf_conn.return_value.add_mock_stmt_sequence(MockStmtSequence() - .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2');", None, [(1,)]) + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2')", None, [(1,)]) # Batch 1 
.add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[0:2]), [(2,)], 5) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 2 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[2:4]), [(2,)], 4) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 3 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[4:6]), [(2,)], 3) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 4 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[6:8]), [(2,)], 5) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 5 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[8:10]), [(2,)], 3) .add_stmt(merge_statement, None, Exception("query timeout")) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) # Batch 6 .add_stmt(tmp_statement, None, [("Table created",)]) .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[10:11]), [(2,)], 2) .add_stmt(merge_statement, None, [(2,)]) .add_stmt(drop_statement, None, [("Table dropped",)]) - .add_stmt("COMMIT;", None, [("Committed", )]) - .add_stmt("ROLLBACK;", None, [("Rolled back", )]) ) - sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", 4) + sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", 4, False, SnowflakeBatchAddMode.merge) with sf.batch_manager(sf): sf.wait_for_writes() assert sf.num_pending_batches() == 0 @@ -239,4 +150,179 @@ def test_batch_mgmt_and_async_write_single_thread(mocker): assert sf.num_pending_batches() > 0 sf.close() assert sf.num_pending_batches() == 0 - assert sf_conn.return_value.were_all_execd() \ No newline at end of file + assert mock_eng.return_value.were_all_execd() + +def test_batch_add_mode_insert_notin(mocker): + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" + merge_statement = f""" + INSERT INTO vrs_objects2 (vrs_id, vrs_object) + SELECT t.vrs_id, PARSE_JSON(t.vrs_object) + FROM tmp_vrs_objects t + LEFT OUTER JOIN vrs_objects2 v ON v.vrs_id = t.vrs_id + WHERE v.vrs_id IS NULL + """ + drop_statement = "DROP TABLE tmp_vrs_objects" + + vrs_id_object_pairs = [ + ("ga4gh:VA.01", MockVRSObject('01')), + ("ga4gh:VA.02", MockVRSObject('02')), + ] + + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + 
mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2')", None, [(1,)]) + # Batch 1 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[0:2]), [(2,)], 5) + .add_stmt(merge_statement, None, [(2,)]) + .add_stmt(drop_statement, None, [("Table dropped",)]) + ) + + sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", None, None, SnowflakeBatchAddMode.insert_notin) + with sf.batch_manager(sf): + sf[vrs_id_object_pairs[0][0]] = vrs_id_object_pairs[0][1] + sf[vrs_id_object_pairs[1][0]] = vrs_id_object_pairs[1][1] + + sf.close() + assert mock_eng.return_value.were_all_execd() + +def test_batch_add_mode_insert(mocker): + tmp_statement = "CREATE TEMP TABLE IF NOT EXISTS tmp_vrs_objects (vrs_id VARCHAR(500) COLLATE 'utf8', vrs_object VARCHAR)" + insert_statement = "INSERT INTO tmp_vrs_objects (vrs_id, vrs_object) VALUES (?, ?)" + merge_statement = f""" + INSERT INTO vrs_objects2 (vrs_id, vrs_object) + SELECT vrs_id, PARSE_JSON(vrs_object) FROM tmp_vrs_objects + """ + drop_statement = "DROP TABLE tmp_vrs_objects" + + vrs_id_object_pairs = [ + ("ga4gh:VA.01", MockVRSObject('01')), + ("ga4gh:VA.02", MockVRSObject('02')), + ] + + mocker.patch("ga4gh.core.is_pydantic_instance", return_value=True) + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt("SELECT COUNT(*) FROM information_schema.tables WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() AND UPPER(table_name) = UPPER('vrs_objects2')", None, [(1,)]) + # Batch 1 + .add_stmt(tmp_statement, None, [("Table created",)]) + .add_stmt(insert_statement, list((pair[0], pair[1].to_json()) for pair in vrs_id_object_pairs[0:2]), [(2,)], 5) + .add_stmt(merge_statement, None, [(2,)]) + .add_stmt(drop_statement, None, [("Table dropped",)]) + ) + + sf = SnowflakeObjectStore("snowflake://account/?param=value", 2, "vrs_objects2", None, None, SnowflakeBatchAddMode.insert) + with sf.batch_manager(sf): + sf[vrs_id_object_pairs[0][0]] = vrs_id_object_pairs[0][1] + sf[vrs_id_object_pairs[1][0]] = vrs_id_object_pairs[1][1] + + sf.close() + assert mock_eng.return_value.were_all_execd() + +def test_insertion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT COUNT(*) + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object:state:sequence) > 1 + """, + None, [(12,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.INSERTION) == 12 + sf.close() + assert mock_eng.were_all_execd() + +def test_substitution_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM 
information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT COUNT(*) + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object:state:sequence) = 1 + """, + None, [(13,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.SUBSTITUTION) == 13 + sf.close() + assert mock_eng.were_all_execd() + +def test_deletion_count(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT COUNT(*) + FROM {vrs_object_table_name} + WHERE LENGTH(vrs_object:state:sequence) = 0 + """, + None, [(14,)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + assert sf.get_variation_count(VariationStatisticType.DELETION) == 14 + sf.close() + assert mock_eng.were_all_execd() + +def test_search_vrs_objects(mocker): + mock_eng = mocker.patch("anyvar.storage.sql_storage.create_engine") + mock_eng.return_value = MockEngine() + mock_eng.return_value.add_mock_stmt_sequence(MockStmtSequence() + .add_stmt( + f""" + SELECT COUNT(*) FROM information_schema.tables + WHERE table_catalog = CURRENT_DATABASE() AND table_schema = CURRENT_SCHEMA() + AND UPPER(table_name) = UPPER('{vrs_object_table_name}') + """, + None, [(1,)]) + .add_stmt( + f""" + SELECT vrs_object + FROM {vrs_object_table_name} + WHERE vrs_object:type = ? + AND vrs_object:location IN ( + SELECT vrs_id FROM {vrs_object_table_name} + WHERE vrs_object:start::INTEGER >= ? + AND vrs_object:end::INTEGER <= ? + AND vrs_object:sequenceReference:refgetAccession = ?) + """, + ("Allele", 123456, 123457, "MySQAccId"), [("{\"id\": 1}",), ("{\"id\": 2}",)]) + ) + sf = SnowflakeObjectStore("snowflake://account/?param=value") + vars = sf.search_variations("MySQAccId", 123456, 123457) + sf.close() + assert len(vars) == 2 + assert "id" in vars[0] and vars[0]["id"] == 1 + assert "id" in vars[1] and vars[1]["id"] == 2 + assert mock_eng.were_all_execd() \ No newline at end of file diff --git a/tests/storage/test_storage_mapping.py b/tests/storage/test_sql_storage_mapping.py similarity index 61% rename from tests/storage/test_storage_mapping.py rename to tests/storage/test_sql_storage_mapping.py index b8354a3..7c6c421 100644 --- a/tests/storage/test_storage_mapping.py +++ b/tests/storage/test_sql_storage_mapping.py @@ -1,6 +1,11 @@ -"""Tests the mutable mapping API of the storage backend""" +""" +Tests the SqlStorage methods that are NOT tested through the +REST API tests. 
To test against different SQL backends, this +test must be run with different ANYVAR_TEST_STORAGE_URI settings +and different ANYVAR_SQL_STORE_BATCH_ADD_MODE settings +""" from ga4gh.vrs import vrs_enref - +from anyvar.storage.snowflake import SnowflakeObjectStore, SnowflakeBatchAddMode from anyvar.translate.vrs_python import VrsPythonTranslator # pause for 5 seconds because Snowflake storage is an async write and @@ -11,12 +16,12 @@ def test_waitforsync(): # __getitem__ def test_getitem(storage, alleles): - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): assert storage[allele_id] is not None # __contains__ def test_contains(storage, alleles): - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): assert allele_id in storage # __len__ @@ -24,31 +29,31 @@ def test_len(storage): assert len(storage) > 0 # __iter__ -def test_iter(storage, alleles): +def test_iter(storage): obj_iter = iter(storage) count = 0 while True: try: - obj = next(obj_iter) + next(obj_iter) count += 1 except StopIteration: break - assert count == 14 + assert count == (18 if isinstance(storage, SnowflakeObjectStore) and storage.batch_add_mode == SnowflakeBatchAddMode.insert else 14) # keys def test_keys(storage, alleles): key_list = storage.keys() - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): assert allele_id in key_list # __delitem__ def test_delitem(storage, alleles): - for allele_id, allele in alleles.items(): + for allele_id, _ in alleles.items(): del storage[allele_id] # __setitem__ def test_setitem(storage, alleles): - for allele_id, allele in alleles.items(): + for _, allele in alleles.items(): variation = allele["params"] definition = variation["definition"] translated_variation = VrsPythonTranslator().translate_variation(definition)