diff --git a/nifi/user-scripts/annotation_manager.py b/nifi/user-scripts/annotation_manager.py index 1048b8a9..bede97ee 100644 --- a/nifi/user-scripts/annotation_manager.py +++ b/nifi/user-scripts/annotation_manager.py @@ -2,7 +2,8 @@ import json import traceback import sys -from utils.sqlite_query import connect_and_query,check_db_exists,create_db_from_file +import sqlite3 +from utils.sqlite_query import connect_and_query,check_db_exists,create_db_from_file,create_connection global DOCUMENT_ID_FIELD_NAME global DOCUMENT_TEXT_FIELD_NAME @@ -70,12 +71,16 @@ def main(): inserted_doc_ids = [] + _sqlite_connection_ro = None + if OPERATION_MODE == "check": + _sqlite_connection_ro = create_connection(db_file_path, read_only_mode=True) + + for record in records: if OPERATION_MODE == "check": document_id = str(record[DOCUMENT_ID_FIELD_NAME]) - - query = "SELECT * FROM annotations WHERE elasticsearch_id LIKE '%" + document_id + "%' LIMIT 1" - result = connect_and_query(query, db_file_path) + query = "SELECT id, elasticsearch_id FROM annotations WHERE elasticsearch_id LIKE '%" + document_id + "%' LIMIT 1" + result = connect_and_query(query, db_file_path, sqlite_connection=_sqlite_connection_ro) if len(result) < 1: output_stream["content"].append(record) diff --git a/nifi/user-scripts/utils/sqlite_query.py b/nifi/user-scripts/utils/sqlite_query.py index 64e427dd..0354f26e 100644 --- a/nifi/user-scripts/utils/sqlite_query.py +++ b/nifi/user-scripts/utils/sqlite_query.py @@ -1,17 +1,30 @@ +from ast import List import sqlite3 -def connect_and_query(query: str, db_file_path: str, sql_script_mode: bool = False): - """ - Executes whatever query. + +def connect_and_query(query: str, db_file_path: str, sqlite_connection: sqlite3.Connection, sql_script_mode: bool = False) -> List: + """ Executes whatever query. Args: - query (string): your SQL query. - """ + query (str): your SQL query. + db_file_path (str): file path to sqlite db + sql_script_mode (bool, optional): if it is transactional or just a fetch query . Defaults to False. + + Raises: + sqlite3.Error: sqlite error. + + Returns: + List: List of results + """ + result = [] - sqlite_connection = None try: - sqlite_connection = sqlite3.connect(db_file_path) + if sqlite_connection: + sqlite_connection = sqlite_connection + else: + sqlite_connection = create_connection(db_file_path) + cursor = sqlite_connection.cursor() if not sql_script_mode: cursor.execute(query) @@ -28,12 +41,36 @@ def connect_and_query(query: str, db_file_path: str, sql_script_mode: bool = Fal return result + +def create_connection(db_file_path: str, read_only_mode=False) -> sqlite3.Connection: + + connection_str = "file:/" + str(db_file_path) + + if read_only_mode: + connection_str += "?mode=ro" + + return sqlite3.connect(connection_str) + + +def query_with_connection(query: str, sqlite_connection: sqlite3.Connection) -> List: + result = [] + try: + cursor = sqlite_connection.cursor() + cursor.execute(query) + result = cursor.fetchall() + cursor.close() + except sqlite3.Error as error: + raise sqlite3.Error(error) + return result + + def check_db_exists(table_name: str, db_file_path: str): query = "PRAGMA table_info(" + table_name + ");" return connect_and_query(query=query, db_file_path=db_file_path) + def create_db_from_file(sqlite_file_path: str, db_file_path: str) -> sqlite3.Cursor: - """ + """ Creates db from .sqlite schema/query file Args: sqlite_file_path (str): sqlite db folder @@ -46,4 +83,4 @@ def create_db_from_file(sqlite_file_path: str, db_file_path: str) -> sqlite3.Cur with open(sqlite_file_path, mode="r") as sql_file: query = sql_file.read() return connect_and_query(query=query, db_file_path=db_file_path, - sql_script_mode=True) + sql_script_mode=True)