From b3d6e9345ab0692fa77a3acaf429f4e35fcc5602 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 14 Dec 2020 19:12:04 -0800 Subject: [PATCH 001/294] Updated fin models --- wbia/algo/detect/densenet.py | 3 +-- wbia/algo/detect/lightnet.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/wbia/algo/detect/densenet.py b/wbia/algo/detect/densenet.py index f49cc88663..e80ccd333e 100644 --- a/wbia/algo/detect/densenet.py +++ b/wbia/algo/detect/densenet.py @@ -47,8 +47,7 @@ 'humpback_dorsal': 'https://wildbookiarepository.azureedge.net/models/labeler.whale_humpback.dorsal.v0.zip', 'orca_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.whale_orca.v0.zip', 'fins_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v0.zip', - 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v1.zip', - 'fins_v1-1': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v1.1.zip', + 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v1.1.zip', 'wilddog_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.wild_dog.v0.zip', 'wilddog_v1': 'https://wildbookiarepository.azureedge.net/models/labeler.wild_dog.v1.zip', 'wilddog_v2': 'https://wildbookiarepository.azureedge.net/models/labeler.wild_dog.v2.zip', diff --git a/wbia/algo/detect/lightnet.py b/wbia/algo/detect/lightnet.py index cc0655ede2..790539aedd 100644 --- a/wbia/algo/detect/lightnet.py +++ b/wbia/algo/detect/lightnet.py @@ -52,8 +52,9 @@ 'humpback_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_humpback.dorsal.v0.py', 'orca_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_orca.v0.py', 'fins_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v0.py', - 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.py', - 'fins_v1-1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py', + 'fins_v1_fluke': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.py', + 'fins_v1_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py', + 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py', 'nassau_grouper_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.grouper_nassau.v0.py', 'spotted_dolphin_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.dolphin_spotted.v0.py', 'spotted_skunk_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.skunk_spotted.v0.py', From 50cf5d4d0fbec67c249db865546271510bd368b8 Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 7 Sep 2020 20:59:49 -0700 Subject: [PATCH 002/294] Use sqlalchemy to create and manage database connections Notes: - How to do `register_converter` with sqlalchemy? `register_adapter` appears to be working though - With sqlalchemy, we need to do: ``` result = connection.execute('...') result.fetchall() ``` instead of: ``` cursor.execute('...') cursor.fetchall() ``` and: ``` transaction = connection.begin() transaction.commit() ``` instead of: ``` connection.commit() ``` - We can't do fetch when there's no results e.g. CREATE TABLE statements, sqlalchemy returns an error: ``` sqlalchemy.exc.ResourceClosedError: This result object does not return rows. It has been closed automatically. ``` The way to check is by doing `result.returns_rows`. 
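A rough end-to-end sketch of the pattern described in these notes (assuming SQLAlchemy 1.3, where `Connection.execute()` still accepts raw SQL strings, and an in-memory SQLite database; the table and values are illustrative only):

```
import sqlalchemy

engine = sqlalchemy.create_engine('sqlite:///:memory:')
connection = engine.connect()

# DDL returns no rows, so check returns_rows before fetching
result = connection.execute('CREATE TABLE t (id INTEGER PRIMARY KEY, x TEXT)')
assert not result.returns_rows

# writes go through an explicit transaction object instead of connection.commit()
transaction = connection.begin()
connection.execute("INSERT INTO t (x) VALUES ('a')")
transaction.commit()

# reads: execute() on the connection, fetchall() on the result
result = connection.execute('SELECT id, x FROM t')
if result.returns_rows:
    rows = result.fetchall()  # [(1, 'a')]
```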
--- wbia/annotmatch_funcs.py | 4 +- wbia/control/IBEISControl.py | 4 +- wbia/control/_sql_helpers.py | 8 +- wbia/control/manual_annot_funcs.py | 2 +- wbia/control/manual_image_funcs.py | 4 +- wbia/control/manual_name_funcs.py | 2 +- wbia/dtool/depcache_control.py | 2 +- wbia/dtool/sql_control.py | 151 +++++++++++++++-------------- 8 files changed, 91 insertions(+), 86 deletions(-) diff --git a/wbia/annotmatch_funcs.py b/wbia/annotmatch_funcs.py index 6739bc858e..579ebaefb2 100644 --- a/wbia/annotmatch_funcs.py +++ b/wbia/annotmatch_funcs.py @@ -52,7 +52,7 @@ def get_annotmatch_rowids_from_aid1(ibs, aid1_list, eager=True, nInput=None): ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, annot_rowid1=manual_annotmatch_funcs.ANNOT_ROWID1, ) - ).fetchall() + ) where_colnames = [manual_annotmatch_funcs.ANNOT_ROWID1] annotmatch_rowid_list = ibs.db.get_where_eq( ibs.const.ANNOTMATCH_TABLE, @@ -89,7 +89,7 @@ def get_annotmatch_rowids_from_aid2( ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, annot_rowid2=manual_annotmatch_funcs.ANNOT_ROWID2, ) - ).fetchall() + ) colnames = (manual_annotmatch_funcs.ANNOTMATCH_ROWID,) # FIXME: col_rowid is not correct params_iter = zip(aid2_list) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 0c8d446a78..2f5651dc6d 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -682,7 +682,7 @@ def _init_sqldbcore(ibs, request_dbversion=None): readonly = None else: readonly = True - db_uri = 'file://{}'.format(realpath(sqldb_fpath)) + db_uri = 'sqlite:///{}'.format(realpath(sqldb_fpath)) ibs.db = dtool.SQLDatabaseController.from_uri(db_uri, readonly=readonly) ibs.readonly = ibs.db.readonly @@ -768,7 +768,7 @@ def _init_sqldbstaging(ibs, request_stagingversion=None): readonly = None else: readonly = True - db_uri = 'file://{}'.format(realpath(sqlstaging_fpath)) + db_uri = 'sqlite:///{}'.format(realpath(sqlstaging_fpath)) ibs.staging = dtool.SQLDatabaseController.from_uri( db_uri, readonly=readonly, diff --git a/wbia/control/_sql_helpers.py b/wbia/control/_sql_helpers.py index db19f22ecb..6d44b5a030 100644 --- a/wbia/control/_sql_helpers.py +++ b/wbia/control/_sql_helpers.py @@ -47,7 +47,7 @@ def _devcheck_backups(): sorted(ut.glob(join(dbdir, '_wbia_backups'), '*staging_back*.sqlite3')) fpaths = sorted(ut.glob(join(dbdir, '_wbia_backups'), '*database_back*.sqlite3')) for fpath in fpaths: - db_uri = 'file://{}'.format(realpath(fpath)) + db_uri = 'sqlite:///{}'.format(realpath(fpath)) db = dt.SQLDatabaseController.from_uri(db_uri) logger.info('fpath = %r' % (fpath,)) num_edges = len(db.executeone('SELECT rowid from annotmatch')) @@ -185,7 +185,7 @@ def copy_database(src_fpath, dst_fpath): # blocked lock for all processes potentially writing to the database timeout = 12 * 60 * 60 # Allow a lock of up to 12 hours for a database backup routine if not src_fpath.startswith('file:'): - src_fpath = 'file://{}'.format(realpath(src_fpath)) + src_fpath = 'sqlite:///{}'.format(realpath(src_fpath)) db = dtool.SQLDatabaseController.from_uri(src_fpath, timeout=timeout) db.backup(dst_fpath) @@ -379,7 +379,7 @@ def _check_superkeys(): ), 'ERROR UPDATING DATABASE, SUPERKEYS of %s DROPPED!' 
% (tablename,) logger.info('[_SQL] update_schema_version') - db_fpath = db.uri.replace('file://', '') + db_fpath = db.uri.replace('sqlite://', '') if dobackup: db_dpath, db_fname = split(db_fpath) db_fname_noext, ext = splitext(db_fname) @@ -540,7 +540,7 @@ def get_nth_test_schema_version(schema_spec, n=-1): cachedir = ut.ensure_app_resource_dir('wbia_test') db_fname = 'test_%s.sqlite3' % dbname ut.delete(join(cachedir, db_fname)) - db_uri = 'file://{}'.format(realpath(join(cachedir, db_fname))) + db_uri = 'sqlite:///{}'.format(realpath(join(cachedir, db_fname))) db = SQLDatabaseController.from_uri(db_uri) ensure_correct_version(None, db, version_expected, schema_spec, dobackup=False) return db diff --git a/wbia/control/manual_annot_funcs.py b/wbia/control/manual_annot_funcs.py index d58b963c78..c238ca8f16 100644 --- a/wbia/control/manual_annot_funcs.py +++ b/wbia/control/manual_annot_funcs.py @@ -2486,7 +2486,7 @@ def get_annot_part_rowids(ibs, aid_list, is_staged=False): """ CREATE INDEX IF NOT EXISTS aid_to_part_rowids ON parts (annot_rowid); """ - ).fetchall() + ) # The index maxes the following query very efficient part_rowids_list = ibs.db.get( ibs.const.PART_TABLE, diff --git a/wbia/control/manual_image_funcs.py b/wbia/control/manual_image_funcs.py index e655d49ea5..bb13020355 100644 --- a/wbia/control/manual_image_funcs.py +++ b/wbia/control/manual_image_funcs.py @@ -2178,7 +2178,7 @@ def get_image_imgsetids(ibs, gid_list): """.format( GSG_RELATION_TABLE=const.GSG_RELATION_TABLE, IMAGE_ROWID=IMAGE_ROWID ) - ).fetchall() + ) colnames = ('imageset_rowid',) imgsetids_list = ibs.db.get( const.GSG_RELATION_TABLE, @@ -2280,7 +2280,7 @@ def get_image_aids(ibs, gid_list, is_staged=False, __check_staged__=True): """ CREATE INDEX IF NOT EXISTS gid_to_aids ON annotations (image_rowid); """ - ).fetchall() + ) # The index maxes the following query very efficient if __check_staged__: diff --git a/wbia/control/manual_name_funcs.py b/wbia/control/manual_name_funcs.py index cee28dbca6..4afdc5a0fb 100644 --- a/wbia/control/manual_name_funcs.py +++ b/wbia/control/manual_name_funcs.py @@ -493,7 +493,7 @@ def get_name_aids(ibs, nid_list, enable_unknown_fix=True, is_staged=False): """ CREATE INDEX IF NOT EXISTS nid_to_aids ON annotations (name_rowid); """ - ).fetchall() + ) aids_list = ibs.db.get( const.ANNOTATION_TABLE, (ANNOT_ROWID,), diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 5a1fb1fb7e..9ec9e52c0a 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -235,7 +235,7 @@ def initialize(depc, _debug=None): fpath = ut.unixjoin(depc.cache_dpath, fname_) # if ut.get_argflag('--clear-all-depcache'): # ut.delete(fpath) - db_uri = 'file://{}'.format(os.path.realpath(fpath)) + db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) db = sql_control.SQLDatabaseController.from_uri(db_uri) depcache_table.ensure_config_table(db) depc.fname_to_db[fname] = db diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 0d6597b6b2..d0486ca367 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -10,6 +10,7 @@ import os import parse import re +import sqlite3 import threading import uuid from collections.abc import Mapping, MutableMapping @@ -17,10 +18,11 @@ from os.path import join, exists import six +import sqlalchemy import utool as ut from deprecated import deprecated -from wbia.dtool import sqlite3 as lite +from wbia.dtool import lite from wbia.dtool.dump import dumps @@ -156,8 +158,8 @@ def __init__( 
context.verbose = verbose context.is_insert = context.operation_type.startswith('INSERT') context.keepwrap = keepwrap - context.cur = None context.connection = None + context.transaction = None def __enter__(context): """ Checks to see if the operating will change the database """ @@ -174,21 +176,12 @@ def __enter__(context): # Start SQL Transaction context.connection = context.db.connection - try: - context.cur = context.connection.cursor() # HACK in a new cursor - except lite.ProgrammingError: - # Get connection for new thread - context.connection = context.db.thread_connection() - context.cur = context.connection.cursor() - - # context.cur = context.db.cur # OR USE DB CURSOR?? if context.start_transaction: - # context.cur.execute('BEGIN', ()) try: - context.cur.execute('BEGIN') - except lite.OperationalError: + context.transaction = context.connection.begin() + except sqlalchemy.exc.OperationalError: context.connection.rollback() - context.cur.execute('BEGIN') + context.transaction = context.connection.begin() if context.verbose or VERBOSE_SQL: logger.info(context.operation_lbl) if context.verbose: @@ -203,8 +196,8 @@ def __enter__(context): def execute_and_generate_results(context, params): """ helper for context statment """ try: - context.cur.execute(context.operation, params) - except lite.Error as ex: + result = context.connection.execute(context.operation, params) + except sqlalchemy.exc.OperationalError as ex: logger.info('Reporting SQLite Error') logger.info('params = ' + ut.repr2(params, truncate=not ut.VERBOSE)) ut.printex(ex, 'sql.Error', keys=['params']) @@ -232,10 +225,10 @@ def execute_and_generate_results(context, params): except Exception: pass raise - return context._results_gen() + return context._results_gen(result) # @profile - def _results_gen(context): + def _results_gen(context, result): """HELPER - Returns as many results as there are. Careful. Overwrites the results once you call it. Basically: Dont call this twice. @@ -244,19 +237,22 @@ def _results_gen(context): # The sqlite3_last_insert_rowid(D) interface returns the # rowid of the most recent successful INSERT # into a rowid table in D - context.cur.execute('SELECT last_insert_rowid()', ()) + result = context.connection.execute('SELECT last_insert_rowid()', ()) # Wraping fetchone in a generator for some pretty tight calls. while True: - result = context.cur.fetchone() - if not result: + if not result.returns_rows: + # Doesn't have any results, e.g. CREATE TABLE statements + return + row = result.fetchone() + if row is None: return if context.keepwrap: # Results are always returned wraped in a tuple - yield result + yield row else: # Here unpacking is conditional # FIXME: can this if be removed? 
- yield result[0] if len(result) == 1 else result + yield row[0] if len(row) == 1 else row def __exit__(context, type_, value, trace): """ Finalization of an SQLController call """ @@ -269,8 +265,8 @@ def __exit__(context, type_, value, trace): return False else: # Commit the transaction - if context.auto_commit: - context.connection.commit() + if context.auto_commit and context.transaction: + context.transaction.commit() else: logger.info('no commit %r' % context.operation_lbl) @@ -637,7 +633,7 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): >>> sqldb_dpath = ut.ensure_app_resource_dir('dtool') >>> sqldb_fname = u'test_database.sqlite3' >>> path = os.path.join(sqldb_dpath, sqldb_fname) - >>> db_uri = 'file://{}'.format(os.path.realpath(path)) + >>> db_uri = 'sqlite:///{}'.format(os.path.realpath(path)) >>> db = SQLDatabaseController.from_uri(db_uri) >>> db.print_schema() >>> print(db) @@ -660,9 +656,6 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): # FIXME (31-Jul-12020) rename to private attribute self.thread_connections = {} self._connection = None - # FIXME (31-Jul-12020) rename to private attribute, no direct access to the connection - # Initialize a cursor - self.cur = self.connection.cursor() if not self.readonly: # Ensure the metadata table is initialized. @@ -677,9 +670,14 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): def connect(self): """Create a connection for the instance or use the existing connection""" - self._connection = lite.connect( - self.uri, detect_types=lite.PARSE_DECLTYPES, timeout=self.timeout, uri=True + # The echo flag is a shortcut to set up SQLAlchemy logging + self._engine = sqlalchemy.create_engine( + self.uri, + # FIXME (27-Aug-12020) Hardcoded for sqlite + connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}, + echo=False, ) + self._connection = self._engine.connect() return self._connection @property @@ -693,7 +691,7 @@ def connection(self): return conn def _create_connection(self): - path = self.uri.replace('file://', '') + path = self.uri.replace('sqlite://', '') if not exists(path): logger.info('[sql] Initializing new database: %r' % (self.uri,)) if self.readonly: @@ -707,9 +705,13 @@ def _create_connection(self): uri = self.uri if self.readonly: uri += '?mode=ro' - connection = lite.connect( - uri, uri=True, detect_types=lite.PARSE_DECLTYPES, timeout=self.timeout + engine = sqlalchemy.create_engine( + uri, + # FIXME (27-Aug-12020) Hardcoded for sqlite + connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}, + echo=False, ) + connection = engine.connect() # Keep track of what thead this was started in threadid = threading.current_thread() @@ -718,7 +720,6 @@ def _create_connection(self): return connection, uri def close(self): - self.cur = None self.connection.close() self.thread_connections = {} @@ -747,7 +748,7 @@ def _ensure_metadata_table(self): """ try: orig_table_kw = self.get_table_autogen_dict(METADATA_TABLE_NAME) - except (lite.OperationalError, NameError): + except (sqlalchemy.exc.OperationalError, NameError): orig_table_kw = None meta_table_kw = ut.odict( @@ -813,7 +814,7 @@ def get_db_init_uuid(self, ensure=True): >>> import os >>> from wbia.dtool.sql_control import * # NOQA >>> # Check random database gets new UUID on init - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController.from_uri('sqlite:///') >>> uuid_ = db.get_db_init_uuid() >>> print('New Database: %r is valid' % (uuid_, )) >>> assert isinstance(uuid_, uuid.UUID) @@ -821,7 +822,7 @@ def 
get_db_init_uuid(self, ensure=True): >>> sqldb_dpath = ut.ensure_app_resource_dir('dtool') >>> sqldb_fname = u'test_database.sqlite3' >>> path = os.path.join(sqldb_dpath, sqldb_fname) - >>> db_uri = 'file://{}'.format(os.path.realpath(path)) + >>> db_uri = 'sqlite:///{}'.format(os.path.realpath(path)) >>> db1 = SQLDatabaseController.from_uri(db_uri) >>> uuid_1 = db1.get_db_init_uuid() >>> db2 = SQLDatabaseController.from_uri(db_uri) @@ -837,14 +838,15 @@ def get_db_init_uuid(self, ensure=True): def reboot(self): logger.info('[sql] reboot') - self.cur.close() - del self.cur self.connection.close() del self.connection - self.connection = lite.connect( - self.uri, detect_types=lite.PARSE_DECLTYPES, timeout=self.timeout, uri=True + self._engine = sqlalchemy.create_engine( + self.uri, + # FIXME (27-Aug-12020) Hardcoded for sqlite + connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}, + echo=False, ) - self.cur = self.connection.cursor() + self.connection = self._engine.connect() def backup(self, backup_filepath): """ @@ -853,10 +855,11 @@ def backup(self, backup_filepath): # Create a brand new conenction to lock out current thread and any others connection, uri = self._create_connection() # Start Exclusive transaction, lock out all other writers from making database changes + transaction = connection.begin() connection.isolation_level = 'EXCLUSIVE' connection.execute('BEGIN EXCLUSIVE') # Assert the database file exists, and copy to backup path - path = self.uri.replace('file://', '') + path = self.uri.replace('sqlite://', '') if exists(path): ut.copy(path, backup_filepath) else: @@ -864,7 +867,7 @@ def backup(self, backup_filepath): 'Could not backup the database as the URI does not exist: %r' % (uri,) ) # Commit the transaction, releasing the lock - connection.commit() + transaction.commit() # Close the connection connection.close() @@ -873,35 +876,35 @@ def optimize(self): # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html if VERBOSE_SQL: logger.info('[sql] running sql pragma optimizions') - # self.cur.execute('PRAGMA cache_size = 0;') - # self.cur.execute('PRAGMA cache_size = 1024;') - # self.cur.execute('PRAGMA page_size = 1024;') + # self.connection.execute('PRAGMA cache_size = 0;') + # self.connection.execute('PRAGMA cache_size = 1024;') + # self.connection.execute('PRAGMA page_size = 1024;') # logger.info('[sql] running sql pragma optimizions') - self.cur.execute('PRAGMA cache_size = 10000;') # Default: 2000 - self.cur.execute('PRAGMA temp_store = MEMORY;') - self.cur.execute('PRAGMA synchronous = OFF;') - # self.cur.execute('PRAGMA synchronous = NORMAL;') - # self.cur.execute('PRAGMA synchronous = FULL;') # Default - # self.cur.execute('PRAGMA parser_trace = OFF;') - # self.cur.execute('PRAGMA busy_timeout = 1;') - # self.cur.execute('PRAGMA default_cache_size = 0;') + self.connection.execute('PRAGMA cache_size = 10000;') # Default: 2000 + self.connection.execute('PRAGMA temp_store = MEMORY;') + self.connection.execute('PRAGMA synchronous = OFF;') + # self.connection.execute('PRAGMA synchronous = NORMAL;') + # self.connection.execute('PRAGMA synchronous = FULL;') # Default + # self.connection.execute('PRAGMA parser_trace = OFF;') + # self.connection.execute('PRAGMA busy_timeout = 1;') + # self.connection.execute('PRAGMA default_cache_size = 0;') def shrink_memory(self): logger.info('[sql] shrink_memory') self.connection.commit() - self.cur.execute('PRAGMA shrink_memory;') + self.connection.execute('PRAGMA shrink_memory;') self.connection.commit() def vacuum(self): 
logger.info('[sql] vaccum') - self.connection.commit() - self.cur.execute('VACUUM;') - self.connection.commit() + transaction = self.connection.begin() + self.connection.execute('VACUUM;') + transaction.commit() def integrity(self): logger.info('[sql] vaccum') self.connection.commit() - self.cur.execute('PRAGMA integrity_check;') + self.connection.execute('PRAGMA integrity_check;') self.connection.commit() def squeeze(self): @@ -1015,7 +1018,7 @@ def add_cleanly( Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController.from_uri('sqlite:///') >>> db.add_table('dummy_table', ( >>> ('rowid', 'INTEGER PRIMARY KEY'), >>> ('key', 'TEXT'), @@ -1106,7 +1109,7 @@ def rows_exist(self, tblname, rowids): """ operation = 'SELECT count(1) FROM {tblname} WHERE rowid=?'.format(tblname=tblname) for rowid in rowids: - yield bool(self.cur.execute(operation, (rowid,)).fetchone()[0]) + yield bool(self.connection.execute(operation, (rowid,)).fetchone()[0]) def get_where_eq( self, @@ -1354,7 +1357,7 @@ def get( 'id_repr': ','.join(map(str, id_iter)), } operation = operation_fmt.format(**fmtdict) - results = self.cur.execute(operation).fetchall() + results = self.connection.execute(operation).fetchall() import numpy as np sortx = np.argsort(np.argsort(id_iter)) @@ -2249,7 +2252,7 @@ def get_table_autogen_dict(self, tablename): Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController.from_uri('sqlite:///') >>> tablename = 'dummy_table' >>> db.add_table(tablename, ( >>> ('rowid', 'INTEGER PRIMARY KEY'), @@ -2284,7 +2287,7 @@ def get_table_autogen_str(self, tablename): Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController.from_uri('sqlite:///') >>> tablename = 'dummy_table' >>> db.add_table(tablename, ( >>> ('rowid', 'INTEGER PRIMARY KEY'), @@ -2381,8 +2384,10 @@ def dump_schema(self): def get_table_names(self, lazy=False): """ Conveinience: """ if not lazy or self._tablenames is None: - self.cur.execute("SELECT name FROM sqlite_master WHERE type='table'") - tablename_list = self.cur.fetchall() + result = self.connection.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ) + tablename_list = result.fetchall() self._tablenames = {str(tablename[0]) for tablename in tablename_list} return self._tablenames @@ -2504,9 +2509,9 @@ def get_columns(self, tablename): ] """ # check if the table exists first. Throws an error if it does not exist. - self.cur.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') - self.cur.execute("PRAGMA TABLE_INFO('" + tablename + "')") - colinfo_list = self.cur.fetchall() + self.connection.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') + result = self.connection.execute("PRAGMA TABLE_INFO('" + tablename + "')") + colinfo_list = result.fetchall() colrichinfo_list = [SQLColumnRichInfo(*colinfo) for colinfo in colinfo_list] return colrichinfo_list @@ -3278,8 +3283,8 @@ def set_db_version(self, version): def get_sql_version(self): """ Conveinience """ - self.cur.execute('SELECT sqlite_version()') - sql_version = self.cur.fetchone() + self.connection.execute('SELECT sqlite_version()') + sql_version = self.connection.fetchone() logger.info('[sql] SELECT sqlite_version = %r' % (sql_version,)) # The version number sqlite3 module. NOT the version of SQLite library. 
logger.info('[sql] sqlite3.version = %r' % (lite.version,)) From 29a3ab5154619e3deada9163c689f471c39b734b Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 8 Sep 2020 10:24:26 -0700 Subject: [PATCH 003/294] Fix sql_control unit tests --- wbia/tests/dtool/test_sql_control.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 5cffef0dd4..55e51d3fb4 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- -import sqlite3 import uuid import pytest +from sqlalchemy.engine import Connection from wbia.dtool.sql_control import ( METADATA_TABLE_COLUMNS, @@ -13,17 +13,17 @@ @pytest.fixture def ctrlr(): - return SQLDatabaseController.from_uri(':memory:') + return SQLDatabaseController.from_uri('sqlite:///:memory:') def test_instantiation(ctrlr): # Check for basic connection information - assert ctrlr.uri == ':memory:' + assert ctrlr.uri == 'sqlite:///:memory:' assert ctrlr.timeout == TIMEOUT # Check for a connection, that would have been made during instantiation - assert isinstance(ctrlr.connection, sqlite3.Connection) - assert isinstance(ctrlr.cur, sqlite3.Cursor) + assert isinstance(ctrlr.connection, Connection) + assert not ctrlr.connection.closed def test_safely_get_db_version(ctrlr): From 0886af274c1b3bc65f461fe7f223d2a002432b35 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 7 Sep 2020 22:46:21 -0700 Subject: [PATCH 004/294] Remove sqlite type integrations Remove the sqlite adapter and converter type integration. Remove sqlite dbapi type initialization from the engine creation. --- wbia/dtool/__init__.py | 4 +- wbia/dtool/_integrate_sqlite3.py | 115 ----------------- wbia/dtool/sql_control.py | 7 -- wbia/tests/dtool/test__integrate_sqlite3.py | 132 -------------------- 4 files changed, 3 insertions(+), 255 deletions(-) delete mode 100644 wbia/dtool/_integrate_sqlite3.py delete mode 100644 wbia/tests/dtool/test__integrate_sqlite3.py diff --git a/wbia/dtool/__init__.py b/wbia/dtool/__init__.py index dcaec258f3..00818297c6 100644 --- a/wbia/dtool/__init__.py +++ b/wbia/dtool/__init__.py @@ -6,7 +6,9 @@ # See `_integrate_sqlite3` module for details. import sqlite3 -from wbia.dtool import _integrate_sqlite3 as lite +# BBB (7-Sept-12020) +import sqlite3 as lite + from wbia.dtool import base from wbia.dtool import sql_control from wbia.dtool import depcache_control diff --git a/wbia/dtool/_integrate_sqlite3.py b/wbia/dtool/_integrate_sqlite3.py deleted file mode 100644 index a260370824..0000000000 --- a/wbia/dtool/_integrate_sqlite3.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -"""Integrates numpy types into sqlite3""" -import io -import uuid -from sqlite3 import register_adapter, register_converter - -import numpy as np -import utool as ut - - -__all__ = () - - -def _read_numpy_from_sqlite3(blob): - # INVESTIGATE: Is memory freed up correctly here? - out = io.BytesIO(blob) - out.seek(0) - # return np.load(out) - # Is this better? 
- arr = np.load(out) - out.close() - return arr - - -def _read_bool(b): - return None if b is None else bool(b) - - -def _write_bool(b): - return b - - -def _write_numpy_to_sqlite3(arr): - out = io.BytesIO() - np.save(out, arr) - out.seek(0) - return memoryview(out.read()) - - -def _read_uuid_from_sqlite3(blob): - try: - return uuid.UUID(bytes_le=blob) - except ValueError as ex: - ut.printex(ex, keys=['blob']) - raise - - -def _read_dict_from_sqlite3(blob): - return ut.from_json(blob) - # return uuid.UUID(bytes_le=blob) - - -def _write_dict_to_sqlite3(dict_): - return ut.to_json(dict_) - - -def _write_uuid_to_sqlite3(uuid_): - return memoryview(uuid_.bytes_le) - - -def register_numpy_dtypes(): - py_int_type = int - for dtype in ( - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ): - register_adapter(dtype, py_int_type) - register_adapter(np.float32, float) - register_adapter(np.float64, float) - - -def register_numpy(): - """ - Tell SQL how to deal with numpy arrays - Utility function allowing numpy arrays to be stored as raw blob data - """ - register_converter('NUMPY', _read_numpy_from_sqlite3) - register_converter('NDARRAY', _read_numpy_from_sqlite3) - register_adapter(np.ndarray, _write_numpy_to_sqlite3) - - -def register_uuid(): - """ Utility function allowing uuids to be stored in sqlite """ - register_converter('UUID', _read_uuid_from_sqlite3) - register_adapter(uuid.UUID, _write_uuid_to_sqlite3) - - -def register_dict(): - register_converter('DICT', _read_dict_from_sqlite3) - register_adapter(dict, _write_dict_to_sqlite3) - - -def register_list(): - register_converter('LIST', ut.from_json) - register_adapter(list, ut.to_json) - - -# def register_bool(): -# # FIXME: ensure this works -# register_converter('BOOL', _read_bool) -# register_adapter(bool, _write_bool) - - -register_numpy_dtypes() -register_numpy() -register_uuid() -register_dict() -register_list() -# register_bool() # TODO diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index d0486ca367..235b75096f 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -10,7 +10,6 @@ import os import parse import re -import sqlite3 import threading import uuid from collections.abc import Mapping, MutableMapping @@ -673,8 +672,6 @@ def connect(self): # The echo flag is a shortcut to set up SQLAlchemy logging self._engine = sqlalchemy.create_engine( self.uri, - # FIXME (27-Aug-12020) Hardcoded for sqlite - connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}, echo=False, ) self._connection = self._engine.connect() @@ -707,8 +704,6 @@ def _create_connection(self): uri += '?mode=ro' engine = sqlalchemy.create_engine( uri, - # FIXME (27-Aug-12020) Hardcoded for sqlite - connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}, echo=False, ) connection = engine.connect() @@ -842,8 +837,6 @@ def reboot(self): del self.connection self._engine = sqlalchemy.create_engine( self.uri, - # FIXME (27-Aug-12020) Hardcoded for sqlite - connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}, echo=False, ) self.connection = self._engine.connect() diff --git a/wbia/tests/dtool/test__integrate_sqlite3.py b/wbia/tests/dtool/test__integrate_sqlite3.py deleted file mode 100644 index 1aca23abd0..0000000000 --- a/wbia/tests/dtool/test__integrate_sqlite3.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -import sqlite3 -import uuid - -import numpy as np -import pytest - -# We do not explicitly call code in this module because -# importing the following module is 
execution of the code. -import wbia.dtool._integrate_sqlite3 # noqa - - -@pytest.fixture -def db(): - with sqlite3.connect(':memory:', detect_types=sqlite3.PARSE_DECLTYPES) as con: - yield con - - -np_number_types = ( - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, -) - - -@pytest.mark.parametrize('num_type', np_number_types) -def test_register_numpy_dtypes_ints(db, num_type): - # The magic takes place in the register_numpy_dtypes function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x integer)') - - # Insert a uuid value into the table - insert_value = num_type(8) - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value - - -@pytest.mark.parametrize('num_type', (np.float32, np.float64)) -def test_register_numpy_dtypes_floats(db, num_type): - # The magic takes place in the register_numpy_dtypes function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x real)') - - # Insert a uuid value into the table - insert_value = num_type(8.0000008) - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value - - -@pytest.mark.parametrize('type_name', ('numpy', 'ndarray')) -def test_register_numpy(db, type_name): - # The magic takes place in the register_numpy function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute(f'create table test(x {type_name})') - - # Insert a numpy array value into the table - insert_value = np.array([[1, 2, 3], [4, 5, 6]], np.int32) - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert (selected_value == insert_value).all() - - -def test_register_uuid(db): - # The magic takes place in the register_uuid function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x uuid)') - - # Insert a uuid value into the table - insert_value = uuid.uuid4() - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value - - -def test_register_dict(db): - # The magic takes place in the register_dict function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x dict)') - - # Insert a dict value into the table - insert_value = { - 'a': 1, - 'b': 2.2, - 'c': [[1, 2, 3], [4, 5, 6]], - } - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - for k, v in selected_value.items(): - assert v == insert_value[k] - - -def test_register_list(db): - # The magic takes place in the register_list function, - # which is implicitly called on module import. 
- - # Create a table that uses the type - db.execute('create table test(x list)') - - # Insert a list of list value into the table - insert_value = [[1, 2, 3], [4, 5, 6]] - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value From c6b77e0b882beb05d01fd4cd76f5cde139a38f6e Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 11 Sep 2020 22:23:47 -0700 Subject: [PATCH 005/294] Revise and test delete & delete_rowids methods Revised to use sqlalchemy. Also delete_rowids is now a submethod of delete. Sidenote, this is a terribly inefficient way to do a delete, one at a time. But I can't seem to get the in operator with an array to work. For example, `where id in (1, 2, 3)` doesn't work because python's sqlite3 lib fails to translate a list to an array. And sqlalchemy won't push it through either. So we're stuck doing what it was doing before. --- wbia/dtool/sql_control.py | 37 ++++------------- wbia/tests/dtool/test_sql_control.py | 61 ++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 28 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 235b75096f..8f3d9938e1 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -20,6 +20,7 @@ import sqlalchemy import utool as ut from deprecated import deprecated +from sqlalchemy.sql import text from wbia.dtool import lite from wbia.dtool.dump import dumps @@ -1503,38 +1504,18 @@ def set( ) def delete(self, tblname, id_list, id_colname='rowid', **kwargs): + """Deletes rows from a SQL table (``tblname``) by ID, + given a sequence of IDs (``id_list``). + Optionally a different ID column can be specified via ``id_colname``. + """ - deleter. 
USE delete_rowids instead - """ - fmtdict = { - 'tblname': tblname, - 'rowid_str': (id_colname + '=?'), - } - operation_fmt = """ - DELETE - FROM {tblname} - WHERE {rowid_str} - """ - params_iter = ((_rowid,) for _rowid in id_list) - return self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) + stmt = text(f'DELETE FROM {tblname} WHERE {id_colname} = :id') + for id in id_list: + self.connection.execute(stmt, id=id) def delete_rowids(self, tblname, rowid_list, **kwargs): """ deletes the the rows in rowid_list """ - fmtdict = { - 'tblname': tblname, - 'rowid_str': ('rowid=?'), - } - operation_fmt = """ - DELETE - FROM {tblname} - WHERE {rowid_str} - """ - params_iter = ((_rowid,) for _rowid in rowid_list) - return self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) + self.delete(tblname, rowid_list, id_colname='rowid', **kwargs) # ============== # CORE WRAPPERS diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 55e51d3fb4..c3c574da72 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy.engine import Connection +from sqlalchemy.sql import text from wbia.dtool.sql_control import ( METADATA_TABLE_COLUMNS, @@ -261,3 +262,63 @@ def test_delitem_for_database(self): self.ctrlr.metadata.database['init_uuid'] = None # Check the value is still a uuid.UUID assert isinstance(self.ctrlr.metadata.database.init_uuid, uuid.UUID) + + +class TestAPI: + """Testing the primary *usage* API""" + + @pytest.fixture(autouse=True) + def fixture(self, ctrlr): + self.ctrlr = ctrlr + + def make_table(self, name): + self.ctrlr.connection.execute( + f'CREATE TABLE IF NOT EXISTS {name} ' + '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' + ) + + def test_delete(self): + # Make a table for records + table_name = 'test_delete' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + results = self.ctrlr.connection.execute(f'SELECT id, CAST((y % 2) AS BOOL) FROM {table_name}') + rows = results.fetchall() + del_ids = [row[0] for row in rows if row[1]] + remaining_ids = sorted([row[0] for row in rows if not row[1]]) + + # Call the testing target + self.ctrlr.delete(table_name, del_ids, 'id') + + # Verify the deletion + results = self.ctrlr.connection.execute(f'SELECT id FROM {table_name}') + assert sorted([r[0] for r in results]) == remaining_ids + + def test_delete_rowid(self): + # Make a table for records + table_name = 'test_delete_rowid' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + results = self.ctrlr.connection.execute(f'SELECT rowid, CAST((y % 2) AS BOOL) FROM {table_name}') + rows = results.fetchall() + del_ids = [row[0] for row in rows if row[1]] + remaining_ids = sorted([row[0] for row in rows if not row[1]]) + + # Call the testing target + self.ctrlr.delete_rowids(table_name, del_ids) + + # Verify the deletion + results = self.ctrlr.connection.execute(f'SELECT rowid FROM {table_name}') + assert sorted([r[0] for r in results]) == remaining_ids From 
805792b0639263b7d532fb923632aa8cbbb57e2e Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 12 Sep 2020 09:48:19 -0700 Subject: [PATCH 006/294] Add test for sql controller set method Note, this is passing as expected, but it'll only pass with with sqlite. --- wbia/tests/dtool/test_sql_control.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index c3c574da72..1c2916919e 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -277,6 +277,42 @@ def make_table(self, name): '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' ) + + def test_setting(self): + # Note, this is not a comprehensive test. It only attempts to test the SQL logic. + # Make a table for records + table_name = 'test_setting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + results = self.ctrlr.connection.execute(f'SELECT id, CAST((y%2) AS BOOL) FROM {table_name}') + rows = results.fetchall() + ids = [row[0] for row in rows if row[1]] + + # Call the testing target + self.ctrlr.set( + table_name, + ['x', 'z'], + [('even', 0.0)] * len(ids), + ids, + id_colname='id' + ) + + # Verify setting + sql_array = ', '.join([str(id) for id in ids]) + results = self.ctrlr.connection.execute( + f'SELECT id, x, z FROM {table_name} ' + f"WHERE id in ({sql_array})" + ) + expected = sorted(map(lambda a: tuple([a] + ['even', 0.0]), ids)) + set_rows = sorted(results) + assert set_rows == expected + def test_delete(self): # Make a table for records table_name = 'test_delete' From 7dc0aa88ccf16b95a7e73dbb1e8a5d075d31b674 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 12 Sep 2020 10:07:17 -0700 Subject: [PATCH 007/294] Update sql controller set method to use sqlalchemy params This should now work between sql implementation. I'm not sure if it's easier or more difficult to read. But I did the best I could without changing the method's signature and usage. --- wbia/dtool/sql_control.py | 65 ++++++--------------------------------- 1 file changed, 9 insertions(+), 56 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 8f3d9938e1..4cd9fd3272 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -81,45 +81,6 @@ def tuplize(list_): return tup_list -def flattenize(list_): - """ - maps flatten to a tuplized list - - Weird function. 
DEPRICATE - - Example: - >>> # DISABLE_DOCTEST - >>> list_ = [[1, 2, 3], [2, 3, [4, 2, 1]], [3, 2], [[1, 2], [3, 4]]] - >>> import utool - >>> from itertools import zip - >>> val_list1 = [(1, 2), (2, 4), (5, 3)] - >>> id_list1 = [(1,), (2,), (3,)] - >>> out_list1 = utool.flattenize(zip(val_list1, id_list1)) - - >>> val_list2 = [1, 4, 5] - >>> id_list2 = [(1,), (2,), (3,)] - >>> out_list2 = utool.flattenize(zip(val_list2, id_list2)) - - >>> val_list3 = [1, 4, 5] - >>> id_list3 = [1, 2, 3] - >>> out_list3 = utool.flattenize(zip(val_list3, id_list3)) - - out_list4 = list(zip(val_list3, id_list3)) - %timeit utool.flattenize(zip(val_list1, id_list1)) - %timeit utool.flattenize(zip(val_list2, id_list2)) - %timeit utool.flattenize(zip(val_list3, id_list3)) - %timeit list(zip(val_list3, id_list3)) - - 100000 loops, best of 3: 14 us per loop - 100000 loops, best of 3: 16.5 us per loop - 100000 loops, best of 3: 18 us per loop - 1000000 loops, best of 3: 1.18 us per loop - """ - tuplized_iter = map(tuplize, list_) - flatenized_list = list(map(ut.flatten, tuplized_iter)) - return flatenized_list - - # ======================= # SQL Context Class # ======================= @@ -1477,6 +1438,7 @@ def set( % (duplicate_behavior,) ) + # Check for incongruity between values and identifiers try: num_val = len(val_list) num_id = len(id_list) @@ -1484,24 +1446,15 @@ def set( except AssertionError as ex: ut.printex(ex, key_list=['num_val', 'num_id']) raise - fmtdict = { - 'tblname_str': tblname, - 'assign_str': ',\n'.join(['%s=?' % name for name in colnames]), - 'where_clause': (id_colname + '=?'), - } - operation_fmt = """ - UPDATE {tblname_str} - SET {assign_str} - WHERE {where_clause} - """ - - # TODO: The flattenize can be removed if we pass in val_lists instead - params_iter = flattenize(list(zip(val_list, id_list))) - # params_iter = list(zip(val_list, id_list)) - return self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) + # Execute the SQL updates for each set of values + assignments = ', '.join([f'{col} = :e{i}' for i, col in enumerate(colnames)]) + where_condition = f'{id_colname} = :id' + stmt = text(f'UPDATE {tblname} SET {assignments} WHERE {where_condition}') + for i, id in enumerate(id_list): + params = {'id': id} + params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) + self.connection.execute(stmt, **params) def delete(self, tblname, id_list, id_colname='rowid', **kwargs): """Deletes rows from a SQL table (``tblname``) by ID, From 14d4582bea7a9928374b1bf1cfb6e48ea7ba1083 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 12 Sep 2020 10:33:02 -0700 Subject: [PATCH 008/294] Raise a TypeError rather than Assertion This is checking that the method received the correct data type. Let's raise a TypeError when it is not. 
--- wbia/dtool/sql_control.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 4cd9fd3272..9d13dcbd37 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1153,8 +1153,8 @@ def get_where( eager=True, **kwargs, ): - """""" - assert isinstance(colnames, tuple), 'colnames must be a tuple' + if not isinstance(colnames, (tuple, list)): + raise TypeError('colnames must be a sequence type of strings') if where_clause is None: operation_fmt = """ @@ -1289,9 +1289,8 @@ def get( + ut.get_caller_name(list(range(1, 4))) + ' db.get(%r, %r, ...)' % (tblname, colnames) ) - assert isinstance(colnames, tuple), 'must specify column names TUPLE to get from' - # if isinstance(colnames, six.string_types): - # colnames = (colnames,) + if not isinstance(colnames, (tuple, list)): + raise TypeError('colnames must be a sequence type of strings') if ( assume_unique @@ -1369,9 +1368,9 @@ def set( >>> table.print_csv() >>> depc.clear_all() """ - assert isinstance(colnames, tuple) - # if isinstance(colnames, six.string_types): - # colnames = (colnames,) + if not isinstance(colnames, (tuple, list)): + raise TypeError('colnames must be a sequence type of strings') + val_list = list(val_iter) # eager evaluation id_list = list(id_iter) # eager evaluation From d9331e1185c477b4a51186ad9aec4b236062dc56 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 12 Sep 2020 10:34:45 -0700 Subject: [PATCH 009/294] Adjust logging to debug This info is debug oriented rather than informationational. Or it'd probably be too much detailed information for the logs. --- wbia/dtool/sql_control.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 9d13dcbd37..358e39da3e 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1374,12 +1374,11 @@ def set( val_list = list(val_iter) # eager evaluation id_list = list(id_iter) # eager evaluation - if VERBOSE_SQL or (NOT_QUIET and VERYVERBOSE): - logger.info('[sql] SETTER: ' + ut.get_caller_name()) - logger.info('[sql] * tblname=%r' % (tblname,)) - logger.info('[sql] * val_list=%r' % (val_list,)) - logger.info('[sql] * id_list=%r' % (id_list,)) - logger.info('[sql] * id_colname=%r' % (id_colname,)) + logger.debug('[sql] SETTER: ' + ut.get_caller_name()) + logger.debug('[sql] * tblname=%r' % (tblname,)) + logger.debug('[sql] * val_list=%r' % (val_list,)) + logger.debug('[sql] * id_list=%r' % (id_list,)) + logger.debug('[sql] * id_colname=%r' % (id_colname,)) if duplicate_behavior == 'error': try: @@ -1406,7 +1405,7 @@ def set( for index in sorted(pop_list, reverse=True): del id_list[index] del val_list[index] - logger.info( + logger.debug( '[!set] Auto Resolution: Removed %d duplicate (id, value) pairs from the database operation' % (len(pop_list),) ) From f55403e6f415db95e72e83755a0b91aa2d6e968f Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 12 Sep 2020 14:26:35 -0700 Subject: [PATCH 010/294] Testing sql controller's get_where method --- wbia/tests/dtool/test_sql_control.py | 90 ++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 10 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 1c2916919e..17eb5971a7 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -277,6 +277,75 @@ def make_table(self, name): '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' ) + 
def test_get_where_without_where_condition(self): + table_name = 'test_get_where' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + results = self.ctrlr.get_where(table_name, ['id', 'y'], tuple(), None,) + + # Verify query + assert results == [(i + 1, i) for i in range(0, 10)] + + def test_scalar_get_where(self): + table_name = 'test_get_where' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + results = self.ctrlr.get_where(table_name, ['id', 'y'], ([1]), 'id = ?',) + evens = results[0] + + # Verify query + assert evens == (1, 0) + + def test_multi_row_get_where(self): + table_name = 'test_get_where' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + results = self.ctrlr.get_where( + table_name, + ['id', 'y'], + (['even'], ['odd']), + 'x = ?', + unpack_scalars=False, # this makes it more than one row of results + ) + evens = results[0] + odds = results[1] + + # Verify query + assert evens == [(i + 1, i) for i in range(0, 10) if not i % 2] + assert odds == [(i + 1, i) for i in range(0, 10) if i % 2] def test_setting(self): # Note, this is not a comprehensive test. It only attempts to test the SQL logic. 
@@ -290,24 +359,21 @@ def test_setting(self): x, y, z = (str(i), i, i * 2.01) self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) - results = self.ctrlr.connection.execute(f'SELECT id, CAST((y%2) AS BOOL) FROM {table_name}') + results = self.ctrlr.connection.execute( + f'SELECT id, CAST((y%2) AS BOOL) FROM {table_name}' + ) rows = results.fetchall() ids = [row[0] for row in rows if row[1]] # Call the testing target self.ctrlr.set( - table_name, - ['x', 'z'], - [('even', 0.0)] * len(ids), - ids, - id_colname='id' + table_name, ['x', 'z'], [('even', 0.0)] * len(ids), ids, id_colname='id' ) # Verify setting sql_array = ', '.join([str(id) for id in ids]) results = self.ctrlr.connection.execute( - f'SELECT id, x, z FROM {table_name} ' - f"WHERE id in ({sql_array})" + f'SELECT id, x, z FROM {table_name} ' f'WHERE id in ({sql_array})' ) expected = sorted(map(lambda a: tuple([a] + ['even', 0.0]), ids)) set_rows = sorted(results) @@ -324,7 +390,9 @@ def test_delete(self): x, y, z = (str(i), i, i * 2.01) self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) - results = self.ctrlr.connection.execute(f'SELECT id, CAST((y % 2) AS BOOL) FROM {table_name}') + results = self.ctrlr.connection.execute( + f'SELECT id, CAST((y % 2) AS BOOL) FROM {table_name}' + ) rows = results.fetchall() del_ids = [row[0] for row in rows if row[1]] remaining_ids = sorted([row[0] for row in rows if not row[1]]) @@ -347,7 +415,9 @@ def test_delete_rowid(self): x, y, z = (str(i), i, i * 2.01) self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) - results = self.ctrlr.connection.execute(f'SELECT rowid, CAST((y % 2) AS BOOL) FROM {table_name}') + results = self.ctrlr.connection.execute( + f'SELECT rowid, CAST((y % 2) AS BOOL) FROM {table_name}' + ) rows = results.fetchall() del_ids = [row[0] for row in rows if row[1]] remaining_ids = sorted([row[0] for row in rows if not row[1]]) From 65851153e7578ae4a19a9ee4c8ecc2812bdfd77a Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 12 Sep 2020 22:41:40 -0700 Subject: [PATCH 011/294] Write documentation of SQLExecutionContext Attempting to document this so that I understand it. --- wbia/dtool/sql_control.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 358e39da3e..133bbd52fc 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -88,10 +88,20 @@ def tuplize(list_): class SQLExecutionContext(object): """ - Context manager for transactional database calls - - FIXME: hash out details. I don't think anybody who programmed this - knows what is going on here. So much for fine grained control. + Context manager for + 1. handling database transactions + 2. returning id information after inserts + 3. 
unwrapping single value queries + + Args: + db (SQLDatabaseController): invoking instance of the database controller + operation (str): sql operation + nInput (int): Number of parameter inputs (only used for informational purposes) + auto_commit (bool): [deprecated] do not use (default: True) + start_transaction (bool): flag to start a transaction (default: False) + keepwrap (bool): flag to unwrap single value queries + verbose (bool): verbosity + tablename (str): [deprecated] name of the table the operation is running on Referencs: http://stackoverflow.com/questions/9573768/understand-sqlite-multi-module-envs From aa6417887179eac410595bc5321c8482c80cf523 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 13 Sep 2020 10:19:29 -0700 Subject: [PATCH 012/294] Add tests for sql controller's executeone method --- wbia/tests/dtool/test_sql_control.py | 67 ++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 17eb5971a7..3b043f55b0 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -277,6 +277,73 @@ def make_table(self, name): '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' ) + def test_executeone(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + result = self.ctrlr.executeone(f'SELECT id, y FROM {table_name}') + + assert result == [(i + 1, i) for i in range(0, 10)] + + def test_executeone_on_insert(self): + # Should return id after an insert + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + result = self.ctrlr.executeone(f'INSERT INTO {table_name} (y) VALUES (?)', (10,)) + + # Cursory check that the result is a single int value + assert result == [11] # the result list with one unwrapped value + + # Check for the actual value associated with the resulting id + inserted_value = self.ctrlr.connection.execute( + f'SELECT id, y FROM {table_name} WHERE rowid = :rowid', rowid=result[0], + ).fetchone() + assert inserted_value == (11, 10,) + + def test_executeone_for_single_column(self): + # Should unwrap the resulting query value (no tuple wrapping) + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + result = self.ctrlr.executeone(f'SELECT y FROM {table_name}') + + # Note the unwrapped values, rather than [(i,) ...] 
+ assert result == [i for i in range(0, 10)] + def test_get_where_without_where_condition(self): table_name = 'test_get_where' self.make_table(table_name) From 94ce0982bb59df7e0e4527ef926053a0a67f47f4 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 13 Sep 2020 10:43:05 -0700 Subject: [PATCH 013/294] Rewrite sql controller's executeone method I attempted to retain the existing behavior. I'm not all together certain I retained all of it. I'd guess that if anything is using the `keepwrap` parameter flag, this might break that usage. However, I don't think that's a bad thing. The strange unwrapping behavior needs to be eliminated anyway. I say strange, because intuitive usage would suggest a constistent return structure no matter what you're doing. --- wbia/dtool/sql_control.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 133bbd52fc..01a5c4b4e2 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1514,17 +1514,30 @@ def _executemany_operation_fmt( # SQLDB CORE # ========= - def executeone(db, operation, params=(), eager=True, verbose=VERBOSE_SQL): - contextkw = dict(nInput=1, verbose=verbose) - with SQLExecutionContext(db, operation, **contextkw) as context: - try: - result_iter = context.execute_and_generate_results(params) - result_list = list(result_iter) - except Exception as ex: - ut.printex(ex, key_list=[(str, 'operation'), 'params']) - # ut.sys.exit(1) - raise - return result_list + def executeone(self, operation, params=(), eager=True, verbose=VERBOSE_SQL): + """Executes the given ``operation`` once with the given set of ``params``""" + # FIXME (12-Sept-12020) Allows passing through '?' (question mark) parameters. + results = self.connection.execute(operation, params) + + # BBB (12-Sept-12020) Retaining insertion rowid result + # FIXME postgresql (12-Sept-12020) This won't work in postgres. + # Maybe see if ResultProxy.inserted_primary_key will work + if 'insert' in operation.lower(): + # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. + return [results.lastrowid] + elif not results.returns_rows: + return None + else: + values = list( + [ + # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. + row[0] if len(row) == 1 else row + for row in results + ] + ) + if not values: # empty list + values = None + return values @profile def executemany( From 1353dcb52658cfb18127e72e9475779f71cc8e17 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 13 Sep 2020 17:52:41 -0700 Subject: [PATCH 014/294] Add tests for sql controller's executemany method This concentrates on the unpacking and transaction behaviors of the method. It's not obvious, but sqlalchemy transactions are nested. So the call to executeone is indeed using the transaction. 
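A minimal sketch of that nesting behavior (assuming SQLAlchemy 1.3 semantics; the table and helper function below are illustrative and not part of this patch):

```
import sqlalchemy

engine = sqlalchemy.create_engine('sqlite:///:memory:')
connection = engine.connect()
connection.execute('CREATE TABLE t (x INTEGER)')

def insert_one(value):
    # Same connection, no transaction of its own; if an outer transaction
    # is open, this statement joins it instead of auto-committing.
    connection.execute('INSERT INTO t (x) VALUES (?)', (value,))

with connection.begin():   # outer transaction, as in executemany()
    insert_one(1)
    insert_one(2)
# both rows are committed together when the block exits

try:
    with connection.begin():
        insert_one(3)
        raise RuntimeError('simulated failure mid-batch')
except RuntimeError:
    pass
# the failed block is rolled back, so row 3 is never persisted
rows = connection.execute('SELECT x FROM t').fetchall()  # [(1,), (2,)]
```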
--- wbia/tests/dtool/test_sql_control.py | 47 ++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 3b043f55b0..6f16e1b429 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -2,6 +2,7 @@ import uuid import pytest +import sqlalchemy.exc from sqlalchemy.engine import Connection from sqlalchemy.sql import text @@ -323,6 +324,52 @@ def test_executeone_on_insert(self): ).fetchone() assert inserted_value == (11, 10,) + def test_executemany(self): + table_name = 'test_executemany' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + results = self.ctrlr.executemany( + f'SELECT id, y FROM {table_name} where x = ?', + (['even'], ['odd']), + unpack_scalars=False, + ) + + # Check for results + evens = [(i + 1, i) for i in range(0, 10) if not i % 2] + odds = [(i + 1, i) for i in range(0, 10) if i % 2] + assert results == [evens, odds] + + def test_executemany_transaction(self): + table_name = 'test_executemany' + self.make_table(table_name) + + # Test a failure to execute in the transaction to test the transaction boundary. + insert = f'INSERT INTO {table_name} (x, y, z) VALUES (?, ?, ?)' + params = [ + ('even', 0, 0.0), + ('odd', 1, 1.01), + ('oops', 2.02), # error + ('odd', 3, 3.03), + ] + with pytest.raises(sqlalchemy.exc.ProgrammingError): + # Call the testing target + results = self.ctrlr.executemany(insert, params) + + # Check for results + results = self.ctrlr.connection.execute(f'select count(*) from {table_name}') + assert results.fetchone()[0] == 0 + def test_executeone_for_single_column(self): # Should unwrap the resulting query value (no tuple wrapping) table_name = 'test_executeone' From 7711fe99241f4d4f633bc8f07efed2074ebc011c Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 13 Sep 2020 18:06:18 -0700 Subject: [PATCH 015/294] Rewrite the sql controller's executemany method This has been rewritten to simplify things. This gets us one step forward by making the logic sqlalchemy compatible. We're sacrificing the logging bits for simplicity. I'm looking to deprecate this method in a later commit. This will be kept around to provide backwards compatiblity with the current users of it. The new method will be a generator, which will give us the ability to bring back the logging facilities if necessary. --- wbia/dtool/sql_control.py | 104 ++++++++------------------------------ 1 file changed, 21 insertions(+), 83 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 01a5c4b4e2..a668f22f03 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -13,7 +13,6 @@ import threading import uuid from collections.abc import Mapping, MutableMapping -from functools import partial from os.path import join, exists import six @@ -1539,90 +1538,29 @@ def executeone(self, operation, params=(), eager=True, verbose=VERBOSE_SQL): values = None return values - @profile - def executemany( - self, - operation, - params_iter, - verbose=VERBOSE_SQL, - unpack_scalars=True, - nInput=None, - eager=True, - keepwrap=False, - showprog=False, - ): - """ - if unpack_scalars is True only a single result must be returned for each query. 
- """ - # --- ARGS PREPROC --- - # Aggresively compute iterator if the nInput is not given - if nInput is None: - if isinstance(params_iter, (list, tuple)): - nInput = len(params_iter) - else: - if VERBOSE_SQL: - logger.info( - '[sql!] WARNING: aggressive eval of params_iter because nInput=None' - ) - params_iter = list(params_iter) - nInput = len(params_iter) - else: - if VERBOSE_SQL: - logger.info('[sql] Taking params_iter as iterator') + def executemany(self, operation, params_iter, unpack_scalars=True, **kwargs): + """Executes the given ``operation`` once for each item in ``params_iter`` - # Do not compute executemany without params - if nInput == 0: - if VERBOSE_SQL: - logger.info( - '[sql!] WARNING: dont use executemany' - 'with no params use executeone instead.' - ) - return [] - # --- SQL EXECUTION --- - contextkw = { - 'nInput': nInput, - 'start_transaction': True, - 'verbose': verbose, - 'keepwrap': keepwrap, - } - with SQLExecutionContext(self, operation, **contextkw) as context: - if eager: - if showprog: - if isinstance(showprog, six.string_types): - lbl = showprog - else: - lbl = 'sqlread' - prog = ut.ProgPartial( - adjust=True, length=nInput, freq=1, lbl=lbl, bs=True - ) - params_iter = prog(params_iter) - results_iter = [ - list(context.execute_and_generate_results(params)) - for params in params_iter - ] - if unpack_scalars: - # list of iterators - _unpacker_ = partial(_unpacker) - results_iter = list(map(_unpacker_, results_iter)) - # Eager evaluation - results_list = list(results_iter) - else: + Args: + operation (str): SQL operation + params_iter (sequence): a sequence of sequences + containing parameters in the sql operation + unpack_scalars (bool): [deprecated] use to unpack a single result from each query + only use with operations that return a single result for each query + (default: True) - def _tmpgen(context): - # Temporary hack to turn off eager_evaluation - for params in params_iter: - # Eval results per query yeild per iter - results = list(context.execute_and_generate_results(params)) - if unpack_scalars: - yield _unpacker(results) - else: - yield results - - results_list = _tmpgen(context) - return results_list - - # def commit(db): - # db.connection.commit() + """ + results = [] + with self.connection.begin(): + for params in params_iter: + value = self.executeone(operation, params) + # Should only be used when the user wants back on value. + # Let the error bubble up if used wrong. + # Deprecated... Do not depend on the unpacking behavior. + if unpack_scalars: + value = _unpacker(value) + results.append(value) + return results def print_dbg_schema(self): logger.info( From 577768bb6d8c0ea1354e976f10137f6a9e40096a Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 14 Sep 2020 09:13:50 -0700 Subject: [PATCH 016/294] Remove SQLExecutionContext It's no longer used. The code that did use it is now using sqlalchemy for transactions. --- wbia/dtool/sql_control.py | 187 -------------------------------------- 1 file changed, 187 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index a668f22f03..e236177df1 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -80,193 +80,6 @@ def tuplize(list_): return tup_list -# ======================= -# SQL Context Class -# ======================= - - -class SQLExecutionContext(object): - """ - Context manager for - 1. handling database transactions - 2. returning id information after inserts - 3. 
unwrapping single value queries - - Args: - db (SQLDatabaseController): invoking instance of the database controller - operation (str): sql operation - nInput (int): Number of parameter inputs (only used for informational purposes) - auto_commit (bool): [deprecated] do not use (default: True) - start_transaction (bool): flag to start a transaction (default: False) - keepwrap (bool): flag to unwrap single value queries - verbose (bool): verbosity - tablename (str): [deprecated] name of the table the operation is running on - - Referencs: - http://stackoverflow.com/questions/9573768/understand-sqlite-multi-module-envs - - """ - - def __init__( - context, - db, - operation, - nInput=None, - auto_commit=True, - start_transaction=False, - keepwrap=False, - verbose=VERBOSE_SQL, - tablename=None, - ): - context.tablename = None - context.auto_commit = auto_commit - context.db = db - context.operation = operation - context.nInput = nInput - context.start_transaction = start_transaction - context.operation_type = get_operation_type(operation) - context.verbose = verbose - context.is_insert = context.operation_type.startswith('INSERT') - context.keepwrap = keepwrap - context.connection = None - context.transaction = None - - def __enter__(context): - """ Checks to see if the operating will change the database """ - # ut.printif(lambda: '[sql] Callers: ' + ut.get_caller_name(range(3, 6)), DEBUG) - if context.nInput is not None: - context.operation_lbl = '[sql] execute nInput=%d optype=%s: ' % ( - context.nInput, - context.operation_type, - ) - else: - context.operation_lbl = '[sql] executeone optype=%s: ' % ( - context.operation_type - ) - # Start SQL Transaction - - context.connection = context.db.connection - if context.start_transaction: - try: - context.transaction = context.connection.begin() - except sqlalchemy.exc.OperationalError: - context.connection.rollback() - context.transaction = context.connection.begin() - if context.verbose or VERBOSE_SQL: - logger.info(context.operation_lbl) - if context.verbose: - logger.info('[sql] operation=\n' + context.operation) - # Comment out timeing code - # if __debug__: - # if NOT_QUIET and (VERBOSE_SQL or context.verbose): - # context.tt = ut.tic(context.operation_lbl) - return context - - # @profile - def execute_and_generate_results(context, params): - """ helper for context statment """ - try: - result = context.connection.execute(context.operation, params) - except sqlalchemy.exc.OperationalError as ex: - logger.info('Reporting SQLite Error') - logger.info('params = ' + ut.repr2(params, truncate=not ut.VERBOSE)) - ut.printex(ex, 'sql.Error', keys=['params']) - if ( - hasattr(ex, 'message') - and ex.message.find('probably unsupported type') > -1 - ): - logger.info( - 'ERR REPORT: given param types = ' + ut.repr2(ut.lmap(type, params)) - ) - if context.tablename is None: - if context.operation_type.startswith('SELECT'): - tablename = ut.str_between( - context.operation, 'FROM', 'WHERE' - ).strip() - else: - tablename = context.operation_type.split(' ')[-1] - else: - tablename = context.tablename - try: - coldef_list = context.db.get_coldef_list(tablename) - logger.info( - 'ERR REPORT: expected types = %s' % (ut.repr4(coldef_list),) - ) - except Exception: - pass - raise - return context._results_gen(result) - - # @profile - def _results_gen(context, result): - """HELPER - Returns as many results as there are. - Careful. Overwrites the results once you call it. - Basically: Dont call this twice. 
- """ - if context.is_insert: - # The sqlite3_last_insert_rowid(D) interface returns the - # rowid of the most recent successful INSERT - # into a rowid table in D - result = context.connection.execute('SELECT last_insert_rowid()', ()) - # Wraping fetchone in a generator for some pretty tight calls. - while True: - if not result.returns_rows: - # Doesn't have any results, e.g. CREATE TABLE statements - return - row = result.fetchone() - if row is None: - return - if context.keepwrap: - # Results are always returned wraped in a tuple - yield row - else: - # Here unpacking is conditional - # FIXME: can this if be removed? - yield row[0] if len(row) == 1 else row - - def __exit__(context, type_, value, trace): - """ Finalization of an SQLController call """ - if trace is not None: - # An SQLError is a serious offence. - logger.info('[sql] FATAL ERROR IN QUERY CONTEXT') - logger.info('[sql] operation=\n' + context.operation) - logger.info('[sql] Error in context manager!: ' + str(value)) - # return a falsey value on error - return False - else: - # Commit the transaction - if context.auto_commit and context.transaction: - context.transaction.commit() - else: - logger.info('no commit %r' % context.operation_lbl) - - -def get_operation_type(operation): - """ - Parses the operation_type from an SQL operation - """ - operation = ' '.join(operation.split('\n')).strip() - operation_type = operation.split(' ')[0].strip() - if operation_type.startswith('SELECT'): - operation_args = ut.str_between(operation, operation_type, 'FROM').strip() - elif operation_type.startswith('INSERT'): - operation_args = ut.str_between(operation, operation_type, '(').strip() - elif operation_type.startswith('DROP'): - operation_args = '' - elif operation_type.startswith('ALTER'): - operation_args = '' - elif operation_type.startswith('UPDATE'): - operation_args = ut.str_between(operation, operation_type, 'FROM').strip() - elif operation_type.startswith('DELETE'): - operation_args = ut.str_between(operation, operation_type, 'FROM').strip() - elif operation_type.startswith('CREATE'): - operation_args = ut.str_between(operation, operation_type, '(').strip() - else: - operation_args = None - operation_type += ' ' + operation_args.replace('\n', ' ') - return operation_type.upper() - - def sanitize_sql(db, tablename_, columns=None): """ Sanatizes an sql tablename and column. Use sparingly """ tablename = re.sub('[^a-zA-Z_0-9]', '', tablename_) From 63c5ad0a09bb66da14a038d01284f1f85310e0b9 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 14 Sep 2020 10:14:08 -0700 Subject: [PATCH 017/294] Fix failing set method example This was failing due to the `val_iter` values not being wrapped. That might have worked before, but IMHO that's bad practice. Let's error if the structure doesn't match that of the column names. In other words, don't make a special snowflake for single value setting; use the method consistently everywhere. Most occurences of this method in the codebase that do a single column set are wrapping the values already. 
For example, in `wbia.control.manual_part_funcs:set_part_metadata`: ``` val_list = ((metadata_str,) for metadata_str in metadata_str_list) ibs.db.set(const.PART_TABLE, ('part_metadata_json',), val_list, id_iter) ``` --- wbia/dtool/sql_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e236177df1..360cfb9422 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1181,11 +1181,11 @@ def set( >>> table.print_csv() >>> # Break things to test set >>> colnames = ('dummy_annot_rowid',) - >>> val_iter = [9003, 9001, 9002] + >>> val_iter = [(9003,), (9001,), (9002,)] >>> orig_data = db.get('notch', colnames, id_iter=rowids) >>> db.set('notch', colnames, val_iter, id_iter=rowids) >>> new_data = db.get('notch', colnames, id_iter=rowids) - >>> assert new_data == val_iter + >>> assert new_data == [x[0] for x in val_iter] >>> assert new_data != orig_data >>> table.print_csv() >>> depc.clear_all() From 37fda94e4d0687347ca79cad792513ecd521864b Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 14 Sep 2020 21:33:10 -0700 Subject: [PATCH 018/294] Require SQLAlchemy text() for all execute* methods This is a breaking change. All existing queries will fail due to this change. We must use `sqlalchemy.sql:text` in order to fully achieve parameter agnostic syntax. This change enforces the requirement. --- wbia/dtool/sql_control.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 360cfb9422..22191a649b 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -19,7 +19,7 @@ import sqlalchemy import utool as ut from deprecated import deprecated -from sqlalchemy.sql import text +from sqlalchemy.sql import text, ClauseElement from wbia.dtool import lite from wbia.dtool.dump import dumps @@ -1328,13 +1328,19 @@ def _executemany_operation_fmt( def executeone(self, operation, params=(), eager=True, verbose=VERBOSE_SQL): """Executes the given ``operation`` once with the given set of ``params``""" + if not isinstance(operation, ClauseElement): + raise TypeError( + "'operation' needs to be a sqlalchemy textual sql instance " + "see docs on 'sqlalchemy.sql:text' factory function; " + f"'operation' is a '{type(operation)}'" + ) # FIXME (12-Sept-12020) Allows passing through '?' (question mark) parameters. results = self.connection.execute(operation, params) # BBB (12-Sept-12020) Retaining insertion rowid result # FIXME postgresql (12-Sept-12020) This won't work in postgres. # Maybe see if ResultProxy.inserted_primary_key will work - if 'insert' in operation.lower(): + if 'insert' in operation.text.lower(): # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. return [results.lastrowid] elif not results.returns_rows: @@ -1363,6 +1369,13 @@ def executemany(self, operation, params_iter, unpack_scalars=True, **kwargs): (default: True) """ + if not isinstance(operation, ClauseElement): + raise TypeError( + "'operation' needs to be a sqlalchemy textual sql instance " + "see docs on 'sqlalchemy.sql:text' factory function; " + f"'operation' is a '{type(operation)}'" + ) + results = [] with self.connection.begin(): for params in params_iter: From 75fb8a0b1f372c69af73983b920e0d42b288e2c1 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 14 Sep 2020 21:48:17 -0700 Subject: [PATCH 019/294] Test and revise sql controller _add method It retains all the previous behavior. 
The only change is in how the logic is executed. --- wbia/dtool/sql_control.py | 24 ++++++++---------------- wbia/tests/dtool/test_sql_control.py | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 22191a649b..dca49c089c 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -740,23 +740,15 @@ def check_rowid_exists(self, tablename, rowid_iter, eager=True, **kwargs): exists_list = [rowid is not None for rowid in rowid_list1] return exists_list - def _add(self, tblname, colnames, params_iter, **kwargs): + def _add(self, tblname, colnames, params_iter, unpack_scalars=True, **kwargs): """ ADDER NOTE: use add_cleanly """ - fmtdict = { - 'tblname': tblname, - 'erotemes': ', '.join(['?'] * len(colnames)), - 'params': ',\n'.join(colnames), - } - operation_fmt = """ - INSERT INTO {tblname}( - rowid, - {params} - ) VALUES (NULL, {erotemes}) - """ - rowid_list = self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) - return rowid_list + columns = ', '.join(colnames) + column_params = ', '.join(f':{col}' for col in colnames) + parameterized_values = [ + {col: val for col, val in zip(colnames, params)} for params in params_iter + ] + stmt = text(f'INSERT INTO {tblname} ({columns}) VALUES ({column_params})') + return self.executemany(stmt, parameterized_values, unpack_scalars=unpack_scalars) def add_cleanly( self, diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 6f16e1b429..726afbdc58 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -391,6 +391,30 @@ def test_executeone_for_single_column(self): # Note the unwrapped values, rather than [(i,) ...] assert result == [i for i in range(0, 10)] + def test_add(self): + table_name = 'test_add' + self.make_table(table_name) + + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + parameter_values = [] + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + parameter_values.append((x, y, z)) + + # Call the testing target + ids = self.ctrlr._add(table_name, ['x', 'y', 'z'], parameter_values) + + # Verify the resulting ids + assert ids == [i + 1 for i in range(0, len(parameter_values))] + # Verify addition of records + results = self.ctrlr.connection.execute(f'SELECT id, x, y, z FROM {table_name}') + expected = [(i + 1, x, y, z) for i, (x, y, z) in enumerate(parameter_values)] + assert results.fetchall() == expected + def test_get_where_without_where_condition(self): table_name = 'test_get_where' self.make_table(table_name) From b5adfd7c1de3147baf74f86e02f9e3766a222c99 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 16 Sep 2020 08:21:25 -0700 Subject: [PATCH 020/294] Fix tests to use SA text() for textual sql All these unittests pass with this in place. 
--- wbia/tests/dtool/test_sql_control.py | 33 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 726afbdc58..b96a8c7925 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -293,7 +293,7 @@ def test_executeone(self): self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target - result = self.ctrlr.executeone(f'SELECT id, y FROM {table_name}') + result = self.ctrlr.executeone(text(f'SELECT id, y FROM {table_name}')) assert result == [(i + 1, i) for i in range(0, 10)] @@ -313,14 +313,16 @@ def test_executeone_on_insert(self): self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target - result = self.ctrlr.executeone(f'INSERT INTO {table_name} (y) VALUES (?)', (10,)) + result = self.ctrlr.executeone( + text(f'INSERT INTO {table_name} (y) VALUES (:y)'), {'y': 10} + ) # Cursory check that the result is a single int value assert result == [11] # the result list with one unwrapped value # Check for the actual value associated with the resulting id inserted_value = self.ctrlr.connection.execute( - f'SELECT id, y FROM {table_name} WHERE rowid = :rowid', rowid=result[0], + text(f'SELECT id, y FROM {table_name} WHERE rowid = :rowid'), rowid=result[0], ).fetchone() assert inserted_value == (11, 10,) @@ -340,8 +342,8 @@ def test_executemany(self): # Call the testing target results = self.ctrlr.executemany( - f'SELECT id, y FROM {table_name} where x = ?', - (['even'], ['odd']), + text(f'SELECT id, y FROM {table_name} where x = :x'), + ({'x': 'even'}, {'x': 'odd'}), unpack_scalars=False, ) @@ -355,14 +357,14 @@ def test_executemany_transaction(self): self.make_table(table_name) # Test a failure to execute in the transaction to test the transaction boundary. - insert = f'INSERT INTO {table_name} (x, y, z) VALUES (?, ?, ?)' + insert = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, y:, :z)') params = [ - ('even', 0, 0.0), - ('odd', 1, 1.01), - ('oops', 2.02), # error - ('odd', 3, 3.03), + dict(x='even', y=0, z=0.0), + dict(x='odd', y=1, z=1.01), + dict(x='oops', z=2.02), # error + dict(x='odd', y=3, z=3.03), ] - with pytest.raises(sqlalchemy.exc.ProgrammingError): + with pytest.raises(sqlalchemy.exc.OperationalError): # Call the testing target results = self.ctrlr.executemany(insert, params) @@ -386,7 +388,7 @@ def test_executeone_for_single_column(self): self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target - result = self.ctrlr.executeone(f'SELECT y FROM {table_name}') + result = self.ctrlr.executeone(text(f'SELECT y FROM {table_name}')) # Note the unwrapped values, rather than [(i,) ...] 
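For anyone following along, the conversion pattern applied throughout these tests looks roughly like this (an illustrative standalone sketch, not code from this repo; assumes a throwaway in-memory SQLite engine):

```
from sqlalchemy import create_engine
from sqlalchemy.sql import text

engine = create_engine('sqlite://')  # throwaway in-memory database
conn = engine.connect()
conn.execute(text('CREATE TABLE example (x TEXT, y INTEGER)'))
conn.execute(text('INSERT INTO example (x, y) VALUES (:x, :y)'), {'x': 'even', 'y': 0})

# before: raw SQL with sqlite-specific positional '?' parameters
#   conn.execute('SELECT y FROM example WHERE x = ?', ('even',))
# after: sqlalchemy text() with dialect-neutral named :x parameters
rows = conn.execute(text('SELECT y FROM example WHERE x = :x'), {'x': 'even'}).fetchall()
assert rows == [(0,)]
```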
assert result == [i for i in range(0, 10)] @@ -395,7 +397,6 @@ def test_add(self): table_name = 'test_add' self.make_table(table_name) - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') parameter_values = [] for i in range(0, 10): x, y, z = ( @@ -450,7 +451,7 @@ def test_scalar_get_where(self): self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target - results = self.ctrlr.get_where(table_name, ['id', 'y'], ([1]), 'id = ?',) + results = self.ctrlr.get_where(table_name, ['id', 'y'], ({'id': 1},), 'id = :id',) evens = results[0] # Verify query @@ -474,8 +475,8 @@ def test_multi_row_get_where(self): results = self.ctrlr.get_where( table_name, ['id', 'y'], - (['even'], ['odd']), - 'x = ?', + ({'x': 'even'}, {'x': 'odd'}), + 'x = :x', unpack_scalars=False, # this makes it more than one row of results ) evens = results[0] From 626fb46184ecccd822c68ad929e94e201de740ee Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 16 Sep 2020 21:33:05 -0700 Subject: [PATCH 021/294] Fix sql controller tests that use executeone Changing over these tests to use the raw sqlalchemy api, so that the tests are not dependent on the controller implementation. --- wbia/tests/dtool/test_sql_control.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index b96a8c7925..5142f6a66f 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -55,14 +55,14 @@ def fixture(self, ctrlr, monkeypatch): self.ctrlr._ensure_metadata_table() # Create metadata in the table + insert_stmt = text( + 'INSERT INTO metadata (metadata_key, metadata_value) VALUES (:key, :value)' + ) for key, value in self.data.items(): unprefixed_name = key.split('_')[-1] if METADATA_TABLE_COLUMNS[unprefixed_name]['is_coded_data']: value = repr(value) - self.ctrlr.executeone( - 'INSERT INTO metadata (metadata_key, metadata_value) VALUES (?, ?)', - (key, value), - ) + self.ctrlr.connection.execute(insert_stmt, key=key, value=value) def monkey_get_table_names(self, *args, **kwargs): return ['foo', 'metadata'] @@ -120,9 +120,9 @@ def test_setting_to_none(self): assert new_value == value # Also check the table does not have the record - assert not self.ctrlr.executeone( + assert not self.ctrlr.connection.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" - ) + ).fetchone() def test_setting_unknown_key(self): # Check setting of an unknown metadata key @@ -141,9 +141,9 @@ def test_deleter(self): assert self.ctrlr.metadata.foo.docstr is None # Also check the table does not have the record - assert not self.ctrlr.executeone( + assert not self.ctrlr.connection.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" - ) + ).fetchone() def test_database_attributes(self): # Check the database version @@ -244,9 +244,9 @@ def test_delitem_for_table(self): assert self.ctrlr.metadata.foo.docstr is None # Also check the table does not have the record - assert not self.ctrlr.executeone( + assert not self.ctrlr.connection.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" - ) + ).fetchone() def test_delitem_for_database(self): # You cannot delete database version metadata From 523480afa7f331d7e6638d4da308361d86270148 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 16 Sep 2020 21:43:20 -0700 Subject: [PATCH 022/294] Update SQL execution to use ':name' parameters This brings the existing unittests to a passing state. 
Essentially we want all the existing queries to be dialect abstract. --- wbia/dtool/sql_control.py | 65 +++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index dca49c089c..26c4d197da 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -192,29 +192,37 @@ def __init__(self, ctrlr): @property def version(self): - stmt = f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = ?' + stmt = text( + f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = :key' + ) try: - return self.ctrlr.executeone(stmt, ('database_version',))[0] - except IndexError: # No result + return self.ctrlr.executeone(stmt, {'key': 'database_version'})[0] + except TypeError: # NoneType return None @version.setter def version(self, value): if not value: raise ValueError(value) - self.ctrlr.executeone( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) VALUES (?, ?)', - ('database_version', value), + stmt = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value)' + 'VALUES (:key, :value)' ) + params = {'key': 'database_version', 'value': value} + self.ctrlr.executeone(stmt, params) @property def init_uuid(self): - stmt = f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = ?' + stmt = text( + f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = :key' + ) try: - value = self.ctrlr.executeone(stmt, ('database_init_uuid',))[0] - except IndexError: # No result + value = self.ctrlr.executeone(stmt, {'key': 'database_init_uuid'})[0] + except TypeError: # NoneType return None - return uuid.UUID(value) + if value is not None: + value = uuid.UUID(value) + return value @init_uuid.setter def init_uuid(self, value): @@ -222,10 +230,12 @@ def init_uuid(self, value): raise ValueError(value) elif isinstance(value, uuid.UUID): value = str(value) - self.ctrlr.executeone( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) VALUES (?, ?)', - ('database_init_uuid', value), + stmt = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) ' + 'VALUES (:key, :value)' ) + params = {'key': 'database_init_uuid', 'value': value} + self.ctrlr.executeone(stmt, params) # collections.abc.MutableMapping abstract methods @@ -272,15 +282,14 @@ def update(self, **kwargs): def __getattr__(self, name): # Query the database for the value represented as name key = '_'.join([self.table_name, name]) - statement = ( + statement = text( 'SELECT metadata_value ' f'FROM {METADATA_TABLE_NAME} ' - 'WHERE metadata_key = ?' 
+ 'WHERE metadata_key = :key' ) try: - value = self.ctrlr.executeone(statement, (key,))[0] - except IndexError: - # No value for the requested metadata_key + value = self.ctrlr.executeone(statement, {'key': key})[0] + except TypeError: # NoneType return None if METADATA_TABLE_COLUMNS[name]['is_coded_data']: value = eval(value) @@ -307,14 +316,14 @@ def __setattr__(self, name, value): # Insert or update the record # FIXME postgresql (4-Aug-12020) 'insert or replace' is not valid for postgresql - statement = ( + statement = text( f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} ' - f'(metadata_key, metadata_value) VALUES (?, ?)' - ) - params = ( - key, - value, + f'(metadata_key, metadata_value) VALUES (:key, :value)' ) + params = { + 'key': key, + 'value': value, + } self.ctrlr.executeone(statement, params) def __delattr__(self, name): @@ -323,8 +332,10 @@ def __delattr__(self, name): raise AttributeError # Insert or update the record - statement = f'DELETE FROM {METADATA_TABLE_NAME} where metadata_key = ?' - params = (self._get_key_name(name),) + statement = text( + f'DELETE FROM {METADATA_TABLE_NAME} where metadata_key = :key' + ) + params = {'key': self._get_key_name(name)} self.ctrlr.executeone(statement, params) def __dir__(self): @@ -1553,7 +1564,7 @@ def _make_add_table_sqlstr( comma = ',' + sep table_body = comma.join(body_list + constraint_list) - return f'CREATE TABLE IF NOT EXISTS {tablename} ({sep}{table_body}{sep})' + return text(f'CREATE TABLE IF NOT EXISTS {tablename} ({sep}{table_body}{sep})') def add_table(self, tablename=None, coldef_list=None, **metadata_keyval): """ From c9f9140bfcf7b3c5cc338ff1f4bd6ae58105bee0 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 16 Sep 2020 21:50:43 -0700 Subject: [PATCH 023/294] Correct get_where unittests This does not change the behavior of the method, only the style of execution. 
--- wbia/dtool/sql_control.py | 46 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 26c4d197da..243067b891 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -967,7 +967,6 @@ def get_where_eq_set( } return self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) - @profile def get_where( self, tblname, @@ -978,34 +977,33 @@ def get_where( eager=True, **kwargs, ): + """ + Interface to do a SQL select with a where clause + + Args: + tblname (str): table name + colnames (tuple[str]): sequence of column names + params_iter (list[dict]): a sequence of dicts with parameters, + where each item in the sequence is used in a SQL execution + where_clause (str): conditional statement used in the where clause + unpack_scalars (bool): [deprecated] use to unpack a single result from each query + only use with operations that return a single result for each query + (default: True) + + """ if not isinstance(colnames, (tuple, list)): raise TypeError('colnames must be a sequence type of strings') + # Build and execute the query + columns = ', '.join(colnames) + stmt = f'SELECT {columns} FROM {tblname}' if where_clause is None: - operation_fmt = """ - SELECT {colnames} - FROM {tblname} - """ - fmtdict = { - 'tblname': tblname, - 'colnames': ', '.join(colnames), - } - val_list = self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) + val_list = self.executeone(text(stmt), **kwargs) else: - operation_fmt = """ - SELECT {colnames} - FROM {tblname} - WHERE {where_clauses} - """ - fmtdict = { - 'tblname': tblname, - 'colnames': ', '.join(colnames), - 'where_clauses': where_clause, - } - val_list = self._executemany_operation_fmt( - operation_fmt, - fmtdict, - params_iter=params_iter, + stmt += f' WHERE {where_clause}' + val_list = self.executemany( + text(stmt), + params_iter, unpack_scalars=unpack_scalars, eager=eager, **kwargs, From 495f8e92b39dd56aaf0b82543a7918d68223cd12 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 16 Sep 2020 23:57:37 -0700 Subject: [PATCH 024/294] Testing and rewriting get_where_eq with SQLAlchemy Corrects the following traceback: ``` Traceback (most recent call last): File "/usr/local/lib/python3.6/dist-packages/xdoctest/doctest_example.py", line 571, in run exec(code, test_globals) File "", line rel: 28, abs: 833, in >>> tblname, colnames, params_iter, get_rowid_from_superkey, superkey_paramx) File "/code/wbia/dtool/sql_control.py", line 853, in add_cleanly rowid_list_ = get_rowid_from_superkey(*superkey_lists) File "", line rel: 26, abs: 831, in get_rowid_from_superkey >>> return db.get_where_eq(tblname, ('rowid',), zip(superkeys_list), superkey_colnames) File "/code/wbia/dtool/sql_control.py", line 930, in get_where_eq **kwargs, File "/code/wbia/dtool/sql_control.py", line 1016, in get_where **kwargs, File "/code/wbia/dtool/sql_control.py", line 1390, in executemany value = self.executeone(operation, params) File "/code/wbia/dtool/sql_control.py", line 1346, in executeone results = self.connection.execute(operation, params) File "/usr/local/lib/python3.6/dist-packages/sqlalchemy/engine/base.py", line 1011, in execute return meth(self, multiparams, params) File "/usr/local/lib/python3.6/dist-packages/sqlalchemy/sql/elements.py", line 298, in _execute_on_connection return connection._execute_clauseelement(self, multiparams, params) File "/usr/local/lib/python3.6/dist-packages/sqlalchemy/engine/base.py", line 1090, in 
_execute_clauseelement keys = list(distilled_params[0].keys()) AttributeError: 'tuple' object has no attribute 'keys' ``` Problem resides within `get_where_eq`, because the parameters are old-style tuple. And the query is build with '?' parameters. Backing up to test this method outside the scope of `add_cleanly`. --- wbia/dtool/sql_control.py | 30 +++++++++++++++++----------- wbia/tests/dtool/test_sql_control.py | 27 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 243067b891..9eb943bc10 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -898,28 +898,34 @@ def get_where_eq( params_iter, where_colnames, unpack_scalars=True, - eager=True, op='AND', **kwargs, ): - """hacked in function for nicer templates + """Executes a SQL select where the given parameters match/equal + the specified where columns. - unpack_scalars = True - kwargs = {} + Args: + tblname (str): table name + colnames (tuple[str]): sequence of column names + params_iter (list[list]): a sequence of a sequence with parameters, + where each item in the sequence is used in a SQL execution + where_colnames (list[str]): column names to match for equality against the same index + of the param_iter values + op (str): SQL boolean operator (e.g. AND, OR) + unpack_scalars (bool): [deprecated] use to unpack a single result from each query + only use with operations that return a single result for each query + (default: True) - Kwargs: - verbose: """ - andwhere_clauses = [colname + '=?' for colname in where_colnames] - logicop_ = ' %s ' % (op,) - where_clause = logicop_.join(andwhere_clauses) + equal_conditions = [f'{c}=:{c}' for c in where_colnames] + where_conditions = f' {op} '.upper().join(equal_conditions) + params = [dict(zip(where_colnames, p)) for p in params_iter] return self.get_where( tblname, colnames, - params_iter, - where_clause, + params, + where_conditions, unpack_scalars=unpack_scalars, - eager=eager, **kwargs, ) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 5142f6a66f..25c31358f1 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -486,6 +486,33 @@ def test_multi_row_get_where(self): assert evens == [(i + 1, i) for i in range(0, 10) if not i % 2] assert odds == [(i + 1, i) for i in range(0, 10) if i % 2] + def test_get_where_eq(self): + table_name = 'test_get_where_eq' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + results = self.ctrlr.get_where_eq( + table_name, + ['id', 'y'], + (['even', 8], ['odd', 7]), # params_iter + ('x', 'y'), # where_colnames + op='AND', + unpack_scalars=True, + ) + + # Verify query + assert results == [(9, 8), (8, 7)] + def test_setting(self): # Note, this is not a comprehensive test. It only attempts to test the SQL logic. # Make a table for records From 496dde90fbf66db54137625724ac75f47b59941d Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 17 Sep 2020 00:11:51 -0700 Subject: [PATCH 025/294] Fix scalar unpacking to check for None Let's use python's builtin ability to translate None or an empty list to False. 
This corrects the following error: ``` Traceback (most recent call last): File "/usr/local/lib/python3.6/dist-packages/xdoctest/doctest_example.py", line 571, in run exec(code, test_globals) File "", line rel: 28, abs: 833, in >>> tblname, colnames, params_iter, get_rowid_from_superkey, superkey_paramx) File "/code/wbia/dtool/sql_control.py", line 853, in add_cleanly rowid_list_ = get_rowid_from_superkey(*superkey_lists) File "", line rel: 26, abs: 831, in get_rowid_from_superkey >>> return db.get_where_eq(tblname, ('rowid',), zip(superkeys_list), superkey_colnames) File "/code/wbia/dtool/sql_control.py", line 936, in get_where_eq **kwargs, File "/code/wbia/dtool/sql_control.py", line 1022, in get_where **kwargs, File "/code/wbia/dtool/sql_control.py", line 1401, in executemany value = _unpacker(value) File "/code/wbia/dtool/sql_control.py", line 70, in _unpacker if len(results_) == 0: TypeError: object of type 'NoneType' has no len() ``` --- wbia/dtool/sql_control.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 9eb943bc10..779eed3ddf 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -64,13 +64,13 @@ METADATA_TABLE_COLUMN_NAMES = list(METADATA_TABLE_COLUMNS.keys()) -def _unpacker(results_): +def _unpacker(results): """ HELPER: Unpacks results if unpack_scalars is True. """ - if len(results_) == 0: + if not results: # Check for None or empty list results = None else: - assert len(results_) <= 1, 'throwing away results! { %r }' % (results_,) - results = results_[0] + assert len(results) <= 1, 'throwing away results! { %r }' % (results,) + results = results[0] return results From 62bc737b13ff77b80d2adc7487bec34a8e4ce943 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 24 Sep 2020 15:19:09 -0700 Subject: [PATCH 026/294] Updating code-style with new version of black I upgraded to black 20.8b1, which causes some of the prior changesets to need updated. This upgrade came mid-way through this branch, otherwise it wouldn't be a problem. 
--- wbia/tests/dtool/test_sql_control.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 25c31358f1..4b5ab1d2f7 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -322,9 +322,13 @@ def test_executeone_on_insert(self): # Check for the actual value associated with the resulting id inserted_value = self.ctrlr.connection.execute( - text(f'SELECT id, y FROM {table_name} WHERE rowid = :rowid'), rowid=result[0], + text(f'SELECT id, y FROM {table_name} WHERE rowid = :rowid'), + rowid=result[0], ).fetchone() - assert inserted_value == (11, 10,) + assert inserted_value == ( + 11, + 10, + ) def test_executemany(self): table_name = 'test_executemany' @@ -431,7 +435,12 @@ def test_get_where_without_where_condition(self): self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target - results = self.ctrlr.get_where(table_name, ['id', 'y'], tuple(), None,) + results = self.ctrlr.get_where( + table_name, + ['id', 'y'], + tuple(), + None, + ) # Verify query assert results == [(i + 1, i) for i in range(0, 10)] @@ -451,7 +460,12 @@ def test_scalar_get_where(self): self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target - results = self.ctrlr.get_where(table_name, ['id', 'y'], ({'id': 1},), 'id = :id',) + results = self.ctrlr.get_where( + table_name, + ['id', 'y'], + ({'id': 1},), + 'id = :id', + ) evens = results[0] # Verify query From 7463bdf3407327276bcfe20a1d6a627a8fa428a4 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 24 Sep 2020 15:23:59 -0700 Subject: [PATCH 027/294] Add test for making the SQL create table statement This is a simple test to check for correct output at the statement level. 
--- wbia/tests/dtool/test_sql_control.py | 48 ++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 4b5ab1d2f7..2738c135d2 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -28,6 +28,54 @@ def test_instantiation(ctrlr): assert not ctrlr.connection.closed +class TestSchemaModifiers: + """Testing the API that creates, modifies or deletes schema elements""" + + @pytest.fixture(autouse=True) + def fixture(self, ctrlr): + self.ctrlr = ctrlr + + def make_table_definition(self, name): + """Creates a table definition for use with the controller's add_table method""" + definition = { + 'tablename': name, + 'coldef_list': [ + (f'{name}_id', 'INTEGER PRIMARY KEY'), + ('meta_labeler_id', 'INTEGER NOT NULL'), + ('indexer_id', 'INTEGER NOT NULL'), + ('config_id', 'INTEGER DEFAULT 0'), + ('data', 'TEXT'), + ], + 'docstr': f'docstr for {name}', + 'superkeys': [ + ('meta_labeler_id', 'indexer_id', 'config_id'), + ], + 'dependson': [ + 'meta_labelers', + 'indexers', + ], + } + return definition + + def test_make_add_table_sqlstr(self): + table_definition = self.make_table_definition('foobars') + + # Call the target + sql = self.ctrlr._make_add_table_sqlstr(**table_definition) + + expected = ( + 'CREATE TABLE IF NOT EXISTS foobars ( ' + 'foobars_id INTEGER PRIMARY KEY, ' + 'meta_labeler_id INTEGER NOT NULL, ' + 'indexer_id INTEGER NOT NULL, ' + 'config_id INTEGER DEFAULT 0, ' + 'data TEXT, ' + 'CONSTRAINT superkey ' + 'UNIQUE (meta_labeler_id,indexer_id,config_id) )' + ) + assert sql.text == expected + + def test_safely_get_db_version(ctrlr): v = ctrlr.get_db_version(ensure=True) assert v == '0.0.0' From 9a9830807bfe57e2e776e5cfb18364910d3ae8e2 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 24 Sep 2020 23:25:50 -0700 Subject: [PATCH 028/294] Adjust testing method to accept depends_on parameter Not really necessary for the existing code, which does not use foreign key relationships, but may be useful in the future. It'll at least make me feel better. --- wbia/tests/dtool/test_sql_control.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 2738c135d2..746f38242c 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -35,7 +35,7 @@ class TestSchemaModifiers: def fixture(self, ctrlr): self.ctrlr = ctrlr - def make_table_definition(self, name): + def make_table_definition(self, name, depends_on=[]): """Creates a table definition for use with the controller's add_table method""" definition = { 'tablename': name, @@ -50,15 +50,14 @@ def make_table_definition(self, name): 'superkeys': [ ('meta_labeler_id', 'indexer_id', 'config_id'), ], - 'dependson': [ - 'meta_labelers', - 'indexers', - ], + 'dependson': depends_on, } return definition def test_make_add_table_sqlstr(self): - table_definition = self.make_table_definition('foobars') + table_definition = self.make_table_definition( + 'foobars', depends_on=['meta_labelers', 'indexers'] + ) # Call the target sql = self.ctrlr._make_add_table_sqlstr(**table_definition) From cca4e32104c364c021fd1cb50060cd318499e84d Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 24 Sep 2020 15:35:12 -0700 Subject: [PATCH 029/294] Update _make_add_table_sqlstr for sqlalchemy Updating some of the internals to be consistent and readable at a glance. 
Also, documenting the code blocks. I'm removing the docstr test because it doesn't test for anything, but maybe that the method output something. So one less test to have hanging around. Besides the previous commit submits a unittest for this method. --- wbia/dtool/sql_control.py | 63 ++++++++++------------------ wbia/tests/dtool/test_sql_control.py | 4 +- 2 files changed, 24 insertions(+), 43 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 779eed3ddf..e5f0984ac2 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1489,56 +1489,39 @@ def add_column(self, tablename, colname, coltype): operation = op_fmtstr.format(**fmtkw) self.executeone(operation, [], verbose=False) - def __make_superkey_constraints(self, superkeys: list) -> list: - """Creates SQL for the 'superkey' constraint. - A 'superkey' is one or more columns that make up a unique constraint on the table. - - """ - has_superkeys = superkeys is not None and len(superkeys) > 0 - constraints = [] - if has_superkeys: - # Create a superkey statement for each superkey item - # superkeys = [(col), (col1, col2, ...), ...], - for columns in superkeys: - columns = ','.join(columns) - constraints.append(f'CONSTRAINT superkey UNIQUE ({columns})') - return constraints + def __make_unique_constraint(self, column_or_columns): + """Creates a SQL ``CONSTRAINT`` clause for ``UNIQUE`` column data""" + if not isinstance(column_or_columns, (list, tuple)): + columns = [column_or_columns] + else: + # Cast as list incase it's a tuple, b/c tuple + list = error + columns = list(column_or_columns) + constraint_name = '_'.join(['unique'] + columns) + columns_listing = ', '.join(columns) + return f'CONSTRAINT {constraint_name} UNIQUE ({columns_listing})' def __make_column_definition(self, name: str, definition: str) -> str: """Creates SQL for the given column `name` and type, default & constraint (i.e. 
`definition`).""" if not name: raise ValueError(f'name cannot be an empty string paired with {definition}') - if not definition: + elif not definition: raise ValueError(f'definition cannot be an empty string paired with {name}') return f'{name} {definition}' def _make_add_table_sqlstr( self, tablename: str, coldef_list: list, sep=' ', **metadata_keyval ): - r"""Creates the SQL for a CREATE TABLE statement + """Creates the SQL for a CREATE TABLE statement Args: tablename (str): table name coldef_list (list): list of tuples (name, type definition) + sep (str): clause separation character(s) (default: space) + kwargs: metadata specifications Returns: str: operation - CommandLine: - python -m dtool.sql_control _make_add_table_sqlstr - - Example: - >>> # ENABLE_DOCTEST - >>> from wbia.dtool.sql_control import * # NOQA - >>> from wbia.dtool.example_depcache import testdata_depc - >>> depc = testdata_depc() - >>> tablename = 'keypoint' - >>> db = depc[tablename].db - >>> autogen_dict = db.get_table_autogen_dict(tablename) - >>> coldef_list = autogen_dict['coldef_list'] - >>> operation = db._make_add_table_sqlstr(tablename, coldef_list) - >>> print(operation) - """ if not coldef_list: raise ValueError(f'empty coldef_list specified for {tablename}') @@ -1548,22 +1531,20 @@ def _make_add_table_sqlstr( if len(bad_kwargs) > 0: raise TypeError(f'got unexpected keyword arguments: {bad_kwargs}') - if ut.DEBUG2: - logger.info('[sql] schema ensuring tablename=%r' % tablename) - if ut.VERBOSE: - logger.info('') - _args = [tablename, coldef_list] - logger.info(ut.func_str(self.add_table, _args, metadata_keyval)) - logger.info('') + logger.debug('[sql] schema ensuring tablename=%r' % tablename) + logger.debug( + ut.func_str(self.add_table, [tablename, coldef_list], metadata_keyval) + ) # Create the main body of the CREATE TABLE statement with column definitions # coldef_list = [(, ,), ...] body_list = [self.__make_column_definition(c, d) for c, d in coldef_list] # Make a list of constraints to place on the table - constraint_list = self.__make_superkey_constraints( - metadata_keyval.get('superkeys', []) - ) + # superkeys = [(, ...), ...] + constraint_list = [ + self.__make_unique_constraint(x) for x in metadata_keyval.get('superkeys', []) + ] constraint_list = ut.unique_ordered(constraint_list) comma = ',' + sep diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 746f38242c..352fd9ef57 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -69,8 +69,8 @@ def test_make_add_table_sqlstr(self): 'indexer_id INTEGER NOT NULL, ' 'config_id INTEGER DEFAULT 0, ' 'data TEXT, ' - 'CONSTRAINT superkey ' - 'UNIQUE (meta_labeler_id,indexer_id,config_id) )' + 'CONSTRAINT unique_meta_labeler_id_indexer_id_config_id ' + 'UNIQUE (meta_labeler_id, indexer_id, config_id) )' ) assert sql.text == expected From af0a65ba99c3818c43b952399e0ee333265eea9c Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 24 Sep 2020 23:00:21 -0700 Subject: [PATCH 030/294] Add a test for the sql controller's add_table method This tests for the creation of the table. It also verifies the metadata has been entered into the metadata table. 
--- wbia/tests/dtool/test_sql_control.py | 67 +++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 352fd9ef57..4cffc94f1d 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- import uuid +from functools import partial import pytest import sqlalchemy.exc +from sqlalchemy import MetaData, Table from sqlalchemy.engine import Connection -from sqlalchemy.sql import text +from sqlalchemy.sql import select, text from wbia.dtool.sql_control import ( METADATA_TABLE_COLUMNS, @@ -74,6 +76,69 @@ def test_make_add_table_sqlstr(self): ) assert sql.text == expected + def test_add_table(self): + # Two tables... + # .. used in the creation of bars table + foos_definition = self.make_table_definition('foos') + # .. bars table depends on foos table + bars_definition = self.make_table_definition('bars', depends_on=['foos']) + # We test against bars table and basically neglect foos table + + # Call the target + self.ctrlr.add_table(**foos_definition) + self.ctrlr.add_table(**bars_definition) + + # Check the table has been added and verify details + # Use sqlalchemy's reflection + table_factory = partial(Table, autoload=True, autoload_with=self.ctrlr._engine) + md = MetaData() + bars = table_factory('bars', md) + metadata = table_factory('metadata', md) + + # Check the table's column definitions + expected_bars_columns = [ + ('bars_id', 'INTEGER'), + ('config_id', 'INTEGER'), + ('data', 'TEXT'), + ('indexer_id', 'INTEGER'), + ('meta_labeler_id', 'INTEGER'), + ] + found_bars_columns = [ + ( + c.name, + c.type.__class__.__name__, + ) + for c in bars.columns + ] + assert sorted(found_bars_columns) == expected_bars_columns + # Check the table's constraints + expected_constraint_info = [ + ('PrimaryKeyConstraint', None, ['bars_id']), + ( + 'UniqueConstraint', + 'unique_meta_labeler_id_indexer_id_config_id', + ['meta_labeler_id', 'indexer_id', 'config_id'], + ), + ] + found_constraint_info = [ + (x.__class__.__name__, x.name, [c.name for c in x.columns]) + for x in bars.constraints + ] + assert sorted(found_constraint_info) == expected_constraint_info + + # Check for metadata entries + results = self.ctrlr.connection.execute( + select([metadata.c.metadata_key, metadata.c.metadata_value]).where( + metadata.c.metadata_key.like('bars_%') + ) + ) + expected_metadata_rows = [ + ('bars_docstr', 'docstr for bars'), + ('bars_superkeys', "[('meta_labeler_id', 'indexer_id', 'config_id')]"), + ('bars_dependson', "['foos']"), + ] + assert results.fetchall() == expected_metadata_rows + def test_safely_get_db_version(ctrlr): v = ctrlr.get_db_version(ensure=True) From 00f4fe70ae6bcbeebf53ef4fa0d354c812fe63e6 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 24 Sep 2020 23:54:49 -0700 Subject: [PATCH 031/294] Refactor table reflect in tests --- wbia/tests/dtool/test_sql_control.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 4cffc94f1d..7a223f7dcc 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -56,6 +56,19 @@ def make_table_definition(self, name, depends_on=[]): } return definition + @property + def _table_factory(self): + return partial(Table, autoload=True, autoload_with=self.ctrlr._engine) + + def reflect_table(self, name, metadata=None): + """Using 
SQLAlchemy to reflect the table at the given ``name`` + to return a SQLAlchemy Table object + + """ + if metadata is None: + metadata = MetaData() + return self._table_factory(name, metadata) + def test_make_add_table_sqlstr(self): table_definition = self.make_table_definition( 'foobars', depends_on=['meta_labelers', 'indexers'] @@ -90,10 +103,9 @@ def test_add_table(self): # Check the table has been added and verify details # Use sqlalchemy's reflection - table_factory = partial(Table, autoload=True, autoload_with=self.ctrlr._engine) md = MetaData() - bars = table_factory('bars', md) - metadata = table_factory('metadata', md) + bars = self.reflect_table('bars', md) + metadata = self.reflect_table('metadata', md) # Check the table's column definitions expected_bars_columns = [ From c1115e0eb36f7e6d2bae64cb8e0510661d4064a8 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 00:27:36 -0700 Subject: [PATCH 032/294] Test and fix sql controller's rename_table method This tests the controller can rename a table. It was broken on the `set` method, where the table's `id` was given as a tuple value. --- wbia/dtool/sql_control.py | 20 ++++++----------- wbia/tests/dtool/test_sql_control.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e5f0984ac2..e9ad67d781 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1778,21 +1778,14 @@ def get_rowid_from_superkey(x): self.rename_table(tablename_temp, tablename_new) def rename_table(self, tablename_old, tablename_new): - if ut.VERBOSE: - logger.info( - '[sql] schema renaming tablename=%r -> %r' - % (tablename_old, tablename_new) - ) + logger.info( + '[sql] schema renaming tablename=%r -> %r' % (tablename_old, tablename_new) + ) # Technically insecure call, but all entries are statically inputted by # the database's owner, who could delete or alter the entire database # anyway. - fmtkw = { - 'tablename_old': tablename_old, - 'tablename_new': tablename_new, - } - op_fmtstr = 'ALTER TABLE {tablename_old} RENAME TO {tablename_new}' - operation = op_fmtstr.format(**fmtkw) - self.executeone(operation, [], verbose=False) + operation = text(f'ALTER TABLE {tablename_old} RENAME TO {tablename_new}') + self.executeone(operation, []) # Rename table's metadata key_old_list = [ @@ -1801,10 +1794,9 @@ def rename_table(self, tablename_old, tablename_new): key_new_list = [ tablename_new + '_' + suffix for suffix in METADATA_TABLE_COLUMN_NAMES ] - id_iter = [(key,) for key in key_old_list] + id_iter = [key for key in key_old_list] val_iter = [(key,) for key in key_new_list] colnames = ('metadata_key',) - # logger.info('Setting metadata_key from %s to %s' % (ut.repr2(id_iter), ut.repr2(val_iter))) self.set( METADATA_TABLE_NAME, colnames, val_iter, id_iter, id_colname='metadata_key' ) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 7a223f7dcc..a20f5fe4d5 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -151,6 +151,38 @@ def test_add_table(self): ] assert results.fetchall() == expected_metadata_rows + def test_rename_table(self): + # Assumes `add_table` passes to reduce this test's complexity. 
+ table_name = 'cookies' + self.ctrlr.add_table(**self.make_table_definition(table_name)) + + # Call the target + new_table_name = 'deserts' + self.ctrlr.rename_table(table_name, new_table_name) + + # Check the table has been renamed use sqlalchemy's reflection + md = MetaData() + metadata = self.reflect_table('metadata', md) + + # Reflecting the table is enough to check that it's been renamed. + self.reflect_table(new_table_name, md) + + # Check for metadata entries have been renamed. + results = self.ctrlr.connection.execute( + select([metadata.c.metadata_key, metadata.c.metadata_value]).where( + metadata.c.metadata_key.like(f'{new_table_name}_%') + ) + ) + expected_metadata_rows = [ + (f'{new_table_name}_docstr', f'docstr for {table_name}'), + ( + f'{new_table_name}_superkeys', + "[('meta_labeler_id', 'indexer_id', 'config_id')]", + ), + (f'{new_table_name}_dependson', '[]'), + ] + assert results.fetchall() == expected_metadata_rows + def test_safely_get_db_version(ctrlr): v = ctrlr.get_db_version(ensure=True) From d623f64ba5964424a3de6c9e110d7aba90d0d8a7 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 07:25:20 -0700 Subject: [PATCH 033/294] Test and fix the sql controller's drop_table method This mostly unittests the behavior and changes the statement to be sqlalchemy compatible. --- wbia/dtool/sql_control.py | 11 +++-------- wbia/tests/dtool/test_sql_control.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e9ad67d781..f5fffb5361 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1802,17 +1802,12 @@ def rename_table(self, tablename_old, tablename_new): ) def drop_table(self, tablename): - if VERBOSE_SQL: - logger.info('[sql] schema dropping tablename=%r' % tablename) + logger.info('[sql] schema dropping tablename=%r' % tablename) # Technically insecure call, but all entries are statically inputted by # the database's owner, who could delete or alter the entire database # anyway. - fmtkw = { - 'tablename': tablename, - } - op_fmtstr = 'DROP TABLE IF EXISTS {tablename}' - operation = op_fmtstr.format(**fmtkw) - self.executeone(operation, [], verbose=False) + operation = text(f'DROP TABLE IF EXISTS {tablename}') + self.executeone(operation, []) # Delete table's metadata key_list = [tablename + '_' + suffix for suffix in METADATA_TABLE_COLUMN_NAMES] diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index a20f5fe4d5..4fe0c2cd33 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -183,6 +183,30 @@ def test_rename_table(self): ] assert results.fetchall() == expected_metadata_rows + def test_drop_table(self): + # Assumes `add_table` passes to reduce this test's complexity. + table_name = 'cookies' + self.ctrlr.add_table(**self.make_table_definition(table_name)) + + # Call the target + self.ctrlr.drop_table(table_name) + + # Check the table using sqlalchemy's reflection + md = MetaData() + metadata = self.reflect_table('metadata', md) + + # This error in the attempt to reflect indicates the table has been removed. + with pytest.raises(sqlalchemy.exc.NoSuchTableError): + self.reflect_table(table_name, md) + + # Check for metadata entries have been renamed. 
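+ # (after a drop the table's metadata rows should be gone entirely, so the query below should come back empty)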
+ results = self.ctrlr.connection.execute( + select([metadata.c.metadata_key, metadata.c.metadata_value]).where( + metadata.c.metadata_key.like(f'{table_name}_%') + ) + ) + assert results.fetchall() == [] + def test_safely_get_db_version(ctrlr): v = ctrlr.get_db_version(ensure=True) From 332afb71b156948878c00d2b08a2c5195edaf6bf Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 07:41:01 -0700 Subject: [PATCH 034/294] Test the sql controller's drop_all_tables method --- wbia/tests/dtool/test_sql_control.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 4fe0c2cd33..fad64fd669 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -207,6 +207,34 @@ def test_drop_table(self): ) assert results.fetchall() == [] + def test_drop_all_table(self): + # Assumes `add_table` passes to reduce this test's complexity. + table_names = ['cookies', 'pies', 'cakes'] + for name in table_names: + self.ctrlr.add_table(**self.make_table_definition(name)) + + # Call the target + self.ctrlr.drop_all_tables() + + # Check the table using sqlalchemy's reflection + md = MetaData() + metadata = self.reflect_table('metadata', md) + + # This error in the attempt to reflect indicates the table has been removed. + for name in table_names: + with pytest.raises(sqlalchemy.exc.NoSuchTableError): + self.reflect_table(name, md) + + # Check for the absents of metadata for the removed tables. + results = self.ctrlr.connection.execute(select([metadata.c.metadata_key])) + expected_metadata_rows = [ + ('database_init_uuid',), + ('database_version',), + ('metadata_docstr',), + ('metadata_superkeys',), + ] + assert results.fetchall() == expected_metadata_rows + def test_safely_get_db_version(ctrlr): v = ctrlr.get_db_version(ensure=True) From bcf7cf09a71a1f6b59ee16d90de19ee347d3c4be Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 14:25:04 -0700 Subject: [PATCH 035/294] Test and correct the sql controller's get method Note, the method's doctest is still failing, because it depends on other moving parts that aren't fixed yet. I'm tempted to drop the doctest because it's not as comprehensive as the unittests I've added in this changeset, but I'll leave it for now. The changes in the method were made to make the method compatible with sqlalchemy. IMO the `assume_unique` bit of this function should be an entirely new method, but if I understand correctly, it may have been put in as an optimization. --- wbia/dtool/sql_control.py | 57 +++++++-------------- wbia/tests/dtool/test_sql_control.py | 76 ++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 40 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index f5fffb5361..1f474dd4af 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1069,7 +1069,7 @@ def get( assume_unique=False, **kwargs, ): - """getter + """Get rows of data by ID Args: tblname (str): table name to get from @@ -1077,26 +1077,8 @@ def get( id_iter (iterable): iterable of search keys id_colname (str): column to be used as the search key (default: rowid) eager (bool): use eager evaluation + assume_unique (bool): default False. Experimental feature that could result in a 10x speedup unpack_scalars (bool): default True - id_colname (bool): default False. 
Experimental feature that could result in a 10x speedup - - CommandLine: - python -m dtool.sql_control get - - Ignore: - tblname = 'annotations' - colnames = ('name_rowid',) - id_iter = aid_list - #id_iter = id_iter[0:20] - id_colname = 'rowid' - eager = True - db = ibs.db - - x1 = db.get(tblname, colnames, id_iter, assume_unique=True) - x2 = db.get(tblname, colnames, id_iter, assume_unique=False) - x1 == x2 - %timeit db.get(tblname, colnames, id_iter, assume_unique=True) - %timeit db.get(tblname, colnames, id_iter, assume_unique=False) Example: >>> # ENABLE_DOCTEST @@ -1112,15 +1094,18 @@ def get( >>> got_data = db.get('notch', colnames, id_iter=rowids) >>> assert got_data == [1, 2, 3] """ - if VERBOSE_SQL: - logger.info( - '[sql]' - + ut.get_caller_name(list(range(1, 4))) - + ' db.get(%r, %r, ...)' % (tblname, colnames) - ) + logger.debug( + '[sql]' + + ut.get_caller_name(list(range(1, 4))) + + ' db.get(%r, %r, ...)' % (tblname, colnames) + ) if not isinstance(colnames, (tuple, list)): raise TypeError('colnames must be a sequence type of strings') + # ??? Getting a single column of unique values that is matched on rowid? + # And sorts the results after the query? + # ??? This seems oddly specific for a generic method. + # Perhaps the logic should be in its own method? if ( assume_unique and id_iter is not None @@ -1128,21 +1113,13 @@ def get( and len(colnames) == 1 ): id_iter = list(id_iter) - operation_fmt = """ - SELECT {colnames} - FROM {tblname} - WHERE rowid in ({id_repr}) - ORDER BY rowid ASC - """ - fmtdict = { - 'tblname': tblname, - 'colnames': ', '.join(colnames), - 'id_repr': ','.join(map(str, id_iter)), - } - operation = operation_fmt.format(**fmtdict) + columns = ', '.join(colnames) + ids_listing = ', '.join(map(str, id_iter)) + operation = f'SELECT {columns} FROM {tblname} WHERE rowid in ({ids_listing}) ORDER BY rowid ASC' results = self.connection.execute(operation).fetchall() import numpy as np + # ??? Why order the results if they are going to be sorted here? sortx = np.argsort(np.argsort(id_iter)) results = ut.take(results, sortx) if kwargs.get('unpack_scalars', True): @@ -1153,8 +1130,8 @@ def get( where_clause = None params_iter = [] else: - where_clause = id_colname + '=?' 
- params_iter = [(_rowid,) for _rowid in id_iter] + where_clause = id_colname + ' = :id' + params_iter = [{'id': id} for id in id_iter] return self.get_where( tblname, colnames, params_iter, where_clause, eager=eager, **kwargs diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index fad64fd669..ecb8a922cd 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -735,6 +735,82 @@ def test_get_where_eq(self): # Verify query assert results == [(9, 8), (8, 7)] + def test_get_all(self): + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Build the expect results of the testing target + results = self.ctrlr.connection.execute(f'SELECT id, x, z FROM {table_name}') + rows = results.fetchall() + row_mapping = {row[0]: row[1:] for row in rows if row[1]} + + # Call the testing target + data = self.ctrlr.get(table_name, ['x', 'z']) + + # Verify getting + assert data == [r for r in row_mapping.values()] + + def test_get_by_id(self): + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + requested_ids = [2, 4, 6] + data = self.ctrlr.get(table_name, ['x', 'z'], requested_ids) + + # Build the expect results of the testing target + sql_array = ', '.join([str(id) for id in requested_ids]) + results = self.ctrlr.connection.execute( + f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' + ) + expected = results.fetchall() + # Verify getting + assert data == expected + + def test_get_as_unique(self): + # This test could be inaccurate, because this logical path appears + # to be bolted on the side. Usage of this path's feature is unknown. + + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + # The table has a INTEGER PRIMARY KEY, which essentially maps to the rowid + # in SQLite. So, we need not change the default `id_colname` param. + requested_ids = [2, 4, 6] + data = self.ctrlr.get(table_name, ['x'], requested_ids, assume_unique=True) + + # Build the expect results of the testing target + sql_array = ', '.join([str(id) for id in requested_ids]) + results = self.ctrlr.connection.execute( + f'SELECT x FROM {table_name} WHERE id in ({sql_array})' + ) + # ... recall that the controller unpacks single values + expected = [row[0] for row in results] + # Verify getting + assert data == expected + def test_setting(self): # Note, this is not a comprehensive test. It only attempts to test the SQL logic. 
# Make a table for records From dc17738907247a28734c665d08f55a33f57c8612 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 14:39:38 -0700 Subject: [PATCH 036/294] Fix get_all_rowids method This allow the `wbia/dtool/sql_control.py::SQLDatabaseController.get_metadata_items:0` doctest to pass. Changes the query to be compatible with sqlalchemy. --- wbia/dtool/sql_control.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 1f474dd4af..cc280d192a 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -714,11 +714,8 @@ def get_row_count(self, tblname): def get_all_rowids(self, tblname, **kwargs): """ returns a list of all rowids from a table in ascending order """ - fmtdict = { - 'tblname': tblname, - } - operation_fmt = 'SELECT rowid FROM {tblname} ORDER BY rowid ASC' - return self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) + operation = text(f'SELECT rowid FROM {tblname} ORDER BY rowid ASC') + return self.executeone(operation, **kwargs) def get_all_col_rows(self, tblname, colname): """ returns a list of all rowids from a table in ascending order """ From 15f3a03c8e5ff27e43cf94dcbc7fa7e679d38913 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 15:02:13 -0700 Subject: [PATCH 037/294] Fix the sql controller's get_column method This allows the following tests to pass: `wbia/dtool/sql_control.py::SQLDatabaseController.get:0`, `wbia/dtool/sql_control.py::SQLDatabaseController.set:0`, `wbia/dtool/sql_control.py::SQLDatabaseController.get_table_column_data:0`, `wbia/dtool/sql_control.py::SQLDatabaseController.make_json_table_definition:0`, `wbia/dtool/sql_control.py::SQLDatabaseController.get_table_new_transferdata:0`, and `wbia/dtool/sql_control.py::SQLDatabaseController.get_table_csv:0`. The change is to be compatible with SQLAlchemy. --- wbia/dtool/sql_control.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index cc280d192a..728ee66f16 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2184,16 +2184,8 @@ def get_column_names(self, tablename): def get_column(self, tablename, name): """ Conveinience: """ - _table, (_column,) = sanitize_sql(self, tablename, (name,)) - column_vals = self.executeone( - operation=""" - SELECT %s - FROM %s - ORDER BY rowid ASC - """ - % (_column, _table) - ) - return column_vals + table, (column,) = sanitize_sql(self, tablename, (name,)) + return self.executeone(text(f'SELECT {column} FROM {table} ORDER BY rowid ASC')) def get_table_as_pandas( self, tablename, rowids=None, columns=None, exclude_columns=[] @@ -2223,15 +2215,14 @@ def get_table_as_pandas( df = pd.DataFrame(ut.dzip(column_names, column_list), index=index) return df + # TODO (25-Sept-12020) Deprecate once ResultProxy can be exposed, + # because it will allow result access by index or column name. def get_table_column_data( self, tablename, columns=None, exclude_columns=[], rowids=None ): """ Grabs a table of information - CommandLine: - python -m dtool.sql_control --test-get_table_column_data - Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA From 7a219af1de7bc3b25c512ba6ed924b9a02cb73b3 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 16:05:44 -0700 Subject: [PATCH 038/294] Raise exception to catch '?' parameterization This `where_clause` often has a '?' parameter in it. 
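For illustration, a hypothetical caller in that style (table and column names made up) might do something like:

```
ibs.db.get_where('annotations', ('annot_uuid',), [(aid,) for aid in aids], 'annot_rowid = ?')
```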
This is no longer valid with sqlalchemy named parameters in practice. This is mostly to help in the efforts of debugging unknown users of this method (e.g. plugins). --- wbia/dtool/sql_control.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 728ee66f16..befeffb82e 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1002,6 +1002,11 @@ def get_where( stmt = f'SELECT {columns} FROM {tblname}' if where_clause is None: val_list = self.executeone(text(stmt), **kwargs) + elif '?' in where_clause: + raise ValueError( + "Statements cannot use '?' parameterization, " + "use ':name' parameters instead." + ) else: stmt += f' WHERE {where_clause}' val_list = self.executemany( From 2c0333c80443988aac523251fde29fe888f9b40e Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 20:04:50 -0700 Subject: [PATCH 039/294] Remove debug property from depcache logic The `_debug` property is used to conditionally print logging messages. It's no longer useful since we've switched to a logger pattern. --- wbia/dtool/depcache_control.py | 193 ++++++++++++++------------------- wbia/dtool/depcache_table.py | 157 ++++++++++++--------------- 2 files changed, 148 insertions(+), 202 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 9ec9e52c0a..d7af3e017b 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -110,9 +110,8 @@ def _register_prop( SEE: dtool.REG_PREPROC_DOC """ - if depc._debug: - logger.info('[depc] Registering tablename=%r' % (tablename,)) - logger.info('[depc] * preproc_func=%r' % (preproc_func,)) + logger.debug('[depc] Registering tablename=%r' % (tablename,)) + logger.debug('[depc] * preproc_func=%r' % (preproc_func,)) # ---------- # Sanitize inputs if isinstance(tablename, six.string_types): @@ -198,20 +197,15 @@ def initialize(depc, _debug=None): logger.info( '[depc] Initialize %s depcache in %r' % (depc.root.upper(), depc.cache_dpath) ) - _debug = depc._debug if _debug is None else _debug if depc._use_globals: reg_preproc = PREPROC_REGISTER[depc.root] reg_subprop = SUBPROP_REGISTER[depc.root] - if ut.VERBOSE: - logger.info( - '[depc.init] Registering %d global preproc funcs' % len(reg_preproc) - ) + logger.info( + '[depc.init] Registering %d global preproc funcs' % len(reg_preproc) + ) for args_, _kwargs in reg_preproc: depc._register_prop(*args_, **_kwargs) - if ut.VERBOSE: - logger.info( - '[depc.init] Registering %d global subprops ' % len(reg_subprop) - ) + logger.info('[depc.init] Registering %d global subprops ' % len(reg_subprop)) for args_, _kwargs in reg_subprop: depc._register_subprop(*args_, **_kwargs) @@ -239,11 +233,10 @@ def initialize(depc, _debug=None): db = sql_control.SQLDatabaseController.from_uri(db_uri) depcache_table.ensure_config_table(db) depc.fname_to_db[fname] = db - if ut.VERBOSE: - logger.info('[depc] Finished initialization') + logger.info('[depc] Finished initialization') for table in depc.cachetable_dict.values(): - table.initialize(_debug=_debug) + table.initialize() # HACKS: # Define injected functions for autocomplete convinience @@ -365,11 +358,9 @@ def _ensure_config(depc, tablekey, config, _debug=False): if config_ is None: # Preferable way to get configs with explicit # configs - if _debug: - logger.info(' **config = %r' % (config,)) + logger.debug(' **config = %r' % (config,)) config_ = configclass(**config) - if _debug: - logger.info(' config_ = %r' % (config_,)) + logger.debug(' 
config_ = %r' % (config_,)) return config_ def get_config_trail(depc, tablename, config): @@ -411,7 +402,6 @@ def _get_parent_input( nInput=nInput, recompute=recompute, recompute_all=recompute_all, - _debug=ut.countdown_flag(_debug), levels_up=1, ) parent_rowids = depc._get_parent_rowids(table, rowid_dict) @@ -494,18 +484,15 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs """ _kwargs = kwargs.copy() _recompute = _kwargs.pop('recompute_all', False) - _debug = _kwargs.get('_debug', False) _hack_rootmost = _kwargs.pop('_hack_rootmost', False) - _debug = depc._debug if _debug is None else _debug if config is None: config = {} - with ut.Indenter('[GetParentID-%s]' % (target_tablename,), enabled=_debug): - if _debug: - logger.info(ut.color_text('Enter get_parent_rowids', 'blue')) - logger.info(' * target_tablename = %r' % (target_tablename,)) - logger.info(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) - logger.info(' * config = %r' % (config,)) + with ut.Indenter('[GetParentID-%s]' % (target_tablename,)): + logger.debug('Enter get_parent_rowids') + logger.debug(' * target_tablename = %r' % (target_tablename,)) + logger.debug(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) + logger.debug(' * config = %r' % (config,)) target_table = depc[target_tablename] # TODO: Expand to the appropriate given inputs @@ -515,8 +502,7 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs else: # otherwise we are given inputs in totalroot form exi_inputs = target_table.rootmost_inputs.total_expand() - if _debug: - logger.info(' * exi_inputs=%s' % (exi_inputs,)) + logger.debug(' * exi_inputs=%s' % (exi_inputs,)) rectified_input = depc.rectify_input_tuple(exi_inputs, input_tuple) @@ -525,25 +511,21 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs rowid_dict[rmi] = rowids compute_edges = exi_inputs.flat_compute_rmi_edges() - if _debug: - logger.info(' * rectified_input=%s' % ut.trunc_repr(rectified_input)) - logger.info(' * compute_edges=%s' % ut.repr2(compute_edges, nl=2)) + logger.debug(' * rectified_input=%s' % ut.trunc_repr(rectified_input)) + logger.debug(' * compute_edges=%s' % ut.repr2(compute_edges, nl=2)) for count, (input_nodes, output_node) in enumerate(compute_edges, start=1): - if _debug: - ut.cprint( - ' * COMPUTING %d/%d EDGE %r -- %r' - % (count, len(compute_edges), input_nodes, output_node), - 'blue', - ) + logger.debug( + ' * COMPUTING %d/%d EDGE %r -- %r' + % (count, len(compute_edges), input_nodes, output_node), + ) tablekey = output_node.tablename table = depc[tablekey] input_nodes_ = input_nodes - if _debug: - logger.info( - 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) - ) - logger.info('input_nodes_ = %r' % (input_nodes_,)) + logger.debug( + 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) + ) + logger.debug('input_nodes_ = %r' % (input_nodes_,)) input_multi_flags = [ node.ismulti and node in exi_inputs.rmi_list for node in input_nodes_ ] @@ -581,42 +563,39 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs _parent_rowids = list(zip(*parent_rowids2_)) # _parent_rowids = list(ut.product(*parent_rowids_)) - if _debug: - logger.info( - 'parent_rowids_ = %s' - % ( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in parent_rowids_], - strvals=True, - ) + logger.debug( + 'parent_rowids_ = %s' + % ( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in parent_rowids_], + strvals=True, ) ) - logger.info( - 'parent_rowids2_ = %s' - % ( 
- ut.repr4( - [ut.trunc_repr(ids_) for ids_ in parent_rowids2_], - strvals=True, - ) + ) + logger.debug( + 'parent_rowids2_ = %s' + % ( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in parent_rowids2_], + strvals=True, ) ) - logger.info( - '_parent_rowids = %s' - % ( - ut.truncate_str( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in _parent_rowids], - strvals=True, - ) + ) + logger.debug( + '_parent_rowids = %s' + % ( + ut.truncate_str( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in _parent_rowids], + strvals=True, ) ) ) + ) - if _debug: - ut.cprint('-------------', 'blue') if output_node.tablename != target_tablename: # Get table configuration - config_ = depc._ensure_config(tablekey, config, _debug) + config_ = depc._ensure_config(tablekey, config) output_rowids = table.get_rowid( _parent_rowids, config=config_, recompute=_recompute, **_kwargs @@ -660,12 +639,11 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): >>> root_rowids = [1, 2, 3] >>> root_rowids2 = [(4, 5, 6, 7)] >>> root_rowids3 = root_rowids2 - >>> _debug = True >>> tablename = 'smk_match' >>> input_tuple = (root_rowids, root_rowids2, root_rowids3) >>> target_table = depc[tablename] >>> inputs = target_table.rootmost_inputs.total_expand() - >>> depc.get_rowids(tablename, input_tuple, _debug=_debug) + >>> depc.get_rowids(tablename, input_tuple) >>> depc.print_all_tables() Example: @@ -696,8 +674,6 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): >>> assert recomp_rowids == initial_rowids, 'rowids should not change due to recompute' """ target_tablename = tablename - _debug = rowid_kw.get('_debug', False) - _debug = depc._debug if _debug is None else _debug _kwargs = rowid_kw.copy() config = _kwargs.pop('config', {}) _hack_rootmost = _kwargs.pop('_hack_rootmost', False) @@ -713,8 +689,8 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): **_kwargs, ) - with ut.Indenter('[GetRowId-%s]' % (target_tablename,), enabled=_debug): - config_ = depc._ensure_config(target_tablename, config, _debug) + with ut.Indenter('[GetRowId-%s]' % (target_tablename,)): + config_ = depc._ensure_config(target_tablename, config) rowids = table.get_rowid( parent_rowids, config=config_, recompute=recompute, **_kwargs ) @@ -768,7 +744,6 @@ def get( >>> depc = testdata_depc3(True) >>> exec(ut.execstr_funckw(depc.get), globals()) >>> aids = [1, 2, 3] - >>> _debug = True >>> tablename = 'labeler' >>> root_rowids = aids >>> prop_list = depc.get( @@ -785,7 +760,6 @@ def get( >>> depc = testdata_depc3(True) >>> exec(ut.execstr_funckw(depc.get), globals()) >>> aids = [1, 2, 3] - >>> _debug = True >>> tablename = 'smk_match' >>> tablename = 'vocab' >>> table = depc[tablename] @@ -804,7 +778,6 @@ def get( >>> depc = testdata_depc3(True) >>> exec(ut.execstr_funckw(depc.get), globals()) >>> aids = [1, 2, 3] - >>> _debug = True >>> depc = testdata_depc() >>> tablename = 'chip' >>> table = depc[tablename] @@ -821,13 +794,11 @@ def get( if tablename == depc.root_tablename: return depc.root_getters[colnames](root_rowids) # pass - _debug = depc._debug if _debug is None else _debug - with ut.Indenter('[GetProp-%s]' % (tablename,), enabled=_debug): - if _debug: - logger.info(' * tablename=%s' % (tablename)) - logger.info(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) - logger.info(' * colnames = %r' % (colnames,)) - logger.info(' * config = %r' % (config,)) + with ut.Indenter('[GetProp-%s]' % (tablename,)): + logger.debug(' * tablename=%s' % (tablename)) + logger.debug(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) + logger.debug(' * 
colnames = %r' % (colnames,)) + logger.debug(' * config = %r' % (config,)) if hack_paths and not ensure and not read_extern: # HACK: should be able to not compute rows to get certain properties @@ -839,17 +810,15 @@ def get( root_rowids, config=config, ensure=True, - _debug=None, recompute_all=False, eager=True, nInput=None, ) config_ = depc._ensure_config(tablename, config) - if _debug: - logger.info(' * (ensured) config_ = %r' % (config_,)) + logger.debug(' * (ensured) config_ = %r' % (config_,)) table = depc[tablename] extern_dpath = table.extern_dpath - ut.ensuredir(extern_dpath, verbose=False or table.depc._debug) + ut.ensuredir(extern_dpath) fname_list = table.get_extern_fnames( parent_rowids, config=config_, extern_col_index=0 ) @@ -866,12 +835,10 @@ def get( ensure=ensure, recompute=recompute, recompute_all=recompute_all, - _debug=_debug, ) rowdata_kw = dict( read_extern=read_extern, - _debug=_debug, num_retries=num_retries, eager=eager, ensure=ensure, @@ -885,10 +852,9 @@ def get( table = depc[tablename] # Vectorized get of properties tbl_rowids = depc.get_rowids(tablename, input_tuple, **rowid_kw) - if _debug: - logger.info( - '[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),) - ) + logger.debug( + '[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),) + ) prop_list = table.get_row_data(tbl_rowids, colnames, **rowdata_kw) except depcache_table.ExternalStorageException: logger.info('!!* Hit ExternalStorageException') @@ -896,8 +862,7 @@ def get( raise else: break - if _debug: - logger.info('* return prop_list=%s' % (ut.trunc_repr(prop_list),)) + logger.debug('* return prop_list=%s' % (ut.trunc_repr(prop_list),)) return prop_list def get_native( @@ -939,24 +904,20 @@ def get_native( >>> print('chips = %r' % (chips,)) """ tbl_rowids = list(tbl_rowids) - _debug = depc._debug if _debug is None else _debug - with ut.Indenter('[GetNative %s]' % (tablename,), enabled=_debug): - if _debug: - logger.info(' * tablename = %r' % (tablename,)) - logger.info(' * colnames = %r' % (colnames,)) - logger.info(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) + with ut.Indenter('[GetNative %s]' % (tablename,)): + logger.debug(' * tablename = %r' % (tablename,)) + logger.debug(' * colnames = %r' % (colnames,)) + logger.debug(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) table = depc[tablename] # import utool # with utool.embed_on_exception_context: # try: - prop_list = table.get_row_data( - tbl_rowids, colnames, _debug=_debug, read_extern=read_extern - ) + prop_list = table.get_row_data(tbl_rowids, colnames, read_extern=read_extern) # except depcache_table.ExternalStorageException: # # This code is a bit rendant and would probably live better elsewhere # # Also need to fix issues if more than one column specified # extern_uris = table.get_row_data( - # tbl_rowids, colnames, _debug=_debug, read_extern=False, + # tbl_rowids, colnames, read_extern=False, # delete_on_fail=True, ensure=False) # from os.path import exists # error_flags = [exists(uri) for uri in extern_uris] @@ -969,7 +930,7 @@ def get_native( # table.get_rowid(parent_rowids, recompute=True, config=config) # # TRY ONE MORE TIME - # prop_list = table.get_row_data(tbl_rowids, colnames, _debug=_debug, + # prop_list = table.get_row_data(tbl_rowids, colnames, # read_extern=read_extern, # delete_on_fail=False) return prop_list @@ -1036,7 +997,10 @@ def delete_property(depc, tablename, root_rowids, config=None, _debug=False): FIXME: make this work for all configs """ rowid_list = depc.get_rowids( - tablename, root_rowids, 
config=config, ensure=False, _debug=_debug + tablename, + root_rowids, + config=config, + ensure=False, ) table = depc[tablename] num_deleted = table.delete_rows(rowid_list) @@ -1112,7 +1076,8 @@ def __init__( get_root_uuid = ut.identity depc.get_root_uuid = get_root_uuid depc.delete_exclude_tables = {} - depc._debug = ut.get_argflag(('--debug-depcache', '--debug-depc')) + # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible + depc._debug = False def get_tablenames(depc): return list(depc.cachetable_dict.keys()) @@ -1486,9 +1451,9 @@ def delete_root( >>> depc = testdata_depc() >>> exec(ut.execstr_funckw(depc.delete_root), globals()) >>> root_rowids = [1] - >>> depc.delete_root(root_rowids, _debug=0) + >>> depc.delete_root(root_rowids) >>> depc.get('fgweight', [1]) - >>> depc.delete_root(root_rowids, _debug=0) + >>> depc.delete_root(root_rowids) """ # graph = depc.make_graph(implicit=False) # hack @@ -1498,7 +1463,7 @@ def delete_root( ) # children = [child for child in graph.succ[depc.root_tablename] # if sum([len(e) for e in graph.pred[child].values()]) == 1] - # depc.delete_property(tablename, root_rowids, _debug=_debug) + # depc.delete_property(tablename, root_rowids) num_deleted = 0 for tablename, table_rowids in rowid_dict.items(): if tablename == depc.root: diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 55680120f4..50d75e2987 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -389,7 +389,7 @@ def get_config_rowid(table, config=None, _debug=None): if isinstance(config, int): config_rowid = config else: - config_rowid = table.add_config(config, _debug) + config_rowid = table.add_config(config) return config_rowid def get_config_hashid(table, config_rowid_list): @@ -430,9 +430,8 @@ def add_config(table, config, _debug=None): except AttributeError: config_strid = ut.to_json(config) config_hashid = ut.hashstr27(config_strid) - if table.depc._debug or _debug: - logger.info('config_strid = %r' % (config_strid,)) - logger.info('config_hashid = %r' % (config_hashid,)) + logger.debug('config_strid = %r' % (config_strid,)) + logger.debug('config_hashid = %r' % (config_hashid,)) get_rowid_from_superkey = table.get_config_rowid_from_hashid colnames = (CONFIG_HASHID, CONFIG_TABLENAME, CONFIG_STRID, CONFIG_DICT) if hasattr(config, 'config'): @@ -444,9 +443,7 @@ def add_config(table, config, _debug=None): CONFIG_TABLE, colnames, param_list, get_rowid_from_superkey ) config_rowid = config_rowid_list[0] - if table.depc._debug: - logger.info('config_rowid_list = %r' % (config_rowid_list,)) - # logger.info('config_rowid = %r' % (config_rowid,)) + logger.debug('config_rowid_list = %r' % (config_rowid_list,)) return config_rowid @@ -1342,11 +1339,11 @@ def prepare_storage( >>> tablename = 'labeler' >>> tablename = 'indexer' >>> config = {tablename + '_param': None, 'foo': 'bar'} - >>> data = depc.get('labeler', [1, 2, 3], 'data', _debug=0) - >>> data = depc.get('labeler', [1, 2, 3], 'data', config=config, _debug=0) - >>> data = depc.get('indexer', [[1, 2, 3]], 'data', _debug=0) - >>> data = depc.get('indexer', [[1, 2, 3]], 'data', config=config, _debug=0) - >>> rowids = depc.get_rowids('indexer', [[1, 2, 3]], config=config, _debug=0) + >>> data = depc.get('labeler', [1, 2, 3], 'data') + >>> data = depc.get('labeler', [1, 2, 3], 'data', config=config) + >>> data = depc.get('indexer', [[1, 2, 3]], 'data') + >>> data = depc.get('indexer', [[1, 2, 3]], 'data', config=config) + >>> rowids = depc.get_rowids('indexer', [[1, 2, 3]], 
config=config) >>> table = depc[tablename] >>> model_uuid_list = table.get_internal_columns(rowids, ('model_uuid',)) >>> model_uuid = model_uuid_list[0] @@ -1544,7 +1541,7 @@ def _prepare_storage_extern( ) # get extern cache directory and fpaths extern_dpath = table.extern_dpath - ut.ensuredir(extern_dpath, verbose=False or table.depc._debug) + ut.ensuredir(extern_dpath) # extern_fpaths_list = [ # [join(extern_dpath, fname) for fname in fnames] # for fnames in extern_fnames_list @@ -1722,19 +1719,18 @@ def _chunk_compute_dirty_rows( >>> from wbia.dtool.example_depcache2 import * # NOQA >>> depc = testdata_depc3(in_memory=False) >>> depc.clear_all() - >>> data = depc.get('labeler', [1, 2, 3], 'data', _debug=True) - >>> data = depc.get('indexer', [[1, 2, 3]], 'data', _debug=True) + >>> data = depc.get('labeler', [1, 2, 3], 'data') + >>> data = depc.get('indexer', [[1, 2, 3]], 'data') >>> depc.print_all_tables() """ nInput = len(dirty_parent_ids) chunksize = nInput if table.chunksize is None else table.chunksize - if verbose: - logger.info( - '[deptbl.compute] nInput={}, chunksize={}, tbl={}'.format( - nInput, table.chunksize, table.tablename - ) + logger.info( + '[deptbl.compute] nInput={}, chunksize={}, tbl={}'.format( + nInput, table.chunksize, table.tablename ) + ) # Report computation progress dirty_iter = list(zip(dirty_parent_ids, dirty_preproc_args)) @@ -1929,8 +1925,7 @@ def initialize(table, _debug=None): table.db = table.depc.fname_to_db[table.fname] # logger.info('Checking sql for table=%r' % (table.tablename,)) if not table.db.has_table(table.tablename): - if _debug or ut.VERBOSE: - logger.info('Initializing table=%r' % (table.tablename,)) + logger.debug('Initializing table=%r' % (table.tablename,)) new_state = table._get_addtable_kw() table.db.add_table(**new_state) else: @@ -2028,18 +2023,16 @@ def ensure_rows( >>> table = depc['vsone'] >>> exec(ut.execstr_funckw(table.get_rowid), globals()) >>> config = table.configclass() - >>> _debug = 5 >>> verbose = True >>> # test duplicate inputs are detected and accounted for >>> parent_rowids = [(i, i) for i in list(range(100))] * 100 >>> rectify_tup = table._rectify_ids(parent_rowids) >>> (parent_ids_, preproc_args, idxs1, idxs2) = rectify_tup - >>> rowids = table.ensure_rows(parent_ids_, preproc_args, config=config, _debug=_debug) + >>> rowids = table.ensure_rows(parent_ids_, preproc_args, config=config) >>> result = ('rowids = %r' % (rowids,)) >>> print(result) """ try: - _debug = table.depc._debug if _debug is None else _debug # Get requested configuration id config_rowid = table.get_config_rowid(config) @@ -2047,12 +2040,11 @@ def ensure_rows( initial_rowid_list = table._get_rowid(parent_ids_, config=config) initial_rowid_list = list(initial_rowid_list) - if table.depc._debug: - logger.info( - '[deptbl.ensure] initial_rowid_list = %s' - % (ut.trunc_repr(initial_rowid_list),) - ) - logger.info('[deptbl.ensure] config_rowid = %r' % (config_rowid,)) + logger.debug( + '[deptbl.ensure] initial_rowid_list = %s' + % (ut.trunc_repr(initial_rowid_list),) + ) + logger.debug('[deptbl.ensure] config_rowid = %r' % (config_rowid,)) # Get corresponding "dirty" parent rowids isdirty_list = ut.flag_None_items(initial_rowid_list) @@ -2061,16 +2053,15 @@ def ensure_rows( if num_dirty > 0: with ut.Indenter('[ADD]', enabled=_debug): - if verbose or _debug: - logger.info( - 'Add %d / %d new rows to %r' - % (num_dirty, num_total, table.tablename) - ) - logger.info( - '[deptbl.add] * config_rowid = {}, config={}'.format( - config_rowid, str(config) - ) + 
logger.debug( + 'Add %d / %d new rows to %r' + % (num_dirty, num_total, table.tablename) + ) + logger.debug( + '[deptbl.add] * config_rowid = {}, config={}'.format( + config_rowid, str(config) ) + ) dirty_parent_ids_ = ut.compress(parent_ids_, isdirty_list) dirty_preproc_args_ = ut.compress(preproc_args, isdirty_list) @@ -2117,16 +2108,14 @@ def ensure_rows( # Remove cache when main add is done table._hack_chunk_cache = None - if verbose or _debug: - logger.info('[deptbl.add] finished add') + logger.debug('[deptbl.add] finished add') # # The requested data is clean and must now exist in the parent # database, do a lookup to ensure the correct order. rowid_list = table._get_rowid(parent_ids_, config=config) else: rowid_list = initial_rowid_list - if _debug: - logger.info('[deptbl.add] rowid_list = %s' % ut.trunc_repr(rowid_list)) + logger.debug('[deptbl.add] rowid_list = %s' % ut.trunc_repr(rowid_list)) except lite.IntegrityError: if retry <= 0: raise @@ -2145,7 +2134,6 @@ def ensure_rows( preproc_args, config=config, verbose=verbose, - _debug=_debug, retry=retry_, retry_delay_min=retry_delay_min, retry_delay_max=retry_delay_max, @@ -2279,7 +2267,7 @@ def get_rowid( eager (bool): (default = True) nInput (int): (default = None) recompute (bool): (default = False) - _debug (None): (default = None) + _debug (None): (default = None) deprecated; no-op Returns: list: rowid_list @@ -2295,22 +2283,18 @@ def get_rowid( >>> table = depc['labeler'] >>> exec(ut.execstr_funckw(table.get_rowid), globals()) >>> config = table.configclass() - >>> _debug = True >>> parent_rowids = list(zip([1, None, None, 2])) - >>> rowids = table.get_rowid(parent_rowids, config=config, _debug=_debug) + >>> rowids = table.get_rowid(parent_rowids, config=config) >>> result = ('rowids = %r' % (rowids,)) >>> print(result) rowids = [1, None, None, 2] """ - _debug = table.depc._debug if _debug is None else _debug - if _debug: - logger.info( - '[deptbl.get_rowid] Get %s rowids via %d parent superkeys' - % (table.tablename, len(parent_rowids)) - ) - if _debug > 1: - logger.info('[deptbl.get_rowid] config = %r' % (config,)) - logger.info('[deptbl.get_rowid] ensure = %r' % (ensure,)) + logger.debug( + '[deptbl.get_rowid] Get %s rowids via %d parent superkeys' + % (table.tablename, len(parent_rowids)) + ) + logger.debug('[deptbl.get_rowid] config = %r' % (config,)) + logger.debug('[deptbl.get_rowid] ensure = %r' % (ensure,)) # Ensure inputs are in the correct format / remove Nones # Collapse multi-inputs into a UUID hash @@ -2321,7 +2305,10 @@ def get_rowid( logger.info('REQUESTED RECOMPUTE') # get existing rowids, delete them, recompute the request rowid_list_ = table._get_rowid( - parent_ids_, config=config, eager=True, nInput=None, _debug=_debug + parent_ids_, + config=config, + eager=True, + nInput=None, ) rowid_list_ = list(rowid_list_) needs_recompute_rowids = ut.filter_Nones(rowid_list_) @@ -2336,35 +2323,38 @@ def get_rowid( for try_num in range(num_retries): try: rowid_list_ = table.ensure_rows( - parent_ids_, preproc_args, config=config, _debug=_debug + parent_ids_, + preproc_args, + config=config, ) except ExternalStorageException: if try_num == num_retries - 1: raise else: rowid_list_ = table._get_rowid( - parent_ids_, config=config, eager=eager, nInput=nInput, _debug=_debug + parent_ids_, + config=config, + eager=eager, + nInput=nInput, ) # Map outputs to correspond with inputs rowid_list = table._unrectify_ids(rowid_list_, parent_rowids, idxs1, idxs2) return rowid_list # @profile - def _get_rowid(table, parent_ids_, 
config=None, eager=True, nInput=None, _debug=None): + def _get_rowid(table, parent_ids_, config=None, eager=True, nInput=None): """ Returns rowids using parent superkeys. Does not add non-existing properties. """ colnames = (table.rowid_colname,) config_rowid = table.get_config_rowid(config=config) - _debug = table.depc._debug if _debug is None else _debug - if _debug: - logger.info('_get_rowid') - logger.info('_get_rowid table.tablename = %r ' % (table.tablename,)) - logger.info('_get_rowid parent_ids_ = %s' % (ut.trunc_repr(parent_ids_))) - logger.info('_get_rowid config = %s' % (config)) - logger.info('_get_rowid table.rowid_colname = %s' % (table.rowid_colname)) - logger.info('_get_rowid config_rowid = %s' % (config_rowid)) + logger.debug('_get_rowid') + logger.debug('_get_rowid table.tablename = %r ' % (table.tablename,)) + logger.debug('_get_rowid parent_ids_ = %s' % (ut.trunc_repr(parent_ids_))) + logger.debug('_get_rowid config = %s' % (config)) + logger.debug('_get_rowid table.rowid_colname = %s' % (table.rowid_colname)) + logger.debug('_get_rowid config_rowid = %s' % (config_rowid)) andwhere_colnames = table.superkey_colnames params_iter = (ids_ + (config_rowid,) for ids_ in parent_ids_) # TODO: make sure things that call this can accept a generator @@ -2379,8 +2369,7 @@ def _get_rowid(table, parent_ids_, config=None, eager=True, nInput=None, _debug= eager=eager, nInput=nInput, ) - if _debug: - logger.info('_get_rowid rowid_list = %s' % (ut.trunc_repr(rowid_list))) + logger.debug('_get_rowid rowid_list = %s' % (ut.trunc_repr(rowid_list))) return rowid_list def clear_table(table): @@ -2633,12 +2622,10 @@ def get_row_data( >>> data = table.get_row_data(tbl_rowids, 'chip', read_extern=False, ensure=False) >>> data = table.get_row_data(tbl_rowids, 'chip', read_extern=False, ensure=True) """ - _debug = table.depc._debug if _debug is None else _debug - if _debug: - logger.info( - ('Get col of tablename=%r, colnames=%r with ' 'tbl_rowids=%s') - % (table.tablename, colnames, ut.trunc_repr(tbl_rowids)) - ) + logger.debug( + ('Get col of tablename=%r, colnames=%r with ' 'tbl_rowids=%s') + % (table.tablename, colnames, ut.trunc_repr(tbl_rowids)) + ) #### # Resolve requested column names if unpack_columns is None: @@ -2652,16 +2639,13 @@ def get_row_data( else: requested_colnames = colnames - if _debug: - logger.info('requested_colnames = %r' % (requested_colnames,)) + logger.debug('requested_colnames = %r' % (requested_colnames,)) tup = table._resolve_requested_columns(requested_colnames) nesting_xs, extern_resolve_tups, flat_intern_colnames = tup - if _debug: - logger.info( - '[deptbl.get_row_data] flat_intern_colnames = %r' - % (flat_intern_colnames,) - ) + logger.debug( + '[deptbl.get_row_data] flat_intern_colnames = %r' % (flat_intern_colnames,) + ) nonNone_flags = ut.flag_not_None_items(tbl_rowids) nonNone_tbl_rowids = ut.compress(tbl_rowids, nonNone_flags) @@ -2759,7 +2743,6 @@ def _generator_resolve_all(): read_extern, delete_on_fail, tries_left, - _debug, ) except ExternalStorageException: if tries_left == 0: @@ -2798,7 +2781,6 @@ def _resolve_any_external_data( read_extern, delete_on_fail, tries_left, - _debug, ): #### # Read data specified by any external columns @@ -2810,8 +2792,7 @@ def _resolve_any_external_data( raise for extern_colx, read_func in extern_resolve_tups: - if _debug: - logger.info('[deptbl.get_row_data] read_func = %r' % (read_func,)) + logger.debug('[deptbl.get_row_data] read_func = %r' % (read_func,)) data_list = [] failed_list = [] for uri in 
prop_listT[extern_colx]: From 1a4650cbcfd6ef1e0e2b9deae8165b3194012387 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 25 Sep 2020 20:46:25 -0700 Subject: [PATCH 040/294] Remove use of ut.Indenter from depcache We're no longer using `print` for logging, so the use of `ut.Indenter` does nothing. --- wbia/dtool/depcache_control.py | 396 ++++++++++++++++----------------- wbia/dtool/depcache_table.py | 110 +++++---- 2 files changed, 247 insertions(+), 259 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index d7af3e017b..bad112b692 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -488,126 +488,123 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs if config is None: config = {} - with ut.Indenter('[GetParentID-%s]' % (target_tablename,)): - logger.debug('Enter get_parent_rowids') - logger.debug(' * target_tablename = %r' % (target_tablename,)) - logger.debug(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) - logger.debug(' * config = %r' % (config,)) - target_table = depc[target_tablename] - - # TODO: Expand to the appropriate given inputs - if _hack_rootmost: - # Hack: if true, we are given inputs in rootmost form - exi_inputs = target_table.rootmost_inputs - else: - # otherwise we are given inputs in totalroot form - exi_inputs = target_table.rootmost_inputs.total_expand() - logger.debug(' * exi_inputs=%s' % (exi_inputs,)) - - rectified_input = depc.rectify_input_tuple(exi_inputs, input_tuple) + logger.debug('Enter get_parent_rowids') + logger.debug(' * target_tablename = %r' % (target_tablename,)) + logger.debug(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) + logger.debug(' * config = %r' % (config,)) + target_table = depc[target_tablename] + + # TODO: Expand to the appropriate given inputs + if _hack_rootmost: + # Hack: if true, we are given inputs in rootmost form + exi_inputs = target_table.rootmost_inputs + else: + # otherwise we are given inputs in totalroot form + exi_inputs = target_table.rootmost_inputs.total_expand() + logger.debug(' * exi_inputs=%s' % (exi_inputs,)) - rowid_dict = {} - for rmi, rowids in zip(exi_inputs.rmi_list, rectified_input): - rowid_dict[rmi] = rowids + rectified_input = depc.rectify_input_tuple(exi_inputs, input_tuple) - compute_edges = exi_inputs.flat_compute_rmi_edges() - logger.debug(' * rectified_input=%s' % ut.trunc_repr(rectified_input)) - logger.debug(' * compute_edges=%s' % ut.repr2(compute_edges, nl=2)) + rowid_dict = {} + for rmi, rowids in zip(exi_inputs.rmi_list, rectified_input): + rowid_dict[rmi] = rowids - for count, (input_nodes, output_node) in enumerate(compute_edges, start=1): - logger.debug( - ' * COMPUTING %d/%d EDGE %r -- %r' - % (count, len(compute_edges), input_nodes, output_node), - ) - tablekey = output_node.tablename - table = depc[tablekey] - input_nodes_ = input_nodes - logger.debug( - 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) - ) - logger.debug('input_nodes_ = %r' % (input_nodes_,)) - input_multi_flags = [ - node.ismulti and node in exi_inputs.rmi_list for node in input_nodes_ - ] + compute_edges = exi_inputs.flat_compute_rmi_edges() + logger.debug(' * rectified_input=%s' % ut.trunc_repr(rectified_input)) + logger.debug(' * compute_edges=%s' % ut.repr2(compute_edges, nl=2)) - # Args currently go in like this: - # args = [..., (pid_{i,1}, pid_{i,2}, ..., pid_{i,M}), ...] - # They get converted into - # argsT = [... (pid_{1,j}, ... pid_{N,j}) ...] 
- # i = row, j = col - sig_multi_flags = table.get_parent_col_attr('ismulti') - parent_rowidsT = ut.take(rowid_dict, input_nodes_) - parent_rowids_ = [] - # TODO: will need to figure out which columns to zip and which - # columns to product (ie take product over ones that have 1 - # item, and zip ones that have equal amount of items) - for flag1, flag2, rowidsT in zip( - sig_multi_flags, input_multi_flags, parent_rowidsT - ): - if flag1 and flag2: - parent_rowids_.append(rowidsT) - elif flag1 and not flag2: - parent_rowids_.append([rowidsT]) - elif not flag1 and flag2: - assert len(rowidsT) == 1 - parent_rowids_.append(rowidsT[0]) - else: - parent_rowids_.append(rowidsT) - # Assume that we are either given corresponding lists or single values - # that must be broadcast. - rowlens = list(map(len, parent_rowids_)) - maxlen = max(rowlens) - parent_rowids2_ = [ - r * maxlen if len(r) == 1 else r for r in parent_rowids_ - ] - _parent_rowids = list(zip(*parent_rowids2_)) - # _parent_rowids = list(ut.product(*parent_rowids_)) + for count, (input_nodes, output_node) in enumerate(compute_edges, start=1): + logger.debug( + ' * COMPUTING %d/%d EDGE %r -- %r' + % (count, len(compute_edges), input_nodes, output_node), + ) + tablekey = output_node.tablename + table = depc[tablekey] + input_nodes_ = input_nodes + logger.debug( + 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) + ) + logger.debug('input_nodes_ = %r' % (input_nodes_,)) + input_multi_flags = [ + node.ismulti and node in exi_inputs.rmi_list for node in input_nodes_ + ] - logger.debug( - 'parent_rowids_ = %s' - % ( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in parent_rowids_], - strvals=True, - ) + # Args currently go in like this: + # args = [..., (pid_{i,1}, pid_{i,2}, ..., pid_{i,M}), ...] + # They get converted into + # argsT = [... (pid_{1,j}, ... pid_{N,j}) ...] + # i = row, j = col + sig_multi_flags = table.get_parent_col_attr('ismulti') + parent_rowidsT = ut.take(rowid_dict, input_nodes_) + parent_rowids_ = [] + # TODO: will need to figure out which columns to zip and which + # columns to product (ie take product over ones that have 1 + # item, and zip ones that have equal amount of items) + for flag1, flag2, rowidsT in zip( + sig_multi_flags, input_multi_flags, parent_rowidsT + ): + if flag1 and flag2: + parent_rowids_.append(rowidsT) + elif flag1 and not flag2: + parent_rowids_.append([rowidsT]) + elif not flag1 and flag2: + assert len(rowidsT) == 1 + parent_rowids_.append(rowidsT[0]) + else: + parent_rowids_.append(rowidsT) + # Assume that we are either given corresponding lists or single values + # that must be broadcast. 
+ rowlens = list(map(len, parent_rowids_)) + maxlen = max(rowlens) + parent_rowids2_ = [r * maxlen if len(r) == 1 else r for r in parent_rowids_] + _parent_rowids = list(zip(*parent_rowids2_)) + # _parent_rowids = list(ut.product(*parent_rowids_)) + + logger.debug( + 'parent_rowids_ = %s' + % ( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in parent_rowids_], + strvals=True, ) ) - logger.debug( - 'parent_rowids2_ = %s' - % ( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in parent_rowids2_], - strvals=True, - ) + ) + logger.debug( + 'parent_rowids2_ = %s' + % ( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in parent_rowids2_], + strvals=True, ) ) - logger.debug( - '_parent_rowids = %s' - % ( - ut.truncate_str( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in _parent_rowids], - strvals=True, - ) + ) + logger.debug( + '_parent_rowids = %s' + % ( + ut.truncate_str( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in _parent_rowids], + strvals=True, ) ) ) + ) - if output_node.tablename != target_tablename: - # Get table configuration - config_ = depc._ensure_config(tablekey, config) + if output_node.tablename != target_tablename: + # Get table configuration + config_ = depc._ensure_config(tablekey, config) - output_rowids = table.get_rowid( - _parent_rowids, config=config_, recompute=_recompute, **_kwargs - ) - rowid_dict[output_node] = output_rowids - # table.get_model_inputs(table.get_model_uuid(output_rowids)[0]) - else: - # We are only computing up to the parents of the table here. - parent_rowids = _parent_rowids - break - # rowids = rowid_dict[output_node] - return parent_rowids + output_rowids = table.get_rowid( + _parent_rowids, config=config_, recompute=_recompute, **_kwargs + ) + rowid_dict[output_node] = output_rowids + # table.get_model_inputs(table.get_model_uuid(output_rowids)[0]) + else: + # We are only computing up to the parents of the table here. 
+ parent_rowids = _parent_rowids + break + # rowids = rowid_dict[output_node] + return parent_rowids def check_rowids(depc, tablename, input_tuple, config={}): """ @@ -689,11 +686,10 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): **_kwargs, ) - with ut.Indenter('[GetRowId-%s]' % (target_tablename,)): - config_ = depc._ensure_config(target_tablename, config) - rowids = table.get_rowid( - parent_rowids, config=config_, recompute=recompute, **_kwargs - ) + config_ = depc._ensure_config(target_tablename, config) + rowids = table.get_rowid( + parent_rowids, config=config_, recompute=recompute, **_kwargs + ) return rowids @ut.accepts_scalar_input2(argx_list=[1]) @@ -794,75 +790,72 @@ def get( if tablename == depc.root_tablename: return depc.root_getters[colnames](root_rowids) # pass - with ut.Indenter('[GetProp-%s]' % (tablename,)): - logger.debug(' * tablename=%s' % (tablename)) - logger.debug(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) - logger.debug(' * colnames = %r' % (colnames,)) - logger.debug(' * config = %r' % (config,)) - - if hack_paths and not ensure and not read_extern: - # HACK: should be able to not compute rows to get certain properties - from os.path import join - - # recompute_ = recompute or recompute_all - parent_rowids = depc.get_parent_rowids( - tablename, - root_rowids, - config=config, - ensure=True, - recompute_all=False, - eager=True, - nInput=None, - ) - config_ = depc._ensure_config(tablename, config) - logger.debug(' * (ensured) config_ = %r' % (config_,)) - table = depc[tablename] - extern_dpath = table.extern_dpath - ut.ensuredir(extern_dpath) - fname_list = table.get_extern_fnames( - parent_rowids, config=config_, extern_col_index=0 - ) - fpath_list = [join(extern_dpath, fname) for fname in fname_list] - return fpath_list + logger.debug(' * tablename=%s' % (tablename)) + logger.debug(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) + logger.debug(' * colnames = %r' % (colnames,)) + logger.debug(' * config = %r' % (config,)) - if nInput is None and ut.is_listlike(root_rowids): - nInput = len(root_rowids) + if hack_paths and not ensure and not read_extern: + # HACK: should be able to not compute rows to get certain properties + from os.path import join - rowid_kw = dict( + # recompute_ = recompute or recompute_all + parent_rowids = depc.get_parent_rowids( + tablename, + root_rowids, config=config, - nInput=nInput, - eager=eager, - ensure=ensure, - recompute=recompute, - recompute_all=recompute_all, + ensure=True, + recompute_all=False, + eager=True, + nInput=None, ) - - rowdata_kw = dict( - read_extern=read_extern, - num_retries=num_retries, - eager=eager, - ensure=ensure, - nInput=nInput, + config_ = depc._ensure_config(tablename, config) + logger.debug(' * (ensured) config_ = %r' % (config_,)) + table = depc[tablename] + extern_dpath = table.extern_dpath + ut.ensuredir(extern_dpath) + fname_list = table.get_extern_fnames( + parent_rowids, config=config_, extern_col_index=0 ) + fpath_list = [join(extern_dpath, fname) for fname in fname_list] + return fpath_list - input_tuple = root_rowids + if nInput is None and ut.is_listlike(root_rowids): + nInput = len(root_rowids) - for trynum in range(num_retries + 1): - try: - table = depc[tablename] - # Vectorized get of properties - tbl_rowids = depc.get_rowids(tablename, input_tuple, **rowid_kw) - logger.debug( - '[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),) - ) - prop_list = table.get_row_data(tbl_rowids, colnames, **rowdata_kw) - except depcache_table.ExternalStorageException: - 
logger.info('!!* Hit ExternalStorageException') - if trynum == num_retries: - raise - else: - break - logger.debug('* return prop_list=%s' % (ut.trunc_repr(prop_list),)) + rowid_kw = dict( + config=config, + nInput=nInput, + eager=eager, + ensure=ensure, + recompute=recompute, + recompute_all=recompute_all, + ) + + rowdata_kw = dict( + read_extern=read_extern, + num_retries=num_retries, + eager=eager, + ensure=ensure, + nInput=nInput, + ) + + input_tuple = root_rowids + + for trynum in range(num_retries + 1): + try: + table = depc[tablename] + # Vectorized get of properties + tbl_rowids = depc.get_rowids(tablename, input_tuple, **rowid_kw) + logger.debug('[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),)) + prop_list = table.get_row_data(tbl_rowids, colnames, **rowdata_kw) + except depcache_table.ExternalStorageException: + logger.info('!!* Hit ExternalStorageException') + if trynum == num_retries: + raise + else: + break + logger.debug('* return prop_list=%s' % (ut.trunc_repr(prop_list),)) return prop_list def get_native( @@ -904,35 +897,34 @@ def get_native( >>> print('chips = %r' % (chips,)) """ tbl_rowids = list(tbl_rowids) - with ut.Indenter('[GetNative %s]' % (tablename,)): - logger.debug(' * tablename = %r' % (tablename,)) - logger.debug(' * colnames = %r' % (colnames,)) - logger.debug(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) - table = depc[tablename] - # import utool - # with utool.embed_on_exception_context: - # try: - prop_list = table.get_row_data(tbl_rowids, colnames, read_extern=read_extern) - # except depcache_table.ExternalStorageException: - # # This code is a bit rendant and would probably live better elsewhere - # # Also need to fix issues if more than one column specified - # extern_uris = table.get_row_data( - # tbl_rowids, colnames, read_extern=False, - # delete_on_fail=True, ensure=False) - # from os.path import exists - # error_flags = [exists(uri) for uri in extern_uris] - # redo_rowids = ut.compress(tbl_rowids, ut.not_list(error_flags)) - # parent_rowids = table.get_parent_rowids(redo_rowids) - # # config_rowids = table.get_row_cfgid(redo_rowids) - # configs = table.get_row_configs(redo_rowids) - # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' - # config = configs[0] - # table.get_rowid(parent_rowids, recompute=True, config=config) - - # # TRY ONE MORE TIME - # prop_list = table.get_row_data(tbl_rowids, colnames, - # read_extern=read_extern, - # delete_on_fail=False) + logger.debug(' * tablename = %r' % (tablename,)) + logger.debug(' * colnames = %r' % (colnames,)) + logger.debug(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) + table = depc[tablename] + # import utool + # with utool.embed_on_exception_context: + # try: + prop_list = table.get_row_data(tbl_rowids, colnames, read_extern=read_extern) + # except depcache_table.ExternalStorageException: + # # This code is a bit rendant and would probably live better elsewhere + # # Also need to fix issues if more than one column specified + # extern_uris = table.get_row_data( + # tbl_rowids, colnames, read_extern=False, + # delete_on_fail=True, ensure=False) + # from os.path import exists + # error_flags = [exists(uri) for uri in extern_uris] + # redo_rowids = ut.compress(tbl_rowids, ut.not_list(error_flags)) + # parent_rowids = table.get_parent_rowids(redo_rowids) + # # config_rowids = table.get_row_cfgid(redo_rowids) + # configs = table.get_row_configs(redo_rowids) + # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' + # 
config = configs[0] + # table.get_rowid(parent_rowids, recompute=True, config=config) + + # # TRY ONE MORE TIME + # prop_list = table.get_row_data(tbl_rowids, colnames, + # read_extern=read_extern, + # delete_on_fail=False) return prop_list def get_config_history(depc, tablename, root_rowids, config=None): diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 50d75e2987..752f46e444 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -2052,67 +2052,63 @@ def ensure_rows( num_total = len(parent_ids_) if num_dirty > 0: - with ut.Indenter('[ADD]', enabled=_debug): - logger.debug( - 'Add %d / %d new rows to %r' - % (num_dirty, num_total, table.tablename) - ) - logger.debug( - '[deptbl.add] * config_rowid = {}, config={}'.format( - config_rowid, str(config) - ) + logger.debug( + 'Add %d / %d new rows to %r' % (num_dirty, num_total, table.tablename) + ) + logger.debug( + '[deptbl.add] * config_rowid = {}, config={}'.format( + config_rowid, str(config) ) + ) + + dirty_parent_ids_ = ut.compress(parent_ids_, isdirty_list) + dirty_preproc_args_ = ut.compress(preproc_args, isdirty_list) + + # Process only unique items + unique_flags = ut.flag_unique_items(dirty_parent_ids_) + dirty_parent_ids = ut.compress(dirty_parent_ids_, unique_flags) + dirty_preproc_args = ut.compress(dirty_preproc_args_, unique_flags) + + # Break iterator into chunks + if False and verbose: + # check parent configs we are working with + for x, parname in enumerate(table.parents()): + if parname == table.depc.root: + continue + parent_table = table.depc[parname] + ut.take_column(parent_ids_, x) + rowid_list = ut.take_column(parent_ids_, x) + try: + parent_history = parent_table.get_config_history(rowid_list) + logger.info('parent_history = %r' % (parent_history,)) + except KeyError: + logger.info( + '[depcache_table] WARNING: config history is having troubles... says Jon' + ) - dirty_parent_ids_ = ut.compress(parent_ids_, isdirty_list) - dirty_preproc_args_ = ut.compress(preproc_args, isdirty_list) - - # Process only unique items - unique_flags = ut.flag_unique_items(dirty_parent_ids_) - dirty_parent_ids = ut.compress(dirty_parent_ids_, unique_flags) - dirty_preproc_args = ut.compress(dirty_preproc_args_, unique_flags) - - # Break iterator into chunks - if False and verbose: - # check parent configs we are working with - for x, parname in enumerate(table.parents()): - if parname == table.depc.root: - continue - parent_table = table.depc[parname] - ut.take_column(parent_ids_, x) - rowid_list = ut.take_column(parent_ids_, x) - try: - parent_history = parent_table.get_config_history( - rowid_list - ) - logger.info('parent_history = %r' % (parent_history,)) - except KeyError: - logger.info( - '[depcache_table] WARNING: config history is having troubles... 
says Jon' - ) - - # Gives the function a hacky cache to use between chunks - table._hack_chunk_cache = {} - gen = table._chunk_compute_dirty_rows( - dirty_parent_ids, dirty_preproc_args, config_rowid, config + # Gives the function a hacky cache to use between chunks + table._hack_chunk_cache = {} + gen = table._chunk_compute_dirty_rows( + dirty_parent_ids, dirty_preproc_args, config_rowid, config + ) + """ + colnames, dirty_params_iter, nChunkInput = next(gen) + """ + for colnames, dirty_params_iter, nChunkInput in gen: + table.db._add( + table.tablename, + colnames, + dirty_params_iter, + nInput=nChunkInput, ) - """ - colnames, dirty_params_iter, nChunkInput = next(gen) - """ - for colnames, dirty_params_iter, nChunkInput in gen: - table.db._add( - table.tablename, - colnames, - dirty_params_iter, - nInput=nChunkInput, - ) - # Remove cache when main add is done - table._hack_chunk_cache = None - logger.debug('[deptbl.add] finished add') - # - # The requested data is clean and must now exist in the parent - # database, do a lookup to ensure the correct order. - rowid_list = table._get_rowid(parent_ids_, config=config) + # Remove cache when main add is done + table._hack_chunk_cache = None + logger.debug('[deptbl.add] finished add') + # + # The requested data is clean and must now exist in the parent + # database, do a lookup to ensure the correct order. + rowid_list = table._get_rowid(parent_ids_, config=config) else: rowid_list = initial_rowid_list logger.debug('[deptbl.add] rowid_list = %s' % ut.trunc_repr(rowid_list)) From 2c6bfd363ff719f13e687a7518da95628b91e179 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 26 Sep 2020 10:14:03 -0700 Subject: [PATCH 041/294] Refactor testing case Refactor how the dummy data is set up. --- wbia/tests/dtool/test_sql_control.py | 80 +++++++--------------------- 1 file changed, 19 insertions(+), 61 deletions(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index ecb8a922cd..c1bde82086 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -486,12 +486,12 @@ def make_table(self, name): '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' ) - def test_executeone(self): - table_name = 'test_executeone' - self.make_table(table_name) + def populate_table(self, name): + """To be used in conjunction with ``make_table`` to populate the table + with records from 0 to 9. 
- # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + """ + insert_stmt = text(f'INSERT INTO {name} (x, y, z) VALUES (:x, :y, :z)') for i in range(0, 10): x, y, z = ( (i % 2) and 'odd' or 'even', @@ -500,6 +500,13 @@ def test_executeone(self): ) self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + def test_executeone(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + # Call the testing target result = self.ctrlr.executeone(text(f'SELECT id, y FROM {table_name}')) @@ -511,14 +518,7 @@ def test_executeone_on_insert(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target result = self.ctrlr.executeone( @@ -543,14 +543,7 @@ def test_executemany(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target results = self.ctrlr.executemany( @@ -590,14 +583,7 @@ def test_executeone_for_single_column(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target result = self.ctrlr.executeone(text(f'SELECT y FROM {table_name}')) @@ -633,14 +619,7 @@ def test_get_where_without_where_condition(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target results = self.ctrlr.get_where( @@ -658,14 +637,7 @@ def test_scalar_get_where(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target results = self.ctrlr.get_where( @@ -684,14 +656,7 @@ def test_multi_row_get_where(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target results = self.ctrlr.get_where( @@ -713,14 +678,7 @@ def test_get_where_eq(self): self.make_table(table_name) # Create some dummy records - insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = ( - (i % 2) and 'odd' or 'even', - i, - i * 
2.01, - ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.populate_table(table_name) # Call the testing target results = self.ctrlr.get_where_eq( From 8af6bfc9a18d92ab5c19ed6214675e19984574ad Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 26 Sep 2020 10:25:34 -0700 Subject: [PATCH 042/294] Add class based division for the api tests Hopefully helps with readability as well as focused testing usage. --- wbia/tests/dtool/test_sql_control.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index c1bde82086..4b4452bc5e 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -473,7 +473,7 @@ def test_delitem_for_database(self): assert isinstance(self.ctrlr.metadata.database.init_uuid, uuid.UUID) -class TestAPI: +class BaseAPITestCase: """Testing the primary *usage* API""" @pytest.fixture(autouse=True) @@ -500,6 +500,8 @@ def populate_table(self, name): ) self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + +class TestExecutionAPI(BaseAPITestCase): def test_executeone(self): table_name = 'test_executeone' self.make_table(table_name) @@ -591,6 +593,8 @@ def test_executeone_for_single_column(self): # Note the unwrapped values, rather than [(i,) ...] assert result == [i for i in range(0, 10)] + +class TestAdditionAPI(BaseAPITestCase): def test_add(self): table_name = 'test_add' self.make_table(table_name) @@ -614,6 +618,8 @@ def test_add(self): expected = [(i + 1, x, y, z) for i, (x, y, z) in enumerate(parameter_values)] assert results.fetchall() == expected + +class TestGettingAPI(BaseAPITestCase): def test_get_where_without_where_condition(self): table_name = 'test_get_where' self.make_table(table_name) @@ -769,6 +775,8 @@ def test_get_as_unique(self): # Verify getting assert data == expected + +class TestSettingAPI(BaseAPITestCase): def test_setting(self): # Note, this is not a comprehensive test. It only attempts to test the SQL logic. # Make a table for records @@ -801,6 +809,8 @@ def test_setting(self): set_rows = sorted(results) assert set_rows == expected + +class TestDeletionAPI(BaseAPITestCase): def test_delete(self): # Make a table for records table_name = 'test_delete' From 8a85d70e8fec5dfcaeea853d811b016283bae4a3 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 15:16:08 -0700 Subject: [PATCH 043/294] Refactor SQLAlchemy engine creation logic Pull this up into a reuseable method since it's the same twice. Also move the engine's creation into the class initialization method. 
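For illustration (not part of this change's diff), a minimal standalone sketch of the
pattern being adopted, using plain SQLAlchemy and an in-memory database:

```
# One Engine per controller, created once at initialization;
# connections are then drawn from it as needed.
import sqlalchemy

engine = sqlalchemy.create_engine('sqlite:///:memory:', echo=False)  # built once, in __init__

conn = engine.connect()   # connect() simply reuses the engine
conn.close()
conn = engine.connect()   # a reboot only needs a fresh connection (or engine.dispose())
conn.close()
```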
--- wbia/dtool/sql_control.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index befeffb82e..1c68b928d4 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -413,6 +413,14 @@ def __iter__(self): def __len__(self): return len(self.ctrlr.get_table_names()) + 1 # for 'database' + def __init_engine(self): + """Create the SQLAlchemy Engine""" + self._engine = sqlalchemy.create_engine( + self.uri, + # The echo flag is a shortcut to set up SQLAlchemy logging + echo=False, + ) + @classmethod def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): """Creates a controller instance from a connection URI @@ -446,6 +454,8 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): self.metadata = self.Metadata(self) self.readonly = readonly + self.__init_engine() + self._tablenames = None # FIXME (31-Jul-12020) rename to private attribute self.thread_connections = {} @@ -464,11 +474,6 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): def connect(self): """Create a connection for the instance or use the existing connection""" - # The echo flag is a shortcut to set up SQLAlchemy logging - self._engine = sqlalchemy.create_engine( - self.uri, - echo=False, - ) self._connection = self._engine.connect() return self._connection @@ -630,10 +635,9 @@ def reboot(self): logger.info('[sql] reboot') self.connection.close() del self.connection - self._engine = sqlalchemy.create_engine( - self.uri, - echo=False, - ) + # Re-initialize the engine + # ??? May be better to use the `dispose()` method? + self.__init_engine() self.connection = self._engine.connect() def backup(self, backup_filepath): From e45b91b4a8427bb537ee9fb34196d61540ee55b4 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 15:25:56 -0700 Subject: [PATCH 044/294] Initialize a SQLAlchemy MetaData object for internal use Setting up to use the metadata object for table reflection. Table reflection is going to allow us to infer the column types. This means we'll be able to use the existing calls to the controller without requesting the users of the controller specify type information. --- wbia/dtool/sql_control.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 1c68b928d4..30aba8bdf0 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -455,6 +455,10 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): self.readonly = readonly self.__init_engine() + # Create a _private_ SQLAlchemy metadata instance + # TODO (27-Sept-12020) Develop API to expose elements of SQLAlchemy. + # The MetaData is unbound to ensure we don't accidentally misuse it. + self._sa_metadata = sqlalchemy.MetaData() self._tablenames = None # FIXME (31-Jul-12020) rename to private attribute From 477ba3254ef81869e81ba92f4d1607242e5c9660 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 19:55:26 -0700 Subject: [PATCH 045/294] Reflect tables on sql controller initialization This pre-populates the SQLAlchemy metadata object's tables through reflection. 
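For illustration (not wbia code), a minimal sketch of what reflection provides, using a
throwaway in-memory database and a made-up table name:

```
import sqlalchemy

engine = sqlalchemy.create_engine('sqlite:///:memory:')
with engine.connect() as conn:
    conn.execute('CREATE TABLE example (id INTEGER PRIMARY KEY, label TEXT)')

    metadata = sqlalchemy.MetaData()
    metadata.reflect(bind=conn)  # discovers 'example' and its column definitions

print(list(metadata.tables))  # ['example']
print([(c.name, repr(c.type)) for c in metadata.tables['example'].columns])
# e.g. [('id', 'INTEGER()'), ('label', 'TEXT()')]
```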
--- wbia/dtool/sql_control.py | 3 ++ wbia/tests/dtool/test_sql_control.py | 70 +++++++++++++++++++++------- 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 30aba8bdf0..2b7e83b595 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -460,6 +460,9 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): # The MetaData is unbound to ensure we don't accidentally misuse it. self._sa_metadata = sqlalchemy.MetaData() + # Reflect all known tables + self._sa_metadata.reflect(bind=self._engine) + self._tablenames = None # FIXME (31-Jul-12020) rename to private attribute self.thread_connections = {} diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 4b4452bc5e..25fba1f2e4 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -20,6 +20,26 @@ def ctrlr(): return SQLDatabaseController.from_uri('sqlite:///:memory:') +def make_table_definition(name, depends_on=[]): + """Creates a table definition for use with the controller's add_table method""" + definition = { + 'tablename': name, + 'coldef_list': [ + (f'{name}_id', 'INTEGER PRIMARY KEY'), + ('meta_labeler_id', 'INTEGER NOT NULL'), + ('indexer_id', 'INTEGER NOT NULL'), + ('config_id', 'INTEGER DEFAULT 0'), + ('data', 'TEXT'), + ], + 'docstr': f'docstr for {name}', + 'superkeys': [ + ('meta_labeler_id', 'indexer_id', 'config_id'), + ], + 'dependson': depends_on, + } + return definition + + def test_instantiation(ctrlr): # Check for basic connection information assert ctrlr.uri == 'sqlite:///:memory:' @@ -30,6 +50,37 @@ def test_instantiation(ctrlr): assert not ctrlr.connection.closed +def test_instantiation_with_table_reflection(tmp_path): + db_file = (tmp_path / 'testing.db').resolve() + creating_ctrlr = SQLDatabaseController.from_uri(f'sqlite:///{db_file}') + # Assumes `add_table` is functional. If you run into failing problems + # check for failures around this method first. + created_tables = [] + table_names = map( + 'table_{}'.format, + ( + 'a', + 'b', + 'c', + ), + ) + for t in table_names: + creating_ctrlr.add_table(**make_table_definition(t, depends_on=created_tables)) + # Build up of denpendence + created_tables.append(t) + + # Delete the controller + del creating_ctrlr + + # Create the controller again for reflection (testing target) + ctrlr = SQLDatabaseController.from_uri(f'sqlite:///{db_file}') + # Verify the tables are loaded on instantiation + assert list(ctrlr._sa_metadata.tables.keys()) == ['metadata'] + created_tables + # Note, we don't have to check for the contents of the tables, + # because that's machinery within SQLAlchemy, + # which will have been tested by SQLAlchemy. 
+ + class TestSchemaModifiers: """Testing the API that creates, modifies or deletes schema elements""" @@ -37,24 +88,7 @@ class TestSchemaModifiers: def fixture(self, ctrlr): self.ctrlr = ctrlr - def make_table_definition(self, name, depends_on=[]): - """Creates a table definition for use with the controller's add_table method""" - definition = { - 'tablename': name, - 'coldef_list': [ - (f'{name}_id', 'INTEGER PRIMARY KEY'), - ('meta_labeler_id', 'INTEGER NOT NULL'), - ('indexer_id', 'INTEGER NOT NULL'), - ('config_id', 'INTEGER DEFAULT 0'), - ('data', 'TEXT'), - ], - 'docstr': f'docstr for {name}', - 'superkeys': [ - ('meta_labeler_id', 'indexer_id', 'config_id'), - ], - 'dependson': depends_on, - } - return definition + make_table_definition = staticmethod(make_table_definition) @property def _table_factory(self): From e2197da98dcec4e954a669328a6960613699f594 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 21:13:45 -0700 Subject: [PATCH 046/294] Add a method to reflect a table by name This is to be selectively used by internal methods. The key benefit here is acquiring Column objects for transparent insertion and updating of user-defined types, without the developer directly specifing how to translate those types from Python to SQL and vice versa. --- wbia/dtool/sql_control.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 2b7e83b595..a58fa19651 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -19,6 +19,7 @@ import sqlalchemy import utool as ut from deprecated import deprecated +from sqlalchemy.schema import Table from sqlalchemy.sql import text, ClauseElement from wbia.dtool import lite @@ -711,6 +712,14 @@ def squeeze(self): self.shrink_memory() self.vacuum() + def _reflect_table(self, table_name): + """Produces a SQLAlchemy Table object from the given ``table_name``""" + # Note, this on introspects once. Repeated calls will pull the Table object + # from the MetaData object. + return Table( + table_name, self._sa_metadata, autoload=True, autoload_with=self._engine + ) + # ============== # API INTERFACE # ============== From 877eb5fe974a31b39a82b028f40481411c7b0e84 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 21:48:10 -0700 Subject: [PATCH 047/294] Enable the our SQLAlchemy event listeners Importing the module is enough to register any listeners for use. --- wbia/dtool/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wbia/dtool/__init__.py b/wbia/dtool/__init__.py index 00818297c6..76706f7d05 100644 --- a/wbia/dtool/__init__.py +++ b/wbia/dtool/__init__.py @@ -26,3 +26,4 @@ from wbia.dtool.base import * # NOQA from wbia.dtool.sql_control import SQLDatabaseController from wbia.dtool.types import TYPE_TO_SQLTYPE +import wbia.dtool.events From 45319e6055d76f04275ae2204ad3871436fff937 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 21:49:23 -0700 Subject: [PATCH 048/294] Insert using the SQLAlchemy Table object This does a reflection of the table to acquire column type information, which will then enable the logic to correctly translate the data-types from Python to SQL. 
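A standalone sketch (not the controller code itself) of the reflected-table insert and
per-row primary key retrieval this relies on; the table and column names are made up:

```
import sqlalchemy
from sqlalchemy.schema import Table

engine = sqlalchemy.create_engine('sqlite:///:memory:')
with engine.connect() as conn:
    conn.execute('CREATE TABLE example (id INTEGER PRIMARY KEY, label TEXT)')

    metadata = sqlalchemy.MetaData()
    table = Table('example', metadata, autoload=True, autoload_with=conn)

    # One execute per row: SQLite only reports the primary key of the most
    # recent insert, so a multi-value insert cannot return all of them.
    for label in ('first', 'second'):
        result = conn.execute(sqlalchemy.insert(table).values(label=label))
        print(result.inserted_primary_key)  # [1], then [2]
```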
--- wbia/dtool/sql_control.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index a58fa19651..f21aa651da 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -770,13 +770,28 @@ def check_rowid_exists(self, tablename, rowid_iter, eager=True, **kwargs): def _add(self, tblname, colnames, params_iter, unpack_scalars=True, **kwargs): """ ADDER NOTE: use add_cleanly """ - columns = ', '.join(colnames) - column_params = ', '.join(f':{col}' for col in colnames) parameterized_values = [ {col: val for col, val in zip(colnames, params)} for params in params_iter ] - stmt = text(f'INSERT INTO {tblname} ({columns}) VALUES ({column_params})') - return self.executemany(stmt, parameterized_values, unpack_scalars=unpack_scalars) + table = self._reflect_table(tblname) + + # It would be possible to do one insert, + # but SQLite is not capable of returning the primary key value after a multi-value insert. + # Thus, we are stuck doing several inserts... ineffecient. + insert_stmt = sqlalchemy.insert(table) + + primary_keys = [] + with self.connection.begin(): # new nested database transaction + for vals in parameterized_values: + result = self.connection.execute(insert_stmt.values(vals)) + + pk = result.inserted_primary_key + if unpack_scalars: + # Assumption at the time of writing this is that the primary key is the SQLite rowid. + # Therefore, we can assume the primary key is a single column value. + pk = pk[0] + primary_keys.append(pk) + return primary_keys def add_cleanly( self, From e75f0c710c5633ff8c5d1a7a12942fa074e81d49 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 27 Sep 2020 22:42:02 -0700 Subject: [PATCH 049/294] Adjust sql controller's get_where_eq method to use SQLAlchemy This pulls in the table so that we can use the Column type information when binding parameters in the conditions. --- wbia/dtool/sql_control.py | 43 +++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index f21aa651da..2d893995a6 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -20,7 +20,7 @@ import utool as ut from deprecated import deprecated from sqlalchemy.schema import Table -from sqlalchemy.sql import text, ClauseElement +from sqlalchemy.sql import bindparam, text, ClauseElement from wbia.dtool import lite from wbia.dtool.dump import dumps @@ -949,14 +949,20 @@ def get_where_eq( (default: True) """ - equal_conditions = [f'{c}=:{c}' for c in where_colnames] - where_conditions = f' {op} '.upper().join(equal_conditions) + table = self._reflect_table(tblname) + # Build the equality conditions using column type information. + # This allows us to bind the parameter with the correct type. 
+ equal_conditions = [ + (table.c[c] == bindparam(c, type_=table.c[c].type)) for c in where_colnames + ] + gate_func = {'and': sqlalchemy.and_, 'or': sqlalchemy.or_}[op.lower()] + where_clause = gate_func(*equal_conditions) params = [dict(zip(where_colnames, p)) for p in params_iter] return self.get_where( tblname, colnames, params, - where_conditions, + where_clause, unpack_scalars=unpack_scalars, **kwargs, ) @@ -1023,7 +1029,7 @@ def get_where( colnames (tuple[str]): sequence of column names params_iter (list[dict]): a sequence of dicts with parameters, where each item in the sequence is used in a SQL execution - where_clause (str): conditional statement used in the where clause + where_clause (str|Operation): conditional statement used in the where clause unpack_scalars (bool): [deprecated] use to unpack a single result from each query only use with operations that return a single result for each query (default: True) @@ -1031,21 +1037,24 @@ def get_where( """ if not isinstance(colnames, (tuple, list)): raise TypeError('colnames must be a sequence type of strings') + elif where_clause is not None: + if '?' in str(where_clause): # cast in case it's an SQLAlchemy object + raise ValueError( + "Statements cannot use '?' parameterization, " + "use ':name' parameters instead." + ) + elif isinstance(where_clause, str): + where_clause = text(where_clause) + + table = self._reflect_table(tblname) + stmt = sqlalchemy.select([table.c[c] for c in colnames]) - # Build and execute the query - columns = ', '.join(colnames) - stmt = f'SELECT {columns} FROM {tblname}' if where_clause is None: - val_list = self.executeone(text(stmt), **kwargs) - elif '?' in where_clause: - raise ValueError( - "Statements cannot use '?' parameterization, " - "use ':name' parameters instead." - ) + val_list = self.executeone(stmt, **kwargs) else: - stmt += f' WHERE {where_clause}' + stmt = stmt.where(where_clause) val_list = self.executemany( - text(stmt), + stmt, params_iter, unpack_scalars=unpack_scalars, eager=eager, @@ -1361,7 +1370,7 @@ def executeone(self, operation, params=(), eager=True, verbose=VERBOSE_SQL): # BBB (12-Sept-12020) Retaining insertion rowid result # FIXME postgresql (12-Sept-12020) This won't work in postgres. # Maybe see if ResultProxy.inserted_primary_key will work - if 'insert' in operation.text.lower(): + if 'insert' in str(operation).lower(): # cast in case it's an SQLAlchemy object # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. return [results.lastrowid] elif not results.returns_rows: From 62e55dec39cd1eae9bd8a1a5593c3306692ca017 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 00:02:21 -0700 Subject: [PATCH 050/294] Test executeone's no results expectation As noted in the test's comments, this breaks backwards compatiblity, where no results returned an empty list (i.e. `[]`). IMO returning None is correct, because that's the expectation from `fetchone`'s DBAPI spec. That is I see `executeone` as the shortcut for calling `execute` and `fetchone`, which I believe is a fairly correct assumption. Therefore, the returned value expectations should match that of `fetchone` (ignoring the weird `unpack_scalars` behavior). 
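The DBAPI behavior this reasoning leans on can be checked with the standard library alone
(illustration only, not part of the patch):

```
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE t (x INTEGER)')

print(conn.execute('SELECT x FROM t').fetchone())   # None -- the DBAPI fetchone contract
print(conn.execute('SELECT x FROM t').fetchall())   # []   -- what the old behavior resembled
```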
--- wbia/tests/dtool/test_sql_control.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 25fba1f2e4..107de2bbdc 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -548,6 +548,19 @@ def test_executeone(self): assert result == [(i + 1, i) for i in range(0, 10)] + def test_executeone_without_results(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Call the testing target + result = self.ctrlr.executeone(text(f'SELECT id, y FROM {table_name}')) + + # Note, this breaks backwards compatiblity, + # where no results returned an empty list (i.e. `[]`). + # IMO returning None is correct, + # because that's the expectation from `fetchone`'s DBAPI spec. + assert result is None + def test_executeone_on_insert(self): # Should return id after an insert table_name = 'test_executeone' From 9b66bf60ec43c1e01f6cc2728c93719c481f7642 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 00:17:12 -0700 Subject: [PATCH 051/294] Fix get_table_column_data to make it backwards compat Fixed the doctest so that it actually tests something... I've fixed the logic so that it's outputting backwards compatible results for empty column data. This goes back to the change made to `executeone` where the empty result behavior now matches that of `fetchone`'s DBAPI spec. --- wbia/dtool/sql_control.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 2d893995a6..7008044568 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2280,6 +2280,10 @@ def get_table_column_data( >>> tablename = 'keypoint' >>> db = depc[tablename].db >>> column_list, column_names = db.get_table_column_data(tablename) + >>> column_list + [[], [], [], [], []] + >>> column_names + ['keypoint_rowid', 'chip_rowid', 'config_rowid', 'kpts', 'num'] """ if columns is None: all_column_names = self.get_column_names(tablename) @@ -2293,6 +2297,9 @@ def get_table_column_data( ] else: column_list = [self.get_column(tablename, name) for name in column_names] + # BBB (28-Sept-12020) The previous implementation of `executeone` returned [] + # rather than None for empty rows. + column_list = [x and x or [] for x in column_list] return column_list, column_names def make_json_table_definition(self, tablename): From 986991a148b0644669df988580766370f8515f9b Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 09:59:29 -0700 Subject: [PATCH 052/294] Fix problematic interpretation of executeone's behavior My interpretation of the `executeone` method's behavior has been slightly incorrect. I believed we were shortcutting the procedure to `execute` and `fetchone`. This turns out to not exactly be true. I've corrected this problem by retaining the, IMO wrong, results of an empty list. But I've added the ability through `use_fetchone_behavior` to correspond with expectations defined in the DBAPI v2 spec. 
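Condensed from the tests below into a usage sketch (assumes the same in-memory controller
the tests use):

```
from sqlalchemy.sql import text
from wbia.dtool.sql_control import SQLDatabaseController

ctrlr = SQLDatabaseController.from_uri('sqlite:///:memory:')
ctrlr.connection.execute('CREATE TABLE t (id INTEGER PRIMARY KEY, y INTEGER)')

stmt = text('SELECT id, y FROM t')
print(ctrlr.executeone(stmt))                              # []   (default, backwards compatible)
print(ctrlr.executeone(stmt, use_fetchone_behavior=True))  # None (DBAPI fetchone expectation)
```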
--- wbia/dtool/sql_control.py | 38 +++++++++++++++++++++++----- wbia/tests/dtool/test_sql_control.py | 17 ++++++++++--- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 7008044568..a55f609d27 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -197,7 +197,9 @@ def version(self): f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = :key' ) try: - return self.ctrlr.executeone(stmt, {'key': 'database_version'})[0] + return self.ctrlr.executeone( + stmt, {'key': 'database_version'}, use_fetchone_behavior=True + )[0] except TypeError: # NoneType return None @@ -218,7 +220,9 @@ def init_uuid(self): f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = :key' ) try: - value = self.ctrlr.executeone(stmt, {'key': 'database_init_uuid'})[0] + value = self.ctrlr.executeone( + stmt, {'key': 'database_init_uuid'}, use_fetchone_behavior=True + )[0] except TypeError: # NoneType return None if value is not None: @@ -289,7 +293,9 @@ def __getattr__(self, name): 'WHERE metadata_key = :key' ) try: - value = self.ctrlr.executeone(statement, {'key': key})[0] + value = self.ctrlr.executeone( + statement, {'key': key}, use_fetchone_behavior=True + )[0] except TypeError: # NoneType return None if METADATA_TABLE_COLUMNS[name]['is_coded_data']: @@ -1356,8 +1362,24 @@ def _executemany_operation_fmt( # SQLDB CORE # ========= - def executeone(self, operation, params=(), eager=True, verbose=VERBOSE_SQL): - """Executes the given ``operation`` once with the given set of ``params``""" + def executeone( + self, + operation, + params=(), + eager=True, + verbose=VERBOSE_SQL, + use_fetchone_behavior=False, + ): + """Executes the given ``operation`` once with the given set of ``params`` + + Args: + operation (str|TextClause): SQL statement + params (sequence|dict): parameters to pass in with SQL execution + eager: [deprecated] no-op + verbose: [deprecated] no-op + use_fetchone_behavior (bool): Use DBAPI ``fetchone`` behavior when outputing no rows (i.e. None) + + """ if not isinstance(operation, ClauseElement): raise TypeError( "'operation' needs to be a sqlalchemy textual sql instance " @@ -1383,7 +1405,11 @@ def executeone(self, operation, params=(), eager=True, verbose=VERBOSE_SQL): for row in results ] ) - if not values: # empty list + # FIXME (28-Sept-12020) No rows results in an empty list. This behavior does not + # match the resulting expectations of `fetchone`'s DBAPI spec. + # If executeone is the shortcut of `execute` and `fetchone`, + # the expectation should be to return according to DBAPI spec. + if use_fetchone_behavior and not values: # empty list values = None return values diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 107de2bbdc..856716d8d7 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -548,6 +548,19 @@ def test_executeone(self): assert result == [(i + 1, i) for i in range(0, 10)] + def test_executeone_using_fetchone_behavior(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Call the testing target with `fetchone` method's returning behavior. + result = self.ctrlr.executeone( + text(f'SELECT id, y FROM {table_name}'), use_fetchone_behavior=True + ) + + # IMO returning None is correct, + # because that's the expectation from `fetchone`'s DBAPI spec. 
+ assert result is None + def test_executeone_without_results(self): table_name = 'test_executeone' self.make_table(table_name) @@ -555,11 +568,9 @@ def test_executeone_without_results(self): # Call the testing target result = self.ctrlr.executeone(text(f'SELECT id, y FROM {table_name}')) - # Note, this breaks backwards compatiblity, - # where no results returned an empty list (i.e. `[]`). # IMO returning None is correct, # because that's the expectation from `fetchone`'s DBAPI spec. - assert result is None + assert result == [] def test_executeone_on_insert(self): # Should return id after an insert From 7b8fa1105c1edaa5f9b5b402741595c997f19d81 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 12:28:59 -0700 Subject: [PATCH 053/294] Deprecate set_db_version in favor of the metadata property --- wbia/dtool/sql_control.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index a55f609d27..7295e25d0d 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2995,25 +2995,9 @@ def view_db_in_external_reader(self): # ut.cmd(sqlite3_reader, sqlite3_db_fpath) pass + @deprecated("Use 'self.metadata.database.version = version' instead") def set_db_version(self, version): - # Do things properly, get the metadata_rowid (best because we want to assert anyway) - metadata_key_list = ['database_version'] - params_iter = ((metadata_key,) for metadata_key in metadata_key_list) - where_clause = 'metadata_key=?' - # list of relationships for each image - metadata_rowid_list = self.get_where( - METADATA_TABLE_NAME, - ('metadata_rowid',), - params_iter, - where_clause, - unpack_scalars=True, - ) - assert ( - len(metadata_rowid_list) == 1 - ), 'duplicate database_version keys in database' - id_iter = ((metadata_rowid,) for metadata_rowid in metadata_rowid_list) - val_list = ((_,) for _ in [version]) - self.set(METADATA_TABLE_NAME, ('metadata_value',), val_list, id_iter) + self.metadata.database.version = version def get_sql_version(self): """ Conveinience """ From a1ea05a77f83ee4ac9da8230fa3c4b46e9c7aaf9 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 19:37:30 -0700 Subject: [PATCH 054/294] Fix value to be items in a sequence This error doesn't result in error in the pre-sqlalchemy implementation because the controller only joined the sequence with a `', '.join(...)` anyway. However, the current changes use the sequence of safely named items to programatically create the constraint name, which cannot have spaces. --- wbia/control/DB_SCHEMA.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/wbia/control/DB_SCHEMA.py b/wbia/control/DB_SCHEMA.py index 22d3a7a468..de4c5643bd 100644 --- a/wbia/control/DB_SCHEMA.py +++ b/wbia/control/DB_SCHEMA.py @@ -224,7 +224,12 @@ def update_1_0_0(db, ibs=None): ('feature_keypoints', 'NUMPY'), ('feature_sifts', 'NUMPY'), ), - superkeys=[('chip_rowid, config_rowid',)], + superkeys=[ + ( + 'chip_rowid', + 'config_rowid', + ) + ], docstr=""" Used to store individual chip features (ellipses)""", ) From 9958383b8e4ab85691ef1d04508a27f0ed087108 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 19:43:36 -0700 Subject: [PATCH 055/294] Add a method to invalidate the table cache We're ending up a stale cache after schema operations. The change ensures we invalidate the cache. 
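A hedged sketch of the staleness this guards against (made-up table name; assumes the
reflection changes from the earlier patches in this series):

```
from wbia.dtool.sql_control import SQLDatabaseController

ctrlr = SQLDatabaseController.from_uri('sqlite:///:memory:')
ctrlr.connection.execute('CREATE TABLE example (id INTEGER PRIMARY KEY, label TEXT)')
ctrlr._reflect_table('example')   # reflected once and cached in _sa_metadata

ctrlr.connection.execute('ALTER TABLE example ADD COLUMN extra TEXT')
ctrlr.invalidate_tables_cache()   # without this, the cached Table never sees 'extra'
print(ctrlr._reflect_table('example').columns.keys())  # now includes 'extra'
```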
--- wbia/dtool/sql_control.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 7295e25d0d..1a2fc384ad 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1817,16 +1817,18 @@ def get_rowid_from_superkey(x): return [None] * len(x) self.add_cleanly(tablename_temp, dst_list, data_list, get_rowid_from_superkey) - if tablename_new is None: + if tablename_new is None: # i.e. not renaming the table # Drop original table - self.drop_table(tablename) + self.drop_table(tablename, invalidate_cache=False) # Rename temp table to original table name - self.rename_table(tablename_temp, tablename) + self.rename_table(tablename_temp, tablename, invalidate_cache=False) else: # Rename new table to new name - self.rename_table(tablename_temp, tablename_new) + self.rename_table(tablename_temp, tablename_new, invalidate_cache=False) + # Any modifications are going to invalidate the cached tables. + self.invalidate_tables_cache() - def rename_table(self, tablename_old, tablename_new): + def rename_table(self, tablename_old, tablename_new, invalidate_cache=True): logger.info( '[sql] schema renaming tablename=%r -> %r' % (tablename_old, tablename_new) ) @@ -1849,8 +1851,10 @@ def rename_table(self, tablename_old, tablename_new): self.set( METADATA_TABLE_NAME, colnames, val_iter, id_iter, id_colname='metadata_key' ) + if invalidate_cache: + self.invalidate_tables_cache() - def drop_table(self, tablename): + def drop_table(self, tablename, invalidate_cache=True): logger.info('[sql] schema dropping tablename=%r' % tablename) # Technically insecure call, but all entries are statically inputted by # the database's owner, who could delete or alter the entire database @@ -1861,6 +1865,8 @@ def drop_table(self, tablename): # Delete table's metadata key_list = [tablename + '_' + suffix for suffix in METADATA_TABLE_COLUMN_NAMES] self.delete(METADATA_TABLE_NAME, key_list, id_colname='metadata_key') + if invalidate_cache: + self.invalidate_tables_cache() def drop_all_tables(self): """ @@ -1869,8 +1875,8 @@ def drop_all_tables(self): self._tablenames = None for tablename in self.get_table_names(): if tablename != 'metadata': - self.drop_table(tablename) - self._tablenames = None + self.drop_table(tablename, invalidate_cache=False) + self.invalidate_tables_cache() # ============== # CONVINENCE @@ -2117,6 +2123,15 @@ def dump_schema(self): file_.write('\t%s%s%s%s%s\n' % col) ut.view_directory(app_resource_dir) + def invalidate_tables_cache(self): + """Invalidates the controller's cache of table names and objects + Resets the caches and/or repopulates them. + + """ + self._tablenames = None + self._sa_metadata = sqlalchemy.MetaData() + self.get_table_names() + def get_table_names(self, lazy=False): """ Conveinience: """ if not lazy or self._tablenames is None: From ea5c7b5207e25f1825b497be2f551f11cd2819e0 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 19:56:44 -0700 Subject: [PATCH 056/294] Invalidate the tables cache after each migration step This corrects errors in migration process. By running this after each migration step we make certain the controller has no stale table data around that would interfere with the next migration step. 
--- wbia/control/_sql_helpers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wbia/control/_sql_helpers.py b/wbia/control/_sql_helpers.py index 6d44b5a030..bc8439d655 100644 --- a/wbia/control/_sql_helpers.py +++ b/wbia/control/_sql_helpers.py @@ -426,10 +426,13 @@ def _check_superkeys(): pre, update, post = db_versions[next_version] if pre is not None: pre(db, ibs=ibs) + db.invalidate_tables_cache() if update is not None: update(db, ibs=ibs) + db.invalidate_tables_cache() if post is not None: post(db, ibs=ibs) + db.invalidate_tables_cache() _check_superkeys() except Exception as ex: if dobackup: From 1b7232980150476a424efa77e718618763febb7e Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 20:20:22 -0700 Subject: [PATCH 057/294] Make the where clause aware of its parameter type We're using the reflected table to get the correct column type information so that the passed in parameter can be properly adapted to SQL. The problem circumstance is SQLite's rowid column, which doesn't really exist as a column, so a special case is put in. --- wbia/dtool/sql_control.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 1a2fc384ad..e1feaf2b8f 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1182,7 +1182,12 @@ def get( where_clause = None params_iter = [] else: - where_clause = id_colname + ' = :id' + where_clause = text(id_colname + ' = :id') + if id_colname != 'rowid': # b/c rowid doesn't really exist as a column + column = self._reflect_table(tblname).c[id_colname] + where_clause = where_clause.bindparams( + bindparam('id', type_=column.type) + ) params_iter = [{'id': id} for id in id_iter] return self.get_where( From 2087a630450598b75d1ee76f2bb01f56d03498eb Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 22:34:15 -0700 Subject: [PATCH 058/294] Build the update statment using with typed columns This updates/rewrites the set method's SQL to build from the Table instance. This solves the issue where the id is something other than a simple integer. For example, this was erroring when the id was a UUID object. With this change in place, the id parameter gets the type treatment of the column it is trying to compare with. 
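To make the reasoning concrete, a standalone sketch (plain SQLAlchemy with made-up names,
not the controller code) of why the bound identifier needs the column's type -- a
TypeDecorator only adapts values for parameters that carry it:

```
import uuid

import sqlalchemy
from sqlalchemy.sql import bindparam, text
from sqlalchemy.types import LargeBinary, TypeDecorator


class BytesUUID(TypeDecorator):
    # Stand-in for the wbia UUID type: stores uuid.UUID values as 16-byte blobs.
    impl = LargeBinary

    def process_bind_param(self, value, dialect):
        return value.bytes if value is not None else None

    def process_result_value(self, value, dialect):
        return uuid.UUID(bytes=value) if value is not None else None


metadata = sqlalchemy.MetaData()
example = sqlalchemy.Table(
    'example', metadata,
    sqlalchemy.Column('example_rowid', sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column('example_uuid', BytesUUID),
    sqlalchemy.Column('note', sqlalchemy.Text),
)

engine = sqlalchemy.create_engine('sqlite:///:memory:')
metadata.create_all(engine)

key = uuid.uuid4()
with engine.connect() as conn:
    conn.execute(example.insert().values(example_uuid=key, note='old'))

    stmt = example.update().values(note=bindparam('e0')).where(
        text('example_uuid = :_identifier').bindparams(
            bindparam('_identifier', type_=example.c['example_uuid'].type)
        )
    )
    # The uuid.UUID value is adapted to bytes because the parameter knows the column type.
    conn.execute(stmt, e0='new', _identifier=key)

    print(conn.execute(sqlalchemy.select([example.c.note])).fetchall())  # [('new',)]
```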
--- wbia/dtool/sql_control.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e1feaf2b8f..3a9021887d 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1309,9 +1309,15 @@ def set( raise # Execute the SQL updates for each set of values - assignments = ', '.join([f'{col} = :e{i}' for i, col in enumerate(colnames)]) - where_condition = f'{id_colname} = :id' - stmt = text(f'UPDATE {tblname} SET {assignments} WHERE {where_condition}') + table = self._reflect_table(tblname) + stmt = table.update().values( + **{col: bindparam(f'e{i}') for i, col in enumerate(colnames)} + ) + where_clause = text(id_colname + ' = :id') + if id_colname != 'rowid': # b/c rowid doesn't really exist as a column + id_column = table.c[id_colname] + where_clause = where_clause.bindparams(bindparam('id', type_=id_column.type)) + stmt = stmt.where(where_clause) for i, id in enumerate(id_list): params = {'id': id} params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) From 1b9f5f86b3de1e5eb3b738a47d092d5e767683fa Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 22:37:07 -0700 Subject: [PATCH 059/294] Fix sql controller's set method to accept mistyped input The problem is that sometimes the method is given inputs that are sequence of sequences, while other times it gets a sequence of item values. It's a terribly inconsistent usage from many places. This puts in a backwards compatiblity set, while at the same time letting the code that follows read as if the problem didn't exist. The other weird thing is that this method sometimes gets called with an empty list of values or ids. This is why the conditions here check the variable for _somethingness_ before looking inside it. --- wbia/dtool/sql_control.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 3a9021887d..7b71aec0bf 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1308,6 +1308,17 @@ def set( ut.printex(ex, key_list=['num_val', 'num_id']) raise + # BBB (28-Sept-12020) This method's usage throughout the codebase allows + # for items in `val_iter` to be a non-sequence value. + has_unsequenced_values = val_list and not isinstance(val_list[0], (tuple, list)) + if has_unsequenced_values: + val_list = [(v,) for v in val_list] + # BBB (28-Sept-12020) This method's usage throughout the codebase allows + # for items in `id_iter` to be a tuple of one value. + has_sequenced_ids = id_list and isinstance(id_list[0], (tuple, list)) + if has_sequenced_ids: + id_list = [x[0] for x in id_list] + # Execute the SQL updates for each set of values table = self._reflect_table(tblname) stmt = table.update().values( From 1c5774016cc2d7d331719b634ce0407f7f986c90 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Sep 2020 23:11:32 -0700 Subject: [PATCH 060/294] Adjust id parameter name to _identifier SQLAlchemy complains when there is a column named 'id'. Name it something generic, but not going to be found as a column name. 
--- wbia/dtool/sql_control.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 7b71aec0bf..29dcc76a07 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1182,13 +1182,14 @@ def get( where_clause = None params_iter = [] else: - where_clause = text(id_colname + ' = :id') + id_param_name = '_identifier' + where_clause = text(id_colname + f' = :{id_param_name}') if id_colname != 'rowid': # b/c rowid doesn't really exist as a column column = self._reflect_table(tblname).c[id_colname] where_clause = where_clause.bindparams( - bindparam('id', type_=column.type) + bindparam(id_param_name, type_=column.type) ) - params_iter = [{'id': id} for id in id_iter] + params_iter = [{id_param_name: id} for id in id_iter] return self.get_where( tblname, colnames, params_iter, where_clause, eager=eager, **kwargs @@ -1320,17 +1321,20 @@ def set( id_list = [x[0] for x in id_list] # Execute the SQL updates for each set of values + id_param_name = '_identifier' table = self._reflect_table(tblname) stmt = table.update().values( **{col: bindparam(f'e{i}') for i, col in enumerate(colnames)} ) - where_clause = text(id_colname + ' = :id') + where_clause = text(id_colname + f' = :{id_param_name}') if id_colname != 'rowid': # b/c rowid doesn't really exist as a column id_column = table.c[id_colname] - where_clause = where_clause.bindparams(bindparam('id', type_=id_column.type)) + where_clause = where_clause.bindparams( + bindparam(id_param_name, type_=id_column.type) + ) stmt = stmt.where(where_clause) for i, id in enumerate(id_list): - params = {'id': id} + params = {id_param_name: id} params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) self.connection.execute(stmt, **params) From c25b0c38b95d764621bbf5740bf89430f6cb2ced Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 1 Oct 2020 21:52:03 +0100 Subject: [PATCH 061/294] Change representation of uuid in database to bytes UUIDs were being stored like `d1d7ac5067d641f487f09ad5066d98e3` instead of `b'\xa2\xcd\xac\xc9\xa0\x04\x8fC\x8f\xed7\x90w\x15\xf4\x83'` like the original code. --- wbia/dtool/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index ffbe568c09..fa81bf12c4 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -161,10 +161,10 @@ def process(value): return value else: if not isinstance(value, uuid.UUID): - return '%.32x' % uuid.UUID(value).int + return uuid.UUID(value).bytes else: # hexstring - return '%.32x' % value.int + return value.bytes return process @@ -174,7 +174,7 @@ def process(value): return value else: if not isinstance(value, uuid.UUID): - return uuid.UUID(value) + return uuid.UUID(bytes=value) else: return value From 72d07901be84ae16932b4e6fdac63d22079a7dc4 Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 1 Oct 2020 23:21:34 +0100 Subject: [PATCH 062/294] Force rowid to be int in SQLDatabaseController.set Sometimes `SQLDatabaseController.set` receives iterator generated from `np.array` as `id_iter`, causing the individual ids to be `numpy.int64`. When this is used in the where clause, it doesn't match any rows and so doesn't update anything. 
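A small illustration (plain numpy behavior, not wbia code) of the input shape that
triggers this -- iterating a numpy array yields numpy scalars rather than plain ints:

```
import numpy as np

id_iter = np.array([1, 2, 3])
print(type(id_iter[0]))                        # <class 'numpy.int64'> (platform dependent)

safe_ids = [int(rowid) for rowid in id_iter]   # the cast the controller now applies for rowid
print(all(isinstance(rowid, int) for rowid in safe_ids))  # True
```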
--- wbia/dtool/sql_control.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 29dcc76a07..d3766c2dbd 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1334,6 +1334,8 @@ def set( ) stmt = stmt.where(where_clause) for i, id in enumerate(id_list): + if id_colname == 'rowid': + id = int(id) params = {id_param_name: id} params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) self.connection.execute(stmt, **params) From 8c6c5ab8d61e2c686ff5733620a6f810222fcdba Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 2 Oct 2020 15:49:09 +0100 Subject: [PATCH 063/294] Load all test databases in wbia-init-testdbs `wbia-init-testdbs` was only loading testdb1. Add testdb2, wd_peter2, NAUT_test and PZ_MTEST. --- wbia/cli/testdbs.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/wbia/cli/testdbs.py b/wbia/cli/testdbs.py index 1575a1ae17..20dbb4cef9 100644 --- a/wbia/cli/testdbs.py +++ b/wbia/cli/testdbs.py @@ -6,7 +6,13 @@ import click from wbia.dbio import ingest_database -from wbia.init.sysres import get_workdir +from wbia.init.sysres import ( + ensure_nauts, + ensure_pz_mtest, + ensure_testdb2, + ensure_wilddogs, + get_workdir, +) @click.command() @@ -18,6 +24,10 @@ def main(force_replace): dbs = { # : 'testdb1': lambda: ingest_database.ingest_standard_database('testdb1'), + 'PZ_MTEST': ensure_pz_mtest, + 'NAUT_test': ensure_nauts, + 'wd_peter2': ensure_wilddogs, + 'testdb2': ensure_testdb2, } for db in dbs: From 29efc51a13ee00c400f641d41cbc97d02feb28f0 Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 2 Oct 2020 17:08:07 +0100 Subject: [PATCH 064/294] Change sql statement from ? to : --- wbia/control/manual_gsgrelate_funcs.py | 9 ++++++--- wbia/control/manual_review_funcs.py | 10 ++++------ wbia/dtool/sql_control.py | 6 +++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/wbia/control/manual_gsgrelate_funcs.py b/wbia/control/manual_gsgrelate_funcs.py index 9310832a6a..09c8384175 100644 --- a/wbia/control/manual_gsgrelate_funcs.py +++ b/wbia/control/manual_gsgrelate_funcs.py @@ -173,10 +173,13 @@ def get_gsgr_rowid_from_superkey(ibs, gid_list, imgsetid_list): Returns: gsgrid_list (list): eg-relate-ids from info constrained to be unique (imgsetid, gid)""" colnames = ('image_rowid',) + where_colnames = ('image_rowid', 'imageset_rowid') params_iter = zip(gid_list, imgsetid_list) - where_clause = 'image_rowid=? AND imageset_rowid=?' - gsgrid_list = ibs.db.get_where( - const.GSG_RELATION_TABLE, colnames, params_iter, where_clause + gsgrid_list = ibs.db.get_where_eq( + const.GSG_RELATION_TABLE, + colnames, + params_iter, + where_colnames, ) return gsgrid_list diff --git a/wbia/control/manual_review_funcs.py b/wbia/control/manual_review_funcs.py index db1ebdc00e..b600842f8c 100644 --- a/wbia/control/manual_review_funcs.py +++ b/wbia/control/manual_review_funcs.py @@ -612,12 +612,11 @@ def get_review_rowids_from_single(ibs, aid_list, eager=True, nInput=None): def get_review_rowids_from_aid1(ibs, aid_list, eager=True, nInput=None): colnames = (REVIEW_ROWID,) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' 
% (REVIEW_AID1,) - review_rowids = ibs.staging.get_where( + review_rowids = ibs.staging.get_where_eq( const.REVIEW_TABLE, colnames, params_iter, - where_clause=where_clause, + (REVIEW_AID1,), unpack_scalars=False, ) return review_rowids @@ -627,12 +626,11 @@ def get_review_rowids_from_aid1(ibs, aid_list, eager=True, nInput=None): def get_review_rowids_from_aid2(ibs, aid_list, eager=True, nInput=None): colnames = (REVIEW_ROWID,) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' % (REVIEW_AID2,) - review_rowids = ibs.staging.get_where( + review_rowids = ibs.staging.get_where_eq( const.REVIEW_TABLE, colnames, params_iter, - where_clause=where_clause, + (REVIEW_AID2,), unpack_scalars=False, ) return review_rowids diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index d3766c2dbd..e6a996770b 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1512,9 +1512,9 @@ def set_metadata_val(self, key, val): 'tablename': METADATA_TABLE_NAME, 'columns': 'metadata_key, metadata_value', } - op_fmtstr = 'INSERT OR REPLACE INTO {tablename} ({columns}) VALUES (?, ?)' - operation = op_fmtstr.format(**fmtkw) - params = [key, val] + op_fmtstr = 'INSERT OR REPLACE INTO {tablename} ({columns}) VALUES (:key, :val)' + operation = text(op_fmtstr.format(**fmtkw)) + params = {'key': key, 'val': val} self.executeone(operation, params, verbose=False) @deprecated('Use metadata property instead') From d96bbb1865617ae2aed38382ede34d524ae61306 Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 5 Oct 2020 23:49:43 +0100 Subject: [PATCH 065/294] Add keepwrap to SQLDatabaseController.executemany In `DependencyCacheTable.get_row_data`, `raw_prop_list` is expected to keep the tuples in the results, so: ``` [('chips_img_id=1_vxqakzqofnfttbuh.png',), ('chips_img_id=2_vxqakzqofnfttbuh.png',), ('chips_img_id=3_vxqakzqofnfttbuh.png',)] ``` but the code returned: ``` ['chips_img_id=1_vxqakzqofnfttbuh.png', 'chips_img_id=2_vxqakzqofnfttbuh.png', 'chips_img_id=3_vxqakzqofnfttbuh.png'] ``` It seems we removed `keepwrap` which changes whether the tuples are returned. --- wbia/dtool/sql_control.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e6a996770b..2227b4d073 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1397,6 +1397,7 @@ def executeone( eager=True, verbose=VERBOSE_SQL, use_fetchone_behavior=False, + keepwrap=False, ): """Executes the given ``operation`` once with the given set of ``params`` @@ -1429,7 +1430,7 @@ def executeone( values = list( [ # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. - row[0] if len(row) == 1 else row + row[0] if not keepwrap and len(row) == 1 else row for row in results ] ) @@ -1441,7 +1442,9 @@ def executeone( values = None return values - def executemany(self, operation, params_iter, unpack_scalars=True, **kwargs): + def executemany( + self, operation, params_iter, unpack_scalars=True, keepwrap=False, **kwargs + ): """Executes the given ``operation`` once for each item in ``params_iter`` Args: @@ -1463,7 +1466,7 @@ def executemany(self, operation, params_iter, unpack_scalars=True, **kwargs): results = [] with self.connection.begin(): for params in params_iter: - value = self.executeone(operation, params) + value = self.executeone(operation, params, keepwrap=keepwrap) # Should only be used when the user wants back on value. # Let the error bubble up if used wrong. # Deprecated... 
Do not depend on the unpacking behavior. From 2b294493cd232e6bba262452febfa5b5eb445d6a Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 7 Oct 2020 19:57:58 +0100 Subject: [PATCH 066/294] Remove fetchall() from "create index" sql statements When running the tests we get this error: ``` Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 5, abs: 51, in >>> autoinit='staging', verbose=4) File "/wbia/wildbook-ia/wbia/algo/graph/core.py", line 1378, in __init__ infr.reset_feedback(autoinit) File "/wbia/wildbook-ia/wbia/algo/graph/core.py", line 470, in reset_feedback infr.external_feedback = infr.read_wbia_staging_feedback() File "/wbia/wildbook-ia/wbia/algo/graph/mixin_wbia.py", line 554, in read_wbia_staging_feedback hack_create_aidpair_index(ibs) File "/wbia/wildbook-ia/wbia/control/manual_review_funcs.py", line 65, in hack_create_aidpair_index ibs.staging.connection.execute(sqlcmd).fetchall() File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/result.py", line 1289, in fetchall e, None, None, self.cursor, self.context File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1514, in _handle_dbapi_exception util.raise_(exc_info[1], with_traceback=exc_info[2]) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/util/compat.py", line 182, in raise_ raise exception File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/result.py", line 1284, in fetchall l = self.process_rows(self._fetchall_impl()) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/result.py", line 1232, in _fetchall_impl return self._non_result([], err) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/result.py", line 1241, in _non_result replace_context=err, File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/util/compat.py", line 182, in raise_ raise exception sqlalchemy.exc.ResourceClosedError: This result object does not return rows. It has been closed automatically. DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/algo/graph/mixin_loops.py::InfrLoops.main_gen:0 ``` --- wbia/control/manual_review_funcs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wbia/control/manual_review_funcs.py b/wbia/control/manual_review_funcs.py index b600842f8c..e585ddfe65 100644 --- a/wbia/control/manual_review_funcs.py +++ b/wbia/control/manual_review_funcs.py @@ -62,19 +62,19 @@ def hack_create_aidpair_index(ibs): table=ibs.const.REVIEW_TABLE, index_cols=','.join([REVIEW_AID1, REVIEW_AID2]), ) - ibs.staging.connection.execute(sqlcmd).fetchall() + ibs.staging.connection.execute(sqlcmd) sqlcmd = sqlfmt.format( index_name='aid1_to_rowids', table=ibs.const.REVIEW_TABLE, index_cols=','.join([REVIEW_AID1]), ) - ibs.staging.connection.execute(sqlcmd).fetchall() + ibs.staging.connection.execute(sqlcmd) sqlcmd = sqlfmt.format( index_name='aid2_to_rowids', table=ibs.const.REVIEW_TABLE, index_cols=','.join([REVIEW_AID2]), ) - ibs.staging.connection.execute(sqlcmd).fetchall() + ibs.staging.connection.execute(sqlcmd) @register_ibs_method From 36fd28a8dbd2709ab36bbbd42a997fb2f6bc56d0 Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 8 Oct 2020 15:55:48 +0100 Subject: [PATCH 067/294] Register Integer columns in SQL_TYPE_TO_SA_TYPE Without this change, the `Integer` class doesn't actually get used. 
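A quick check of the effect (assumes the reflection/event wiring introduced earlier in
this series):

```
from wbia.dtool.types import SQL_TYPE_TO_SA_TYPE, Integer

# 'INTEGER' columns now map to the wbia Integer TypeDecorator, whose bind
# processing casts values with int() before they reach the database.
assert SQL_TYPE_TO_SA_TYPE['INTEGER'] is Integer
```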
--- wbia/dtool/types.py | 4 +++- wbia/tests/dtool/test_sql_control.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index fa81bf12c4..e373710a5f 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -119,7 +119,8 @@ class Integer(TypeDecorator): impl = SAInteger def process_bind_param(self, value, dialect): - return int(value) + if value is not None: + return int(value) class List(JSONCodeableType): @@ -184,3 +185,4 @@ def process(value): _USER_DEFINED_TYPES = (Dict, List, NDArray, Number, UUID) # SQL type (e.g. 'DICT') to SQLAlchemy type: SQL_TYPE_TO_SA_TYPE = {cls().get_col_spec(): cls for cls in _USER_DEFINED_TYPES} +SQL_TYPE_TO_SA_TYPE['INTEGER'] = Integer diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 856716d8d7..ca051cd6b2 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -143,17 +143,14 @@ def test_add_table(self): # Check the table's column definitions expected_bars_columns = [ - ('bars_id', 'INTEGER'), - ('config_id', 'INTEGER'), - ('data', 'TEXT'), - ('indexer_id', 'INTEGER'), - ('meta_labeler_id', 'INTEGER'), + ('bars_id', 'wbia.dtool.types.Integer'), + ('config_id', 'wbia.dtool.types.Integer'), + ('data', 'sqlalchemy.sql.sqltypes.TEXT'), + ('indexer_id', 'wbia.dtool.types.Integer'), + ('meta_labeler_id', 'wbia.dtool.types.Integer'), ] found_bars_columns = [ - ( - c.name, - c.type.__class__.__name__, - ) + (c.name, '.'.join([c.type.__class__.__module__, c.type.__class__.__name__])) for c in bars.columns ] assert sorted(found_bars_columns) == expected_bars_columns From c00187ab944ca00305ffe557e0435da4308b700f Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 8 Oct 2020 16:56:17 +0100 Subject: [PATCH 068/294] Fix TypeError in _TableComputeHelper.prepare_storage ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 10, abs: 656, in >>> rowids = depc.get_rowids('keypoint', aids, ensure=True, config=config) File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 691, in get_rowids parent_rowids, config=config_, recompute=recompute, **_kwargs File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2324, in get_rowid config=config, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2097, in ensure_rows for colnames, dirty_params_iter, nChunkInput in gen: File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 1770, in _chunk_compute_dirty_rows config, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 1688, in _compute_dirty_rows proptup_gen = list(proptup_gen) File "/wbia/wildbook-ia/wbia/dtool/example_depcache.py", line 316, in dummy_preproc_kpts chip_fpath_list = depc.get_native('chip', chip_rowids, 'chip', read_extern=False) File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 907, in get_native prop_list = table.get_row_data(tbl_rowids, colnames, read_extern=read_extern) File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2741, in get_row_data tries_left, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2833, in _resolve_any_external_data table._recompute_external_storage(failed_rowids) File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2865, in _recompute_external_storage parent_ids, parent_args, config_rowid=cfgid, config=config File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 
1701, in _compute_dirty_rows dirty_params_iter = list(dirty_params_iter) File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 1429, in prepare_storage + parent_extra TypeError: unsupported operand type(s) for +: 'RowProxy' and 'tuple' DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/dtool/depcache_control.py::_CoreDependencyCache.get_rowids:1 ``` --- wbia/dtool/depcache_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 752f46e444..a48fb6868e 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1422,7 +1422,7 @@ def prepare_storage( # fname in zip(multi_args, # multi_fpaths)])) row_tup = ( - ids_ + tuple(ids_) + (config_rowid,) + quick_access_tup + data_cols From 253303fad12ffcae955d1cd0a7874764d96c4e6c Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 8 Oct 2020 17:00:35 +0100 Subject: [PATCH 069/294] Wrap sql statement in text() in SQLDatabaseController._executeone_operation_fmt ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 7, abs: 1446, in >>> depc.delete_root(root_rowids) File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 1454, in delete_root root_rowids, table_config_filter File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 1518, in get_allconfig_descendant_rowids op='AND', File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1018, in get_where_eq_set return self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1367, in _executeone_operation_fmt return self.executeone(operation, params, eager=eager, **kwargs) File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1414, in executeone "'operation' needs to be a sqlalchemy textual sql instance " TypeError: 'operation' needs to be a sqlalchemy textual sql instance see docs on 'sqlalchemy.sql:text' factory function; 'operation' is a '' DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/dtool/depcache_control.py::DependencyCache.delete_root:0 ``` --- wbia/dtool/sql_control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 2227b4d073..b9ffc71ddf 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1364,7 +1364,7 @@ def _executeone_operation_fmt( if params is None: params = [] operation = operation_fmt.format(**fmtdict) - return self.executeone(operation, params, eager=eager, **kwargs) + return self.executeone(text(operation), params, eager=eager, **kwargs) @profile def _executemany_operation_fmt( From fc8c666f5250ed34ce0f7dc3f2cc57c0bda22cae Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 9 Oct 2020 00:17:35 +0100 Subject: [PATCH 070/294] Handle duplicate columns in SQLDatabaseController.executeone Specifically, this test was failing: ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 10, abs: 2585, in >>> prop_list = table.get_row_data(tbl_rowids, colnames, **kwargs) File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2741, in get_row_data tries_left, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2794, in _resolve_any_external_data for uri in prop_listT[extern_colx]: 
IndexError: list index out of range DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/dtool/depcache_table.py::DependencyCacheTable.get_row_data:0 /wbia/wildbook-ia/wbia/dtool/depcache_table.py:2585: IndexError ``` The test was getting `size_1`, `size_0`, `size_1`, `chip_extern_uri`, `chip_extern_uri` from the `chip` table. SQLAlchemy generated sql that looks like: ``` SELECT chip.size_1, chip.size_0, chip.chip_extern_uri FROM chip WHERE chip_rowid = :_identifier ``` removing all the duplicate columns, causing the `IndexError` above. --- wbia/dtool/sql_control.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index b9ffc71ddf..235b6652be 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1066,6 +1066,7 @@ def get_where( eager=eager, **kwargs, ) + return val_list def exists_where_eq( @@ -1427,6 +1428,24 @@ def executeone( elif not results.returns_rows: return None else: + if isinstance(operation, sqlalchemy.sql.selectable.Select): + # This code is specifically for handling duplication in colnames + # because sqlalchemy removes them. + # e.g. select field1, field1, field2 from table; + # becomes + # select field1, field2 from table; + # so the items in val_list only have 2 values + # but the caller isn't expecting it so it causes problems + returned_columns = tuple([c.name for c in operation.columns]) + raw_columns = tuple([c.name for c in operation._raw_columns]) + if raw_columns != returned_columns: + results_ = [] + for r in results: + results_.append( + tuple(r[returned_columns.index(c)] for c in raw_columns) + ) + results = results_ + values = list( [ # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. From 13bcb46a76e9ec9f9e1f982b3fcf15ac9028584b Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 9 Oct 2020 16:43:34 +0100 Subject: [PATCH 071/294] Change db.cur to db.connection in other/ibsfuncs.py --- wbia/other/ibsfuncs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index e411acb0d3..54a566bd9b 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -4875,7 +4875,7 @@ def filter_aids_to_quality(ibs, aid_list, minqual, unknown_ok=True, speedhack=Tr else: operation = 'SELECT rowid from annotations WHERE annot_quality NOTNULL AND annot_quality>={minqual_int} AND rowid IN ({aids})' operation = operation.format(aids=list_repr, minqual_int=minqual_int) - aid_list_ = ut.take_column(ibs.db.cur.execute(operation).fetchall(), 0) + aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) else: qual_flags = list( ibs.get_quality_filterflags(aid_list, minqual, unknown_ok=unknown_ok) @@ -4965,7 +4965,7 @@ def filter_aids_without_name(ibs, aid_list, invert=False, speedhack=True): 'SELECT rowid from annotations WHERE name_rowid>0 AND rowid IN (%s)' % (list_repr,) ) - aid_list_ = ut.take_column(ibs.db.cur.execute(operation).fetchall(), 0) + aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) else: flag_list = ibs.is_aid_unknown(aid_list) if not invert: @@ -5101,7 +5101,7 @@ def filter_aids_to_species(ibs, aid_list, species, speedhack=True): list_repr = ','.join(map(str, aid_list)) operation = 'SELECT rowid from annotations WHERE (species_rowid == {species_rowid}) AND rowid IN ({aids})' operation = operation.format(aids=list_repr, species_rowid=species_rowid) - aid_list_ = ut.take_column(ibs.db.cur.execute(operation).fetchall(), 0) + 
aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) else: species_rowid_list = ibs.get_annot_species_rowids(aid_list) is_valid_species = [sid == species_rowid for sid in species_rowid_list] From 77a5bf99caf167126657cb66205ff9dffdf2e76b Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 9 Oct 2020 22:17:58 +0100 Subject: [PATCH 072/294] Fix superkey constraint error in create table sql Fix error if superkeys is None when creating "create table" sql ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 4, abs: 1442, in >>> ibs = wbia.opendb('testdb2') File "/wbia/wildbook-ia/wbia/entry_points.py", line 573, in opendb ibs = _init_wbia(dbdir, verbose=verbose, use_cache=use_cache, web=web, **kwargs) File "/wbia/wildbook-ia/wbia/entry_points.py", line 92, in _init_wbia force_serial=force_serial, File "/wbia/wildbook-ia/wbia/control/IBEISControl.py", line 279, in request_IBEISController request_stagingversion=request_stagingversion, File "/wbia/wildbook-ia/wbia/control/IBEISControl.py", line 369, in __init__ request_stagingversion=request_stagingversion, File "/wbia/wildbook-ia/wbia/control/IBEISControl.py", line 601, in _init_sql ibs._init_sqldbcore(request_dbversion=request_dbversion) File "/wbia/wildbook-ia/wbia/control/IBEISControl.py", line 696, in _init_sqldbcore dobackup=not ibs.readonly, File "/wbia/wildbook-ia/wbia/control/_sql_helpers.py", line 340, in ensure_correct_version ibs, db, schema_spec, version, version_expected, dobackup=dobackup File "/wbia/wildbook-ia/wbia/control/_sql_helpers.py", line 431, in update_schema_version update(db, ibs=ibs) File "/wbia/wildbook-ia/wbia/control/DB_SCHEMA.py", line 1332, in update_1_4_3 (None, 'annotmatch_is_nondistinct', 'INTEGER', None), File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1839, in modify_table self.add_table(tablename_temp, coldef_list, **metadata_keyval2) File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1665, in add_table operation = self._make_add_table_sqlstr(tablename, coldef_list, **metadata_keyval) File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1645, in _make_add_table_sqlstr self.__make_unique_constraint(x) for x in metadata_keyval.get('superkeys', []) TypeError: 'NoneType' object is not iterable DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/control/manual_annot_funcs.py::get_annot_contact_aids:1 ``` Fix error superkey constraint in create table sql When modifying tables in a migration, the create table sql generated returns an error: ``` sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near "encounter_rowid": syntax error [SQL: CREATE TABLE IF NOT EXISTS encounter_image_relationship_temp1fedefaa ( egr_rowid INTEGER PRIMARY KEY, image_rowid INTEGER NOT NULL, encounter_rowid INTEGER, CONSTRAINT unique_image_rowid, encounter_rowid UNIQUE (image_rowid, encounter_rowid) )] (Background on this error at: http://sqlalche.me/e/13/e3q8) DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/control/manual_annot_funcs.py::get_annot_thetas:0 ``` This was caused by table metadata returning `[('image_rowid, encounter_rowid',)]` as the superkeys value instead of `[('image_rowid',), ('encounter_rowid',)]`. 
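For illustration only (not part of this patch), a standalone sketch of the normalization this change performs before the UNIQUE constraint is rendered. The helper names and the constraint-name format below are hypothetical, not the controller's actual ones; the split simply undoes the comma-joined string stored in the metadata table.

```
def normalize_superkeys(value):
    # Legacy metadata rows store one comma-joined string,
    # e.g. [('image_rowid, encounter_rowid',)]
    if isinstance(value, list) and len(value) == 1 and len(value[0]) == 1:
        value = [tuple(value[0][0].split(', '))]
    return value


def unique_constraint_sql(colnames):
    # One composite UNIQUE constraint spanning every column of a superkey
    return 'CONSTRAINT unique_{} UNIQUE ({})'.format(
        '_'.join(colnames), ', '.join(colnames)
    )


legacy = [('image_rowid, encounter_rowid',)]
superkeys = normalize_superkeys(legacy)
print(superkeys)  # [('image_rowid', 'encounter_rowid')]
print(unique_constraint_sql(superkeys[0]))
# CONSTRAINT unique_image_rowid_encounter_rowid UNIQUE (image_rowid, encounter_rowid)
```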
--- wbia/dtool/sql_control.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 235b6652be..6f0635904a 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -300,6 +300,11 @@ def __getattr__(self, name): return None if METADATA_TABLE_COLUMNS[name]['is_coded_data']: value = eval(value) + if name == 'superkeys' and isinstance(value, list): + # superkeys looks like [('image_rowid, encounter_rowid',)] + # instead of [('image_rowid',), ('encounter_rowid',)] + if len(value) == 1 and len(value[0]) == 1: + value = [tuple(value[0][0].split(', '))] return value def __getattribute__(self, name): @@ -1642,7 +1647,8 @@ def _make_add_table_sqlstr( # Make a list of constraints to place on the table # superkeys = [(, ...), ...] constraint_list = [ - self.__make_unique_constraint(x) for x in metadata_keyval.get('superkeys', []) + self.__make_unique_constraint(x) + for x in metadata_keyval.get('superkeys') or [] ] constraint_list = ut.unique_ordered(constraint_list) From cffdca9a69e35d1957bbdae8c0d534fb2ca6f3c7 Mon Sep 17 00:00:00 2001 From: karen chan Date: Sat, 10 Oct 2020 15:49:17 +0100 Subject: [PATCH 073/294] Remove .fetchall() from create index sql statement ``` sqlalchemy.exc.ResourceClosedError: This result object does not return rows. It has been closed automatically. DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/control/manual_gsgrelate_funcs.py::delete_empty_imgsetids:0 ``` --- wbia/control/manual_imageset_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/control/manual_imageset_funcs.py b/wbia/control/manual_imageset_funcs.py index 48c498128d..cb870002bb 100644 --- a/wbia/control/manual_imageset_funcs.py +++ b/wbia/control/manual_imageset_funcs.py @@ -549,7 +549,7 @@ def get_imageset_gids(ibs, imgsetid_list): """.format( GSG_RELATION_TABLE=const.GSG_RELATION_TABLE ) - ).fetchall() + ) gids_list = ibs.db.get( const.GSG_RELATION_TABLE, ('image_rowid',), From 2d3fe891519d0690e3d11ef59afeae9d5c5fadb6 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 10 Oct 2020 18:34:13 -0700 Subject: [PATCH 074/294] Use primary key when checking for rowid existence ``` Traceback (most recent call last): File "/usr/local/lib/python3.6/dist-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 20, abs: 89, in >>> merge_databases(ibs_src, ibs_dst) File "/code/wbia/dbio/export_subset.py", line 104, in merge_databases ibs_src.fix_invalid_annotmatches() File "/code/wbia/other/ibsfuncs.py", line 1441, in fix_invalid_annotmatches invalid_annotmatch_rowids = ibs.check_annotmatch_consistency() File "/code/wbia/other/ibsfuncs.py", line 1043, in check_annotmatch_consistency exists1_list = ibs.db.check_rowid_exists(const.ANNOTATION_TABLE, aid1_list) File "/code/wbia/dtool/sql_control.py", line 786, in check_rowid_exists rowid_list1 = self.get(tablename, ('rowid',), rowid_iter) File "/code/wbia/dtool/sql_control.py", line 1226, in get tblname, colnames, params_iter, where_clause, eager=eager, **kwargs File "/code/wbia/dtool/sql_control.py", line 1069, in get_where stmt = sqlalchemy.select([table.c[c] for c in colnames]) File "/code/wbia/dtool/sql_control.py", line 1069, in stmt = sqlalchemy.select([table.c[c] for c in colnames]) File "/usr/local/lib/python3.6/dist-packages/sqlalchemy/util/_collections.py", line 194, in __getitem__ return self._data[key] KeyError: 'rowid' ``` A KeyError is being raised because we are inquiring 
about the 'rowid' column on a sqlalchemy table object that does not have a column named 'rowid'. This column is an alias for the primary key when the primary key is an integer. The SQLAlchemy table object is being used to get the column types for parameter binding, but it only works with real defined columns. The change here is to use the actual primary key instead of rowid, which should be one and the same. --- wbia/dtool/sql_control.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 6f0635904a..4f38b7857d 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -775,7 +775,20 @@ def get_all_rowids_where(self, tblname, where_clause, params, **kwargs): return self._executeone_operation_fmt(operation_fmt, fmtdict, params, **kwargs) def check_rowid_exists(self, tablename, rowid_iter, eager=True, **kwargs): - rowid_list1 = self.get(tablename, ('rowid',), rowid_iter) + """Check for the existence of rows (``rowid_iter``) in a table (``tablename``). + Returns as sequence of rowids that exist in the given sequence. + + The 'rowid' term is an alias for the primary key. When calling this method, + you should know that the primary key may be more than one column. + + """ + # BBB (10-Oct-12020) 'rowid' only exists in SQLite and auto-magically gets mapped + # to an integer primary key. However, SQLAlchemy doesn't abide by this magic. + # The aliased column is not part of a reflected table. + # So we find and use the primary key instead. + table = self._reflect_table(tablename) + columns = tuple(c.name for c in table.primary_key.columns) + rowid_list1 = self.get(tablename, columns, rowid_iter) exists_list = [rowid is not None for rowid in rowid_list1] return exists_list From b05da67c9a5837c1a2a0d11391951e36c1cd3f58 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 10 Oct 2020 20:48:26 -0700 Subject: [PATCH 075/294] Fix get_rowid_from_superkey to use primary key Change to non-? based query and reuse logic in get_where_eq, which is essentially was doing. --- wbia/dtool/sql_control.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 4f38b7857d..188cbd8c93 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1127,8 +1127,12 @@ def get_rowid_from_superkey( self, tblname, params_iter=None, superkey_colnames=None, **kwargs ): """ getter which uses the constrained superkeys instead of rowids """ - where_clause = ' AND '.join([colname + '=?' for colname in superkey_colnames]) - return self.get_where(tblname, ('rowid',), params_iter, where_clause, **kwargs) + # ??? Why can this be called with params_iter=None & superkey_colnames=None? + table = self._reflect_table(tblname) + columns = tuple(c.name for c in table.primary_key.columns) + return self.get_where_eq( + tblname, columns, params_iter, superkey_colnames, op='AND', **kwargs + ) def get( self, From 54388e88a02a52f1d3ee6ed318c9040e8861060b Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 10 Oct 2020 22:00:05 -0700 Subject: [PATCH 076/294] Correct the output of get_column This is a straight text query without knowledge of how to cast the results to python types. This change uses the table definition to ensure the correct python typed value is returned. An example of this failing would be around a selection of UUID values. Because we store them as bytes, they don't even come across as inferable. 
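For illustration only (not part of this patch), a self-contained sketch of the difference between a textual SELECT and selecting through a typed column. The toy UUID type below only mimics the idea of the project's UUID column type (UUIDs stored as bytes), the table and column names are made up, and the snippet targets the 1.3-era SQLAlchemy API used in this series.

```
import uuid

from sqlalchemy import Column, LargeBinary, MetaData, Table, create_engine, select, text
from sqlalchemy.types import TypeDecorator


class ToyUUID(TypeDecorator):
    impl = LargeBinary

    def process_bind_param(self, value, dialect):
        return value.bytes if value is not None else None

    def process_result_value(self, value, dialect):
        return uuid.UUID(bytes=value) if value is not None else None


metadata = MetaData()
demo = Table('demo', metadata, Column('img_uuid', ToyUUID))
engine = create_engine('sqlite:///:memory:')
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(demo.insert(), {'img_uuid': uuid.uuid4()})
    raw = conn.execute(text('SELECT img_uuid FROM demo')).fetchall()
    typed = conn.execute(select([demo.c.img_uuid])).fetchall()
    print(type(raw[0][0]))    # bytes -- a textual query has no column type to apply
    print(type(typed[0][0]))  # uuid.UUID -- the column's result processor rebuilds the value
```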
--- wbia/dtool/sql_control.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 188cbd8c93..b7af93c014 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2346,9 +2346,12 @@ def get_column_names(self, tablename): return column_names def get_column(self, tablename, name): - """ Conveinience: """ - table, (column,) = sanitize_sql(self, tablename, (name,)) - return self.executeone(text(f'SELECT {column} FROM {table} ORDER BY rowid ASC')) + """Get all the values for the specified column (``name``) of the table (``tablename``)""" + table = self._reflect_table(tablename) + stmt = sqlalchemy.select([table.c[name]]).order_by( + *[c.asc() for c in table.primary_key.columns] + ) + return self.executeone(stmt) def get_table_as_pandas( self, tablename, rowids=None, columns=None, exclude_columns=[] From f4651c9a3d9553f680950992a52ee24bbd2bba7f Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 10 Oct 2020 22:06:53 -0700 Subject: [PATCH 077/294] Remove utool.embed_on_exception_context usage Hopefully this won't be necessary by the time I'm done with it. --- wbia/dtool/sql_control.py | 225 ++++++++++++++++++-------------------- 1 file changed, 108 insertions(+), 117 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index b7af93c014..6924b1746a 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2558,134 +2558,125 @@ def get_table_new_transferdata(self, tablename, exclude_columns=[]): >>> print('dependsmap = %s' % (ut.repr2(dependsmap, nl=True),)) >>> print('L___') """ - import utool + all_column_names = self.get_column_names(tablename) + isvalid_list = [name not in exclude_columns for name in all_column_names] + column_names = ut.compress(all_column_names, isvalid_list) + column_list = [ + self.get_column(tablename, name) + for name in column_names + if name not in exclude_columns + ] - with utool.embed_on_exception_context: - all_column_names = self.get_column_names(tablename) - isvalid_list = [name not in exclude_columns for name in all_column_names] - column_names = ut.compress(all_column_names, isvalid_list) - column_list = [ - self.get_column(tablename, name) - for name in column_names - if name not in exclude_columns - ] + extern_colx_list = [] + extern_tablename_list = [] + extern_superkey_colname_list = [] + extern_superkey_colval_list = [] + extern_primarycolnames_list = [] + dependsmap = self.metadata[tablename].dependsmap + if dependsmap is not None: + for colname, dependtup in six.iteritems(dependsmap): + assert len(dependtup) == 3, 'must be 3 for now' + ( + extern_tablename, + extern_primary_colnames, + extern_superkey_colnames, + ) = dependtup + if extern_primary_colnames is None: + # INFER PRIMARY COLNAMES + extern_primary_colnames = self.get_table_primarykey_colnames( + extern_tablename + ) + if extern_superkey_colnames is None: - extern_colx_list = [] - extern_tablename_list = [] - extern_superkey_colname_list = [] - extern_superkey_colval_list = [] - extern_primarycolnames_list = [] - dependsmap = self.metadata[tablename].dependsmap - if dependsmap is not None: - for colname, dependtup in six.iteritems(dependsmap): - assert len(dependtup) == 3, 'must be 3 for now' - ( - extern_tablename, - extern_primary_colnames, - extern_superkey_colnames, - ) = dependtup - if extern_primary_colnames is None: - # INFER PRIMARY COLNAMES - extern_primary_colnames = self.get_table_primarykey_colnames( - extern_tablename - ) - if 
extern_superkey_colnames is None: - - def get_standard_superkey_colnames(tablename_): - try: - # FIXME: Rectify duplicate code - superkeys = self.get_table_superkey_colnames(tablename_) - if len(superkeys) > 1: - primary_superkey = self.metadata[ - tablename_ - ].primary_superkey - self.get_table_superkey_colnames('contributors') - if primary_superkey is None: - raise AssertionError( - ( - 'tablename_=%r has multiple superkeys=%r, ' - 'but no primary superkey.' - ' A primary superkey is required' - ) - % (tablename_, superkeys) + def get_standard_superkey_colnames(tablename_): + try: + # FIXME: Rectify duplicate code + superkeys = self.get_table_superkey_colnames(tablename_) + if len(superkeys) > 1: + primary_superkey = self.metadata[ + tablename_ + ].primary_superkey + self.get_table_superkey_colnames('contributors') + if primary_superkey is None: + raise AssertionError( + ( + 'tablename_=%r has multiple superkeys=%r, ' + 'but no primary superkey.' + ' A primary superkey is required' ) - else: - index = superkeys.index(primary_superkey) - superkey_colnames = superkeys[index] - elif len(superkeys) == 1: - superkey_colnames = superkeys[0] - else: - logger.info(self.get_table_csv_header(tablename_)) - self.print_table_csv( - 'metadata', exclude_columns=['metadata_value'] + % (tablename_, superkeys) ) - # Execute hack to fix contributor tables - if tablename_ == 'contributors': - # hack to fix contributors table - constraint_str = self.metadata[ - tablename_ - ].constraint - parse_result = parse.parse( - 'CONSTRAINT superkey UNIQUE ({superkey})', - constraint_str, - ) - superkey = parse_result['superkey'] - assert ( - superkey == 'contributor_tag' - ), 'hack failed1' - assert ( - self.metadata['contributors'].superkey is None - ), 'hack failed2' - self.metadata['contributors'].superkey = [ - (superkey,) - ] - return (superkey,) - else: - raise NotImplementedError( - 'Cannot Handle: len(superkeys) == 0. ' - 'Probably a degenerate case' - ) - except Exception as ex: - ut.printex( - ex, - 'Error Getting superkey colnames', - keys=['tablename_', 'superkeys'], + else: + index = superkeys.index(primary_superkey) + superkey_colnames = superkeys[index] + elif len(superkeys) == 1: + superkey_colnames = superkeys[0] + else: + logger.info(self.get_table_csv_header(tablename_)) + self.print_table_csv( + 'metadata', exclude_columns=['metadata_value'] ) - raise - return superkey_colnames - - try: - extern_superkey_colnames = get_standard_superkey_colnames( - extern_tablename - ) + # Execute hack to fix contributor tables + if tablename_ == 'contributors': + # hack to fix contributors table + constraint_str = self.metadata[tablename_].constraint + parse_result = parse.parse( + 'CONSTRAINT superkey UNIQUE ({superkey})', + constraint_str, + ) + superkey = parse_result['superkey'] + assert superkey == 'contributor_tag', 'hack failed1' + assert ( + self.metadata['contributors'].superkey is None + ), 'hack failed2' + self.metadata['contributors'].superkey = [(superkey,)] + return (superkey,) + else: + raise NotImplementedError( + 'Cannot Handle: len(superkeys) == 0. 
' + 'Probably a degenerate case' + ) except Exception as ex: ut.printex( ex, - 'Error Building Transferdata', - keys=['tablename_', 'dependtup'], + 'Error Getting superkey colnames', + keys=['tablename_', 'superkeys'], ) raise - # INFER SUPERKEY COLNAMES - colx = ut.listfind(column_names, colname) - extern_rowids = column_list[colx] - superkey_column = self.get( - extern_tablename, extern_superkey_colnames, extern_rowids - ) - extern_colx_list.append(colx) - extern_superkey_colname_list.append(extern_superkey_colnames) - extern_superkey_colval_list.append(superkey_column) - extern_tablename_list.append(extern_tablename) - extern_primarycolnames_list.append(extern_primary_colnames) + return superkey_colnames - new_transferdata = ( - column_list, - column_names, - extern_colx_list, - extern_superkey_colname_list, - extern_superkey_colval_list, - extern_tablename_list, - extern_primarycolnames_list, - ) + try: + extern_superkey_colnames = get_standard_superkey_colnames( + extern_tablename + ) + except Exception as ex: + ut.printex( + ex, + 'Error Building Transferdata', + keys=['tablename_', 'dependtup'], + ) + raise + # INFER SUPERKEY COLNAMES + colx = ut.listfind(column_names, colname) + extern_rowids = column_list[colx] + superkey_column = self.get( + extern_tablename, extern_superkey_colnames, extern_rowids + ) + extern_colx_list.append(colx) + extern_superkey_colname_list.append(extern_superkey_colnames) + extern_superkey_colval_list.append(superkey_column) + extern_tablename_list.append(extern_tablename) + extern_primarycolnames_list.append(extern_primary_colnames) + + new_transferdata = ( + column_list, + column_names, + extern_colx_list, + extern_superkey_colname_list, + extern_superkey_colval_list, + extern_tablename_list, + extern_primarycolnames_list, + ) return new_transferdata # def import_table_new_transferdata(tablename, new_transferdata): From f115e14398e49ff135f244c1bca7c535b47b7ac2 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 10 Oct 2020 22:11:14 -0700 Subject: [PATCH 078/294] Refactored to use the reflected table object I'm not exactly sure why there was a check for validity of a column name in there. Was get_column_names somehow returning faux columns? 
--- wbia/dtool/sql_control.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 6924b1746a..cb9a96c0b6 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2558,14 +2558,9 @@ def get_table_new_transferdata(self, tablename, exclude_columns=[]): >>> print('dependsmap = %s' % (ut.repr2(dependsmap, nl=True),)) >>> print('L___') """ - all_column_names = self.get_column_names(tablename) - isvalid_list = [name not in exclude_columns for name in all_column_names] - column_names = ut.compress(all_column_names, isvalid_list) - column_list = [ - self.get_column(tablename, name) - for name in column_names - if name not in exclude_columns - ] + table = self._reflect_table(tablename) + column_names = [c.name for c in table.columns if c.name not in exclude_columns] + column_list = [self.get_column(tablename, name) for name in column_names] extern_colx_list = [] extern_tablename_list = [] From 8f1d6bd54625e5a2a7e109da9599f8c31b3d98b0 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 10 Oct 2020 22:57:04 -0700 Subject: [PATCH 079/294] Adjust get_imageset_gsgrids to use get_where_eq This removes dialect specific SQL --- wbia/control/manual_imageset_funcs.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/wbia/control/manual_imageset_funcs.py b/wbia/control/manual_imageset_funcs.py index cb870002bb..cf15ae3c82 100644 --- a/wbia/control/manual_imageset_funcs.py +++ b/wbia/control/manual_imageset_funcs.py @@ -594,37 +594,34 @@ def get_imageset_gsgrids(ibs, imgsetid_list=None, gid_list=None): if imgsetid_list is not None and gid_list is None: # TODO: Group type params_iter = ((imgsetid,) for imgsetid in imgsetid_list) - where_clause = 'imageset_rowid=?' # list of relationships for each imageset - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('imageset_rowid',), unpack_scalars=False, ) elif gid_list is not None and imgsetid_list is None: # TODO: Group type params_iter = ((gid,) for gid in gid_list) - where_clause = 'image_rowid=?' # list of relationships for each imageset - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('image_rowid',), unpack_scalars=False, ) else: # TODO: Group type params_iter = ((imgsetid, gid) for imgsetid, gid in zip(imgsetid_list, gid_list)) - where_clause = 'imageset_rowid=? AND image_rowid=?' # list of relationships for each imageset - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('imageset_rowid', 'image_rowid'), unpack_scalars=False, ) return gsgrids_list From 7770120ce3ae8de435ea94b40ac006a91dd52c02 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 11 Oct 2020 00:12:01 -0700 Subject: [PATCH 080/294] Cast numpy.ndarray integer values to int Upon the SQLController's `get` method being given a `numpy.ndarray` or `list` of `numpy.int*` values, cast all the values to `int`. This is only needed for `rowid`, because it's not a real column. Otherwise the sqlalchemy types would convert the values to the correct sql type. This resolves about 70+ failing tests. 
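For illustration only (not part of this patch), a short sketch of the underlying binding problem. The table name is made up, and the cast is written as a plain conditional that expresses the same intent as the diff (cast to int, let None pass through).

```
import sqlite3

import numpy as np

rowids = np.array([2, 4, 6])
print(isinstance(rowids[0], int))  # False: numpy integer scalars are not Python ints

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE t (id INTEGER PRIMARY KEY)')
conn.executemany('INSERT INTO t (id) VALUES (?)', [(int(i),) for i in rowids])
try:
    conn.execute('SELECT id FROM t WHERE rowid = ?', (rowids[0],)).fetchall()
except sqlite3.InterfaceError as exc:
    # Without an adapter, sqlite3 typically refuses to bind a numpy integer.
    print('binding a numpy integer failed:', exc)

# Casting first (and letting None pass through untouched) always binds cleanly.
cast_ids = [int(i) if i is not None else None for i in list(rowids) + [None]]
print(conn.execute('SELECT id FROM t WHERE rowid = ?', (cast_ids[0],)).fetchall())  # [(2,)]
```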
--- wbia/dtool/sql_control.py | 31 ++++++++++++++++++++++------ wbia/tests/dtool/test_sql_control.py | 25 ++++++++++++++++++++++ 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index cb9a96c0b6..81bbcea37a 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1207,7 +1207,11 @@ def get( else: id_param_name = '_identifier' where_clause = text(id_colname + f' = :{id_param_name}') - if id_colname != 'rowid': # b/c rowid doesn't really exist as a column + if id_colname == 'rowid': + # Cast all item values to in, in case values are numpy.integer* + # Strangely allow for None values + id_iter = [id_ is not None and int(id_) or id_ for id_ in id_iter] + else: # b/c rowid doesn't really exist as a column column = self._reflect_table(tblname).c[id_colname] where_clause = where_clause.bindparams( bindparam(id_param_name, type_=column.type) @@ -1350,15 +1354,17 @@ def set( **{col: bindparam(f'e{i}') for i, col in enumerate(colnames)} ) where_clause = text(id_colname + f' = :{id_param_name}') - if id_colname != 'rowid': # b/c rowid doesn't really exist as a column + if id_colname == 'rowid': + # Cast all item values to in, in case values are numpy.integer* + # Strangely allow for None values + id_list = [id_ is not None and int(id_) or id_ for id_ in id_list] + else: # b/c rowid doesn't really exist as a column id_column = table.c[id_colname] where_clause = where_clause.bindparams( bindparam(id_param_name, type_=id_column.type) ) stmt = stmt.where(where_clause) for i, id in enumerate(id_list): - if id_colname == 'rowid': - id = int(id) params = {id_param_name: id} params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) self.connection.execute(stmt, **params) @@ -1369,9 +1375,22 @@ def delete(self, tblname, id_list, id_colname='rowid', **kwargs): Optionally a different ID column can be specified via ``id_colname``. 
""" - stmt = text(f'DELETE FROM {tblname} WHERE {id_colname} = :id') + id_param_name = '_identifier' + table = self._reflect_table(tblname) + stmt = table.delete() + where_clause = text(id_colname + f' = :{id_param_name}') + if id_colname == 'rowid': + # Cast all item values to in, in case values are numpy.integer* + # Strangely allow for None values + id_list = [id_ is not None and int(id_) or id_ for id_ in id_list] + else: # b/c rowid doesn't really exist as a column + id_column = table.c[id_colname] + where_clause = where_clause.bindparams( + bindparam(id_param_name, type_=id_column.type) + ) + stmt = stmt.where(where_clause) for id in id_list: - self.connection.execute(stmt, id=id) + self.connection.execute(stmt, {id_param_name: id}) def delete_rowids(self, tblname, rowid_list, **kwargs): """ deletes the the rows in rowid_list """ diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index ca051cd6b2..d492b72552 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -2,6 +2,7 @@ import uuid from functools import partial +import numpy as np import pytest import sqlalchemy.exc from sqlalchemy import MetaData, Table @@ -800,6 +801,30 @@ def test_get_by_id(self): # Verify getting assert data == expected + def test_get_by_numpy_array_of_ids(self): + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + requested_ids = np.array([2, 4, 6]) + data = self.ctrlr.get(table_name, ['x', 'z'], requested_ids) + + # Build the expect results of the testing target + sql_array = ', '.join([str(id) for id in requested_ids]) + results = self.ctrlr.connection.execute( + f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' + ) + expected = results.fetchall() + # Verify getting + assert data == expected + def test_get_as_unique(self): # This test could be inaccurate, because this logical path appears # to be bolted on the side. Usage of this path's feature is unknown. From 83be46a30099148786d482dbb096cb43e574b5f1 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sun, 11 Oct 2020 16:43:14 -0700 Subject: [PATCH 081/294] Fix JSON encoding and decoding There is a fair bit of stuff going on in the original sqlite conversion that used utool.{to|from}_json. The primary benefit of just using the existing implementation is that it'll just work. Though, the implementation it calls isn't as self-explanitory as I'd like. 
--- wbia/dtool/types.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index e373710a5f..651f965b0b 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """Mapping of Python types to SQL types""" import io -import json import uuid import numpy as np +from utool.util_cache import from_json, to_json from sqlalchemy.types import Integer as SAInteger from sqlalchemy.types import TypeDecorator, UserDefinedType @@ -46,25 +46,13 @@ def get_col_spec(self, **kw): def bind_processor(self, dialect): def process(value): - if value is None: - return value - else: - if isinstance(value, self.base_py_type): - return json.dumps(value) - else: - return value + return to_json(value) return process def result_processor(self, dialect, coltype): def process(value): - if value is None: - return value - else: - if not isinstance(value, self.base_py_type): - return json.loads(value) - else: - return value + return from_json(value) return process From 34060b68331c3163f132f2318436428593b799a2 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 14 Oct 2020 22:27:24 +0100 Subject: [PATCH 082/294] Fix DependencyCache initialize to not create ":memory:" file After running test "wbia/dtool/example_depcache.py::dummy_example_depcacahe:0", a file called ":memory" is created in the current directory. This is because the DependencyCache initialize code prepends `sqlite:///` to `:memory` causing sqlite to create the file. --- wbia/dtool/depcache_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index bad112b692..1d4ccb097d 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -218,7 +218,7 @@ def initialize(depc, _debug=None): for fname in depc.fname_to_db.keys(): if fname == ':memory:': - fpath = fname + db_uri = 'sqlite:///:memory:' else: fname_ = ut.ensure_ext(fname, '.sqlite') from os.path import dirname @@ -227,9 +227,9 @@ def initialize(depc, _debug=None): if prefix_dpath: ut.ensuredir(ut.unixjoin(depc.cache_dpath, prefix_dpath)) fpath = ut.unixjoin(depc.cache_dpath, fname_) + db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) # if ut.get_argflag('--clear-all-depcache'): # ut.delete(fpath) - db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) db = sql_control.SQLDatabaseController.from_uri(db_uri) depcache_table.ensure_config_table(db) depc.fname_to_db[fname] = db From f3c7b13d5bf34f6b5a64b3ff24d87d47781811bc Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 19 Oct 2020 17:29:21 +0100 Subject: [PATCH 083/294] Change connection.commit to sqlalchemy transaction.commit sqlalchemy database connection doesn't have the method `.commit()`, we need to create a transaction and commit it instead. 
``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 5, abs: 1251, in >>> result = check_cache_purge(ibs) File "/wbia/wildbook-ia/wbia/other/ibsfuncs.py", line 1477, in check_cache_purge db.squeeze() File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 726, in squeeze self.shrink_memory() File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 708, in shrink_memory self.connection.commit() AttributeError: 'Connection' object has no attribute 'commit' DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/other/ibsfuncs.py::check_cache_purge:0 ``` --- wbia/dtool/sql_control.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 81bbcea37a..e237133c14 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -702,9 +702,9 @@ def optimize(self): def shrink_memory(self): logger.info('[sql] shrink_memory') - self.connection.commit() + transaction = self.connection.begin() self.connection.execute('PRAGMA shrink_memory;') - self.connection.commit() + transaction.commit() def vacuum(self): logger.info('[sql] vaccum') @@ -714,9 +714,9 @@ def vacuum(self): def integrity(self): logger.info('[sql] vaccum') - self.connection.commit() + transaction = self.connection.begin() self.connection.execute('PRAGMA integrity_check;') - self.connection.commit() + transaction.commit() def squeeze(self): logger.info('[sql] squeeze') From 08ad98e141f304fad5593d7e5d04117416e049c0 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 13 Oct 2020 10:12:10 -0700 Subject: [PATCH 084/294] Use 'self' in controller class methods --- wbia/control/IBEISControl.py | 614 +++++++++++++++++------------------ 1 file changed, 307 insertions(+), 307 deletions(-) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 2f5651dc6d..e1aa683594 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -325,7 +325,7 @@ class IBEISController(BASE_CLASS): @profile def __init__( - ibs, + self, dbdir=None, ensure=True, wbaddr=None, @@ -338,15 +338,15 @@ def __init__( # if verbose and ut.VERBOSE: logger.info('\n[ibs.__init__] new IBEISController') - ibs.dbname = None + self.dbname = None # an dict to hack in temporary state - ibs.const = const - ibs.readonly = None - ibs.depc_image = None - ibs.depc_annot = None - ibs.depc_part = None - # ibs.allow_override = 'override+warn' - ibs.allow_override = True + self.const = const + self.readonly = None + self.depc_image = None + self.depc_annot = None + self.depc_part = None + # self.allow_override = 'override+warn' + self.allow_override = True if force_serial is None: if ut.get_argflag(('--utool-force-serial', '--force-serial', '--serial')): force_serial = True @@ -354,60 +354,60 @@ def __init__( force_serial = not ut.in_main_process() # if const.CONTAINERIZED: # force_serial = True - ibs.force_serial = force_serial + self.force_serial = force_serial # observer_weakref_list keeps track of the guibacks connected to this # controller - ibs.observer_weakref_list = [] + self.observer_weakref_list = [] # not completely working decorator cache - ibs.table_cache = None - ibs._initialize_self() - ibs._init_dirs(dbdir=dbdir, ensure=ensure) + self.table_cache = None + self._initialize_self() + self._init_dirs(dbdir=dbdir, ensure=ensure) # _send_wildbook_request will do nothing if no wildbook address is # 
specified - ibs._send_wildbook_request(wbaddr) - ibs._init_sql( + self._send_wildbook_request(wbaddr) + self._init_sql( request_dbversion=request_dbversion, request_stagingversion=request_stagingversion, ) - ibs._init_config() - if not ut.get_argflag('--noclean') and not ibs.readonly: - # ibs._init_burned_in_species() - ibs._clean_species() - ibs.job_manager = None + self._init_config() + if not ut.get_argflag('--noclean') and not self.readonly: + # self._init_burned_in_species() + self._clean_species() + self.job_manager = None # Hack for changing the way chips compute # by default use serial because warpAffine is weird with multiproc is_mac = 'macosx' in ut.get_plat_specifier().lower() - ibs._parallel_chips = not ibs.force_serial and not is_mac + self._parallel_chips = not self.force_serial and not is_mac - ibs.containerized = const.CONTAINERIZED - ibs.production = const.PRODUCTION + self.containerized = const.CONTAINERIZED + self.production = const.PRODUCTION - logger.info('[ibs.__init__] CONTAINERIZED: %s\n' % (ibs.containerized,)) - logger.info('[ibs.__init__] PRODUCTION: %s\n' % (ibs.production,)) + logger.info('[ibs.__init__] CONTAINERIZED: %s\n' % (self.containerized,)) + logger.info('[ibs.__init__] PRODUCTION: %s\n' % (self.production,)) # Hack to store HTTPS flag (deliver secure content in web) - ibs.https = const.HTTPS + self.https = const.HTTPS logger.info('[ibs.__init__] END new IBEISController\n') - def reset_table_cache(ibs): - ibs.table_cache = accessor_decors.init_tablecache() + def reset_table_cache(self): + self.table_cache = accessor_decors.init_tablecache() - def clear_table_cache(ibs, tablename=None): + def clear_table_cache(self, tablename=None): logger.info('[ibs] clearing table_cache[%r]' % (tablename,)) if tablename is None: - ibs.reset_table_cache() + self.reset_table_cache() else: try: - del ibs.table_cache[tablename] + del self.table_cache[tablename] except KeyError: pass - def show_depc_graph(ibs, depc, reduced=False): + def show_depc_graph(self, depc, reduced=False): depc.show_graph(reduced=reduced) - def show_depc_image_graph(ibs, **kwargs): + def show_depc_image_graph(self, **kwargs): """ CommandLine: python -m wbia.control.IBEISControl --test-show_depc_image_graph --show @@ -422,9 +422,9 @@ def show_depc_image_graph(ibs, **kwargs): >>> ibs.show_depc_image_graph(reduced=reduced) >>> ut.show_if_requested() """ - ibs.show_depc_graph(ibs.depc_image, **kwargs) + self.show_depc_graph(self.depc_image, **kwargs) - def show_depc_annot_graph(ibs, *args, **kwargs): + def show_depc_annot_graph(self, *args, **kwargs): """ CommandLine: python -m wbia.control.IBEISControl --test-show_depc_annot_graph --show @@ -439,9 +439,9 @@ def show_depc_annot_graph(ibs, *args, **kwargs): >>> ibs.show_depc_annot_graph(reduced=reduced) >>> ut.show_if_requested() """ - ibs.show_depc_graph(ibs.depc_annot, *args, **kwargs) + self.show_depc_graph(self.depc_annot, *args, **kwargs) - def show_depc_annot_table_input(ibs, tablename, *args, **kwargs): + def show_depc_annot_table_input(self, tablename, *args, **kwargs): """ CommandLine: python -m wbia.control.IBEISControl --test-show_depc_annot_table_input --show --tablename=vsone @@ -457,30 +457,30 @@ def show_depc_annot_table_input(ibs, tablename, *args, **kwargs): >>> ibs.show_depc_annot_table_input(tablename) >>> ut.show_if_requested() """ - ibs.depc_annot[tablename].show_input_graph() + self.depc_annot[tablename].show_input_graph() - def get_cachestats_str(ibs): + def get_cachestats_str(self): """ Returns info about the underlying SQL cache 
memory """ total_size_str = ut.get_object_size_str( - ibs.table_cache, lbl='size(table_cache): ' + self.table_cache, lbl='size(table_cache): ' ) - total_size_str = '\nlen(table_cache) = %r' % (len(ibs.table_cache)) + total_size_str = '\nlen(table_cache) = %r' % (len(self.table_cache)) table_size_str_list = [ ut.get_object_size_str(val, lbl='size(table_cache[%s]): ' % (key,)) - for key, val in six.iteritems(ibs.table_cache) + for key, val in six.iteritems(self.table_cache) ] cachestats_str = total_size_str + ut.indentjoin(table_size_str_list, '\n * ') return cachestats_str - def print_cachestats_str(ibs): - cachestats_str = ibs.get_cachestats_str() + def print_cachestats_str(self): + cachestats_str = self.get_cachestats_str() logger.info('IBEIS Controller Cache Stats:') logger.info(cachestats_str) return cachestats_str - def _initialize_self(ibs): + def _initialize_self(self): """ Injects code from plugin modules into the controller @@ -488,17 +488,17 @@ def _initialize_self(ibs): """ if ut.VERBOSE: logger.info('[ibs] _initialize_self()') - ibs.reset_table_cache() + self.reset_table_cache() ut.util_class.inject_all_external_modules( - ibs, + self, controller_inject.CONTROLLER_CLASSNAME, - allow_override=ibs.allow_override, + allow_override=self.allow_override, ) - assert hasattr(ibs, 'get_database_species'), 'issue with ibsfuncs' - assert hasattr(ibs, 'get_annot_pair_timedelta'), 'issue with annotmatch_funcs' - ibs.register_controller() + assert hasattr(self, 'get_database_species'), 'issue with ibsfuncs' + assert hasattr(self, 'get_annot_pair_timedelta'), 'issue with annotmatch_funcs' + self.register_controller() - def _on_reload(ibs): + def _on_reload(self): """ For utools auto reload (rrr). Called before reload @@ -506,36 +506,36 @@ def _on_reload(ibs): # Reloading breaks flask, turn it off controller_inject.GLOBAL_APP_ENABLED = False # Only warn on first load. Overrideing while reloading is ok - ibs.allow_override = True - ibs.unregister_controller() + self.allow_override = True + self.unregister_controller() # Reload dependent modules ut.reload_injected_modules(controller_inject.CONTROLLER_CLASSNAME) - def load_plugin_module(ibs, module): + def load_plugin_module(self, module): ut.inject_instance( - ibs, + self, classkey=module.CLASS_INJECT_KEY, - allow_override=ibs.allow_override, + allow_override=self.allow_override, strict=False, verbose=False, ) # We should probably not implement __del__ # see: https://docs.python.org/2/reference/datamodel.html#object.__del__ - # def __del__(ibs): - # ibs.cleanup() + # def __del__(self): + # self.cleanup() # ------------ # SELF REGISTRATION # ------------ - def register_controller(ibs): + def register_controller(self): """ registers controller with global list """ - ibs_weakref = weakref.ref(ibs) + ibs_weakref = weakref.ref(self) __ALL_CONTROLLERS__.append(ibs_weakref) - def unregister_controller(ibs): - ibs_weakref = weakref.ref(ibs) + def unregister_controller(self): + ibs_weakref = weakref.ref(self) try: __ALL_CONTROLLERS__.remove(ibs_weakref) pass @@ -546,48 +546,48 @@ def unregister_controller(ibs): # OBSERVER REGISTRATION # ------------ - def cleanup(ibs): + def cleanup(self): """ call on del? 
""" - logger.info('[ibs.cleanup] Observers (if any) notified [controller killed]') - for observer_weakref in ibs.observer_weakref_list: + logger.info('[self.cleanup] Observers (if any) notified [controller killed]') + for observer_weakref in self.observer_weakref_list: observer_weakref().notify_controller_killed() - def register_observer(ibs, observer): + def register_observer(self, observer): logger.info('[register_observer] Observer registered: %r' % observer) observer_weakref = weakref.ref(observer) - ibs.observer_weakref_list.append(observer_weakref) + self.observer_weakref_list.append(observer_weakref) - def remove_observer(ibs, observer): + def remove_observer(self, observer): logger.info('[remove_observer] Observer removed: %r' % observer) - ibs.observer_weakref_list.remove(observer) + self.observer_weakref_list.remove(observer) - def notify_observers(ibs): + def notify_observers(self): logger.info('[notify_observers] Observers (if any) notified') - for observer_weakref in ibs.observer_weakref_list: + for observer_weakref in self.observer_weakref_list: observer_weakref().notify() # ------------ - def _init_rowid_constants(ibs): + def _init_rowid_constants(self): # ADD TO CONSTANTS # THIS IS EXPLICIT IN CONST, USE THAT VERSION INSTEAD - # ibs.UNKNOWN_LBLANNOT_ROWID = const.UNKNOWN_LBLANNOT_ROWID - # ibs.UNKNOWN_NAME_ROWID = ibs.UNKNOWN_LBLANNOT_ROWID - # ibs.UNKNOWN_SPECIES_ROWID = ibs.UNKNOWN_LBLANNOT_ROWID + # self.UNKNOWN_LBLANNOT_ROWID = const.UNKNOWN_LBLANNOT_ROWID + # self.UNKNOWN_NAME_ROWID = self.UNKNOWN_LBLANNOT_ROWID + # self.UNKNOWN_SPECIES_ROWID = self.UNKNOWN_LBLANNOT_ROWID - # ibs.MANUAL_CONFIG_SUFFIX = 'MANUAL_CONFIG' - # ibs.MANUAL_CONFIGID = ibs.add_config(ibs.MANUAL_CONFIG_SUFFIX) + # self.MANUAL_CONFIG_SUFFIX = 'MANUAL_CONFIG' + # self.MANUAL_CONFIGID = self.add_config(self.MANUAL_CONFIG_SUFFIX) # duct_tape.fix_compname_configs(ibs) # duct_tape.remove_database_slag(ibs) # duct_tape.fix_nulled_yaws(ibs) lbltype_names = const.KEY_DEFAULTS.keys() lbltype_defaults = const.KEY_DEFAULTS.values() - lbltype_ids = ibs.add_lbltype(lbltype_names, lbltype_defaults) - ibs.lbltype_ids = dict(zip(lbltype_names, lbltype_ids)) + lbltype_ids = self.add_lbltype(lbltype_names, lbltype_defaults) + self.lbltype_ids = dict(zip(lbltype_names, lbltype_ids)) @profile - def _init_sql(ibs, request_dbversion=None, request_stagingversion=None): + def _init_sql(self, request_dbversion=None, request_stagingversion=None): """ Load or create sql database """ from wbia.other import duct_tape # NOQA @@ -598,22 +598,22 @@ def _init_sql(ibs, request_dbversion=None, request_stagingversion=None): # DATABASE DURING A POST UPDATE FUNCTION ROUTINE, WHICH HAS TO BE LOADED # FIRST AND DEFINED IN ORDER TO MAKE THE SUBSEQUENT WRITE CALLS TO THE # RELEVANT CACHE DATABASE - ibs._init_depcache() - ibs._init_sqldbcore(request_dbversion=request_dbversion) - ibs._init_sqldbstaging(request_stagingversion=request_stagingversion) - # ibs.db.dump_schema() - ibs._init_rowid_constants() + self._init_depcache() + self._init_sqldbcore(request_dbversion=request_dbversion) + self._init_sqldbstaging(request_stagingversion=request_stagingversion) + # self.db.dump_schema() + self._init_rowid_constants() - def _needs_backup(ibs): + def _needs_backup(self): needs_backup = not ut.get_argflag('--nobackup') - if ibs.get_dbname() == 'PZ_MTEST': + if self.get_dbname() == 'PZ_MTEST': needs_backup = False if dtool.sql_control.READ_ONLY: needs_backup = False return needs_backup @profile - def _init_sqldbcore(ibs, request_dbversion=None): + 
def _init_sqldbcore(self, request_dbversion=None): """ Example: >>> # DISABLE_DOCTEST @@ -643,14 +643,14 @@ def _init_sqldbcore(ibs, request_dbversion=None): backup_idx = ut.get_argval('--loadbackup', type_=int, default=None) sqldb_fpath = None if backup_idx is not None: - backups = _sql_helpers.get_backup_fpaths(ibs) + backups = _sql_helpers.get_backup_fpaths(self) logger.info('backups = %r' % (backups,)) sqldb_fpath = backups[backup_idx] logger.info('CHOSE BACKUP sqldb_fpath = %r' % (sqldb_fpath,)) - if backup_idx is None and ibs._needs_backup(): + if backup_idx is None and self._needs_backup(): try: _sql_helpers.ensure_daily_database_backup( - ibs.get_ibsdir(), ibs.sqldb_fname, ibs.backupdir + self.get_ibsdir(), self.sqldb_fname, self.backupdir ) except IOError as ex: ut.printex( @@ -662,45 +662,45 @@ def _init_sqldbcore(ibs, request_dbversion=None): raise # IBEIS SQL State Database if request_dbversion is None: - ibs.db_version_expected = '2.0.0' + self.db_version_expected = '2.0.0' else: - ibs.db_version_expected = request_dbversion + self.db_version_expected = request_dbversion # TODO: add this functionality to SQLController if backup_idx is None: new_version, new_fname = dtool.sql_control.dev_test_new_schema_version( - ibs.get_dbname(), - ibs.get_ibsdir(), - ibs.sqldb_fname, - ibs.db_version_expected, + self.get_dbname(), + self.get_ibsdir(), + self.sqldb_fname, + self.db_version_expected, version_next='2.0.0', ) - ibs.db_version_expected = new_version - ibs.sqldb_fname = new_fname + self.db_version_expected = new_version + self.sqldb_fname = new_fname if sqldb_fpath is None: assert backup_idx is None - sqldb_fpath = join(ibs.get_ibsdir(), ibs.sqldb_fname) + sqldb_fpath = join(self.get_ibsdir(), self.sqldb_fname) readonly = None else: readonly = True db_uri = 'sqlite:///{}'.format(realpath(sqldb_fpath)) - ibs.db = dtool.SQLDatabaseController.from_uri(db_uri, readonly=readonly) - ibs.readonly = ibs.db.readonly + self.db = dtool.SQLDatabaseController.from_uri(db_uri, readonly=readonly) + self.readonly = self.db.readonly if backup_idx is None: # Ensure correct schema versions _sql_helpers.ensure_correct_version( - ibs, - ibs.db, - ibs.db_version_expected, + self, + self.db, + self.db_version_expected, DB_SCHEMA, verbose=ut.VERBOSE, - dobackup=not ibs.readonly, + dobackup=not self.readonly, ) # import sys # sys.exit(1) @profile - def _init_sqldbstaging(ibs, request_stagingversion=None): + def _init_sqldbstaging(self, request_stagingversion=None): """ Example: >>> # DISABLE_DOCTEST @@ -730,15 +730,15 @@ def _init_sqldbstaging(ibs, request_stagingversion=None): backup_idx = ut.get_argval('--loadbackup-staging', type_=int, default=None) sqlstaging_fpath = None if backup_idx is not None: - backups = _sql_helpers.get_backup_fpaths(ibs) + backups = _sql_helpers.get_backup_fpaths(self) logger.info('backups = %r' % (backups,)) sqlstaging_fpath = backups[backup_idx] logger.info('CHOSE BACKUP sqlstaging_fpath = %r' % (sqlstaging_fpath,)) # HACK - if backup_idx is None and ibs._needs_backup(): + if backup_idx is None and self._needs_backup(): try: _sql_helpers.ensure_daily_database_backup( - ibs.get_ibsdir(), ibs.sqlstaging_fname, ibs.backupdir + self.get_ibsdir(), self.sqlstaging_fname, self.backupdir ) except IOError as ex: ut.printex( @@ -748,39 +748,39 @@ def _init_sqldbstaging(ibs, request_stagingversion=None): raise # IBEIS SQL State Database if request_stagingversion is None: - ibs.staging_version_expected = '1.2.0' + self.staging_version_expected = '1.2.0' else: - 
ibs.staging_version_expected = request_stagingversion + self.staging_version_expected = request_stagingversion # TODO: add this functionality to SQLController if backup_idx is None: new_version, new_fname = dtool.sql_control.dev_test_new_schema_version( - ibs.get_dbname(), - ibs.get_ibsdir(), - ibs.sqlstaging_fname, - ibs.staging_version_expected, + self.get_dbname(), + self.get_ibsdir(), + self.sqlstaging_fname, + self.staging_version_expected, version_next='1.2.0', ) - ibs.staging_version_expected = new_version - ibs.sqlstaging_fname = new_fname + self.staging_version_expected = new_version + self.sqlstaging_fname = new_fname if sqlstaging_fpath is None: assert backup_idx is None - sqlstaging_fpath = join(ibs.get_ibsdir(), ibs.sqlstaging_fname) + sqlstaging_fpath = join(self.get_ibsdir(), self.sqlstaging_fname) readonly = None else: readonly = True db_uri = 'sqlite:///{}'.format(realpath(sqlstaging_fpath)) - ibs.staging = dtool.SQLDatabaseController.from_uri( + self.staging = dtool.SQLDatabaseController.from_uri( db_uri, readonly=readonly, ) - ibs.readonly = ibs.staging.readonly + self.readonly = self.staging.readonly if backup_idx is None: # Ensure correct schema versions _sql_helpers.ensure_correct_version( - ibs, - ibs.staging, - ibs.staging_version_expected, + self, + self.staging, + self.staging_version_expected, STAGING_SCHEMA, verbose=ut.VERBOSE, ) @@ -788,96 +788,96 @@ def _init_sqldbstaging(ibs, request_stagingversion=None): # sys.exit(1) @profile - def _init_depcache(ibs): + def _init_depcache(self): # Initialize dependency cache for images image_root_getters = {} - ibs.depc_image = dtool.DependencyCache( + self.depc_image = dtool.DependencyCache( root_tablename=const.IMAGE_TABLE, default_fname=const.IMAGE_TABLE + '_depcache', - cache_dpath=ibs.get_cachedir(), - controller=ibs, - get_root_uuid=ibs.get_image_uuids, + cache_dpath=self.get_cachedir(), + controller=self, + get_root_uuid=self.get_image_uuids, root_getters=image_root_getters, ) - ibs.depc_image.initialize() + self.depc_image.initialize() """ Need to reinit this sometimes if cache is ever deleted """ # Initialize dependency cache for annotations annot_root_getters = { - 'name': ibs.get_annot_names, - 'species': ibs.get_annot_species, - 'yaw': ibs.get_annot_yaws, - 'viewpoint_int': ibs.get_annot_viewpoint_int, - 'viewpoint': ibs.get_annot_viewpoints, - 'bbox': ibs.get_annot_bboxes, - 'verts': ibs.get_annot_verts, - 'image_uuid': lambda aids: ibs.get_image_uuids( - ibs.get_annot_image_rowids(aids) + 'name': self.get_annot_names, + 'species': self.get_annot_species, + 'yaw': self.get_annot_yaws, + 'viewpoint_int': self.get_annot_viewpoint_int, + 'viewpoint': self.get_annot_viewpoints, + 'bbox': self.get_annot_bboxes, + 'verts': self.get_annot_verts, + 'image_uuid': lambda aids: self.get_image_uuids( + self.get_annot_image_rowids(aids) ), - 'theta': ibs.get_annot_thetas, - 'occurrence_text': ibs.get_annot_occurrence_text, + 'theta': self.get_annot_thetas, + 'occurrence_text': self.get_annot_occurrence_text, } - ibs.depc_annot = dtool.DependencyCache( + self.depc_annot = dtool.DependencyCache( # root_tablename='annot', # const.ANNOTATION_TABLE root_tablename=const.ANNOTATION_TABLE, default_fname=const.ANNOTATION_TABLE + '_depcache', - cache_dpath=ibs.get_cachedir(), - controller=ibs, - get_root_uuid=ibs.get_annot_visual_uuids, + cache_dpath=self.get_cachedir(), + controller=self, + get_root_uuid=self.get_annot_visual_uuids, root_getters=annot_root_getters, ) # backwards compatibility - ibs.depc = ibs.depc_annot + self.depc 
= self.depc_annot # TODO: root_uuids should be specified as the # base_root_uuid plus a hash of the attributes that matter for the # requested computation. - ibs.depc_annot.initialize() + self.depc_annot.initialize() # Initialize dependency cache for parts part_root_getters = {} - ibs.depc_part = dtool.DependencyCache( + self.depc_part = dtool.DependencyCache( root_tablename=const.PART_TABLE, default_fname=const.PART_TABLE + '_depcache', - cache_dpath=ibs.get_cachedir(), - controller=ibs, - get_root_uuid=ibs.get_part_uuids, + cache_dpath=self.get_cachedir(), + controller=self, + get_root_uuid=self.get_part_uuids, root_getters=part_root_getters, ) - ibs.depc_part.initialize() + self.depc_part.initialize() - def _close_depcache(ibs): - ibs.depc_image.close() - ibs.depc_image = None - ibs.depc_annot.close() - ibs.depc_annot = None - ibs.depc_part.close() - ibs.depc_part = None + def _close_depcache(self): + self.depc_image.close() + self.depc_image = None + self.depc_annot.close() + self.depc_annot = None + self.depc_part.close() + self.depc_part = None - def disconnect_sqldatabase(ibs): + def disconnect_sqldatabase(self): logger.info('disconnecting from sql database') - ibs._close_depcache() - ibs.db.close() - ibs.db = None - ibs.staging.close() - ibs.staging = None - - def clone_handle(ibs, **kwargs): - ibs2 = IBEISController(dbdir=ibs.get_dbdir(), ensure=False) + self._close_depcache() + self.db.close() + self.db = None + self.staging.close() + self.staging = None + + def clone_handle(self, **kwargs): + ibs2 = IBEISController(dbdir=self.get_dbdir(), ensure=False) if len(kwargs) > 0: ibs2.update_query_cfg(**kwargs) - # if ibs.qreq is not None: - # ibs2._prep_qreq(ibs.qreq.qaids, ibs.qreq.daids) + # if self.qreq is not None: + # ibs2._prep_qreq(self.qreq.qaids, self.qreq.daids) return ibs2 - def backup_database(ibs): + def backup_database(self): from wbia.control import _sql_helpers - _sql_helpers.database_backup(ibs.get_ibsdir(), ibs.sqldb_fname, ibs.backupdir) + _sql_helpers.database_backup(self.get_ibsdir(), self.sqldb_fname, self.backupdir) _sql_helpers.database_backup( - ibs.get_ibsdir(), ibs.sqlstaging_fname, ibs.backupdir + self.get_ibsdir(), self.sqlstaging_fname, self.backupdir ) - def _send_wildbook_request(ibs, wbaddr, payload=None): + def _send_wildbook_request(self, wbaddr, payload=None): import requests if wbaddr is None: @@ -898,7 +898,7 @@ def _send_wildbook_request(ibs, wbaddr, payload=None): return response def _init_dirs( - ibs, dbdir=None, dbname='testdb_1', workdir='~/wbia_workdir', ensure=True + self, dbdir=None, dbname='testdb_1', workdir='~/wbia_workdir', ensure=True ): """ Define ibs directories @@ -907,67 +907,67 @@ def _init_dirs( REL_PATHS = const.REL_PATHS if not ut.QUIET: - logger.info('[ibs._init_dirs] ibs.dbdir = %r' % dbdir) + logger.info('[self._init_dirs] self.dbdir = %r' % dbdir) if dbdir is not None: workdir, dbname = split(dbdir) - ibs.workdir = ut.truepath(workdir) - ibs.dbname = dbname - ibs.sqldb_fname = PATH_NAMES.sqldb - ibs.sqlstaging_fname = PATH_NAMES.sqlstaging + self.workdir = ut.truepath(workdir) + self.dbname = dbname + self.sqldb_fname = PATH_NAMES.sqldb + self.sqlstaging_fname = PATH_NAMES.sqlstaging # Make sure you are not nesting databases assert PATH_NAMES._ibsdb != ut.dirsplit( - ibs.workdir + self.workdir ), 'cannot work in _ibsdb internals' assert PATH_NAMES._ibsdb != dbname, 'cannot create db in _ibsdb internals' - ibs.dbdir = join(ibs.workdir, ibs.dbname) + self.dbdir = join(self.workdir, self.dbname) # All internal paths live in 
/_ibsdb # TODO: constantify these # so non controller objects (like in score normalization) have access # to these - ibs._ibsdb = join(ibs.dbdir, REL_PATHS._ibsdb) - ibs.trashdir = join(ibs.dbdir, REL_PATHS.trashdir) - ibs.cachedir = join(ibs.dbdir, REL_PATHS.cache) - ibs.backupdir = join(ibs.dbdir, REL_PATHS.backups) - ibs.logsdir = join(ibs.dbdir, REL_PATHS.logs) - ibs.chipdir = join(ibs.dbdir, REL_PATHS.chips) - ibs.imgdir = join(ibs.dbdir, REL_PATHS.images) - ibs.uploadsdir = join(ibs.dbdir, REL_PATHS.uploads) + self._ibsdb = join(self.dbdir, REL_PATHS._ibsdb) + self.trashdir = join(self.dbdir, REL_PATHS.trashdir) + self.cachedir = join(self.dbdir, REL_PATHS.cache) + self.backupdir = join(self.dbdir, REL_PATHS.backups) + self.logsdir = join(self.dbdir, REL_PATHS.logs) + self.chipdir = join(self.dbdir, REL_PATHS.chips) + self.imgdir = join(self.dbdir, REL_PATHS.images) + self.uploadsdir = join(self.dbdir, REL_PATHS.uploads) # All computed dirs live in /_ibsdb/_wbia_cache - ibs.thumb_dpath = join(ibs.dbdir, REL_PATHS.thumbs) - ibs.flanndir = join(ibs.dbdir, REL_PATHS.flann) - ibs.qresdir = join(ibs.dbdir, REL_PATHS.qres) - ibs.bigcachedir = join(ibs.dbdir, REL_PATHS.bigcache) - ibs.distinctdir = join(ibs.dbdir, REL_PATHS.distinctdir) + self.thumb_dpath = join(self.dbdir, REL_PATHS.thumbs) + self.flanndir = join(self.dbdir, REL_PATHS.flann) + self.qresdir = join(self.dbdir, REL_PATHS.qres) + self.bigcachedir = join(self.dbdir, REL_PATHS.bigcache) + self.distinctdir = join(self.dbdir, REL_PATHS.distinctdir) if ensure: - ibs.ensure_directories() + self.ensure_directories() assert dbdir is not None, 'must specify database directory' - def ensure_directories(ibs): + def ensure_directories(self): """ Makes sure the core directores for the controller exist """ _verbose = ut.VERBOSE - ut.ensuredir(ibs._ibsdb) - ut.ensuredir(ibs.cachedir, verbose=_verbose) - ut.ensuredir(ibs.backupdir, verbose=_verbose) - ut.ensuredir(ibs.logsdir, verbose=_verbose) - ut.ensuredir(ibs.workdir, verbose=_verbose) - ut.ensuredir(ibs.imgdir, verbose=_verbose) - ut.ensuredir(ibs.chipdir, verbose=_verbose) - ut.ensuredir(ibs.flanndir, verbose=_verbose) - ut.ensuredir(ibs.qresdir, verbose=_verbose) - ut.ensuredir(ibs.bigcachedir, verbose=_verbose) - ut.ensuredir(ibs.thumb_dpath, verbose=_verbose) - ut.ensuredir(ibs.distinctdir, verbose=_verbose) - ibs.get_smart_patrol_dir() + ut.ensuredir(self._ibsdb) + ut.ensuredir(self.cachedir, verbose=_verbose) + ut.ensuredir(self.backupdir, verbose=_verbose) + ut.ensuredir(self.logsdir, verbose=_verbose) + ut.ensuredir(self.workdir, verbose=_verbose) + ut.ensuredir(self.imgdir, verbose=_verbose) + ut.ensuredir(self.chipdir, verbose=_verbose) + ut.ensuredir(self.flanndir, verbose=_verbose) + ut.ensuredir(self.qresdir, verbose=_verbose) + ut.ensuredir(self.bigcachedir, verbose=_verbose) + ut.ensuredir(self.thumb_dpath, verbose=_verbose) + ut.ensuredir(self.distinctdir, verbose=_verbose) + self.get_smart_patrol_dir() # -------------- # --- DIRS ---- # -------------- @register_api('/api/core/db/name/', methods=['GET']) - def get_dbname(ibs): + def get_dbname(self): """ Returns: list_ (list): database name @@ -976,14 +976,14 @@ def get_dbname(ibs): Method: GET URL: /api/core/db/name/ """ - return ibs.dbname + return self.dbname - def get_db_name(ibs): - """ Alias for ibs.get_dbname(). """ - return ibs.get_dbname() + def get_db_name(self): + """ Alias for self.get_dbname(). 
""" + return self.get_dbname() @register_api(CORE_DB_UUID_INIT_API_RULE, methods=['GET']) - def get_db_init_uuid(ibs): + def get_db_init_uuid(self): """ Returns: UUID: The SQLDatabaseController's initialization UUID @@ -992,125 +992,125 @@ def get_db_init_uuid(ibs): Method: GET URL: /api/core/db/uuid/init/ """ - return ibs.db.get_db_init_uuid() + return self.db.get_db_init_uuid() - def get_logdir_local(ibs): - return ibs.logsdir + def get_logdir_local(self): + return self.logsdir - def get_logdir_global(ibs, local=False): + def get_logdir_global(self, local=False): if const.CONTAINERIZED: - return ibs.get_logdir_local() + return self.get_logdir_local() else: return ut.get_logging_dir(appname='wbia') - def get_dbdir(ibs): + def get_dbdir(self): """ database dir with ibs internal directory """ - return ibs.dbdir + return self.dbdir - def get_db_core_path(ibs): - return ibs.db.uri + def get_db_core_path(self): + return self.db.uri - def get_db_staging_path(ibs): - return ibs.staging.uri + def get_db_staging_path(self): + return self.staging.uri - def get_db_cache_path(ibs): - return ibs.dbcache.uri + def get_db_cache_path(self): + return self.dbcache.uri - def get_shelves_path(ibs): + def get_shelves_path(self): engine_slot = const.ENGINE_SLOT engine_slot = str(engine_slot).lower() if engine_slot in ['none', 'null', '1', 'default']: engine_shelve_dir = 'engine_shelves' else: engine_shelve_dir = 'engine_shelves_%s' % (engine_slot,) - return join(ibs.get_cachedir(), engine_shelve_dir) + return join(self.get_cachedir(), engine_shelve_dir) - def get_trashdir(ibs): - return ibs.trashdir + def get_trashdir(self): + return self.trashdir - def get_ibsdir(ibs): + def get_ibsdir(self): """ ibs internal directory """ - return ibs._ibsdb + return self._ibsdb - def get_chipdir(ibs): - return ibs.chipdir + def get_chipdir(self): + return self.chipdir - def get_probchip_dir(ibs): - return join(ibs.get_cachedir(), 'prob_chips') + def get_probchip_dir(self): + return join(self.get_cachedir(), 'prob_chips') - def get_fig_dir(ibs): + def get_fig_dir(self): """ ibs internal directory """ - return join(ibs._ibsdb, 'figures') + return join(self._ibsdb, 'figures') - def get_imgdir(ibs): + def get_imgdir(self): """ ibs internal directory """ - return ibs.imgdir + return self.imgdir - def get_uploadsdir(ibs): + def get_uploadsdir(self): """ ibs internal directory """ - return ibs.uploadsdir + return self.uploadsdir - def get_thumbdir(ibs): + def get_thumbdir(self): """ database directory where thumbnails are cached """ - return ibs.thumb_dpath + return self.thumb_dpath - def get_workdir(ibs): + def get_workdir(self): """ directory where databases are saved to """ - return ibs.workdir + return self.workdir - def get_cachedir(ibs): + def get_cachedir(self): """ database directory of all cached files """ - return ibs.cachedir + return self.cachedir - def get_match_thumbdir(ibs): - match_thumb_dir = ut.unixjoin(ibs.get_cachedir(), 'match_thumbs') + def get_match_thumbdir(self): + match_thumb_dir = ut.unixjoin(self.get_cachedir(), 'match_thumbs') ut.ensuredir(match_thumb_dir) return match_thumb_dir - def get_wbia_resource_dir(ibs): + def get_wbia_resource_dir(self): """ returns the global resource dir in .config or AppData or whatever """ resource_dir = sysres.get_wbia_resource_dir() return resource_dir - def get_detect_modeldir(ibs): + def get_detect_modeldir(self): return join(sysres.get_wbia_resource_dir(), 'detectmodels') - def get_detectimg_cachedir(ibs): + def get_detectimg_cachedir(self): """ Returns: detectimgdir 
(str): database directory of image resized for detections """ - return join(ibs.cachedir, const.PATH_NAMES.detectimg) + return join(self.cachedir, const.PATH_NAMES.detectimg) - def get_flann_cachedir(ibs): + def get_flann_cachedir(self): """ Returns: flanndir (str): database directory where the FLANN KD-Tree is stored """ - return ibs.flanndir + return self.flanndir - def get_qres_cachedir(ibs): + def get_qres_cachedir(self): """ Returns: qresdir (str): database directory where query results are stored """ - return ibs.qresdir + return self.qresdir - def get_neighbor_cachedir(ibs): - neighbor_cachedir = ut.unixjoin(ibs.get_cachedir(), 'neighborcache2') + def get_neighbor_cachedir(self): + neighbor_cachedir = ut.unixjoin(self.get_cachedir(), 'neighborcache2') return neighbor_cachedir - def get_big_cachedir(ibs): + def get_big_cachedir(self): """ Returns: bigcachedir (str): database directory where aggregate results are stored """ - return ibs.bigcachedir + return self.bigcachedir - def get_smart_patrol_dir(ibs, ensure=True): + def get_smart_patrol_dir(self, ensure=True): """ Args: ensure (bool): @@ -1133,7 +1133,7 @@ def get_smart_patrol_dir(ibs, ensure=True): >>> # verify results >>> ut.assertpath(smart_patrol_dpath, verbose=True) """ - smart_patrol_dpath = join(ibs.dbdir, const.PATH_NAMES.smartpatrol) + smart_patrol_dpath = join(self.dbdir, const.PATH_NAMES.smartpatrol) if ensure: ut.ensuredir(smart_patrol_dpath) return smart_patrol_dpath @@ -1143,7 +1143,7 @@ def get_smart_patrol_dir(ibs, ensure=True): # ------------------ @register_api('/log/current/', methods=['GET']) - def get_current_log_text(ibs): + def get_current_log_text(self): r""" CommandLine: python -m wbia.control.IBEISControl --exec-get_current_log_text @@ -1164,30 +1164,30 @@ def get_current_log_text(ibs): return text @register_api('/api/core/db/info/', methods=['GET']) - def get_dbinfo(ibs): + def get_dbinfo(self): from wbia.other import dbinfo - locals_ = dbinfo.get_dbinfo(ibs) + locals_ = dbinfo.get_dbinfo(self) return locals_['info_str'] - # return ut.repr2(dbinfo.get_dbinfo(ibs), nl=1)['infostr'] + # return ut.repr2(dbinfo.get_dbinfo(self), nl=1)['infostr'] # -------------- # --- MISC ---- # -------------- - def copy_database(ibs, dest_dbdir): + def copy_database(self, dest_dbdir): # TODO: rectify with rsync, script, and merge script. 
from wbia.init import sysres - sysres.copy_wbiadb(ibs.get_dbdir(), dest_dbdir) + sysres.copy_wbiadb(self.get_dbdir(), dest_dbdir) - def dump_database_csv(ibs): - dump_dir = join(ibs.get_dbdir(), 'CSV_DUMP') - ibs.db.dump_tables_to_csv(dump_dir=dump_dir) + def dump_database_csv(self): + dump_dir = join(self.get_dbdir(), 'CSV_DUMP') + self.db.dump_tables_to_csv(dump_dir=dump_dir) with open(join(dump_dir, '_ibsdb.dump'), 'w') as fp: - dump(ibs.db.connection, fp) + dump(self.db.connection, fp) - def get_database_icon(ibs, max_dsize=(None, 192), aid=None): + def get_database_icon(self, max_dsize=(None, 192), aid=None): r""" Args: max_dsize (tuple): (default = (None, 192)) @@ -1204,58 +1204,58 @@ def get_database_icon(ibs, max_dsize=(None, 192), aid=None): >>> from wbia.control.IBEISControl import * # NOQA >>> import wbia >>> ibs = wbia.opendb(defaultdb='testdb1') - >>> icon = ibs.get_database_icon() + >>> icon = self.get_database_icon() >>> ut.quit_if_noshow() >>> import wbia.plottool as pt >>> pt.imshow(icon) >>> ut.show_if_requested() """ - # if ibs.get_dbname() == 'Oxford': + # if self.get_dbname() == 'Oxford': # pass # else: import vtool as vt - if hasattr(ibs, 'force_icon_aid'): - aid = ibs.force_icon_aid + if hasattr(self, 'force_icon_aid'): + aid = self.force_icon_aid if aid is None: - species = ibs.get_primary_database_species() + species = self.get_primary_database_species() # Use a url to get the icon url = { - ibs.const.TEST_SPECIES.GIR_MASAI: 'http://i.imgur.com/tGDVaKC.png', - ibs.const.TEST_SPECIES.ZEB_PLAIN: 'http://i.imgur.com/2Ge1PRg.png', - ibs.const.TEST_SPECIES.ZEB_GREVY: 'http://i.imgur.com/PaUT45f.png', + self.const.TEST_SPECIES.GIR_MASAI: 'http://i.imgur.com/tGDVaKC.png', + self.const.TEST_SPECIES.ZEB_PLAIN: 'http://i.imgur.com/2Ge1PRg.png', + self.const.TEST_SPECIES.ZEB_GREVY: 'http://i.imgur.com/PaUT45f.png', }.get(species, None) if url is not None: icon = vt.imread(ut.grab_file_url(url), orient='auto') else: # HACK: (this should probably be a db setting) # use an specific aid to get the icon - aid = {'Oxford': 73, 'seaturtles': 37}.get(ibs.get_dbname(), None) + aid = {'Oxford': 73, 'seaturtles': 37}.get(self.get_dbname(), None) if aid is None: # otherwise just grab a random aid - aid = ibs.get_valid_aids()[0] + aid = self.get_valid_aids()[0] if aid is not None: - icon = ibs.get_annot_chips(aid) + icon = self.get_annot_chips(aid) icon = vt.resize_to_maxdims(icon, max_dsize) return icon - def _custom_ibsstr(ibs): + def _custom_ibsstr(self): # typestr = ut.type_str(type(ibs)).split('.')[-1] - typestr = ibs.__class__.__name__ - dbname = ibs.get_dbname() + typestr = self.__class__.__name__ + dbname = self.get_dbname() # hash_str = hex(id(ibs)) # ibsstr = '<%s(%s) at %s>' % (typestr, dbname, hash_str, ) - hash_str = ibs.get_db_init_uuid() + hash_str = self.get_db_init_uuid() ibsstr = '<%s(%s) with UUID %s>' % (typestr, dbname, hash_str) return ibsstr - def __str__(ibs): - return ibs._custom_ibsstr() + def __str__(self): + return self._custom_ibsstr() - def __repr__(ibs): - return ibs._custom_ibsstr() + def __repr__(self): + return self._custom_ibsstr() - def __getstate__(ibs): + def __getstate__(self): """ Example: >>> # ENABLE_DOCTEST @@ -1267,12 +1267,12 @@ def __getstate__(ibs): """ # Hack to allow for wbia objects to be pickled state = { - 'dbdir': ibs.get_dbdir(), + 'dbdir': self.get_dbdir(), 'machine_name': ut.get_computer_name(), } return state - def __setstate__(ibs, state): + def __setstate__(self, state): # Hack to allow for wbia objects to be pickled import wbia 
@@ -1288,20 +1288,20 @@ def __setstate__(ibs, state): if not iswarning: raise ibs2 = wbia.opendb(dbdir=dbdir, web=False) - ibs.__dict__.update(**ibs2.__dict__) + self.__dict__.update(**ibs2.__dict__) - def predict_ws_injury_interim_svm(ibs, aids): + def predict_ws_injury_interim_svm(self, aids): from wbia.scripts import classify_shark - return classify_shark.predict_ws_injury_interim_svm(ibs, aids) + return classify_shark.predict_ws_injury_interim_svm(self, aids) def get_web_port_via_scan( - ibs, url_base='127.0.0.1', port_base=5000, scan_limit=100, verbose=True + self, url_base='127.0.0.1', port_base=5000, scan_limit=100, verbose=True ): import requests api_rule = CORE_DB_UUID_INIT_API_RULE - target_uuid = ibs.get_db_init_uuid() + target_uuid = self.get_db_init_uuid() for candidate_port in range(port_base, port_base + scan_limit + 1): candidate_url = 'http://%s:%s%s' % (url_base, candidate_port, api_rule) try: From 3aacc2912de9e52fa29664aa7196471ceaa5c906 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 10:41:24 -0700 Subject: [PATCH 085/294] Fix class to use the 'self' argument convention --- wbia/dtool/depcache_table.py | 904 +++++++++++++++++------------------ 1 file changed, 449 insertions(+), 455 deletions(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index a48fb6868e..ae083d0aea 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -171,7 +171,7 @@ def ensure_config_table(db): class _TableConfigHelper(object): """ helper for configuration table """ - def get_parent_rowids(table, rowid_list): + def get_parent_rowids(self, rowid_list): """ Args: rowid_list (list): native table rowids @@ -185,12 +185,12 @@ def get_parent_rowids(table, rowid_list): >>> # Then add two items to this table, and for each item >>> # Find their parent inputs """ - parent_rowids = table.get_internal_columns( - rowid_list, table.parent_id_colnames, unpack_scalars=True, keepwrap=True + parent_rowids = self.get_internal_columns( + rowid_list, self.parent_id_colnames, unpack_scalars=True, keepwrap=True ) return parent_rowids - def get_parent_rowargs(table, rowid_list): + def get_parent_rowargs(self, rowid_list): """ Args: rowid_list (list): native table rowids @@ -204,18 +204,18 @@ def get_parent_rowargs(table, rowid_list): >>> # Then add two items to this table, and for each item >>> # Find their parent inputs """ - parent_rowids = table.get_parent_rowids(rowid_list) - parent_ismulti = table.get_parent_col_attr('ismulti') + parent_rowids = self.get_parent_rowids(rowid_list) + parent_ismulti = self.get_parent_col_attr('ismulti') if any(parent_ismulti): # If any of the parent columns are multi-indexes, then lookup the # mapping from the aggregated uuid to the expanded rowid set. 
parent_args = [] - model_uuids = table.get_model_uuid(rowid_list) + model_uuids = self.get_model_uuid(rowid_list) for rowid, uuid, p_id_list in zip(rowid_list, model_uuids, parent_rowids): - input_info = table.get_model_inputs(uuid) + input_info = self.get_model_inputs(uuid) fixed_args = [] for p_name, p_id, flag in zip( - table.parent_id_colnames, p_id_list, parent_ismulti + self.parent_id_colnames, p_id_list, parent_ismulti ): if flag: new_p_id = input_info[p_name + '_model_input'] @@ -231,7 +231,7 @@ def get_parent_rowargs(table, rowid_list): parent_args = parent_rowids return parent_args - def get_row_parent_rowid_map(table, rowid_list): + def get_row_parent_rowid_map(self, rowid_list): """ >>> from wbia.dtool.depcache_table import * # NOQA @@ -239,13 +239,13 @@ def get_row_parent_rowid_map(table, rowid_list): key = parent_rowid_dict.keys()[0] val = parent_rowid_dict.values()[0] """ - parent_rowids = table.get_parent_rowids(rowid_list) + parent_rowids = self.get_parent_rowids(rowid_list) parent_rowid_dict = dict( - zip(table.parent_id_tablenames, ut.list_transpose(parent_rowids)) + zip(self.parent_id_tablenames, ut.list_transpose(parent_rowids)) ) return parent_rowid_dict - def get_config_history(table, rowid_list, assume_unique=True): + def get_config_history(self, rowid_list, assume_unique=True): """ Returns the list of config objects for all properties in the dependency history of this object. Multi-edges are handled. Set assume_unique to @@ -260,23 +260,23 @@ def get_config_history(table, rowid_list, assume_unique=True): """ if assume_unique: rowid_list = rowid_list[0:1] - tbl_cfgids = table.get_row_cfgid(rowid_list) + tbl_cfgids = self.get_row_cfgid(rowid_list) cfgid2_rowids = ut.group_items(rowid_list, tbl_cfgids) unique_cfgids = cfgid2_rowids.keys() unique_cfgids = ut.filter_Nones(unique_cfgids) if len(unique_cfgids) == 0: return None - unique_configs = table.get_config_from_rowid(unique_cfgids) + unique_configs = self.get_config_from_rowid(unique_cfgids) - # parent_rowids = table.get_parent_rowids(rowid_list) - parent_rowargs = table.get_parent_rowargs(rowid_list) + # parent_rowids = self.get_parent_rowids(rowid_list) + parent_rowargs = self.get_parent_rowargs(rowid_list) ret_list = [unique_configs] - depc = table.depc + depc = self.depc rowargsT = ut.listT(parent_rowargs) - parent_ismulti = table.get_parent_col_attr('ismulti') + parent_ismulti = self.get_parent_col_attr('ismulti') for tblname, ismulti, ids in zip( - table.parent_id_tablenames, parent_ismulti, rowargsT + self.parent_id_tablenames, parent_ismulti, rowargsT ): if tblname == depc.root: continue @@ -288,16 +288,16 @@ def get_config_history(table, rowid_list, assume_unique=True): ret_list.extend(ancestor_configs) return ret_list - def __remove_old_configs(table): + def __remove_old_configs(self): """ table = ibs.depc['pairwise_match'] """ # developing - # c = table.db.get_table_as_pandas('config') - # t = table.db.get_table_as_pandas(table.tablename) + # c = self.db.get_table_as_pandas('config') + # t = self.db.get_table_as_pandas(self.tablename) - # config_rowids = table.db.get_all_rowids(CONFIG_TABLE) - # cfgdict_list = table.db.get( + # config_rowids = self.db.get_all_rowids(CONFIG_TABLE) + # cfgdict_list = self.db.get( # CONFIG_TABLE, colnames=(CONFIG_DICT,), id_iter=config_rowids, # id_colname=CONFIG_ROWID) # bad_rowids = [] @@ -310,10 +310,10 @@ def __remove_old_configs(table): SELECT rowid, {} from {} """ ).format(CONFIG_DICT, CONFIG_TABLE) - table.db.cur.execute(command) + self.db.cur.execute(command) 
bad_rowids = [] - for rowid, cfgdict in table.db.cur.fetchall(): + for rowid, cfgdict in self.db.cur.fetchall(): # MAKE GENERAL CONDITION if cfgdict['version'] < 7: bad_rowids.append(rowid) @@ -324,17 +324,17 @@ def __remove_old_configs(table): SELECT rowid from {tablename} WHERE config_rowid IN {bad_rowids} """ - ).format(tablename=table.tablename, bad_rowids=in_str) + ).format(tablename=self.tablename, bad_rowids=in_str) # logger.info(command) - table.db.cur.execute(command) - rowids = ut.flatten(table.db.cur.fetchall()) - table.delete_rows(rowids, dry=True, verbose=True, delete_extern=True) + self.db.cur.execute(command) + rowids = ut.flatten(self.db.cur.fetchall()) + self.delete_rows(rowids, dry=True, verbose=True, delete_extern=True) - def get_ancestor_rowids(table, rowid_list, target_table): - parent_rowids = table.get_parent_rowids(rowid_list) - depc = table.depc + def get_ancestor_rowids(self, rowid_list, target_table): + parent_rowids = self.get_parent_rowids(rowid_list) + depc = self.depc for tblname, ids in zip( - table.parent_id_tablenames, ut.list_transpose(parent_rowids) + self.parent_id_tablenames, ut.list_transpose(parent_rowids) ): if tblname == target_table: return ids @@ -344,14 +344,14 @@ def get_ancestor_rowids(table, rowid_list, target_table): return ancestor_ids return None # Base case - def get_row_cfgid(table, rowid_list): + def get_row_cfgid(self, rowid_list): """ >>> from wbia.dtool.depcache_table import * # NOQA """ - config_rowids = table.get_internal_columns(rowid_list, (CONFIG_ROWID,)) + config_rowids = self.get_internal_columns(rowid_list, (CONFIG_ROWID,)) return config_rowids - def get_row_configs(table, rowid_list): + def get_row_configs(self, rowid_list): """ Example: >>> # ENABLE_DOCTEST @@ -363,21 +363,21 @@ def get_row_configs(table, rowid_list): >>> rowid_list = depc.get_rowids('chip', [1, 2], config={}) >>> configs = table.get_row_configs(rowid_list) """ - config_rowids = table.get_row_cfgid(rowid_list) + config_rowids = self.get_row_cfgid(rowid_list) # Only look up the configs that are needed unique_config_rowids, groupxs = ut.group_indices(config_rowids) - unique_configs = table.get_config_from_rowid(unique_config_rowids) + unique_configs = self.get_config_from_rowid(unique_config_rowids) configs = ut.ungroup_unique(unique_configs, groupxs, maxval=len(rowid_list) - 1) return configs - def get_row_cfghashid(table, rowid_list): - config_rowids = table.get_row_cfgid(rowid_list) - config_hashids = table.get_config_hashid(config_rowids) + def get_row_cfghashid(self, rowid_list): + config_rowids = self.get_row_cfgid(rowid_list) + config_hashids = self.get_config_hashid(config_rowids) return config_hashids - def get_row_cfgstr(table, rowid_list): - config_rowids = table.get_row_cfgid(rowid_list) - cfgstr_list = table.db.get( + def get_row_cfgstr(self, rowid_list): + config_rowids = self.get_row_cfgid(rowid_list) + cfgstr_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_STRID,), id_iter=config_rowids, @@ -385,15 +385,15 @@ def get_row_cfgstr(table, rowid_list): ) return cfgstr_list - def get_config_rowid(table, config=None, _debug=None): + def get_config_rowid(self, config=None, _debug=None): if isinstance(config, int): config_rowid = config else: - config_rowid = table.add_config(config) + config_rowid = self.add_config(config) return config_rowid - def get_config_hashid(table, config_rowid_list): - hashid_list = table.db.get( + def get_config_hashid(self, config_rowid_list): + hashid_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_HASHID,), 
id_iter=config_rowid_list, @@ -401,8 +401,8 @@ def get_config_hashid(table, config_rowid_list): ) return hashid_list - def get_config_rowid_from_hashid(table, config_hashid_list): - config_rowid_list = table.db.get( + def get_config_rowid_from_hashid(self, config_hashid_list): + config_rowid_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_ROWID,), id_iter=config_hashid_list, @@ -410,20 +410,19 @@ def get_config_rowid_from_hashid(table, config_hashid_list): ) return config_rowid_list - def get_config_from_rowid(table, config_rowids): - cfgdict_list = table.db.get( + def get_config_from_rowid(self, config_rowids): + cfgdict_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_DICT,), id_iter=config_rowids, id_colname=CONFIG_ROWID, ) return [ - None if dict_ is None else table.configclass(**dict_) - for dict_ in cfgdict_list + None if dict_ is None else self.configclass(**dict_) for dict_ in cfgdict_list ] # @profile - def add_config(table, config, _debug=None): + def add_config(self, config, _debug=None): try: # assume config is AlgoRequest or TableConfig config_strid = config.get_cfgstr() @@ -432,14 +431,14 @@ def add_config(table, config, _debug=None): config_hashid = ut.hashstr27(config_strid) logger.debug('config_strid = %r' % (config_strid,)) logger.debug('config_hashid = %r' % (config_hashid,)) - get_rowid_from_superkey = table.get_config_rowid_from_hashid + get_rowid_from_superkey = self.get_config_rowid_from_hashid colnames = (CONFIG_HASHID, CONFIG_TABLENAME, CONFIG_STRID, CONFIG_DICT) if hasattr(config, 'config'): # Hack for requests config = config.config cfgdict = config.__getstate__() - param_list = [(config_hashid, table.tablename, config_strid, cfgdict)] - config_rowid_list = table.db.add_cleanly( + param_list = [(config_hashid, self.tablename, config_strid, cfgdict)] + config_rowid_list = self.db.add_cleanly( CONFIG_TABLE, colnames, param_list, get_rowid_from_superkey ) config_rowid = config_rowid_list[0] @@ -453,11 +452,11 @@ class _TableDebugHelper(object): Contains printing and debug things """ - def print_sql_info(table): - add_op = table.db._make_add_table_sqlstr(sep='\n ', **table._get_addtable_kw()) + def print_sql_info(self): + add_op = self.db._make_add_table_sqlstr(sep='\n ', **self._get_addtable_kw()) ut.cprint(add_op, 'sql') - def print_internal_info(table, all_attrs=False): + def print_internal_info(self, all_attrs=False): """ CommandLine: python -m dtool.depcache_table --exec-print_internal_info @@ -472,77 +471,75 @@ def print_internal_info(table, all_attrs=False): >>> table.print_internal_info() """ logger.info('----') - logger.info(table) + logger.info(self) # Print the other inferred attrs logger.info( - 'table.parent_col_attrs = %s' % (ut.repr3(table.parent_col_attrs, nl=2),) + 'self.parent_col_attrs = %s' % (ut.repr3(self.parent_col_attrs, nl=2),) ) - logger.info('table.data_col_attrs = %s' % (ut.repr3(table.data_col_attrs, nl=2),)) + logger.info('self.data_col_attrs = %s' % (ut.repr3(self.data_col_attrs, nl=2),)) # Print the inferred allcol attrs ut.cprint( - 'table.internal_col_attrs = %s' - % (ut.repr3(table.internal_col_attrs, nl=1, sorted_=False)), + 'self.internal_col_attrs = %s' + % (ut.repr3(self.internal_col_attrs, nl=1, sorted_=False)), 'python', ) - add_table_kw = table._get_addtable_kw() - logger.info('table.add_table_kw = %s' % (ut.repr2(add_table_kw, nl=2),)) - table.print_sql_info() + add_table_kw = self._get_addtable_kw() + logger.info('self.add_table_kw = %s' % (ut.repr2(add_table_kw, nl=2),)) + self.print_sql_info() if all_attrs: # Print 
all attributes - for a in ut.get_instance_attrnames( - table, with_properties=True, default=False - ): - logger.info(' table.%s = %r' % (a, getattr(table, a))) + for a in ut.get_instance_attrnames(self, with_properties=True, default=False): + logger.info(' self.%s = %r' % (a, getattr(self, a))) - def print_table(table): - table.db.print_table_csv(table.tablename) - # if table.ismulti: - # table.print_model_manifests() + def print_table(self): + self.db.print_table_csv(self.tablename) + # if self.ismulti: + # self.print_model_manifests() - def print_info(table, with_colattrs=True, with_graphattrs=True): + def print_info(self, with_colattrs=True, with_graphattrs=True): """ debug function """ logger.info('TABLE ATTRIBUTES') - logger.info('table.tablename = %r' % (table.tablename,)) - logger.info('table.isinteractive = %r' % (table.isinteractive,)) - logger.info('table.default_onthefly = %r' % (table.default_onthefly,)) - logger.info('table.rm_extern_on_delete = %r' % (table.rm_extern_on_delete,)) - logger.info('table.chunksize = %r' % (table.chunksize,)) - logger.info('table.fname = %r' % (table.fname,)) - logger.info('table.docstr = %r' % (table.docstr,)) - logger.info('table.data_colnames = %r' % (table.data_colnames,)) - logger.info('table.data_coltypes = %r' % (table.data_coltypes,)) + logger.info('self.tablename = %r' % (self.tablename,)) + logger.info('self.isinteractive = %r' % (self.isinteractive,)) + logger.info('self.default_onthefly = %r' % (self.default_onthefly,)) + logger.info('self.rm_extern_on_delete = %r' % (self.rm_extern_on_delete,)) + logger.info('self.chunksize = %r' % (self.chunksize,)) + logger.info('self.fname = %r' % (self.fname,)) + logger.info('self.docstr = %r' % (self.docstr,)) + logger.info('self.data_colnames = %r' % (self.data_colnames,)) + logger.info('self.data_coltypes = %r' % (self.data_coltypes,)) if with_graphattrs: logger.info('TABLE GRAPH ATTRIBUTES') - logger.info('table.children = %r' % (table.children,)) - logger.info('table.parent = %r' % (table.parent,)) - logger.info('table.configclass = %r' % (table.configclass,)) - logger.info('table.requestclass = %r' % (table.requestclass,)) + logger.info('self.children = %r' % (self.children,)) + logger.info('self.parent = %r' % (self.parent,)) + logger.info('self.configclass = %r' % (self.configclass,)) + logger.info('self.requestclass = %r' % (self.requestclass,)) if with_colattrs: nl = 1 logger.info('TABEL COLUMN ATTRIBUTES') logger.info( - 'table.data_col_attrs = %s' % (ut.repr3(table.data_col_attrs, nl=nl),) + 'self.data_col_attrs = %s' % (ut.repr3(self.data_col_attrs, nl=nl),) ) logger.info( - 'table.parent_col_attrs = %s' % (ut.repr3(table.parent_col_attrs, nl=nl),) + 'self.parent_col_attrs = %s' % (ut.repr3(self.parent_col_attrs, nl=nl),) ) logger.info( - 'table.internal_data_col_attrs = %s' - % (ut.repr3(table.internal_data_col_attrs, nl=nl),) + 'self.internal_data_col_attrs = %s' + % (ut.repr3(self.internal_data_col_attrs, nl=nl),) ) logger.info( - 'table.internal_parent_col_attrs = %s' - % (ut.repr3(table.internal_parent_col_attrs, nl=nl),) + 'self.internal_parent_col_attrs = %s' + % (ut.repr3(self.internal_parent_col_attrs, nl=nl),) ) logger.info( - 'table.internal_col_attrs = %s' - % (ut.repr3(table.internal_col_attrs, nl=nl),) + 'self.internal_col_attrs = %s' + % (ut.repr3(self.internal_col_attrs, nl=nl),) ) - def print_schemadef(table): - logger.info('\n'.join(table.db.get_table_autogen_str(table.tablename))) + def print_schemadef(self): + 
logger.info('\n'.join(self.db.get_table_autogen_str(self.tablename))) - def print_configs(table): + def print_configs(self): """ CommandLine: python -m dtool.depcache_table --exec-print_configs @@ -562,28 +559,28 @@ def print_configs(table): >>> rowids = depc.get_rowids('spam', [1, 2]) >>> table.print_configs() """ - text = table.db.get_table_csv(CONFIG_TABLE) + text = self.db.get_table_csv(CONFIG_TABLE) logger.info(text) - def print_csv(table, truncate=True): - logger.info(table.db.get_table_csv(table.tablename, truncate=truncate)) + def print_csv(self, truncate=True): + logger.info(self.db.get_table_csv(self.tablename, truncate=truncate)) - def print_model_manifests(table): + def print_model_manifests(self): logger.info('manifests') - rowids = table._get_all_rowids() - uuids = table.get_model_uuid(rowids) + rowids = self._get_all_rowids() + uuids = self.get_model_uuid(rowids) for rowid, uuid in zip(rowids, uuids): logger.info('rowid = %r' % (rowid,)) - logger.info(ut.repr3(table.get_model_inputs(uuid), nl=1)) + logger.info(ut.repr3(self.get_model_inputs(uuid), nl=1)) - def _assert_self(table): - assert len(table.data_colnames) == len( - table.data_coltypes + def _assert_self(self): + assert len(self.data_colnames) == len( + self.data_coltypes ), 'specify same number of colnames and coltypes' - if table.preproc_func is not None: + if self.preproc_func is not None: # Check that preproc_func has a valid signature # ie (depc, parent_ids, config) - argspec = ut.get_func_argspec(table.preproc_func) + argspec = ut.get_func_argspec(self.preproc_func) args = argspec.args if argspec.varargs and argspec.keywords: assert len(args) == 1, 'varargs and kwargs must have one arg for depcache' @@ -594,22 +591,20 @@ def _assert_self(table): 'preproc_func=%r for table=%s must have a ' 'depcache arg, at least one parent rowid arg, ' 'and a config arg' - ) % (table.preproc_func, table.tablename) + ) % (self.preproc_func, self.tablename) raise AssertionError(msg) rowid_args = args[1:-1] - if len(rowid_args) != len(table.parents()): - logger.info('table.preproc_func = %r' % (table.preproc_func,)) + if len(rowid_args) != len(self.parents()): + logger.info('self.preproc_func = %r' % (self.preproc_func,)) logger.info('args = %r' % (args,)) logger.info('rowid_args = %r' % (rowid_args,)) msg = ( 'preproc function for table=%s must have as many ' 'rowids %d args as parent %d' - ) % (table.tablename, len(rowid_args), len(table.parents())) + ) % (self.tablename, len(rowid_args), len(self.parents())) raise AssertionError(msg) extern_class_colattrs = [ - colattr - for colattr in table.data_col_attrs - if colattr.get('is_external_class') + colattr for colattr in self.data_col_attrs if colattr.get('is_external_class') ] for colattr in extern_class_colattrs: cls = colattr['coltype'] @@ -634,7 +629,7 @@ class _TableInternalSetup(ub.NiceRepr): """ helper that sets up column information """ @profile - def _infer_datacol(table): + def _infer_datacol(self): """ Constructs the columns needed to represent relationship to data @@ -662,7 +657,7 @@ def _infer_datacol(table): data_col_attrs = [] # Parse column datatypes - _iter = enumerate(zip(table.data_colnames, table.data_coltypes)) + _iter = enumerate(zip(self.data_colnames, self.data_coltypes)) for data_colx, (colname, coltype) in _iter: colattr = ut.odict() # Check column input subtypes @@ -729,7 +724,7 @@ def _infer_datacol(table): assert hasattr(coltype, '__getstate__') and hasattr( coltype, '__setstate__' ), ('External classes must have __getstate__ and ' '__setstate__ 
methods') - read_func, write_func = make_extern_io_funcs(table, coltype) + read_func, write_func = make_extern_io_funcs(self, coltype) sqltype = TYPE_TO_SQLTYPE[str] intern_colname = colname + EXTERN_SUFFIX # raise AssertionError('external class columns') @@ -744,7 +739,7 @@ def _infer_datacol(table): return data_col_attrs @profile - def _infer_parentcol(table): + def _infer_parentcol(self): """ construct columns to represent relationship to parent @@ -776,7 +771,7 @@ def _infer_parentcol(table): >>> depc.d.get_indexer_data([ >>> uuid.UUID('a01eda32-e4e0-b139-3274-e91d1b3e9ecf')]) """ - parent_tablenames = table.parent_tablenames + parent_tablenames = self.parent_tablenames parent_col_attrs = [] # Handle dependencies when a parent are pairwise between tables @@ -863,7 +858,7 @@ def _infer_parentcol(table): return parent_col_attrs @profile - def _infer_allcol(table): + def _infer_allcol(self): r""" Combine information from parentcol and datacol Build column definitions that will directly define SQL columns @@ -873,7 +868,7 @@ def _infer_allcol(table): # Append primary column colattr = ut.odict( [ - ('intern_colname', table.rowid_colname), + ('intern_colname', self.rowid_colname), ('sqltype', 'INTEGER PRIMARY KEY'), ('isprimary', True), ] @@ -883,7 +878,7 @@ def _infer_allcol(table): # Append parent columns ismulti = False - for parent_colattr in table.parent_col_attrs: + for parent_colattr in self.parent_col_attrs: colattr = ut.odict() colattr['intern_colname'] = parent_colattr['intern_colname'] colattr['parent_table'] = parent_colattr['parent_table'] @@ -913,26 +908,26 @@ def _infer_allcol(table): internal_col_attrs.append(colattr) # Append quick access column - # return any(table.get_parent_col_attr('ismulti')) - # if table.ismulti: + # return any(self.get_parent_col_attr('ismulti')) + # if self.ismulti: if ismulti: # Append model uuid column colattr = ut.odict() - colattr['intern_colname'] = table.model_uuid_colname + colattr['intern_colname'] = self.model_uuid_colname colattr['sqltype'] = 'UUID NOT NULL' colattr['intern_colx'] = len(internal_col_attrs) internal_col_attrs.append(colattr) # Append model uuid column colattr = ut.odict() - colattr['intern_colname'] = table.is_augmented_colname + colattr['intern_colname'] = self.is_augmented_colname colattr['sqltype'] = 'INTEGER DEFAULT 0' colattr['intern_colx'] = len(internal_col_attrs) internal_col_attrs.append(colattr) if False: # TODO: eventually enable - if table.taggable: + if self.taggable: colattr = ut.odict() colattr['intern_colname'] = 'model_tag' colattr['sqltype'] = 'TEXT' @@ -944,7 +939,7 @@ def _infer_allcol(table): pass # Append data columns - for data_colattr in table.data_col_attrs: + for data_colattr in self.data_col_attrs: colname = data_colattr['colname'] if data_colattr.get('isnested', False): for nestcol in data_colattr['nestattrs']: @@ -971,7 +966,7 @@ def _infer_allcol(table): internal_col_attrs.append(colattr) # Append extra columns - for parent_colattr in table.parent_col_attrs: + for parent_colattr in self.parent_col_attrs: for extra_colattr in parent_colattr.get('extra_cols', []): colattr = ut.odict() colattr['intern_colname'] = extra_colattr['intern_colname'] @@ -986,176 +981,175 @@ def _infer_allcol(table): class _TableGeneralHelper(ub.NiceRepr): """ helper """ - def __nice__(table): - num_parents = len(table.parent_tablenames) - num_cols = len(table.data_colnames) + def __nice__(self): + num_parents = len(self.parent_tablenames) + num_cols = len(self.data_colnames) return '(%s) nP=%d%s nC=%d' % ( - 
table.tablename, + self.tablename, num_parents, - '*' if False and table.ismulti else '', + '*' if False and self.ismulti else '', num_cols, ) # @property - # def _table_colnames(table): + # def _table_colnames(self): # return @property - def extern_dpath(table): - cache_dpath = table.depc.cache_dpath - extern_dname = 'extern_' + table.tablename + def extern_dpath(self): + cache_dpath = self.depc.cache_dpath + extern_dname = 'extern_' + self.tablename extern_dpath = join(cache_dpath, extern_dname) return extern_dpath @property - def dpath(table): + def dpath(self): # assert table.ismulti, 'only valid for models' - dname = table.tablename + '_storage' - dpath = join(table.depc.cache_dpath, dname) + dname = self.tablename + '_storage' + dpath = join(self.depc.cache_dpath, dname) # ut.ensuredir(dpath) return dpath - # def dpath(table): + # def dpath(self): # from os.path import dirname - # dpath = dirname(table.db.fpath) + # dpath = dirname(self.db.fpath) # return dpath @property @ut.memoize - def ismulti(table): + def ismulti(self): # TODO: or has multi parent - return any(table.get_parent_col_attr('ismulti')) + return any(self.get_parent_col_attr('ismulti')) @property - def configclass(table): - return table.depc.configclass_dict[table.tablename] + def configclass(self): + return self.depc.configclass_dict[self.tablename] @property - def requestclass(table): - return table.depc.requestclass_dict.get(table.tablename, None) + def requestclass(self): + return self.depc.requestclass_dict.get(self.tablename, None) - def new_request(table, qaids, daids, cfgdict=None): - request = table.depc.new_request(table.tablename, qaids, daids, cfgdict=cfgdict) + def new_request(self, qaids, daids, cfgdict=None): + request = self.depc.new_request(self.tablename, qaids, daids, cfgdict=cfgdict) return request # --- Standard Properties @property - def internal_data_col_attrs(table): - flags = table.get_intern_col_attr('isdata') - return ut.compress(table.internal_col_attrs, flags) + def internal_data_col_attrs(self): + flags = self.get_intern_col_attr('isdata') + return ut.compress(self.internal_col_attrs, flags) @property - def internal_parent_col_attrs(table): - flags = table.get_intern_col_attr('isparent') - return ut.compress(table.internal_col_attrs, flags) + def internal_parent_col_attrs(self): + flags = self.get_intern_col_attr('isparent') + return ut.compress(self.internal_col_attrs, flags) # --- / Standard Properties @ut.memoize - def get_parent_col_attr(table, key): - return ut.dict_take_column(table.parent_col_attrs, key) + def get_parent_col_attr(self, key): + return ut.dict_take_column(self.parent_col_attrs, key) @ut.memoize - def get_intern_data_col_attr(table, key): - return ut.dict_take_column(table.internal_data_col_attrs, key) + def get_intern_data_col_attr(self, key): + return ut.dict_take_column(self.internal_data_col_attrs, key) @ut.memoize - def get_intern_parent_col_attr(table, key): - return ut.dict_take_column(table.internal_parent_col_attrs, key) + def get_intern_parent_col_attr(self, key): + return ut.dict_take_column(self.internal_parent_col_attrs, key) @ut.memoize - def get_intern_col_attr(table, key): - return ut.dict_take_column(table.internal_col_attrs, key) + def get_intern_col_attr(self, key): + return ut.dict_take_column(self.internal_col_attrs, key) @ut.memoize - def get_data_col_attr(table, key): - return ut.dict_take_column(table.data_col_attrs, key) + def get_data_col_attr(self, key): + return ut.dict_take_column(self.data_col_attrs, key) @property @ut.memoize - def 
parent_id_tablenames(table): + def parent_id_tablenames(self): tablenames = tuple( - [parent_colattr['parent_table'] for parent_colattr in table.parent_col_attrs] + [parent_colattr['parent_table'] for parent_colattr in self.parent_col_attrs] ) return tablenames @property @ut.memoize - def parent_id_prefix(table): + def parent_id_prefix(self): prefixes = tuple( - [parent_colattr['prefix'] for parent_colattr in table.parent_col_attrs] + [parent_colattr['prefix'] for parent_colattr in self.parent_col_attrs] ) return prefixes @property - def extern_columns(table): - colnames = table.get_data_col_attr('colname') - flags = table.get_data_col_attr('is_extern') + def extern_columns(self): + colnames = self.get_data_col_attr('colname') + flags = self.get_data_col_attr('is_extern') return ut.compress(colnames, flags) @property - def rowid_colname(table): + def rowid_colname(self): """ rowid of this table used by other dependant tables """ - return table.tablename + '_rowid' + return self.tablename + '_rowid' @property - def superkey_colnames(table): - return table.parent_id_colnames + (CONFIG_ROWID,) + def superkey_colnames(self): + return self.parent_id_colnames + (CONFIG_ROWID,) @property - def model_uuid_colname(table): + def model_uuid_colname(self): return 'model_uuid' @property - def is_augmented_colname(table): + def is_augmented_colname(self): return 'augment_bit' @property - def parent_id_colnames(table): - return tuple([colattr['intern_colname'] for colattr in table.parent_col_attrs]) + def parent_id_colnames(self): + return tuple([colattr['intern_colname'] for colattr in self.parent_col_attrs]) - def get_rowids_from_root(table, root_rowids, config=None): - return table.depc.get_rowids(table.tablename, root_rowids, config=config) + def get_rowids_from_root(self, root_rowids, config=None): + return self.depc.get_rowids(self.tablename, root_rowids, config=config) @property @ut.memoize - def parent(table): + def parent(self): return ut.odict( [ (parent_colattr['parent_table'], parent_colattr) - for parent_colattr in table.parent_col_attrs + for parent_colattr in self.parent_col_attrs ] ) # return tuple([parent_colattr['parent_table'] - # for parent_colattr in table.parent_col_attrs]) + # for parent_colattr in self.parent_col_attrs]) @ut.memoize - def parents(table, data=None): + def parents(self, data=None): if data: return [ (parent_colattr['parent_table'], parent_colattr) - for parent_colattr in table.parent_col_attrs + for parent_colattr in self.parent_col_attrs ] else: return [ - parent_colattr['parent_table'] - for parent_colattr in table.parent_col_attrs + parent_colattr['parent_table'] for parent_colattr in self.parent_col_attrs ] @property - def children(table): - graph = table.depc.explicit_graph - children_tablenames = list(nx.neighbors(graph, table.tablename)) + def children(self): + graph = self.depc.explicit_graph + children_tablenames = list(nx.neighbors(graph, self.tablename)) return children_tablenames @property - def ancestors(table): - graph = table.depc.explicit_graph - children_tablenames = list(nx.ancestors(graph, table.tablename)) + def ancestors(self): + graph = self.depc.explicit_graph + children_tablenames = list(nx.ancestors(graph, self.tablename)) return children_tablenames - def show_dep_subgraph(table, inter=None): + def show_dep_subgraph(self, inter=None): from wbia.plottool.interactions import ExpandableInteraction autostart = inter is None @@ -1163,8 +1157,8 @@ def show_dep_subgraph(table, inter=None): inter = ExpandableInteraction(nCols=2) import wbia.plottool 
as pt - graph = table.depc.explicit_graph - nodes = ut.nx_all_nodes_between(graph, None, table.tablename) + graph = self.depc.explicit_graph + nodes = ut.nx_all_nodes_between(graph, None, self.tablename) G = graph.subgraph(nodes) plot_kw = {'fontname': 'Ubuntu'} @@ -1172,14 +1166,14 @@ def show_dep_subgraph(table, inter=None): ut.partial( pt.show_nx, G, - title='Dependency Subgraph (%s)' % (table.tablename), + title='Dependency Subgraph (%s)' % (self.tablename), **plot_kw, ) ) if autostart: inter.start() - def show_input_graph(table, inter=None): + def show_input_graph(self, inter=None): """ CommandLine: python -m dtool.depcache_table show_input_graph --show @@ -1201,8 +1195,8 @@ def show_input_graph(table, inter=None): autostart = inter is None if inter is None: inter = ExpandableInteraction(nCols=2) - table.show_dep_subgraph(inter) - inputs = table.rootmost_inputs + self.show_dep_subgraph(inter) + inputs = self.rootmost_inputs inter = inputs.show_exi_graph(inter) if autostart: inter.start() @@ -1210,7 +1204,7 @@ def show_input_graph(table, inter=None): @property @ut.memoize - def expanded_input_graph(table): + def expanded_input_graph(self): """ CommandLine: python -m dtool.depcache_table --exec-expanded_input_graph --show --table=neighbs @@ -1238,13 +1232,13 @@ def expanded_input_graph(table): """ from wbia.dtool import input_helpers - graph = table.depc.explicit_graph.copy() - target = table.tablename + graph = self.depc.explicit_graph.copy() + target = self.tablename exi_graph = input_helpers.make_expanded_input_graph(graph, target) return exi_graph @property - def rootmost_inputs(table): + def rootmost_inputs(self): """ CommandLine: python -m dtool.depcache_table rootmost_inputs --show @@ -1266,24 +1260,24 @@ def rootmost_inputs(table): """ from wbia.dtool import input_helpers - exi_graph = table.expanded_input_graph - rootmost_inputs = input_helpers.get_rootmost_inputs(exi_graph, table) + exi_graph = self.expanded_input_graph + rootmost_inputs = input_helpers.get_rootmost_inputs(exi_graph, self) return rootmost_inputs @ut.memoize - def requestable_col_attrs(table): + def requestable_col_attrs(self): """ Maps names of requestable columns to indicies of internal columns """ requestable_col_attrs = {} - for colattr in table.internal_data_col_attrs: + for colattr in self.internal_data_col_attrs: rattr = {} colname = colattr['intern_colname'] rattr['intern_colx'] = colattr['intern_colx'] rattr['intern_colname'] = colattr['intern_colname'] requestable_col_attrs[colname] = rattr - for colattr in table.data_col_attrs: + for colattr in self.data_col_attrs: rattr = {} if colattr.get('isnested'): nest_internal_names = ut.take_column(colattr['nestattrs'], 'flat_colname') @@ -1306,11 +1300,11 @@ def requestable_col_attrs(table): return requestable_col_attrs @ut.memoize - def computable_colnames(table): + def computable_colnames(self): # These are the colnames that we expect to be computed - intern_colnames = ut.take_column(table.internal_col_attrs, 'intern_colname') + intern_colnames = ut.take_column(self.internal_col_attrs, 'intern_colname') insertable_flags = [ - not colattr.get('isprimary') for colattr in table.internal_col_attrs + not colattr.get('isprimary') for colattr in self.internal_col_attrs ] colnames = tuple(ut.compress(intern_colnames, insertable_flags)) return colnames @@ -1322,7 +1316,7 @@ class _TableComputeHelper(object): # @profile def prepare_storage( - table, dirty_parent_ids, proptup_gen, dirty_preproc_args, config_rowid, config + self, dirty_parent_ids, proptup_gen, 
dirty_preproc_args, config_rowid, config ): """ Converts output from ``preproc_func`` to data that can be stored in SQL @@ -1355,19 +1349,19 @@ def prepare_storage( >>> table.print_model_manifests() >>> #ut.vd(depc.cache_dpath) """ - if table.default_to_unpack: + if self.default_to_unpack: # Hack for tables explicilty specified with a single column proptup_gen = (None if data is None else (data,) for data in proptup_gen) # Flatten nested columns - if any(table.get_data_col_attr('isnested')): - proptup_gen = table._prepare_storage_nested(proptup_gen) + if any(self.get_data_col_attr('isnested')): + proptup_gen = self._prepare_storage_nested(proptup_gen) # Write external columns - if any(table.get_data_col_attr('write_func')): - proptup_gen = table._prepare_storage_extern( + if any(self.get_data_col_attr('write_func')): + proptup_gen = self._prepare_storage_extern( dirty_parent_ids, config_rowid, config, proptup_gen ) - if table.ismulti: - manifest_dpath = table.dpath + if self.ismulti: + manifest_dpath = self.dpath ut.ensuredir(manifest_dpath) # Concatenate data with internal rowids / config-id for ids_, data_cols, args_ in zip( @@ -1377,19 +1371,19 @@ def prepare_storage( if data_cols is None: yield None else: - multi_parent_flags = table.get_parent_col_attr('ismulti') - parent_colnames = table.get_parent_col_attr('intern_colname') + multi_parent_flags = self.get_parent_col_attr('ismulti') + parent_colnames = self.get_parent_col_attr('intern_colname') multi_id_names = ut.compress(parent_colnames, multi_parent_flags) multi_ids = ut.compress(ids_, multi_parent_flags) multi_args = ut.compress(args_, multi_parent_flags) - if table.ismulti: + if self.ismulti: multi_setsizes = [] manifest_data = {} for multi_id, arg_, name in zip( multi_ids, multi_args, multi_id_names ): - assert table.ismulti, 'only valid for models' + assert self.ismulti, 'only valid for models' # TODO: need to get back to root ids manifest_data.update( **{ @@ -1407,7 +1401,7 @@ def prepare_storage( manifest_data['model_uuid'] = model_uuid manifest_data['augmented'] = False - manifest_fpath = table.get_model_manifest_fpath(model_uuid) + manifest_fpath = self.get_model_manifest_fpath(model_uuid) ut.save_json(manifest_fpath, manifest_data, pretty=1) # TODO: hash all input UUIDs and the full config together @@ -1436,37 +1430,37 @@ def prepare_storage( ) raise - def get_model_manifest_fname(table, model_uuid): + def get_model_manifest_fname(self, model_uuid): manifest_fname = 'input_manifest_%s.json' % (model_uuid,) return manifest_fname - def get_model_manifest_fpath(table, model_uuid): - manifest_fname = table.get_model_manifest_fname(model_uuid) - manifest_fpath = join(table.dpath, manifest_fname) + def get_model_manifest_fpath(self, model_uuid): + manifest_fname = self.get_model_manifest_fname(model_uuid) + manifest_fpath = join(self.dpath, manifest_fname) return manifest_fpath - def get_model_inputs(table, model_uuid): + def get_model_inputs(self, model_uuid): """ Ignore: >>> table.get_model_uuid([2]) [UUID('5b66772c-e654-dd9a-c9de-0ccc1bb6861c')] """ - assert table.ismulti, 'must be a model' - manifest_fpath = table.get_model_manifest_fpath(model_uuid) + assert self.ismulti, 'must be a model' + manifest_fpath = self.get_model_manifest_fpath(model_uuid) manifest_data = ut.load_json(manifest_fpath) return manifest_data - def get_model_uuid(table, rowids): + def get_model_uuid(self, rowids): """ Ignore: >>> table.get_model_uuid([2]) [UUID('5b66772c-e654-dd9a-c9de-0ccc1bb6861c')] """ - assert table.ismulti, 'must be a model' - 
model_uuid_list = table.get_internal_columns(rowids, ('model_uuid',)) + assert self.ismulti, 'must be a model' + model_uuid_list = self.get_internal_columns(rowids, ('model_uuid',)) return model_uuid_list - def get_model_rowids(table, model_uuid_list): + def get_model_rowids(self, model_uuid_list): """ Get the rowid of a model given its uuid @@ -1475,12 +1469,12 @@ def get_model_rowids(table, model_uuid_list): >>> table.get_model_rowids([uuid.UUID('5b66772c-e654-dd9a-c9de-0ccc1bb6861c')]) [2] """ - assert table.ismulti, 'must be a model' - colnames = (table.rowid_colname,) - andwhere_colnames = (table.model_uuid_colname,) + assert self.ismulti, 'must be a model' + colnames = (self.rowid_colname,) + andwhere_colnames = (self.model_uuid_colname,) params_iter = list(zip(model_uuid_list)) - rowid_list = table.db.get_where_eq( - table.tablename, + rowid_list = self.db.get_where_eq( + self.tablename, colnames, params_iter, andwhere_colnames, @@ -1490,13 +1484,13 @@ def get_model_rowids(table, model_uuid_list): return rowid_list @profile - def _prepare_storage_nested(table, proptup_gen): + def _prepare_storage_nested(self, proptup_gen): """ Hack for when a sql schema has tuples defined in it. Accepts nested tuples and flattens them to fit into the sql tables """ - nCols = len(table.data_colnames) - idxs1 = ut.where(table.get_data_col_attr('isnested')) + nCols = len(self.data_colnames) + idxs1 = ut.where(self.get_data_col_attr('isnested')) idxs2 = ut.index_complement(idxs1, nCols) for data in proptup_gen: if data is None: @@ -1515,12 +1509,12 @@ def _prepare_storage_nested(table, proptup_gen): # @profile def _prepare_storage_extern( - table, dirty_parent_ids, config_rowid, config, proptup_gen + self, dirty_parent_ids, config_rowid, config, proptup_gen ): """ Writes external data to disk if write function is specified. 
""" - internal_data_col_attrs = table.internal_data_col_attrs + internal_data_col_attrs = self.internal_data_col_attrs writable_flags = ut.dict_take_column(internal_data_col_attrs, 'write_func', False) extern_colattrs = ut.compress(internal_data_col_attrs, writable_flags) # extern_colnames = ut.dict_take_column(extern_colattrs, 'colname') @@ -1532,7 +1526,7 @@ def _prepare_storage_extern( extern_fnames_list = list( zip( *[ - table._get_extern_fnames( + self._get_extern_fnames( dirty_parent_ids, config_rowid, config, extern_colattr ) for extern_colattr in extern_colattrs @@ -1540,7 +1534,7 @@ def _prepare_storage_extern( ) ) # get extern cache directory and fpaths - extern_dpath = table.extern_dpath + extern_dpath = self.extern_dpath ut.ensuredir(extern_dpath) # extern_fpaths_list = [ # [join(extern_dpath, fname) for fname in fnames] @@ -1574,7 +1568,7 @@ def _prepare_storage_extern( data_new = tuple(ut.ungroup(grouped_items, groupxs, nCols - 1)) yield data_new - def get_extern_fnames(table, parent_rowids, config, extern_col_index=0): + def get_extern_fnames(self, parent_rowids, config, extern_col_index=0): """ convinience function around get_extern_fnames @@ -1594,25 +1588,25 @@ def get_extern_fnames(table, parent_rowids, config, extern_col_index=0): >>> fname_list = table.get_extern_fnames(parent_rowids, config) >>> print('fname_list = %r' % (fname_list,)) """ - config_rowid = table.get_config_rowid(config) + config_rowid = self.get_config_rowid(config) # depc.get_rowids(tablename, root_rowids, config) - internal_data_col_attrs = table.internal_data_col_attrs + internal_data_col_attrs = self.internal_data_col_attrs writable_flags = ut.dict_take_column(internal_data_col_attrs, 'write_func', False) extern_colattrs = ut.compress(internal_data_col_attrs, writable_flags) extern_colattr = extern_colattrs[extern_col_index] - fname_list = table._get_extern_fnames( + fname_list = self._get_extern_fnames( parent_rowids, config_rowid, config, extern_colattr ) # if False: - # root_rowids = table.depc.get_root_rowids(table.tablename, rowid_list) + # root_rowids = self.depc.get_root_rowids(self.tablename, rowid_list) # info_props = ['image_uuid', 'verts', 'theta'] - # table.depc.make_root_info_uuid(root_rowids, info_props) + # self.depc.make_root_info_uuid(root_rowids, info_props) return fname_list def _get_extern_fnames( - table, parent_rowids, config_rowid, config, extern_colattr=None + self, parent_rowids, config_rowid, config, extern_colattr=None ): """ TODO: @@ -1624,17 +1618,17 @@ def _get_extern_fnames( Args: parent_rowids (list of tuples) - list of tuples of rowids """ - config_hashid = table.get_config_hashid([config_rowid])[0] - prefix = table.tablename + config_hashid = self.get_config_hashid([config_rowid])[0] + prefix = self.tablename prefix += '_' + extern_colattr['colname'] - colattrs = table.data_col_attrs[extern_colattr['data_colx']] + colattrs = self.data_col_attrs[extern_colattr['data_colx']] # if colname is not None: # prefix += '_' + colname # TODO: Put relevant root properties into the hash of the filename # (like bbox, parent image. basically the general vuuid and suuid. 
fmtstr = '{prefix}_id={rowids}_{config_hashid}{ext}' # HACK: check if the config specifies the extension type - # extkey = table.extern_ext_config_keys.get(colname, 'ext') + # extkey = self.extern_ext_config_keys.get(colname, 'ext') if 'extern_ext' in colattrs: ext = colattrs['extern_ext'] else: @@ -1652,7 +1646,7 @@ def _get_extern_fnames( return fname_list def _compute_dirty_rows( - table, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True + self, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True ): """ dirty_preproc_args = preproc_args @@ -1671,15 +1665,15 @@ def _compute_dirty_rows( config_ = config.config if hasattr(config, 'config') else config # call registered worker function - if table.vectorized: + if self.vectorized: # Function is written in a way that only accepts multiple inputs at # once and generates output - proptup_gen = table.preproc_func(table.depc, *argsT, config=config_) + proptup_gen = self.preproc_func(self.depc, *argsT, config=config_) else: # Function is written in a way that only accepts a single row of # input at a time proptup_gen = ( - table.preproc_func(table.depc, *argrow, config=config_) + self.preproc_func(self.depc, *argrow, config=config_) for argrow in zip(*argsT) ) @@ -1694,7 +1688,7 @@ def _compute_dirty_rows( nInput, ) # Append rowids and rectify nested and external columns - dirty_params_iter = table.prepare_storage( + dirty_params_iter = self.prepare_storage( dirty_parent_ids, proptup_gen, dirty_preproc_args, config_rowid, config_ ) if DEBUG_LIST_MODE: @@ -1704,7 +1698,7 @@ def _compute_dirty_rows( return dirty_params_iter def _chunk_compute_dirty_rows( - table, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True + self, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True ): """ Executes registered functions, does external storage and yeilds results @@ -1724,11 +1718,11 @@ def _chunk_compute_dirty_rows( >>> depc.print_all_tables() """ nInput = len(dirty_parent_ids) - chunksize = nInput if table.chunksize is None else table.chunksize + chunksize = nInput if self.chunksize is None else self.chunksize logger.info( '[deptbl.compute] nInput={}, chunksize={}, tbl={}'.format( - nInput, table.chunksize, table.tablename + nInput, self.chunksize, self.tablename ) ) @@ -1738,20 +1732,20 @@ def _chunk_compute_dirty_rows( dirty_iter, chunksize, nInput, - lbl='[deptbl.compute] add %s chunk' % (table.tablename), + lbl='[deptbl.compute] add %s chunk' % (self.tablename), ) # These are the colnames that we expect to be computed - colnames = table.computable_colnames() + colnames = self.computable_colnames() # def unfinished_features(): - # if table._asobject: + # if self._asobject: # # Convinience - # argsT = [table.depc.get_obj(parent, rowids) - # for parent, rowids in zip(table.parents(), + # argsT = [self.depc.get_obj(parent, rowids) + # for parent, rowids in zip(self.parents(), # dirty_parent_ids_chunk)] # onthefly = None - # if table.default_onthefly or onthefly: - # assert not table.ismulti, ('cannot onthefly multi tables') - # proptup_gen = [tuple([None] * len(table.data_col_attrs)) + # if self.default_onthefly or onthefly: + # assert not self.ismulti, ('cannot onthefly multi tables') + # proptup_gen = [tuple([None] * len(self.data_col_attrs)) # for _ in range(len(dirty_parent_ids_chunk))] # pass # CALL EXTERNAL PREPROCESSING / GENERATION FUNCTION @@ -1763,7 +1757,7 @@ def _chunk_compute_dirty_rows( return dirty_parent_ids_chunk, dirty_preproc_args_chunk = zip(*dirty_chunk) 
- dirty_params_iter = table._compute_dirty_rows( + dirty_params_iter = self._compute_dirty_rows( dirty_parent_ids_chunk, dirty_preproc_args_chunk, config_rowid, @@ -1785,12 +1779,12 @@ def _chunk_compute_dirty_rows( 'error in add_rowids', keys=[ 'table', - 'table.parents()', + 'self.parents()', 'config', 'argsT', 'config_rowid', 'dirty_parent_ids', - 'table.preproc_func', + 'self.preproc_func', ], tb=True, ) @@ -1845,7 +1839,7 @@ class DependencyCacheTable( @profile def __init__( - table, + self, depc=None, parent_tablenames=None, tablename=None, @@ -1867,92 +1861,92 @@ def __init__( recieves kwargs from depc._register_prop """ try: - table.db = None + self.db = None except Exception: # HACK: jedi type hinting. Need to have non-obvious condition - table.db = SQLDatabaseController() - table.fpath_to_db = {} + self.db = SQLDatabaseController() + self.fpath_to_db = {} assert ( re.search('[0-9]', tablename) is None ), 'tablename=%r cannot contain numbers' % (tablename,) # parent depcache - table.depc = depc + self.depc = depc # Definitions - table.tablename = tablename - table.docstr = docstr - table.parent_tablenames = parent_tablenames - table.data_colnames = tuple(data_colnames) - table.data_coltypes = data_coltypes - table.preproc_func = preproc_func - table.fname = fname + self.tablename = tablename + self.docstr = docstr + self.parent_tablenames = parent_tablenames + self.data_colnames = tuple(data_colnames) + self.data_coltypes = data_coltypes + self.preproc_func = preproc_func + self.fname = fname # Behavior - table.on_delete = None - table.default_to_unpack = default_to_unpack - table.vectorized = vectorized - table.taggable = taggable + self.on_delete = None + self.default_to_unpack = default_to_unpack + self.vectorized = vectorized + self.taggable = taggable - # table.store_modification_time = True + # self.store_modification_time = True # Use the filesystem to accomplish this - # table.store_access_time = True - # table.store_create_time = True - # table.store_delete_time = True + # self.store_access_time = True + # self.store_create_time = True + # self.store_delete_time = True - table.chunksize = chunksize + self.chunksize = chunksize # Developmental properties - table.subproperties = {} - table.isinteractive = isinteractive - table._asobject = asobject - table.default_onthefly = default_onthefly + self.subproperties = {} + self.isinteractive = isinteractive + self._asobject = asobject + self.default_onthefly = default_onthefly # SQL Internals - table.sqldb_fpath = None - table.rm_extern_on_delete = rm_extern_on_delete + self.sqldb_fpath = None + self.rm_extern_on_delete = rm_extern_on_delete # Update internals - table.parent_col_attrs = table._infer_parentcol() - table.data_col_attrs = table._infer_datacol() - table.internal_col_attrs = table._infer_allcol() + self.parent_col_attrs = self._infer_parentcol() + self.data_col_attrs = self._infer_datacol() + self.internal_col_attrs = self._infer_allcol() # Check for errors if ut.SUPER_STRICT: - table._assert_self() + self._assert_self() - table._hack_chunk_cache = None + self._hack_chunk_cache = None # @profile - def initialize(table, _debug=None): + def initialize(self, _debug=None): """ Ensures the SQL schema for this cache table """ - table.db = table.depc.fname_to_db[table.fname] - # logger.info('Checking sql for table=%r' % (table.tablename,)) - if not table.db.has_table(table.tablename): - logger.debug('Initializing table=%r' % (table.tablename,)) - new_state = table._get_addtable_kw() - table.db.add_table(**new_state) + 
self.db = self.depc.fname_to_db[self.fname] + # logger.info('Checking sql for table=%r' % (self.tablename,)) + if not self.db.has_table(self.tablename): + logger.debug('Initializing table=%r' % (self.tablename,)) + new_state = self._get_addtable_kw() + self.db.add_table(**new_state) else: # TODO: Check for table modifications - new_state = table._get_addtable_kw() + new_state = self._get_addtable_kw() try: - current_state = table.db.get_table_autogen_dict(table.tablename) + current_state = self.db.get_table_autogen_dict(self.tablename) except Exception as ex: strict = True ut.printex( ex, - 'TABLE %s IS CORRUPTED' % (table.tablename,), + 'TABLE %s IS CORRUPTED' % (self.tablename,), iswarning=not strict, ) if strict: raise - table.clear_table() - current_state = table.db.get_table_autogen_dict(table.tablename) + self.clear_table() + current_state = self.db.get_table_autogen_dict(self.tablename) if current_state['coldef_list'] != new_state['coldef_list']: logger.info('WARNING TABLE IS MODIFIED') - if predrop_grace_period(table.tablename): - table.clear_table() + if predrop_grace_period(self.tablename): + self.clear_table() else: raise NotImplementedError('Need to be able to modify tables') - def _get_addtable_kw(table): + def _get_addtable_kw(self): """ Information that defines the SQL table @@ -1975,16 +1969,16 @@ def _get_addtable_kw(table): """ coldef_list = [ (colattr['intern_colname'], colattr['sqltype']) - for colattr in table.internal_col_attrs + for colattr in self.internal_col_attrs ] - superkeys = [table.superkey_colnames] + superkeys = [self.superkey_colnames] add_table_kw = ut.odict( [ - ('tablename', table.tablename), + ('tablename', self.tablename), ('coldef_list', coldef_list), - ('docstr', table.docstr), + ('docstr', self.docstr), ('superkeys', superkeys), - ('dependson', table.parents()), + ('dependson', self.parents()), ] ) return add_table_kw @@ -1993,16 +1987,16 @@ def _get_addtable_kw(table): # --- GETTERS NATIVE --- # ---------------------- - def _get_all_rowids(table): - return table.db.get_all_rowids(table.tablename) + def _get_all_rowids(self): + return self.db.get_all_rowids(self.tablename) @property - def number_of_rows(table): - return table.db.get_row_count(table.tablename) + def number_of_rows(self): + return self.db.get_row_count(self.tablename) # @profile def ensure_rows( - table, + self, parent_ids_, preproc_args, config=None, @@ -2034,10 +2028,10 @@ def ensure_rows( """ try: # Get requested configuration id - config_rowid = table.get_config_rowid(config) + config_rowid = self.get_config_rowid(config) # Check which rows are already computed - initial_rowid_list = table._get_rowid(parent_ids_, config=config) + initial_rowid_list = self._get_rowid(parent_ids_, config=config) initial_rowid_list = list(initial_rowid_list) logger.debug( @@ -2053,7 +2047,7 @@ def ensure_rows( if num_dirty > 0: logger.debug( - 'Add %d / %d new rows to %r' % (num_dirty, num_total, table.tablename) + 'Add %d / %d new rows to %r' % (num_dirty, num_total, self.tablename) ) logger.debug( '[deptbl.add] * config_rowid = {}, config={}'.format( @@ -2072,10 +2066,10 @@ def ensure_rows( # Break iterator into chunks if False and verbose: # check parent configs we are working with - for x, parname in enumerate(table.parents()): - if parname == table.depc.root: + for x, parname in enumerate(self.parents()): + if parname == self.depc.root: continue - parent_table = table.depc[parname] + parent_table = self.depc[parname] ut.take_column(parent_ids_, x) rowid_list = ut.take_column(parent_ids_, x) 
try: @@ -2087,28 +2081,28 @@ def ensure_rows( ) # Gives the function a hacky cache to use between chunks - table._hack_chunk_cache = {} - gen = table._chunk_compute_dirty_rows( + self._hack_chunk_cache = {} + gen = self._chunk_compute_dirty_rows( dirty_parent_ids, dirty_preproc_args, config_rowid, config ) """ colnames, dirty_params_iter, nChunkInput = next(gen) """ for colnames, dirty_params_iter, nChunkInput in gen: - table.db._add( - table.tablename, + self.db._add( + self.tablename, colnames, dirty_params_iter, nInput=nChunkInput, ) # Remove cache when main add is done - table._hack_chunk_cache = None + self._hack_chunk_cache = None logger.debug('[deptbl.add] finished add') # # The requested data is clean and must now exist in the parent # database, do a lookup to ensure the correct order. - rowid_list = table._get_rowid(parent_ids_, config=config) + rowid_list = self._get_rowid(parent_ids_, config=config) else: rowid_list = initial_rowid_list logger.debug('[deptbl.add] rowid_list = %s' % ut.trunc_repr(rowid_list)) @@ -2118,14 +2112,14 @@ def ensure_rows( logger.error( 'DEPC ENSURE_ROWS FOR TABLE %r FAILED DUE TO INTEGRITY ERROR (RETRY %d)!' - % (table, retry) + % (self, retry) ) retry_delay = random.uniform(retry_delay_min, retry_delay_max) logger.error('\t WAITING %0.02f SECONDS THEN RETRYING' % (retry_delay,)) time.sleep(retry_delay) retry_ = retry - 1 - rowid_list = table.ensure_rows( + rowid_list = self.ensure_rows( parent_ids_, preproc_args, config=config, @@ -2137,7 +2131,7 @@ def ensure_rows( return rowid_list - def _rectify_ids(table, parent_rowids): + def _rectify_ids(self, parent_rowids): r""" Filters any rows containing None ids and transforms many-to-one sets of rowids into hashable UUIDS. @@ -2183,9 +2177,9 @@ def _rectify_ids(table, parent_rowids): valid_parent_ids_ = ut.take(parent_rowids, idxs1) preproc_args = valid_parent_ids_ - if table.ismulti: + if self.ismulti: # Convert any parent-id containing multiple values into a hash of uuids - multi_parent_flags = table.get_parent_col_attr('ismulti') + multi_parent_flags = self.get_parent_col_attr('ismulti') num_parents = len(multi_parent_flags) multi_parent_colxs = ut.where(multi_parent_flags) normal_colxs = ut.index_complement(multi_parent_colxs, num_parents) @@ -2197,9 +2191,9 @@ def _rectify_ids(table, parent_rowids): ] # TODO: give each table a uuid getter function that derives from # get_root_uuids - multicol_tables = ut.take(table.parents(), multi_parent_colxs) + multicol_tables = ut.take(self.parents(), multi_parent_colxs) parent_uuid_getters = [ - table.depc.get_root_uuid if col == table.depc.root else ut.identity + self.depc.get_root_uuid if col == self.depc.root else ut.identity for col in multicol_tables ] @@ -2231,7 +2225,7 @@ def _rectify_ids(table, parent_rowids): rectify_tup = parent_ids_, preproc_args, idxs1, idxs2 return rectify_tup - def _unrectify_ids(table, rowid_list_, parent_rowids, idxs1, idxs2): + def _unrectify_ids(self, rowid_list_, parent_rowids, idxs1, idxs2): """ Ensures that output is the same length as input. Inserts necessary Nones where the original input was also None. 
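The retry path in `ensure_rows` above boils down to a bounded retry with a randomized sleep between attempts: on an integrity error it waits a random delay drawn from `[retry_delay_min, retry_delay_max]` and re-invokes itself with a decremented `retry` count. A minimal standalone sketch of that pattern follows; the helper name and the broad `except` clause are illustrative only, not part of the patch.

```python
import random
import time


def retry_with_jitter(func, retries=3, delay_min=1.0, delay_max=3.0):
    # Call ``func`` and retry on failure, sleeping a random delay between attempts.
    for attempt in range(retries + 1):
        try:
            return func()
        except Exception:
            if attempt == retries:
                raise
            time.sleep(random.uniform(delay_min, delay_max))
```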
@@ -2241,7 +2235,7 @@ def _unrectify_ids(table, rowid_list_, parent_rowids, idxs1, idxs2): return rowid_list def get_rowid( - table, + self, parent_rowids, config=None, ensure=True, @@ -2287,20 +2281,20 @@ def get_rowid( """ logger.debug( '[deptbl.get_rowid] Get %s rowids via %d parent superkeys' - % (table.tablename, len(parent_rowids)) + % (self.tablename, len(parent_rowids)) ) logger.debug('[deptbl.get_rowid] config = %r' % (config,)) logger.debug('[deptbl.get_rowid] ensure = %r' % (ensure,)) # Ensure inputs are in the correct format / remove Nones # Collapse multi-inputs into a UUID hash - rectify_tup = table._rectify_ids(parent_rowids) + rectify_tup = self._rectify_ids(parent_rowids) (parent_ids_, preproc_args, idxs1, idxs2) = rectify_tup # Do the getting / adding work if recompute: logger.info('REQUESTED RECOMPUTE') # get existing rowids, delete them, recompute the request - rowid_list_ = table._get_rowid( + rowid_list_ = self._get_rowid( parent_ids_, config=config, eager=True, @@ -2309,16 +2303,16 @@ def get_rowid( rowid_list_ = list(rowid_list_) needs_recompute_rowids = ut.filter_Nones(rowid_list_) try: - table._recompute_and_store(needs_recompute_rowids) + self._recompute_and_store(needs_recompute_rowids) except Exception: # If the config changes, there is nothing we can do. # We have to delete the rows. - table.delete_rows(rowid_list_) + self.delete_rows(rowid_list_) if ensure or recompute: # Compute properties if they do not exist for try_num in range(num_retries): try: - rowid_list_ = table.ensure_rows( + rowid_list_ = self.ensure_rows( parent_ids_, preproc_args, config=config, @@ -2327,38 +2321,38 @@ def get_rowid( if try_num == num_retries - 1: raise else: - rowid_list_ = table._get_rowid( + rowid_list_ = self._get_rowid( parent_ids_, config=config, eager=eager, nInput=nInput, ) # Map outputs to correspond with inputs - rowid_list = table._unrectify_ids(rowid_list_, parent_rowids, idxs1, idxs2) + rowid_list = self._unrectify_ids(rowid_list_, parent_rowids, idxs1, idxs2) return rowid_list # @profile - def _get_rowid(table, parent_ids_, config=None, eager=True, nInput=None): + def _get_rowid(self, parent_ids_, config=None, eager=True, nInput=None): """ Returns rowids using parent superkeys. Does not add non-existing properties. 
""" - colnames = (table.rowid_colname,) - config_rowid = table.get_config_rowid(config=config) + colnames = (self.rowid_colname,) + config_rowid = self.get_config_rowid(config=config) logger.debug('_get_rowid') - logger.debug('_get_rowid table.tablename = %r ' % (table.tablename,)) + logger.debug('_get_rowid self.tablename = %r ' % (self.tablename,)) logger.debug('_get_rowid parent_ids_ = %s' % (ut.trunc_repr(parent_ids_))) logger.debug('_get_rowid config = %s' % (config)) - logger.debug('_get_rowid table.rowid_colname = %s' % (table.rowid_colname)) + logger.debug('_get_rowid self.rowid_colname = %s' % (self.rowid_colname)) logger.debug('_get_rowid config_rowid = %s' % (config_rowid)) - andwhere_colnames = table.superkey_colnames + andwhere_colnames = self.superkey_colnames params_iter = (ids_ + (config_rowid,) for ids_ in parent_ids_) # TODO: make sure things that call this can accept a generator # Then remove this next line params_iter = list(params_iter) # logger.info('**params_iter = %r' % (params_iter,)) - rowid_list = table.db.get_where_eq( - table.tablename, + rowid_list = self.db.get_where_eq( + self.tablename, colnames, params_iter, andwhere_colnames, @@ -2368,17 +2362,17 @@ def _get_rowid(table, parent_ids_, config=None, eager=True, nInput=None): logger.debug('_get_rowid rowid_list = %s' % (ut.trunc_repr(rowid_list))) return rowid_list - def clear_table(table): + def clear_table(self): """ Deletes all data in this table """ # TODO: need to clear one-to-one dependencies as well - logger.info('Clearing data in %r' % (table,)) - table.db.drop_table(table.tablename) - table.db.add_table(**table._get_addtable_kw()) + logger.info('Clearing data in %r' % (self,)) + self.db.drop_table(self.tablename) + self.db.add_table(**self._get_addtable_kw()) # @profile - def delete_rows(table, rowid_list, delete_extern=None, dry=False, verbose=None): + def delete_rows(self, rowid_list, delete_extern=None, dry=False, verbose=None): """ CommandLine: python -m dtool.depcache_table --exec-delete_rows @@ -2414,30 +2408,30 @@ def delete_rows(table, rowid_list, delete_extern=None, dry=False, verbose=None): """ # import networkx as nx # from wbia.dtool.algo.preproc import preproc_feat - if table.on_delete is not None and not dry: - table.on_delete() + if self.on_delete is not None and not dry: + self.on_delete() if delete_extern is None: - delete_extern = table.rm_extern_on_delete + delete_extern = self.rm_extern_on_delete if verbose is None: verbose = False if ut.NOT_QUIET: if ut.VERBOSE: logger.info( 'Requested delete of %d rows from %s' - % (len(rowid_list), table.tablename) + % (len(rowid_list), self.tablename) ) if dry: logger.info('Dry run') # logger.info('delete_extern = %r' % (delete_extern,)) - depc = table.depc + depc = self.depc # TODO: # REMOVE EXTERNAL FILES - internal_colnames = table.get_intern_data_col_attr('intern_colname') - is_extern = table.get_intern_data_col_attr('is_external_pointer') + internal_colnames = self.get_intern_data_col_attr('intern_colname') + is_extern = self.get_intern_data_col_attr('is_external_pointer') extern_colnames = tuple(ut.compress(internal_colnames, is_extern)) if len(extern_colnames) > 0: - uris = table.get_internal_columns( + uris = self.get_internal_columns( rowid_list, extern_colnames, unpack_scalars=False, @@ -2449,7 +2443,7 @@ def delete_rows(table, rowid_list, delete_extern=None, dry=False, verbose=None): if not isinstance(uri, tuple): uri = [uri] for uri_ in uri: - absuris.append(join(table.extern_dpath, uri_)) + absuris.append(join(self.extern_dpath, 
uri_)) fpaths = [fpath for fpath in absuris if exists(fpath)] if delete_extern: if ut.VERBOSE or len(fpaths) > 0: @@ -2482,17 +2476,17 @@ def get_child_partial_rowids(child_table, rowid_list, parent_colnames): return child_rowids if ut.VERBOSE: - if table.children: - logger.info('Deleting from %r children' % (len(table.children),)) + if self.children: + logger.info('Deleting from %r children' % (len(self.children),)) else: logger.info('Table is a leaf node') - for child in table.children: - child_table = table.depc[child] + for child in self.children: + child_table = self.depc[child] if not child_table.ismulti: # Hack, wont work for vsone / multisets parent_colnames = ( - child_table.parent[table.tablename]['intern_colname'], + child_table.parent[self.tablename]['intern_colname'], ) child_rowids = get_child_partial_rowids( child_table, rowid_list, parent_colnames @@ -2504,24 +2498,24 @@ def get_child_partial_rowids(child_table, rowid_list, parent_colnames): if ut.VERBOSE or len(non_none_rowids) > 0: logger.info( 'Deleting %d non-None rows from %s' - % (len(non_none_rowids), table.tablename) + % (len(non_none_rowids), self.tablename) ) logger.info('...done!') # Finalize: Delete rows from this table if not dry: - table.db.delete_rowids(table.tablename, rowid_list) + self.db.delete_rowids(self.tablename, rowid_list) num_deleted = len(ut.filter_Nones(rowid_list)) else: num_deleted = 0 return num_deleted - def _resolve_requested_columns(table, requested_colnames): + def _resolve_requested_columns(self, requested_colnames): ######## # Map requested colnames flat to internal colnames ######## # Get requested column information - requestable_col_attrs = table.requestable_col_attrs() + requestable_col_attrs = self.requestable_col_attrs() requested_colattrs = ut.take(requestable_col_attrs, requested_colnames) # Make column indicies iterable for grouping intern_colxs = [ @@ -2535,7 +2529,7 @@ def _resolve_requested_columns(table, requested_colnames): extern_colattrs = ut.compress(requested_colattrs, isextern_flags) extern_resolve_colxs = ut.compress(nested_offsets_start, isextern_flags) extern_read_funcs = ut.take_column(extern_colattrs, 'read_func') - intern_colnames_ = ut.take_column(table.internal_col_attrs, 'intern_colname') + intern_colnames_ = ut.take_column(self.internal_col_attrs, 'intern_colname') intern_colnames = ut.unflat_take(intern_colnames_, intern_colxs) # TODO: this can be cleaned up @@ -2549,7 +2543,7 @@ def _resolve_requested_columns(table, requested_colnames): # @profile def get_row_data( - table, + self, tbl_rowids, colnames=None, _debug=None, @@ -2620,14 +2614,14 @@ def get_row_data( """ logger.debug( ('Get col of tablename=%r, colnames=%r with ' 'tbl_rowids=%s') - % (table.tablename, colnames, ut.trunc_repr(tbl_rowids)) + % (self.tablename, colnames, ut.trunc_repr(tbl_rowids)) ) #### # Resolve requested column names if unpack_columns is None: - unpack_columns = table.default_to_unpack + unpack_columns = self.default_to_unpack if colnames is None: - requested_colnames = table.data_colnames + requested_colnames = self.data_colnames elif isinstance(colnames, six.string_types): # Unpack columns if only a single column is requested. 
requested_colnames = (colnames,) @@ -2636,7 +2630,7 @@ def get_row_data( requested_colnames = colnames logger.debug('requested_colnames = %r' % (requested_colnames,)) - tup = table._resolve_requested_columns(requested_colnames) + tup = self._resolve_requested_columns(requested_colnames) nesting_xs, extern_resolve_tups, flat_intern_colnames = tup logger.debug( @@ -2652,15 +2646,15 @@ def get_row_data( #### # Read data stored in SQL # FIXME: understand unpack_scalars and keepwrap - # if table.default_onthefly: - # table._onthefly_dataget + # if self.default_onthefly: + # self._onthefly_dataget # else: if nInput is None and ut.is_listlike(nonNone_tbl_rowids): nInput = len(nonNone_tbl_rowids) generator_version = not eager - raw_prop_list = table.get_internal_columns( + raw_prop_list = self.get_internal_columns( nonNone_tbl_rowids, flat_intern_colnames, eager=eager, @@ -2695,7 +2689,7 @@ def tuptake(list_, index_list): if generator_version: def _generator_resolve_all(): - extern_dpath = table.extern_dpath + extern_dpath = self.extern_dpath for rawprop in raw_prop_list: if rawprop is None: raise Exception( @@ -2731,7 +2725,7 @@ def _generator_resolve_all(): for try_num in range(num_retries + 1): tries_left = num_retries - try_num try: - prop_listT = table._resolve_any_external_data( + prop_listT = self._resolve_any_external_data( nonNone_tbl_rowids, raw_prop_list, extern_resolve_tups, @@ -2769,7 +2763,7 @@ def _generator_resolve_all(): return prop_list def _resolve_any_external_data( - table, + self, nonNone_tbl_rowids, raw_prop_list, extern_resolve_tups, @@ -2780,7 +2774,7 @@ def _resolve_any_external_data( ): #### # Read data specified by any external columns - extern_dpath = table.extern_dpath + extern_dpath = self.extern_dpath try: prop_listT = list(zip(*raw_prop_list)) except TypeError as ex: @@ -2830,8 +2824,8 @@ def _resolve_any_external_data( ) failed_rowids = ut.compress(nonNone_tbl_rowids, failed_list) if delete_on_fail: - table._recompute_external_storage(failed_rowids) - # table.delete_rows(failed_rowids, delete_extern=None) + self._recompute_external_storage(failed_rowids) + # self.delete_rows(failed_rowids, delete_extern=None) raise ExternalStorageException( 'Some cached filenames failed to read. ' 'Need to recompute %d/%d rows' % (sum(failed_list), len(failed_list)) @@ -2840,7 +2834,7 @@ def _resolve_any_external_data( prop_listT[extern_colx] = data_list return prop_listT - def _recompute_external_storage(table, tbl_rowids): + def _recompute_external_storage(self, tbl_rowids): """ Recomputes the external file stored for this row. This DOES NOT modify the depcache internals. @@ -2848,26 +2842,26 @@ def _recompute_external_storage(table, tbl_rowids): logger.info('Recomputing external data (_recompute_external_storage)') # TODO: need to rectify parent ids? 
- parent_rowids = table.get_parent_rowids(tbl_rowids) - parent_rowargs = table.get_parent_rowargs(tbl_rowids) + parent_rowids = self.get_parent_rowids(tbl_rowids) + parent_rowargs = self.get_parent_rowargs(tbl_rowids) - # configs = table.get_row_configs(tbl_rowids) + # configs = self.get_row_configs(tbl_rowids) # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' # TODO; groupby config - config_rowids = table.get_row_cfgid(tbl_rowids) + config_rowids = self.get_row_cfgid(tbl_rowids) unique_cfgids, groupxs = ut.group_indices(config_rowids) for xs, cfgid in zip(groupxs, unique_cfgids): parent_ids = ut.take(parent_rowids, xs) parent_args = ut.take(parent_rowargs, xs) - config = table.get_config_from_rowid([cfgid])[0] - dirty_params_iter = table._compute_dirty_rows( + config = self.get_config_from_rowid([cfgid])[0] + dirty_params_iter = self._compute_dirty_rows( parent_ids, parent_args, config_rowid=cfgid, config=config ) # Evaulate just to ensure storage ut.evaluate_generator(dirty_params_iter) - def _recompute_and_store(table, tbl_rowids, config=None): + def _recompute_and_store(self, tbl_rowids, config=None): """ Recomputes all data stored for this row. This DOES modify the depcache internals. @@ -2875,31 +2869,31 @@ def _recompute_and_store(table, tbl_rowids, config=None): logger.info('Recomputing external data (_recompute_and_store)') if len(tbl_rowids) == 0: return - parent_rowids = table.get_parent_rowids(tbl_rowids) - parent_rowargs = table.get_parent_rowargs(tbl_rowids) - # configs = table.get_row_configs(tbl_rowids) + parent_rowids = self.get_parent_rowids(tbl_rowids) + parent_rowargs = self.get_parent_rowargs(tbl_rowids) + # configs = self.get_row_configs(tbl_rowids) # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' # TODO; groupby config if config is None: - config_rowids = table.get_row_cfgid(tbl_rowids) + config_rowids = self.get_row_cfgid(tbl_rowids) unique_cfgids, groupxs = ut.group_indices(config_rowids) else: # This is incredibly hacky. 
pass - colnames = table.computable_colnames() + colnames = self.computable_colnames() for xs, cfgid in zip(groupxs, unique_cfgids): parent_ids = ut.take(parent_rowids, xs) parent_args = ut.take(parent_rowargs, xs) rowids = ut.take(tbl_rowids, xs) - config = table.get_config_from_rowid([cfgid])[0] - dirty_params_iter = table._compute_dirty_rows( + config = self.get_config_from_rowid([cfgid])[0] + dirty_params_iter = self._compute_dirty_rows( parent_ids, parent_args, config_rowid=cfgid, config=config ) # Evaulate to external and internal storage - table.db.set(table.tablename, colnames, dirty_params_iter, rowids) + self.db.set(self.tablename, colnames, dirty_params_iter, rowids) # _onthefly_dataget # togroup_args = [parent_rowids] @@ -2907,11 +2901,11 @@ def _recompute_and_store(table, tbl_rowids, config=None): # unique_args_list = [unique_configs] # raw_prop_lists = [] - # # func = ut.partial(table.preproc_func, table.depc) + # # func = ut.partial(self.preproc_func, self.depc) # def groupmap_func(group_args, unique_args): # config_ = unique_args[0] # argsT = group_args - # propgen = table.preproc_func(table.depc, *argsT, config=config_) + # propgen = self.preproc_func(self.depc, *argsT, config=config_) # return list(propgen) # def grouped_map(groupmap_func, groupxs, togroup_args, unique_args_list): @@ -2931,7 +2925,7 @@ def _recompute_and_store(table, tbl_rowids, config=None): # @profile def get_internal_columns( - table, + self, tbl_rowids, colnames=None, eager=True, @@ -2944,11 +2938,11 @@ def get_internal_columns( Access data in this table using the table PRIMARY KEY rowids (not depc PRIMARY ids) """ - prop_list = table.db.get( - table.tablename, + prop_list = self.db.get( + self.tablename, colnames, tbl_rowids, - id_colname=table.rowid_colname, + id_colname=self.rowid_colname, eager=eager, nInput=nInput, unpack_scalars=unpack_scalars, @@ -2957,7 +2951,7 @@ def get_internal_columns( ) return prop_list - def export_rows(table, rowid, target): + def export_rows(self, rowid, target): """ The goal of this is to export taggable data that can be used independantly of its dependant features. 
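Both `_recompute_external_storage` and `_recompute_and_store` above lean on `ut.group_indices` to batch rows that share a config rowid before recomputing them. The following stand-in is shown only to make that grouping step concrete; the real utool implementation may differ in key ordering and return types.

```python
from collections import defaultdict


def group_indices(keys):
    # Map each unique key to the list of positions where it occurs,
    # e.g. [7, 7, 9] -> ([7, 9], [[0, 1], [2]]).
    groups = defaultdict(list)
    for idx, key in enumerate(keys):
        groups[key].append(idx)
    unique_keys = list(groups)
    groupxs = [groups[k] for k in unique_keys]
    return unique_keys, groupxs
```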
@@ -2992,17 +2986,17 @@ def export_rows(table, rowid, target): rowid = 1 """ raise NotImplementedError('unfinished') - colnames = tuple(table.db.get_column_names(table.tablename)) - colvals = table.db.get(table.tablename, colnames, [rowid])[0] # NOQA + colnames = tuple(self.db.get_column_names(self.tablename)) + colvals = self.db.get(self.tablename, colnames, [rowid])[0] # NOQA - uuid = table.get_model_uuid([rowid])[0] - manifest_data = table.get_model_inputs(uuid) # NOQA + uuid = self.get_model_uuid([rowid])[0] + manifest_data = self.get_model_inputs(uuid) # NOQA - config_history = table.get_config_history([rowid]) # NOQA + config_history = self.get_config_history([rowid]) # NOQA - table.parent_col_attrs = table._infer_parentcol() - table.data_col_attrs - table.internal_col_attrs + self.parent_col_attrs = self._infer_parentcol() + self.data_col_attrs + self.internal_col_attrs - table.db.cur.execute('SELECT * FROM {tablename} WHERE rowid=?') + self.db.cur.execute('SELECT * FROM {tablename} WHERE rowid=?') pass From 9d14071eca7aed3b48e5542457eb949852dfe017 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 10:58:00 -0700 Subject: [PATCH 086/294] Fix class to use the 'self' argument convention --- wbia/dtool/depcache_control.py | 380 ++++++++++++++++----------------- 1 file changed, 190 insertions(+), 190 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 1d4ccb097d..c9f37f032d 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -92,7 +92,7 @@ class _CoreDependencyCache(object): @profile def _register_prop( - depc, + self, tablename, parents=None, colnames=None, @@ -117,7 +117,7 @@ def _register_prop( if isinstance(tablename, six.string_types): tablename = six.text_type(tablename) if parents is None: - parents = [depc.root] + parents = [self.root] if colnames is None: colnames = 'data' if coltypes is None: @@ -135,7 +135,7 @@ def _register_prop( raise ValueError('must specify coltypes of %s' % (tablename,)) coltypes = [np.ndarray] * len(colnames) if fname is None: - fname = depc.default_fname + fname = self.default_fname if configclass is None: # Make a default config with no parameters configclass = {} @@ -146,10 +146,10 @@ def _register_prop( # ---------- # Register a new table and configuration if requestclass is not None: - depc.requestclass_dict[tablename] = requestclass - depc.fname_to_db[fname] = None + self.requestclass_dict[tablename] = requestclass + self.fname_to_db[fname] = None table = depcache_table.DependencyCacheTable( - depc=depc, + depc=self, parent_tablenames=parents, tablename=tablename, data_colnames=colnames, @@ -159,12 +159,12 @@ def _register_prop( default_to_unpack=default_to_unpack, **kwargs, ) - depc.cachetable_dict[tablename] = table - depc.configclass_dict[tablename] = configclass + self.cachetable_dict[tablename] = table + self.configclass_dict[tablename] = configclass return table @ut.apply_docstr(REG_PREPROC_DOC) - def register_preproc(depc, *args, **kwargs): + def register_preproc(self, *args, **kwargs): """ Decorator for registration of cachables """ @@ -172,51 +172,51 @@ def register_preproc(depc, *args, **kwargs): def register_preproc_wrapper(func): check_register(args, kwargs) kwargs['preproc_func'] = func - depc._register_prop(*args, **kwargs) + self._register_prop(*args, **kwargs) return func return register_preproc_wrapper - def _register_subprop(depc, tablename, propname=None, preproc_func=None): + def _register_subprop(self, tablename, propname=None, 
preproc_func=None): """ subproperties are always recomputeed on the fly """ - table = depc.cachetable_dict[tablename] + table = self.cachetable_dict[tablename] table.subproperties[propname] = preproc_func - def close(depc): + def close(self): """ Close all managed SQL databases """ - for fname, db in depc.fname_to_db.items(): + for fname, db in self.fname_to_db.items(): db.close() @profile - def initialize(depc, _debug=None): + def initialize(self, _debug=None): """ Creates all registered tables """ logger.info( - '[depc] Initialize %s depcache in %r' % (depc.root.upper(), depc.cache_dpath) + '[depc] Initialize %s depcache in %r' % (self.root.upper(), self.cache_dpath) ) - if depc._use_globals: - reg_preproc = PREPROC_REGISTER[depc.root] - reg_subprop = SUBPROP_REGISTER[depc.root] + if self._use_globals: + reg_preproc = PREPROC_REGISTER[self.root] + reg_subprop = SUBPROP_REGISTER[self.root] logger.info( '[depc.init] Registering %d global preproc funcs' % len(reg_preproc) ) for args_, _kwargs in reg_preproc: - depc._register_prop(*args_, **_kwargs) + self._register_prop(*args_, **_kwargs) logger.info('[depc.init] Registering %d global subprops ' % len(reg_subprop)) for args_, _kwargs in reg_subprop: - depc._register_subprop(*args_, **_kwargs) + self._register_subprop(*args_, **_kwargs) - ut.ensuredir(depc.cache_dpath) + ut.ensuredir(self.cache_dpath) # Memory filestore # if False: # # http://docs.pyfilesystem.org/en/latest/getting_started.html # pip install fs - for fname in depc.fname_to_db.keys(): + for fname in self.fname_to_db.keys(): if fname == ':memory:': db_uri = 'sqlite:///:memory:' else: @@ -225,17 +225,17 @@ def initialize(depc, _debug=None): prefix_dpath = dirname(fname_) if prefix_dpath: - ut.ensuredir(ut.unixjoin(depc.cache_dpath, prefix_dpath)) - fpath = ut.unixjoin(depc.cache_dpath, fname_) + ut.ensuredir(ut.unixjoin(self.cache_dpath, prefix_dpath)) + fpath = ut.unixjoin(self.cache_dpath, fname_) db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) # if ut.get_argflag('--clear-all-depcache'): # ut.delete(fpath) db = sql_control.SQLDatabaseController.from_uri(db_uri) depcache_table.ensure_config_table(db) - depc.fname_to_db[fname] = db + self.fname_to_db[fname] = db logger.info('[depc] Finished initialization') - for table in depc.cachetable_dict.values(): + for table in self.cachetable_dict.values(): table.initialize() # HACKS: @@ -243,15 +243,15 @@ def initialize(depc, _debug=None): class InjectedDepc(object): pass - depc.d = InjectedDepc() - depc.w = InjectedDepc() - d = depc.d - w = depc.w + self.d = InjectedDepc() + self.w = InjectedDepc() + d = self.d + w = self.w inject_patterns = [ - ('get_{tablename}_rowids', depc.get_rowids), - ('get_{tablename}_config_history', depc.get_config_history), + ('get_{tablename}_rowids', self.get_rowids), + ('get_{tablename}_config_history', self.get_config_history), ] - for table in depc.cachetable_dict.values(): + for table in self.cachetable_dict.values(): wobj = InjectedDepc() # Set nested version setattr(w, table.tablename, wobj) @@ -264,7 +264,7 @@ class InjectedDepc(object): setattr(wobj, funcname, func) dfmtstr = 'get_{tablename}_{colname}' for colname in table.data_colnames: - get_prop = ut.partial(depc.get, table.tablename, colnames=colname) + get_prop = ut.partial(self.get, table.tablename, colnames=colname) attrname = dfmtstr.format(tablename=table.tablename, colname=colname) # Set flat version setattr(d, attrname, get_prop) @@ -273,7 +273,7 @@ class InjectedDepc(object): # ----------------------------- # GRAPH INSPECTION 
- def get_dependencies(depc, tablename): + def get_dependencies(self, tablename): """ gets level dependences from root to tablename @@ -296,11 +296,11 @@ def get_dependencies(depc, tablename): ] """ try: - assert tablename in depc.cachetable_dict, 'tablename=%r does not exist' % ( + assert tablename in self.cachetable_dict, 'tablename=%r does not exist' % ( tablename, ) - root = depc.root_tablename - children_, parents_ = list(zip(*depc.get_edges())) + root = self.root_tablename + children_, parents_ = list(zip(*self.get_edges())) child_to_parents = ut.group_items(children_, parents_) if ut.VERYVERBOSE: logger.info('root = %r' % (root,)) @@ -331,7 +331,7 @@ def get_dependencies(depc, tablename): return dependency_levels - def _ensure_config(depc, tablekey, config, _debug=False): + def _ensure_config(self, tablekey, config, _debug=False): """ Creates a full table configuration with all defaults using config @@ -339,8 +339,8 @@ def _ensure_config(depc, tablekey, config, _debug=False): tablekey (str): name of the table to grab config from config (dict): may be overspecified or underspecfied """ - configclass = depc.configclass_dict.get(tablekey, None) - # requestclass = depc.requestclass_dict.get(tablekey, None) + configclass = self.configclass_dict.get(tablekey, None) + # requestclass = self.requestclass_dict.get(tablekey, None) if configclass is None: config_ = config else: @@ -363,24 +363,24 @@ def _ensure_config(depc, tablekey, config, _debug=False): logger.debug(' config_ = %r' % (config_,)) return config_ - def get_config_trail(depc, tablename, config): - graph = depc.make_graph(implicit=True) - tablename_list = ut.nx_all_nodes_between(graph, depc.root, tablename) + def get_config_trail(self, tablename, config): + graph = self.make_graph(implicit=True) + tablename_list = ut.nx_all_nodes_between(graph, self.root, tablename) tablename_list = ut.nx_topsort_nodes(graph, tablename_list) config_trail = [] for tablekey in tablename_list: - if tablekey in depc.configclass_dict: - config_ = depc._ensure_config(tablekey, config) + if tablekey in self.configclass_dict: + config_ = self._ensure_config(tablekey, config) config_trail.append(config_) return config_trail - def get_config_trail_str(depc, tablename, config): - config_trail = depc.get_config_trail(tablename, config) + def get_config_trail_str(self, tablename, config): + config_trail = self.get_config_trail(tablename, config) trail_cfgstr = '_'.join([x.get_cfgstr() for x in config_trail]) return trail_cfgstr def _get_parent_input( - depc, + self, tablename, root_rowids, config, @@ -392,8 +392,8 @@ def _get_parent_input( nInput=None, ): # Get ancestor rowids that are descendants of root - table = depc[tablename] - rowid_dict = depc.get_all_descendant_rowids( + table = self[tablename] + rowid_dict = self.get_all_descendant_rowids( tablename, root_rowids, config=config, @@ -404,13 +404,13 @@ def _get_parent_input( recompute_all=recompute_all, levels_up=1, ) - parent_rowids = depc._get_parent_rowids(table, rowid_dict) + parent_rowids = self._get_parent_rowids(table, rowid_dict) return parent_rowids # ----------------------------- # STATE GETTERS - def rectify_input_tuple(depc, exi_inputs, input_tuple): + def rectify_input_tuple(self, exi_inputs, input_tuple): """ Standardizes inputs allowed for convinience into the expected input for get_parent_rowids. 
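`_ensure_config` above accepts a possibly over- or under-specified dict and expands it into a full table configuration via the class registered in `configclass_dict`. The sketch below only illustrates that general idea with a hypothetical dataclass; `ChipConfig`, `dim_size`, and `ext` are made-up names and this is not the mechanism the real config classes use.

```python
from dataclasses import dataclass, fields


@dataclass
class ChipConfig:
    dim_size: int = 450
    ext: str = '.png'


def ensure_config(configclass, overrides):
    # Keep only keys the config class knows about; everything else falls back to defaults.
    valid = {f.name for f in fields(configclass)}
    return configclass(**{k: v for k, v in (overrides or {}).items() if k in valid})


# ensure_config(ChipConfig, {'dim_size': 700, 'unrelated': 1})
# -> ChipConfig(dim_size=700, ext='.png')
```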
@@ -451,7 +451,7 @@ def rectify_input_tuple(depc, exi_inputs, input_tuple): rectified_input.append(x) return rectified_input - def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs): + def get_parent_rowids(self, target_tablename, input_tuple, config=None, **kwargs): """ Returns the parent rowids needed to get / compute a property of tablename @@ -492,7 +492,7 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs logger.debug(' * target_tablename = %r' % (target_tablename,)) logger.debug(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) logger.debug(' * config = %r' % (config,)) - target_table = depc[target_tablename] + target_table = self[target_tablename] # TODO: Expand to the appropriate given inputs if _hack_rootmost: @@ -503,7 +503,7 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs exi_inputs = target_table.rootmost_inputs.total_expand() logger.debug(' * exi_inputs=%s' % (exi_inputs,)) - rectified_input = depc.rectify_input_tuple(exi_inputs, input_tuple) + rectified_input = self.rectify_input_tuple(exi_inputs, input_tuple) rowid_dict = {} for rmi, rowids in zip(exi_inputs.rmi_list, rectified_input): @@ -519,7 +519,7 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs % (count, len(compute_edges), input_nodes, output_node), ) tablekey = output_node.tablename - table = depc[tablekey] + table = self[tablekey] input_nodes_ = input_nodes logger.debug( 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) @@ -592,7 +592,7 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs if output_node.tablename != target_tablename: # Get table configuration - config_ = depc._ensure_config(tablekey, config) + config_ = self._ensure_config(tablekey, config) output_rowids = table.get_rowid( _parent_rowids, config=config_, recompute=_recompute, **_kwargs @@ -606,18 +606,18 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs # rowids = rowid_dict[output_node] return parent_rowids - def check_rowids(depc, tablename, input_tuple, config={}): + def check_rowids(self, tablename, input_tuple, config={}): """ Returns a list of flags where True means the row has been computed and False means that it needs to be computed. """ - existing_rowids = depc.get_rowids( + existing_rowids = self.get_rowids( tablename, input_tuple, config=config, ensure=False ) flags = ut.flag_not_None_items(existing_rowids) return flags - def get_rowids(depc, tablename, input_tuple, **rowid_kw): + def get_rowids(self, tablename, input_tuple, **rowid_kw): """ Used to get tablename rowids. Ensures rows exist unless ensure=False. rowids uniquely specify parent inputs and a configuration. 
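`check_rowids` and `get_rowids` above are the two entry points callers typically pair: one reports which rows are already computed, the other ensures and returns them. A hypothetical usage sketch, assuming a depcache instance `depc` with a registered table named 'chip' (the table name and root rowids are placeholders):

```python
aid_list = [1, 2, 3]

# Which of these already have cached rows for the default config?
flags = depc.check_rowids('chip', aid_list, config={})
missing = [aid for aid, ok in zip(aid_list, flags) if not ok]

# Ensure (compute if needed) and return the table rowids for the same inputs.
rowids = depc.get_rowids('chip', aid_list, config={})
```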
@@ -676,9 +676,9 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): _hack_rootmost = _kwargs.pop('_hack_rootmost', False) _recompute_all = _kwargs.pop('recompute_all', False) recompute = _kwargs.pop('recompute', _recompute_all) - table = depc[target_tablename] + table = self[target_tablename] - parent_rowids = depc.get_parent_rowids( + parent_rowids = self.get_parent_rowids( target_tablename, input_tuple, config=config, @@ -686,7 +686,7 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): **_kwargs, ) - config_ = depc._ensure_config(target_tablename, config) + config_ = self._ensure_config(target_tablename, config) rowids = table.get_rowid( parent_rowids, config=config_, recompute=recompute, **_kwargs ) @@ -694,7 +694,7 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): @ut.accepts_scalar_input2(argx_list=[1]) def get( - depc, + self, tablename, root_rowids, colnames=None, @@ -787,8 +787,8 @@ def get( >>> prop_list3 = depc.get(tablename, root_rowids) >>> assert np.all(prop_list1[0][1] == prop_list3[0][1]), 'computed same info' """ - if tablename == depc.root_tablename: - return depc.root_getters[colnames](root_rowids) + if tablename == self.root_tablename: + return self.root_getters[colnames](root_rowids) # pass logger.debug(' * tablename=%s' % (tablename)) logger.debug(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) @@ -800,7 +800,7 @@ def get( from os.path import join # recompute_ = recompute or recompute_all - parent_rowids = depc.get_parent_rowids( + parent_rowids = self.get_parent_rowids( tablename, root_rowids, config=config, @@ -809,9 +809,9 @@ def get( eager=True, nInput=None, ) - config_ = depc._ensure_config(tablename, config) + config_ = self._ensure_config(tablename, config) logger.debug(' * (ensured) config_ = %r' % (config_,)) - table = depc[tablename] + table = self[tablename] extern_dpath = table.extern_dpath ut.ensuredir(extern_dpath) fname_list = table.get_extern_fnames( @@ -844,9 +844,9 @@ def get( for trynum in range(num_retries + 1): try: - table = depc[tablename] + table = self[tablename] # Vectorized get of properties - tbl_rowids = depc.get_rowids(tablename, input_tuple, **rowid_kw) + tbl_rowids = self.get_rowids(tablename, input_tuple, **rowid_kw) logger.debug('[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),)) prop_list = table.get_row_data(tbl_rowids, colnames, **rowdata_kw) except depcache_table.ExternalStorageException: @@ -859,7 +859,7 @@ def get( return prop_list def get_native( - depc, tablename, tbl_rowids, colnames=None, _debug=None, read_extern=True + self, tablename, tbl_rowids, colnames=None, _debug=None, read_extern=True ): """ Gets data using internal ids, which is faster if you have them. 
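`get` resolves root rowids through the dependency graph before reading, while `get_native` below skips that resolution when the caller already holds the table's own rowids. A hypothetical comparison, reusing the placeholder 'chip' table and `depc` instance from above:

```python
tbl_rowids = depc.get_rowids('chip', [1, 2, 3], config={})

props_a = depc.get('chip', [1, 2, 3])          # root rowids in
props_b = depc.get_native('chip', tbl_rowids)  # native table rowids in
# Both return the same stored column values; get_native just skips rowid resolution.
```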
@@ -900,7 +900,7 @@ def get_native( logger.debug(' * tablename = %r' % (tablename,)) logger.debug(' * colnames = %r' % (colnames,)) logger.debug(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) - table = depc[tablename] + table = self[tablename] # import utool # with utool.embed_on_exception_context: # try: @@ -927,12 +927,12 @@ def get_native( # delete_on_fail=False) return prop_list - def get_config_history(depc, tablename, root_rowids, config=None): + def get_config_history(self, tablename, root_rowids, config=None): # Vectorized get of properties - tbl_rowids = depc.get_rowids(tablename, root_rowids, config=config) - return depc[tablename].get_config_history(tbl_rowids) + tbl_rowids = self.get_rowids(tablename, root_rowids, config=config) + return self[tablename].get_config_history(tbl_rowids) - def get_root_rowids(depc, tablename, native_rowids): + def get_root_rowids(self, tablename, native_rowids): r""" Args: tablename (str): @@ -959,58 +959,58 @@ def get_root_rowids(depc, tablename, native_rowids): >>> assert ancestor_rowids1 == root_rowids, 'should have same root' >>> assert ancestor_rowids2 == root_rowids, 'should have same root' """ - return depc.get_ancestor_rowids(tablename, native_rowids, depc.root) + return self.get_ancestor_rowids(tablename, native_rowids, self.root) - def get_ancestor_rowids(depc, tablename, native_rowids, ancestor_tablename=None): + def get_ancestor_rowids(self, tablename, native_rowids, ancestor_tablename=None): """ ancestor_tablename = depc.root; native_rowids = cid_list; tablename = const.CHIP_TABLE """ if ancestor_tablename is None: - ancestor_tablename = depc.root - table = depc[tablename] + ancestor_tablename = self.root + table = self[tablename] ancestor_rowids = table.get_ancestor_rowids(native_rowids, ancestor_tablename) return ancestor_rowids - def new_request(depc, tablename, qaids, daids, cfgdict=None): + def new_request(self, tablename, qaids, daids, cfgdict=None): """ creates a request for data that can be executed later """ logger.info('[depc] NEW %s request' % (tablename,)) - requestclass = depc.requestclass_dict[tablename] - request = requestclass.new(depc, qaids, daids, cfgdict, tablename=tablename) + requestclass = self.requestclass_dict[tablename] + request = requestclass.new(self, qaids, daids, cfgdict, tablename=tablename) return request # ----------------------------- # STATE MODIFIERS - def delete_property(depc, tablename, root_rowids, config=None, _debug=False): + def delete_property(self, tablename, root_rowids, config=None, _debug=False): """ Deletes the rowids of `tablename` that correspond to `root_rowids` using `config`. FIXME: make this work for all configs """ - rowid_list = depc.get_rowids( + rowid_list = self.get_rowids( tablename, root_rowids, config=config, ensure=False, ) - table = depc[tablename] + table = self[tablename] num_deleted = table.delete_rows(rowid_list) return num_deleted - def delete_property_all(depc, tablename, root_rowids, _debug=False): + def delete_property_all(self, tablename, root_rowids, _debug=False): """ Deletes the rowids of `tablename` that correspond to `root_rowids` using `config`. 
FIXME: make this work for all configs """ - table = depc[tablename] + table = self[tablename] all_rowid_list = table._get_all_rowids() if len(all_rowid_list) == 0: return 0 - ancestor_rowid_list = depc.get_ancestor_rowids(tablename, all_rowid_list) + ancestor_rowid_list = self.get_ancestor_rowids(tablename, all_rowid_list) rowid_list = [] root_rowids_set = set(root_rowids) @@ -1031,7 +1031,7 @@ class DependencyCache(_CoreDependencyCache, ut.NiceRepr): """ def __init__( - depc, + self, root_tablename=None, cache_dpath='./DEPCACHE', controller=None, @@ -1044,68 +1044,68 @@ def __init__( if default_fname is None: default_fname = root_tablename + '_primary_cache' # default_fname = ':memory:' - depc.root_getters = root_getters + self.root_getters = root_getters # Root of all dependencies - depc.root_tablename = root_tablename + self.root_tablename = root_tablename # Directory all cachefiles are stored in - depc.cache_dpath = ut.truepath(cache_dpath) + self.cache_dpath = ut.truepath(cache_dpath) # Parent (ibs) controller - depc.controller = controller + self.controller = controller # Internal dictionary of dependant tables - depc.cachetable_dict = {} - depc.configclass_dict = {} - depc.requestclass_dict = {} - depc.resultclass_dict = {} + self.cachetable_dict = {} + self.configclass_dict = {} + self.requestclass_dict = {} + self.resultclass_dict = {} # Mapping of different files properties are stored in - depc.fname_to_db = {} + self.fname_to_db = {} # Function to map a root rowid to an object - # depc._root_asobject = root_asobject - depc._use_globals = use_globals - depc.default_fname = default_fname + # self._root_asobject = root_asobject + self._use_globals = use_globals + self.default_fname = default_fname if get_root_uuid is None: logger.info('WARNING NEED UUID FUNCTION') # HACK get_root_uuid = ut.identity - depc.get_root_uuid = get_root_uuid - depc.delete_exclude_tables = {} + self.get_root_uuid = get_root_uuid + self.delete_exclude_tables = {} # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible - depc._debug = False + self._debug = False - def get_tablenames(depc): - return list(depc.cachetable_dict.keys()) + def get_tablenames(self): + return list(self.cachetable_dict.keys()) @property - def tables(depc): - return list(depc.cachetable_dict.values()) + def tables(self): + return list(self.cachetable_dict.values()) @property - def tablenames(depc): - return depc.get_tablenames() + def tablenames(self): + return self.get_tablenames() - def print_schemas(depc): - for fname, db in depc.fname_to_db.items(): + def print_schemas(self): + for fname, db in self.fname_to_db.items(): logger.info('fname = %r' % (fname,)) db.print_schema() - # def print_table_csv(depc, tablename): - # depc[tablename] + # def print_table_csv(self, tablename): + # self[tablename] - def print_table(depc, tablename): - depc[tablename].print_table() + def print_table(self, tablename): + self[tablename].print_table() - def print_all_tables(depc): - for tablename, table in depc.cachetable_dict.items(): + def print_all_tables(self): + for tablename, table in self.cachetable_dict.items(): table.print_table() # db = table.db # db.print_table_csv(tablename) - def print_config_tables(depc): - for fname in depc.fname_to_db: + def print_config_tables(self): + for fname in self.fname_to_db: logger.info('---') logger.info('db_fname = %r' % (fname,)) - depc.fname_to_db[fname].print_table_csv('config') + self.fname_to_db[fname].print_table_csv('config') - def get_edges(depc, data=False): + def get_edges(self, 
data=False): """ edges for networkx structure """ @@ -1145,26 +1145,26 @@ def get_edgedata(tablekey, parentkey, parent_data): edges = [ (parentkey, tablekey, get_edgedata(tablekey, parentkey, parent_data)) - for tablekey, table in depc.cachetable_dict.items() + for tablekey, table in self.cachetable_dict.items() for parentkey, parent_data in table.parents(data=True) ] else: edges = [ (parentkey, tablekey) - for tablekey, table in depc.cachetable_dict.items() + for tablekey, table in self.cachetable_dict.items() for parentkey in table.parents(data=False) ] return edges - def get_implicit_edges(depc, data=False): + def get_implicit_edges(self, data=False): """ Edges defined by subconfigurations """ # add implicit edges implicit_edges = [] # Map config classes to tablenames - _inverted_ccdict = ut.invert_dict(depc.configclass_dict) - for tablename2, configclass in depc.configclass_dict.items(): + _inverted_ccdict = ut.invert_dict(self.configclass_dict) + for tablename2, configclass in self.configclass_dict.items(): cfg = configclass() subconfigs = cfg.get_sub_config_list() if subconfigs is not None and len(subconfigs) > 0: @@ -1176,7 +1176,7 @@ def get_implicit_edges(depc, data=False): return implicit_edges @ut.memoize - def make_graph(depc, **kwargs): + def make_graph(self, **kwargs): """ Constructs a networkx representation of the dependency graph @@ -1232,13 +1232,13 @@ def make_graph(depc, **kwargs): # graph = nx.DiGraph() graph = nx.MultiDiGraph() - nodes = list(depc.cachetable_dict.keys()) - edges = depc.get_edges(data=True) + nodes = list(self.cachetable_dict.keys()) + edges = self.get_edges(data=True) graph.add_nodes_from(nodes) graph.add_edges_from(edges) if kwargs.get('implicit', True): - implicit_edges = depc.get_implicit_edges(data=True) + implicit_edges = self.get_implicit_edges(data=True) graph.add_edges_from(implicit_edges) shape_dict = { @@ -1261,8 +1261,8 @@ def make_graph(depc, **kwargs): } def _node_attrs(dict_): - props = {k: dict_['node'] for k, v in depc.cachetable_dict.items()} - props[depc.root] = dict_['root'] + props = {k: dict_['node'] for k, v in self.cachetable_dict.items()} + props[self.root] = dict_['root'] return props nx.set_node_attributes(graph, name='color', values=_node_attrs(color_dict)) @@ -1387,40 +1387,40 @@ def _node_attrs(dict_): return graph @property - def graph(depc): - return depc.make_graph() + def graph(self): + return self.make_graph() @property - def explicit_graph(depc): - return depc.make_graph(implicit=False) + def explicit_graph(self): + return self.make_graph(implicit=False) @property - def reduced_graph(depc): - return depc.make_graph(reduced=True) + def reduced_graph(self): + return self.make_graph(reduced=True) - def show_graph(depc, reduced=False, **kwargs): + def show_graph(self, reduced=False, **kwargs): """ Helper "fluff" function """ import wbia.plottool as pt - graph = depc.make_graph(reduced=reduced) + graph = self.make_graph(reduced=reduced) if ut.is_developer(): ut.ensureqt() kwargs['layout'] = 'agraph' pt.show_nx(graph, **kwargs) - def __nice__(depc): - infostr_ = 'nTables=%d' % len(depc.cachetable_dict) - return '(%s) %s' % (depc.root_tablename, infostr_) + def __nice__(self): + infostr_ = 'nTables=%d' % len(self.cachetable_dict) + return '(%s) %s' % (self.root_tablename, infostr_) - def __getitem__(depc, tablekey): - return depc.cachetable_dict[tablekey] + def __getitem__(self, tablekey): + return self.cachetable_dict[tablekey] @property - def root(depc): - return depc.root_tablename + def root(self): + return 
self.root_tablename def delete_root( - depc, + self, root_rowids, delete_extern=None, _debug=False, @@ -1447,56 +1447,56 @@ def delete_root( >>> depc.get('fgweight', [1]) >>> depc.delete_root(root_rowids) """ - # graph = depc.make_graph(implicit=False) + # graph = self.make_graph(implicit=False) # hack # check to make sure child does not have another parent - rowid_dict = depc.get_allconfig_descendant_rowids( + rowid_dict = self.get_allconfig_descendant_rowids( root_rowids, table_config_filter ) - # children = [child for child in graph.succ[depc.root_tablename] + # children = [child for child in graph.succ[self.root_tablename] # if sum([len(e) for e in graph.pred[child].values()]) == 1] - # depc.delete_property(tablename, root_rowids) + # self.delete_property(tablename, root_rowids) num_deleted = 0 for tablename, table_rowids in rowid_dict.items(): - if tablename == depc.root: + if tablename == self.root: continue # Specific prop exclusion - delete_exclude_table_set_prop = depc.delete_exclude_tables.get(prop, []) - delete_exclude_table_set_all = depc.delete_exclude_tables.get(None, []) + delete_exclude_table_set_prop = self.delete_exclude_tables.get(prop, []) + delete_exclude_table_set_all = self.delete_exclude_tables.get(None, []) if ( tablename in delete_exclude_table_set_prop or tablename in delete_exclude_table_set_all ): continue - table = depc[tablename] + table = self[tablename] num_deleted += table.delete_rows(table_rowids, delete_extern=delete_extern) return num_deleted - def register_delete_table_exclusion(depc, tablename, prop): - if prop not in depc.delete_exclude_tables: - depc.delete_exclude_tables[prop] = set([]) - depc.delete_exclude_tables[prop].add(tablename) - args = (ut.repr3(depc.delete_exclude_tables),) + def register_delete_table_exclusion(self, tablename, prop): + if prop not in self.delete_exclude_tables: + self.delete_exclude_tables[prop] = set([]) + self.delete_exclude_tables[prop].add(tablename) + args = (ut.repr3(self.delete_exclude_tables),) logger.info('[depc] Updated delete tables: %s' % args) - def get_allconfig_descendant_rowids(depc, root_rowids, table_config_filter=None): + def get_allconfig_descendant_rowids(self, root_rowids, table_config_filter=None): import networkx as nx - # list(nx.topological_sort(nx.bfs_tree(graph, depc.root))) - # decendants = nx.descendants(graph, depc.root) + # list(nx.topological_sort(nx.bfs_tree(graph, self.root))) + # decendants = nx.descendants(graph, self.root) # raise NotImplementedError() - graph = depc.make_graph(implicit=True) - root = depc.root + graph = self.make_graph(implicit=True) + root = self.root rowid_dict = {} rowid_dict[root] = root_rowids # Find all rowids that inherit from the specific root rowids - sinks = list(ut.nx_sink_nodes(nx.bfs_tree(graph, depc.root))) + sinks = list(ut.nx_sink_nodes(nx.bfs_tree(graph, self.root))) for target_tablename in sinks: path = nx.shortest_path(graph, root, target_tablename) for parent, child in ut.itertwo(path): - child_table = depc[child] + child_table = self[child] relevant_col_attrs = [ attrs for attrs in child_table.parent_col_attrs @@ -1547,7 +1547,7 @@ def get_allconfig_descendant_rowids(depc, root_rowids, table_config_filter=None) ) return rowid_dict - def notify_root_changed(depc, root_rowids, prop, force_delete=False): + def notify_root_changed(self, root_rowids, prop, force_delete=False): """ this is where we are notified that a "registered" root property has changed. 
@@ -1557,18 +1557,18 @@ def notify_root_changed(depc, root_rowids, prop, force_delete=False): % (prop, len(root_rowids)) ) # for key in tables_depending_on(prop) - # depc.delete_property(key, root_rowids) + # self.delete_property(key, root_rowids) # TODO: check which properties were invalidated by this prop # TODO; remove invalidated properties if force_delete: - depc.delete_root(root_rowids, prop=prop) + self.delete_root(root_rowids, prop=prop) - def clear_all(depc): - logger.info('Clearning all cached data in %r' % (depc,)) - for table in depc.cachetable_dict.values(): + def clear_all(self): + logger.info('Clearning all cached data in %r' % (self,)) + for table in self.cachetable_dict.values(): table.clear_table() - def make_root_info_uuid(depc, root_rowids, info_props): + def make_root_info_uuid(self, root_rowids, info_props): """ Creates a uuid that depends on certain properties of the root object. This is used for implicit cache invalidation because, if those @@ -1583,24 +1583,24 @@ def make_root_info_uuid(depc, root_rowids, info_props): >>> info_props = ['image_uuid', 'verts', 'theta'] >>> info_props = ['image_uuid', 'verts', 'theta', 'name', 'species', 'yaw'] """ - getters = ut.dict_take(depc.root_getters, info_props) + getters = ut.dict_take(self.root_getters, info_props) infotup_list = zip(*[getter(root_rowids) for getter in getters]) info_uuid_list = [ut.augment_uuid(*tup) for tup in infotup_list] return info_uuid_list - def get_uuids(depc, tablename, root_rowids, config=None): + def get_uuids(self, tablename, root_rowids, config=None): """ # TODO: Make uuids for dependant object based on root uuid and path of # construction. """ - if tablename == depc.root: - uuid_list = depc.get_root_uuid(root_rowids) + if tablename == self.root: + uuid_list = self.get_root_uuid(root_rowids) return uuid_list get_native_property = _CoreDependencyCache.get_native get_property = _CoreDependencyCache.get - def stacked_config(depc, source, dest, config): + def stacked_config(self, source, dest, config): r""" CommandLine: python -m dtool.depcache_control stacked_config --show @@ -1621,15 +1621,15 @@ def stacked_config(depc, source, dest, config): if config is None: config = {} if source is None: - source = depc.root - graph = depc.make_graph(implicit=True) + source = self.root + graph = self.make_graph(implicit=True) requires_tables = ut.setdiff( ut.nx_all_nodes_between(graph, source, dest), [source] ) - # requires_tables = ut.setdiff(ut.nx_all_nodes_between(depc.graph, 'annotations', 'featweight'), ['annotations']) - requires_tables = ut.nx_topsort_nodes(depc.graph, requires_tables) + # requires_tables = ut.setdiff(ut.nx_all_nodes_between(self.graph, 'annotations', 'featweight'), ['annotations']) + requires_tables = ut.nx_topsort_nodes(self.graph, requires_tables) requires_configs = [ - depc.configclass_dict[tblname](**config) for tblname in requires_tables + self.configclass_dict[tblname](**config) for tblname in requires_tables ] # cfgstr_list = [cfg.get_cfgstr() for cfg in requires_configs] stacked_config = base.StackedConfig(requires_configs) From 0c536071b496a3bb7c509bd63cb310c612d68315 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 13 Oct 2020 21:48:58 +0100 Subject: [PATCH 087/294] Remove brambox from runtime requirements wbia-brambox is going to be included in the wbia-lightnet package so we don't need it in the runtime requirements anymore. 
--- requirements/runtime.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 25e197676f..b18841028d 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,6 +1,6 @@ blinker boto>=2.20.1 -brambox +# wbia-brambox # already included in wbia-lightnet cachetools>=1.1.6 click colorama>=0.3.2 @@ -17,7 +17,6 @@ gdm imgaug ipython>=5.0.0 -lightnet lockfile>=0.10.2 mako matplotlib>=3.3.0 @@ -65,6 +64,7 @@ tqdm ubelt >= 0.8.7 wbia-cnn>=3.0.2 +wbia-lightnet wbia-pydarknet >= 3.0.1 wbia-pyflann >= 3.1.0 wbia-pyhesaff >= 3.0.2 From 6f6623d29a3e819e191bbb036c0673e2896b59e8 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 19 Oct 2020 23:53:22 -0700 Subject: [PATCH 088/294] Create named Dependency Caches This works towards creating named dependency caches as opposed to the file based focus. With just a name we can derive a filename or a database name, whichever is more applicable. --- wbia/control/IBEISControl.py | 33 ++++++++------------ wbia/dtool/depcache_control.py | 55 ++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index e1aa683594..e1b49b731b 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -791,17 +791,15 @@ def _init_sqldbstaging(self, request_stagingversion=None): def _init_depcache(self): # Initialize dependency cache for images image_root_getters = {} - self.depc_image = dtool.DependencyCache( - root_tablename=const.IMAGE_TABLE, - default_fname=const.IMAGE_TABLE + '_depcache', - cache_dpath=self.get_cachedir(), - controller=self, - get_root_uuid=self.get_image_uuids, + self.depc_image = dtool.DependencyCache.as_named( + self, + const.IMAGE_TABLE, + self.get_image_uuids, root_getters=image_root_getters, ) self.depc_image.initialize() - """ Need to reinit this sometimes if cache is ever deleted """ + # Need to reinit this sometimes if cache is ever deleted # Initialize dependency cache for annotations annot_root_getters = { 'name': self.get_annot_names, @@ -817,13 +815,10 @@ def _init_depcache(self): 'theta': self.get_annot_thetas, 'occurrence_text': self.get_annot_occurrence_text, } - self.depc_annot = dtool.DependencyCache( - # root_tablename='annot', # const.ANNOTATION_TABLE - root_tablename=const.ANNOTATION_TABLE, - default_fname=const.ANNOTATION_TABLE + '_depcache', - cache_dpath=self.get_cachedir(), - controller=self, - get_root_uuid=self.get_annot_visual_uuids, + self.depc_annot = dtool.DependencyCache.as_named( + self, + const.ANNOTATION_TABLE, + self.get_annot_visual_uuids, root_getters=annot_root_getters, ) # backwards compatibility @@ -835,12 +830,10 @@ def _init_depcache(self): # Initialize dependency cache for parts part_root_getters = {} - self.depc_part = dtool.DependencyCache( - root_tablename=const.PART_TABLE, - default_fname=const.PART_TABLE + '_depcache', - cache_dpath=self.get_cachedir(), - controller=self, - get_root_uuid=self.get_part_uuids, + self.depc_part = dtool.DependencyCache.as_named( + self, + const.PART_TABLE, + self.get_part_uuids, root_getters=part_root_getters, ) self.depc_part.initialize() diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index c9f37f032d..3a057494a7 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -1071,6 +1071,61 @@ def __init__( # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible self._debug = False + @classmethod + 
def as_named( + cls, + controller, + name, + get_root_uuid, + table_name=None, + root_getters=None, + use_globals=True, + ): + """ + Args: + controller (IBEISController): main controller + name (str): name of this controller instance, which is used in naming the data storage + table_name (str): (optional) if not the same as the 'name' + get_root_uuid: ??? + root_getters: ??? + use_globals (bool): ??? (default: True) + + """ + if table_name is None: + table_name = name + + self = cls.__new__(cls) + + # Parent (ibs) controller + self.controller = controller + # Internal dictionary of dependant tables + self.cachetable_dict = {} + self.configclass_dict = {} + self.requestclass_dict = {} + self.resultclass_dict = {} + + self.root_getters = root_getters + # Root of all dependencies + self.root_tablename = table_name + # XXX Directory all cachefiles are stored in + self.cache_dpath = ut.truepath(self.controller.get_cachedir()) + + # XXX Mapping of different files properties are stored in + self.fname_to_db = {} + + # Function to map a root rowid to an object + self._use_globals = use_globals + + # XXX remove filesystem name + self.default_fname = f'{table_name}_cache' + + self.get_root_uuid = get_root_uuid + self.delete_exclude_tables = {} + # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible + self._debug = False + + return self + def get_tablenames(self): return list(self.cachetable_dict.keys()) From 50a6e39b42ccea36dcdb4053a3412c1d3fb10377 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 01:44:17 -0700 Subject: [PATCH 089/294] Add a base_uri property to the controller The intension is to allow the child controllers to use this base_uri to make their own connections to the database. --- wbia/control/IBEISControl.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index e1b49b731b..84eb360049 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -53,6 +53,7 @@ (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') + # Import modules which define injectable functions # tuples represent conditional imports with the flags in the first part of the @@ -362,6 +363,12 @@ def __init__( self.table_cache = None self._initialize_self() self._init_dirs(dbdir=dbdir, ensure=ensure) + + # FIXME (20-Oct-12020) Set the base URI + # This is a temporary adjustment to obtain a URI where one would not + # be present as part of this filesystem focused contructor. + self._base_uri = f'sqlite:///{self.get_ibsdir()}' + # _send_wildbook_request will do nothing if no wildbook address is # specified self._send_wildbook_request(wbaddr) @@ -391,6 +398,11 @@ def __init__( logger.info('[ibs.__init__] END new IBEISController\n') + @property + def base_uri(self): + """Base database URI without a specific database name""" + return self._base_uri + def reset_table_cache(self): self.table_cache = accessor_decors.init_tablecache() From 351942ea34de7bc794334b7ac70daee8ac983102 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 12:12:05 -0700 Subject: [PATCH 090/294] Remove unused property fpath_to_db This attribute doesn't appear to be used anywhere and the tests pass without it. 
--- wbia/dtool/depcache_table.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index ae083d0aea..634b966e7d 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1865,7 +1865,6 @@ def __init__( except Exception: # HACK: jedi type hinting. Need to have non-obvious condition self.db = SQLDatabaseController() - self.fpath_to_db = {} assert ( re.search('[0-9]', tablename) is None ), 'tablename=%r cannot contain numbers' % (tablename,) From d4e4a67281ed5d651187587bb385e7e603324b2a Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 12:29:08 -0700 Subject: [PATCH 091/294] Abstract the database naming away from file names This pulls the database connection responsiblity into the controller where it can be abstracted for use. The intention is to move away from a filebased name, but also to remove the `fname_to_db` attribute from the scope and vision of the DepCTable class. --- wbia/dtool/depcache_control.py | 6 ++++++ wbia/dtool/depcache_table.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 3a057494a7..66581e3876 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -189,6 +189,12 @@ def close(self): for fname, db in self.fname_to_db.items(): db.close() + def get_db_by_name(self, name): + """Get the database (i.e. SQLController) for the given database name""" + # FIXME (20-Oct-12020) Currently handled via a mapping of 'fname' + # to database controller objects. + return self.fname_to_db[name] + @profile def initialize(self, _debug=None): """ diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 634b966e7d..bc6fb73706 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1910,12 +1910,18 @@ def __init__( self._hack_chunk_cache = None - # @profile + @property + def db_name(self): + """Name of the database this Table belongs to""" + # FIXME (20-Oct-12020) 'fname' is the legacy name, but can't be changed just yet + # as its fairly intertwined in the registration decorator code. + return self.fname + def initialize(self, _debug=None): """ Ensures the SQL schema for this cache table """ - self.db = self.depc.fname_to_db[self.fname] + self.db = self.depc.get_db_by_name(self.db_name) # logger.info('Checking sql for table=%r' % (self.tablename,)) if not self.db.has_table(self.tablename): logger.debug('Initializing table=%r' % (self.tablename,)) From 706a28e456fefd017d20f219f10615c8c162ca78 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 12:59:29 -0700 Subject: [PATCH 092/294] Fix AttributeError for keywords attribute This FullArgsSpec object doesn't have a keywords attribute, but looks like under some circumstances it could. So I'm changing it to be optionally checked. 
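For illustration only (not part of the patch): the legacy `inspect.getargspec`
result exposes a `keywords` field, while `inspect.getfullargspec` returns a
`FullArgSpec` that names the same thing `varkw`, which is why the attribute has
to be probed before use. A minimal sketch:

```
import inspect

def preproc(depc, *args, **kwargs):
    pass

spec = inspect.getfullargspec(preproc)
assert not hasattr(spec, 'keywords')  # FullArgSpec has no `keywords` field
assert spec.varkw == 'kwargs'         # the **kwargs name lives in `varkw`

# Guarded lookup that tolerates either spec type:
has_var_kwargs = bool(getattr(spec, 'keywords', None) or spec.varkw)
```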
--- wbia/dtool/depcache_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index bc6fb73706..be28ddc71a 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -582,7 +582,7 @@ def _assert_self(self): # ie (depc, parent_ids, config) argspec = ut.get_func_argspec(self.preproc_func) args = argspec.args - if argspec.varargs and argspec.keywords: + if argspec.varargs and (hasattr(argspec, 'keywords') and argspec.keywords): assert len(args) == 1, 'varargs and kwargs must have one arg for depcache' else: if len(args) < 3: From fbca249034371c4f16049ce85d8539a06efc2536 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 13:04:35 -0700 Subject: [PATCH 093/294] Add from_name constructor for DepCacheTable The intent of this new constructor method is to move away from the file focus. This is a step in that direction. Passing tests --- wbia/dtool/depcache_control.py | 8 ++-- wbia/dtool/depcache_table.py | 87 +++++++++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 11 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 66581e3876..8c37a0dc72 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -148,14 +148,14 @@ def _register_prop( if requestclass is not None: self.requestclass_dict[tablename] = requestclass self.fname_to_db[fname] = None - table = depcache_table.DependencyCacheTable( - depc=self, + table = depcache_table.DependencyCacheTable.from_name( + fname, + tablename, + self, parent_tablenames=parents, - tablename=tablename, data_colnames=colnames, data_coltypes=coltypes, preproc_func=preproc_func, - fname=fname, default_to_unpack=default_to_unpack, **kwargs, ) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index be28ddc71a..35a6520dd3 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1877,7 +1877,7 @@ def __init__( self.data_colnames = tuple(data_colnames) self.data_coltypes = data_coltypes self.preproc_func = preproc_func - self.fname = fname + self._db_name = fname # Behavior self.on_delete = None self.default_to_unpack = default_to_unpack @@ -1908,20 +1908,93 @@ def __init__( if ut.SUPER_STRICT: self._assert_self() + # ??? Clearly a hack, but to what end? 
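As a usage sketch (the table name, column names, and the `depc`/`compute_chip`
objects are illustrative, not taken from the patch), registration via the new
constructor reads roughly like the `_register_prop` call in the diff that
follows:

```
import numpy as np

# `depc` is an existing DependencyCache; `compute_chip` is a user-supplied
# preprocessing function.
table = DependencyCacheTable.from_name(
    'annot_primary_cache',          # database name (formerly `fname`)
    'chip',                         # table name
    depc,                           # owning DependencyCache controller
    parent_tablenames=('annot',),
    data_colnames=('size', 'chip'),
    data_coltypes=(tuple, np.ndarray),
    preproc_func=compute_chip,
)
table.initialize()                  # binds the table to its backing database
```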
self._hack_chunk_cache = None + @classmethod + def from_name( + cls, + db_name, + table_name, + depcache_controller, + parent_tablenames=None, + data_colnames=None, + data_coltypes=None, + preproc_func=None, + docstr='no docstr', + asobject=False, + chunksize=None, + isinteractive=False, + default_to_unpack=False, + default_onthefly=False, + rm_extern_on_delete=False, + vectorized=True, + taggable=False, + ): + """Build the instance based on a database and table name.""" + self = cls.__new__(cls) + + self._db_name = db_name + + # Set the table name + if re.search('[0-9]', table_name): + raise ValueError(f"tablename = '{table_name}' cannot contain numbers") + self.tablename = table_name + + # Set the parent depcache controller + self.depc = depcache_controller + + # Definitions + self.docstr = docstr + self.parent_tablenames = parent_tablenames + self.data_colnames = tuple(data_colnames) + self.data_coltypes = data_coltypes + self.preproc_func = preproc_func + + # Behavior + self.on_delete = None + self.default_to_unpack = default_to_unpack + self.vectorized = vectorized + self.taggable = taggable + + # XXX (20-Oct-12020) It's not clear if these attributes are absolutely necessary. + self.chunksize = chunksize + # Developmental properties + self.subproperties = {} + self.isinteractive = isinteractive + self._asobject = asobject + self.default_onthefly = default_onthefly + # SQL Internals + self.sqldb_fpath = None + self.rm_extern_on_delete = rm_extern_on_delete + # Update internals + self.parent_col_attrs = self._infer_parentcol() + self.data_col_attrs = self._infer_datacol() + self.internal_col_attrs = self._infer_allcol() + # /XXX + + # Check for errors + # FIXME (20-Oct-12020) This seems like a bad idea... + # Why would you sometimes want to check and not at other times? + if ut.SUPER_STRICT: + self._assert_self() + + # ??? Clearly a hack, but to what end? + self._hack_chunk_cache = None + + return self + @property - def db_name(self): - """Name of the database this Table belongs to""" - # FIXME (20-Oct-12020) 'fname' is the legacy name, but can't be changed just yet - # as its fairly intertwined in the registration decorator code. - return self.fname + def fname(self): + """Backwards compatible name of the database this Table belongs to""" + # BBB (20-Oct-12020) 'fname' is the legacy name for the database name + return self._db_name def initialize(self, _debug=None): """ Ensures the SQL schema for this cache table """ - self.db = self.depc.get_db_by_name(self.db_name) + self.db = self.depc.get_db_by_name(self._db_name) # logger.info('Checking sql for table=%r' % (self.tablename,)) if not self.db.has_table(self.tablename): logger.debug('Initializing table=%r' % (self.tablename,)) From 89112701758779745d792337b27c900d22810dd2 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 13:20:22 -0700 Subject: [PATCH 094/294] Remove unused attribute sqldb_fpath It's set to None and tests pass without it. Gone. 
--- wbia/dtool/depcache_table.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 35a6520dd3..18c166960d 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1965,7 +1965,6 @@ def from_name( self._asobject = asobject self.default_onthefly = default_onthefly # SQL Internals - self.sqldb_fpath = None self.rm_extern_on_delete = rm_extern_on_delete # Update internals self.parent_col_attrs = self._infer_parentcol() From bca59c557db8f717a74840552f9e51992051d4e0 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 13:23:26 -0700 Subject: [PATCH 095/294] Document DepCacheTable.rm_extern_on_delete --- wbia/dtool/depcache_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 18c166960d..eee1a386db 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1956,6 +1956,8 @@ def from_name( self.default_to_unpack = default_to_unpack self.vectorized = vectorized self.taggable = taggable + #: Flag to enable the deletion of external files on associated SQL row deletion. + self.rm_extern_on_delete = rm_extern_on_delete # XXX (20-Oct-12020) It's not clear if these attributes are absolutely necessary. self.chunksize = chunksize @@ -1964,8 +1966,6 @@ def from_name( self.isinteractive = isinteractive self._asobject = asobject self.default_onthefly = default_onthefly - # SQL Internals - self.rm_extern_on_delete = rm_extern_on_delete # Update internals self.parent_col_attrs = self._infer_parentcol() self.data_col_attrs = self._infer_datacol() From 00b2376b47f85a60ad338aad0dcaccc56b7d39eb Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 13:29:41 -0700 Subject: [PATCH 096/294] Document DepCacheTable.chunksize --- wbia/dtool/depcache_table.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index eee1a386db..34a383876f 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -1950,6 +1950,8 @@ def from_name( self.data_colnames = tuple(data_colnames) self.data_coltypes = data_coltypes self.preproc_func = preproc_func + #: Optional specification of the amount of blobs to modify in one SQL operation + self.chunksize = chunksize # Behavior self.on_delete = None @@ -1960,7 +1962,6 @@ def from_name( self.rm_extern_on_delete = rm_extern_on_delete # XXX (20-Oct-12020) It's not clear if these attributes are absolutely necessary. - self.chunksize = chunksize # Developmental properties self.subproperties = {} self.isinteractive = isinteractive From 332c5f17edf13f02ef6494b9c9a780579ce5109f Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 13:36:07 -0700 Subject: [PATCH 097/294] Remove initialization of "development properties" These "development properties" appear to be largely unused. Thus, they don't need to be hanging around. 
--- wbia/dtool/depcache_table.py | 42 ++++++------------------------------ 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 34a383876f..9a1defc607 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -500,8 +500,6 @@ def print_info(self, with_colattrs=True, with_graphattrs=True): """ debug function """ logger.info('TABLE ATTRIBUTES') logger.info('self.tablename = %r' % (self.tablename,)) - logger.info('self.isinteractive = %r' % (self.isinteractive,)) - logger.info('self.default_onthefly = %r' % (self.default_onthefly,)) logger.info('self.rm_extern_on_delete = %r' % (self.rm_extern_on_delete,)) logger.info('self.chunksize = %r' % (self.chunksize,)) logger.info('self.fname = %r' % (self.fname,)) @@ -1736,18 +1734,6 @@ def _chunk_compute_dirty_rows( ) # These are the colnames that we expect to be computed colnames = self.computable_colnames() - # def unfinished_features(): - # if self._asobject: - # # Convinience - # argsT = [self.depc.get_obj(parent, rowids) - # for parent, rowids in zip(self.parents(), - # dirty_parent_ids_chunk)] - # onthefly = None - # if self.default_onthefly or onthefly: - # assert not self.ismulti, ('cannot onthefly multi tables') - # proptup_gen = [tuple([None] * len(self.data_col_attrs)) - # for _ in range(len(dirty_parent_ids_chunk))] - # pass # CALL EXTERNAL PREPROCESSING / GENERATION FUNCTION try: # prog_iter = list(prog_iter) @@ -1850,9 +1836,9 @@ def __init__( fname=None, asobject=False, chunksize=None, - isinteractive=False, + isinteractive=False, # no-op default_to_unpack=False, - default_onthefly=False, + default_onthefly=False, # no-op rm_extern_on_delete=False, vectorized=True, taggable=False, @@ -1891,11 +1877,6 @@ def __init__( # self.store_delete_time = True self.chunksize = chunksize - # Developmental properties - self.subproperties = {} - self.isinteractive = isinteractive - self._asobject = asobject - self.default_onthefly = default_onthefly # SQL Internals self.sqldb_fpath = None self.rm_extern_on_delete = rm_extern_on_delete @@ -1924,9 +1905,7 @@ def from_name( docstr='no docstr', asobject=False, chunksize=None, - isinteractive=False, default_to_unpack=False, - default_onthefly=False, rm_extern_on_delete=False, vectorized=True, taggable=False, @@ -1953,6 +1932,11 @@ def from_name( #: Optional specification of the amount of blobs to modify in one SQL operation self.chunksize = chunksize + # FIXME (20-Oct-12020) This definition of behavior by external means is a scope issue + # Another object should not be directly manipulating this object. + #: functions defined and populated through DependencyCache._register_subprop + self.subproperties = {} + # Behavior self.on_delete = None self.default_to_unpack = default_to_unpack @@ -1962,11 +1946,6 @@ def from_name( self.rm_extern_on_delete = rm_extern_on_delete # XXX (20-Oct-12020) It's not clear if these attributes are absolutely necessary. 
- # Developmental properties - self.subproperties = {} - self.isinteractive = isinteractive - self._asobject = asobject - self.default_onthefly = default_onthefly # Update internals self.parent_col_attrs = self._infer_parentcol() self.data_col_attrs = self._infer_datacol() @@ -2721,12 +2700,6 @@ def get_row_data( idxs2 = ut.index_complement(idxs1, len(tbl_rowids)) - #### - # Read data stored in SQL - # FIXME: understand unpack_scalars and keepwrap - # if self.default_onthefly: - # self._onthefly_dataget - # else: if nInput is None and ut.is_listlike(nonNone_tbl_rowids): nInput = len(nonNone_tbl_rowids) @@ -2973,7 +2946,6 @@ def _recompute_and_store(self, tbl_rowids, config=None): # Evaulate to external and internal storage self.db.set(self.tablename, colnames, dirty_params_iter, rowids) - # _onthefly_dataget # togroup_args = [parent_rowids] # grouped_parent_ids = ut.apply_grouping(parent_rowids, groupxs) # unique_args_list = [unique_configs] From 45ae22ecc0e47df88f7725253d7ca3c7a1ebf638 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 15:30:49 -0700 Subject: [PATCH 098/294] Change fname_to_db to a private attribute _db_by_name This moves away from the filename based db structure. --- wbia/dtool/depcache_control.py | 45 +++++++++++++++++++--------------- wbia/dtool/example_depcache.py | 4 +-- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 8c37a0dc72..cf094e9318 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -135,6 +135,8 @@ def _register_prop( raise ValueError('must specify coltypes of %s' % (tablename,)) coltypes = [np.ndarray] * len(colnames) if fname is None: + # FIXME (20-Oct-12020) Base class doesn't define this property + # and thus expects the subclass to know it should be assigned. fname = self.default_fname if configclass is None: # Make a default config with no parameters @@ -147,7 +149,7 @@ def _register_prop( # Register a new table and configuration if requestclass is not None: self.requestclass_dict[tablename] = requestclass - self.fname_to_db[fname] = None + self._db_by_name.setdefault(fname, None) table = depcache_table.DependencyCacheTable.from_name( fname, tablename, @@ -182,18 +184,16 @@ def _register_subprop(self, tablename, propname=None, preproc_func=None): table = self.cachetable_dict[tablename] table.subproperties[propname] = preproc_func - def close(self): - """ - Close all managed SQL databases - """ - for fname, db in self.fname_to_db.items(): - db.close() - def get_db_by_name(self, name): """Get the database (i.e. SQLController) for the given database name""" # FIXME (20-Oct-12020) Currently handled via a mapping of 'fname' # to database controller objects. 
- return self.fname_to_db[name] + return self._db_by_name[name] + + def close(self): + """Close all managed SQL databases""" + for db_inst in self._db_by_name.values(): + db_inst.close() @profile def initialize(self, _debug=None): @@ -222,7 +222,7 @@ def initialize(self, _debug=None): # # http://docs.pyfilesystem.org/en/latest/getting_started.html # pip install fs - for fname in self.fname_to_db.keys(): + for fname in self._db_by_name.keys(): if fname == ':memory:': db_uri = 'sqlite:///:memory:' else: @@ -238,7 +238,7 @@ def initialize(self, _debug=None): # ut.delete(fpath) db = sql_control.SQLDatabaseController.from_uri(db_uri) depcache_table.ensure_config_table(db) - self.fname_to_db[fname] = db + self._db_by_name[fname] = db logger.info('[depc] Finished initialization') for table in self.cachetable_dict.values(): @@ -1048,6 +1048,7 @@ def __init__( use_globals=True, ): if default_fname is None: + # ??? So 'None_primary_cache' is a good name? default_fname = root_tablename + '_primary_cache' # default_fname = ':memory:' self.root_getters = root_getters @@ -1062,8 +1063,10 @@ def __init__( self.configclass_dict = {} self.requestclass_dict = {} self.resultclass_dict = {} - # Mapping of different files properties are stored in - self.fname_to_db = {} + # Mapping of database connections by name + # - names populated by _CoreDependencyCache._register_prop + # - values populated by _CoreDependencyCache.initialize + self._db_by_name = {} # Function to map a root rowid to an object # self._root_asobject = root_asobject self._use_globals = use_globals @@ -1116,8 +1119,10 @@ def as_named( # XXX Directory all cachefiles are stored in self.cache_dpath = ut.truepath(self.controller.get_cachedir()) - # XXX Mapping of different files properties are stored in - self.fname_to_db = {} + # Mapping of database connections by name + # - names populated by _CoreDependencyCache._register_prop + # - values populated by _CoreDependencyCache.initialize + self._db_by_name = {} # Function to map a root rowid to an object self._use_globals = use_globals @@ -1144,8 +1149,8 @@ def tablenames(self): return self.get_tablenames() def print_schemas(self): - for fname, db in self.fname_to_db.items(): - logger.info('fname = %r' % (fname,)) + for name, db in self._db_by_name.items(): + logger.info('name = %r' % (name,)) db.print_schema() # def print_table_csv(self, tablename): @@ -1161,10 +1166,10 @@ def print_all_tables(self): # db.print_table_csv(tablename) def print_config_tables(self): - for fname in self.fname_to_db: + for name in self._db_by_name: logger.info('---') - logger.info('db_fname = %r' % (fname,)) - self.fname_to_db[fname].print_table_csv('config') + logger.info('db_name = %r' % (name,)) + self._db_by_name[name].print_table_csv('config') def get_edges(self, data=False): """ diff --git a/wbia/dtool/example_depcache.py b/wbia/dtool/example_depcache.py index 3fe744e5c6..46a3ca28b2 100644 --- a/wbia/dtool/example_depcache.py +++ b/wbia/dtool/example_depcache.py @@ -725,8 +725,8 @@ def dummy_example_depcacahe(): req.execute() # ut.InstanceList( - db = list(depc.fname_to_db.values())[0] - # db_list = ut.InstanceList(depc.fname_to_db.values()) + db = list(depc._db_by_name.values())[0] + # db_list = ut.InstanceList(depc._db_by_name.values()) # db_list.print_table_csv('config', exclude_columns='config_strid') print('config table') From 9f445a0f5e821358cb8b0f23966ee8aabfb54740 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 16:13:29 -0700 Subject: [PATCH 099/294] Remove mention of sqlite :memory: 
--- wbia/dtool/depcache_control.py | 20 ++++++++------------ wbia/dtool/sql_control.py | 1 - 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index cf094e9318..2849e6a39c 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -223,17 +223,14 @@ def initialize(self, _debug=None): # pip install fs for fname in self._db_by_name.keys(): - if fname == ':memory:': - db_uri = 'sqlite:///:memory:' - else: - fname_ = ut.ensure_ext(fname, '.sqlite') - from os.path import dirname - - prefix_dpath = dirname(fname_) - if prefix_dpath: - ut.ensuredir(ut.unixjoin(self.cache_dpath, prefix_dpath)) - fpath = ut.unixjoin(self.cache_dpath, fname_) - db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) + fname_ = ut.ensure_ext(fname, '.sqlite') + from os.path import dirname + + prefix_dpath = dirname(fname_) + if prefix_dpath: + ut.ensuredir(ut.unixjoin(self.cache_dpath, prefix_dpath)) + fpath = ut.unixjoin(self.cache_dpath, fname_) + db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) # if ut.get_argflag('--clear-all-depcache'): # ut.delete(fpath) db = sql_control.SQLDatabaseController.from_uri(db_uri) @@ -1050,7 +1047,6 @@ def __init__( if default_fname is None: # ??? So 'None_primary_cache' is a good name? default_fname = root_tablename + '_primary_cache' - # default_fname = ':memory:' self.root_getters = root_getters # Root of all dependencies self.root_tablename = root_tablename diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e237133c14..0d0d20124f 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -514,7 +514,6 @@ def _create_connection(self): raise AssertionError('Cannot open a new database in readonly mode') # Open the SQL database connection with support for custom types # lite.enable_callback_tracebacks(True) - # self.fpath = ':memory:' # References: # http://stackoverflow.com/questions/10205744/opening-sqlite3-database-from-python-in-read-only-mode From 976d19156409b03ade30ce699fd457451856c6d6 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 20:19:12 -0700 Subject: [PATCH 100/294] Add a method to create a cache db URI The `make_cache_db_uri` method helps pull the responsiblity of uri creation into the main controller. Now the caches can be blissfully ignorant of the details of what database it's connecting to. 
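Rough sketch of the intended call pattern (the `ibs` controller instance and
the cache name are illustrative):

```
# The depcache asks its parent controller for a connection URI instead of
# assembling a file path itself.
uri = ibs.make_cache_db_uri('annot_primary_cache')
# -> 'sqlite:///<ibs.get_cachedir()>/annot_primary_cache.sqlite'
db = sql_control.SQLDatabaseController.from_uri(uri)
```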
--- wbia/control/IBEISControl.py | 4 ++++ wbia/dtool/depcache_control.py | 30 ++++++++---------------------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 84eb360049..15dd5a71e1 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -403,6 +403,10 @@ def base_uri(self): """Base database URI without a specific database name""" return self._base_uri + def make_cache_db_uri(self, name): + """Given a name of the cache produce a database connection URI""" + return f'sqlite:///{self.get_cachedir()}/{name}.sqlite' + def reset_table_cache(self): self.table_cache = accessor_decors.init_tablecache() diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 2849e6a39c..4acbd54c8f 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -3,7 +3,6 @@ implicit version of dependency cache from wbia/templates/template_generator """ import logging -import os.path import utool as ut import numpy as np @@ -215,28 +214,15 @@ def initialize(self, _debug=None): for args_, _kwargs in reg_subprop: self._register_subprop(*args_, **_kwargs) - ut.ensuredir(self.cache_dpath) - - # Memory filestore - # if False: - # # http://docs.pyfilesystem.org/en/latest/getting_started.html - # pip install fs - - for fname in self._db_by_name.keys(): - fname_ = ut.ensure_ext(fname, '.sqlite') - from os.path import dirname - - prefix_dpath = dirname(fname_) - if prefix_dpath: - ut.ensuredir(ut.unixjoin(self.cache_dpath, prefix_dpath)) - fpath = ut.unixjoin(self.cache_dpath, fname_) - db_uri = 'sqlite:///{}'.format(os.path.realpath(fpath)) - # if ut.get_argflag('--clear-all-depcache'): - # ut.delete(fpath) - db = sql_control.SQLDatabaseController.from_uri(db_uri) + for name in self._db_by_name.keys(): + # FIXME (20-Oct-12020) 'smk/smk_agg_rvecs' is known to have issues. + # Either fix the name or find a better normalizer/slugifier. + normalized_name = name.replace('/', '__') + uri = self.controller.make_cache_db_uri(normalized_name) + db = sql_control.SQLDatabaseController.from_uri(uri) + # ??? This seems out of place. Shouldn't this be within the depcachetable instance? depcache_table.ensure_config_table(db) - self._db_by_name[fname] = db - logger.info('[depc] Finished initialization') + self._db_by_name[name] = db for table in self.cachetable_dict.values(): table.initialize() From aa44ff7d08274ce821104e742c60f561ec6f6482 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 20:21:28 -0700 Subject: [PATCH 101/294] Refactor DependencyCache initialization The 'd' and 'w' attributes are a mystery that I've yet to crack. --- wbia/dtool/depcache_control.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 4acbd54c8f..003bd45c52 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -199,9 +199,6 @@ def initialize(self, _debug=None): """ Creates all registered tables """ - logger.info( - '[depc] Initialize %s depcache in %r' % (self.root.upper(), self.cache_dpath) - ) if self._use_globals: reg_preproc = PREPROC_REGISTER[self.root] reg_subprop = SUBPROP_REGISTER[self.root] @@ -232,10 +229,9 @@ def initialize(self, _debug=None): class InjectedDepc(object): pass + # ??? What's the significance of 'd' and 'w'? Do the mean anything? 
self.d = InjectedDepc() self.w = InjectedDepc() - d = self.d - w = self.w inject_patterns = [ ('get_{tablename}_rowids', self.get_rowids), ('get_{tablename}_config_history', self.get_config_history), @@ -243,20 +239,20 @@ class InjectedDepc(object): for table in self.cachetable_dict.values(): wobj = InjectedDepc() # Set nested version - setattr(w, table.tablename, wobj) + setattr(self.w, table.tablename, wobj) for dfmtstr, func in inject_patterns: funcname = ut.get_funcname(func) attrname = dfmtstr.format(tablename=table.tablename) - get_rowids = ut.partial(func, table.tablename) + partial_func = ut.partial(func, table.tablename) # Set flat version - setattr(d, attrname, get_rowids) + setattr(self.d, attrname, partial_func) setattr(wobj, funcname, func) dfmtstr = 'get_{tablename}_{colname}' for colname in table.data_colnames: get_prop = ut.partial(self.get, table.tablename, colnames=colname) attrname = dfmtstr.format(tablename=table.tablename, colname=colname) # Set flat version - setattr(d, attrname, get_prop) + setattr(self.d, attrname, get_prop) setattr(wobj, 'get_' + colname, get_prop) # ----------------------------- From 1f93e67b7ca91b1866511990ead9880265c9403a Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 22:42:10 -0700 Subject: [PATCH 102/294] Collapse _CoreDependencyCache into DependencyCache There was no benefit to these classes being separate. The lack of constructor on `_CoreDependencyCache` with the dependence on attributes defined by `DependencyCache` was bothering me. Rather than dance around the issue, I'm just making them one class. I'm not sure it'd be easy to subclass `_CoreDependencyCache` in that state anyhow. Because the new `as_named` constructor depends on the controller, I've added a `DummyController` to all the example_depcache* cases. This implements enough of the controller to instantiate a depcache. My hope is that this dependence won't last, but is a bridge to a different type of solution moving forward. --- wbia/dtool/depcache_control.py | 223 +++++++++++++++----------------- wbia/dtool/example_depcache.py | 29 ++++- wbia/dtool/example_depcache2.py | 49 +++---- 3 files changed, 156 insertions(+), 145 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 003bd45c52..d70ae2468d 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -83,11 +83,111 @@ def _wrapper(func): return _depcdecors -class _CoreDependencyCache(object): - """ - Core worker functions for the depcache - Inherited by a calss with some "nice extras - """ +class DependencyCache: + def __init__( + self, + root_tablename=None, + cache_dpath='./DEPCACHE', + controller=None, + default_fname=None, + # root_asobject=None, + get_root_uuid=None, + root_getters=None, + use_globals=True, + ): + if default_fname is None: + # ??? So 'None_primary_cache' is a good name? 
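Roughly, the controller surface the examples now depend on is tiny; a minimal
sketch (paths and names are illustrative, see the DummyController added in the
diff below for the real thing):

```
from pathlib import Path

import utool as ut
from wbia.dtool.depcache_control import DependencyCache

class MinimalController:
    """Just enough controller API for DependencyCache.as_named/initialize."""

    def __init__(self, cache_dpath):
        self.cache_dpath = Path(cache_dpath)
        self.cache_dpath.mkdir(exist_ok=True)

    def get_cachedir(self):
        return self.cache_dpath

    def make_cache_db_uri(self, name):
        return f'sqlite:///{self.cache_dpath}/{name}.sqlite'

depc = DependencyCache.as_named(
    MinimalController('/tmp/DEPCACHE_EXAMPLE'),  # controller
    'annot',                                     # cache / root table name
    ut.identity,                                 # get_root_uuid
    use_globals=False,
)
depc.initialize()
```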
+ default_fname = root_tablename + '_primary_cache' + self.root_getters = root_getters + # Root of all dependencies + self.root_tablename = root_tablename + self.name = root_tablename + # Directory all cachefiles are stored in + self.cache_dpath = ut.truepath(cache_dpath) + # Parent (ibs) controller + self.controller = controller + # Internal dictionary of dependant tables + self.cachetable_dict = {} + self.configclass_dict = {} + self.requestclass_dict = {} + self.resultclass_dict = {} + # Mapping of database connections by name + # - names populated by _CoreDependencyCache._register_prop + # - values populated by _CoreDependencyCache.initialize + self._db_by_name = {} + # Function to map a root rowid to an object + # self._root_asobject = root_asobject + self._use_globals = use_globals + self.default_fname = default_fname + if get_root_uuid is None: + logger.info('WARNING NEED UUID FUNCTION') + # HACK + get_root_uuid = ut.identity + self.get_root_uuid = get_root_uuid + self.delete_exclude_tables = {} + # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible + self._debug = False + + @classmethod + def as_named( + cls, + controller, + name, + get_root_uuid, + table_name=None, + root_getters=None, + use_globals=True, + ): + """ + Args: + controller (IBEISController): main controller + name (str): name of this controller instance, which is used in naming the data storage + table_name (str): (optional) if not the same as the 'name' + get_root_uuid: ??? + root_getters: ??? + use_globals (bool): ??? (default: True) + + """ + if table_name is None: + table_name = name + + self = cls.__new__(cls) + + self.name = name + # Parent (ibs) controller + self.controller = controller + # Internal dictionary of dependant tables + self.cachetable_dict = {} + self.configclass_dict = {} + self.requestclass_dict = {} + self.resultclass_dict = {} + + self.root_getters = root_getters + # Root of all dependencies + self.root_tablename = table_name + # XXX Directory all cachefiles are stored in + self.cache_dpath = self.controller.get_cachedir() + + # Mapping of database connections by name + # - names populated by _register_prop + # - values populated by initialize + self._db_by_name = {} + + # Function to map a root rowid to an object + self._use_globals = use_globals + + # FIXME (20-Oct-12020) remove filesystem name + self.default_fname = f'{table_name}_cache' + + self.get_root_uuid = get_root_uuid + self.delete_exclude_tables = {} + # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible + self._debug = False + + return self + + def __repr__(self): + return f'' @profile def _register_prop( @@ -1006,115 +1106,6 @@ def delete_property_all(self, tablename, root_rowids, _debug=False): num_deleted = table.delete_rows(rowid_list) return num_deleted - -@six.add_metaclass(ut.ReloadingMetaclass) -class DependencyCache(_CoreDependencyCache, ut.NiceRepr): - """ - Currently, to use this class a user must: - * on root modification, call depc.on_root_modified - * use decorators to register relevant functions - """ - - def __init__( - self, - root_tablename=None, - cache_dpath='./DEPCACHE', - controller=None, - default_fname=None, - # root_asobject=None, - get_root_uuid=None, - root_getters=None, - use_globals=True, - ): - if default_fname is None: - # ??? So 'None_primary_cache' is a good name? 
- default_fname = root_tablename + '_primary_cache' - self.root_getters = root_getters - # Root of all dependencies - self.root_tablename = root_tablename - # Directory all cachefiles are stored in - self.cache_dpath = ut.truepath(cache_dpath) - # Parent (ibs) controller - self.controller = controller - # Internal dictionary of dependant tables - self.cachetable_dict = {} - self.configclass_dict = {} - self.requestclass_dict = {} - self.resultclass_dict = {} - # Mapping of database connections by name - # - names populated by _CoreDependencyCache._register_prop - # - values populated by _CoreDependencyCache.initialize - self._db_by_name = {} - # Function to map a root rowid to an object - # self._root_asobject = root_asobject - self._use_globals = use_globals - self.default_fname = default_fname - if get_root_uuid is None: - logger.info('WARNING NEED UUID FUNCTION') - # HACK - get_root_uuid = ut.identity - self.get_root_uuid = get_root_uuid - self.delete_exclude_tables = {} - # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible - self._debug = False - - @classmethod - def as_named( - cls, - controller, - name, - get_root_uuid, - table_name=None, - root_getters=None, - use_globals=True, - ): - """ - Args: - controller (IBEISController): main controller - name (str): name of this controller instance, which is used in naming the data storage - table_name (str): (optional) if not the same as the 'name' - get_root_uuid: ??? - root_getters: ??? - use_globals (bool): ??? (default: True) - - """ - if table_name is None: - table_name = name - - self = cls.__new__(cls) - - # Parent (ibs) controller - self.controller = controller - # Internal dictionary of dependant tables - self.cachetable_dict = {} - self.configclass_dict = {} - self.requestclass_dict = {} - self.resultclass_dict = {} - - self.root_getters = root_getters - # Root of all dependencies - self.root_tablename = table_name - # XXX Directory all cachefiles are stored in - self.cache_dpath = ut.truepath(self.controller.get_cachedir()) - - # Mapping of database connections by name - # - names populated by _CoreDependencyCache._register_prop - # - values populated by _CoreDependencyCache.initialize - self._db_by_name = {} - - # Function to map a root rowid to an object - self._use_globals = use_globals - - # XXX remove filesystem name - self.default_fname = f'{table_name}_cache' - - self.get_root_uuid = get_root_uuid - self.delete_exclude_tables = {} - # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible - self._debug = False - - return self - def get_tablenames(self): return list(self.cachetable_dict.keys()) @@ -1641,8 +1632,8 @@ def get_uuids(self, tablename, root_rowids, config=None): uuid_list = self.get_root_uuid(root_rowids) return uuid_list - get_native_property = _CoreDependencyCache.get_native - get_property = _CoreDependencyCache.get + get_native_property = get_native + get_property = get def stacked_config(self, source, dest, config): r""" diff --git a/wbia/dtool/example_depcache.py b/wbia/dtool/example_depcache.py index 46a3ca28b2..931b5ac108 100644 --- a/wbia/dtool/example_depcache.py +++ b/wbia/dtool/example_depcache.py @@ -34,6 +34,19 @@ def dummy_global_preproc_func(depc, parent_rowids, config=None): yield 'dummy' +class DummyController: + """Just enough (IBEIS) controller to make the dependency cache examples work""" + + def __init__(self, cache_dpath): + self.cache_dpath = cache_dpath + + def make_cache_db_uri(self, name): + return f"sqlite:///{self.cache_dpath}/{name}.sqlite" + + def 
get_cachedir(self): + return self.cache_dpath + + class DummyKptsConfig(dtool.Config): def get_param_info_list(self): return [ @@ -224,12 +237,16 @@ def get_root_uuid(aid_list): dtool_repo = dirname(ut.get_module_dir(dtool)) cache_dpath = join(dtool_repo, 'DEPCACHE') - depc = dtool.DependencyCache( - root_tablename=dummy_root, - default_fname=fname, - cache_dpath=cache_dpath, - get_root_uuid=get_root_uuid, - # root_asobject=root_asobject, + if not fname: + fname = dummy_root + + controller = DummyController(cache_dpath) + depc = dtool.DependencyCache.as_named( + controller, + fname, + get_root_uuid, + table_name=dummy_root, + root_getters=None, use_globals=False, ) diff --git a/wbia/dtool/example_depcache2.py b/wbia/dtool/example_depcache2.py index 27aa5af07f..6b146d8154 100644 --- a/wbia/dtool/example_depcache2.py +++ b/wbia/dtool/example_depcache2.py @@ -86,20 +86,20 @@ def testdata_depc3(in_memory=True): >>> ut.show_if_requested() """ from wbia import dtool + from wbia.dtool.example_depcache import DummyController # put the test cache in the dtool repo dtool_repo = dirname(ut.get_module_dir(dtool)) cache_dpath = join(dtool_repo, 'DEPCACHE3') - # FIXME: this only puts the sql files in memory - default_fname = ':memory:' if in_memory else None - + controller = DummyController(cache_dpath) root = 'annot' - depc = dtool.DependencyCache( - root_tablename=root, - get_root_uuid=ut.identity, - default_fname=default_fname, - cache_dpath=cache_dpath, + depc = dtool.DependencyCache.as_named( + controller, + root, + ut.identity, + table_name=None, + root_getters=None, use_globals=False, ) @@ -156,20 +156,20 @@ def testdata_depc4(in_memory=True): >>> ut.show_if_requested() """ from wbia import dtool + from wbia.dtool.example_depcache import DummyController # put the test cache in the dtool repo dtool_repo = dirname(ut.get_module_dir(dtool)) cache_dpath = join(dtool_repo, 'DEPCACHE3') - # FIXME: this only puts the sql files in memory - default_fname = ':memory:' if in_memory else None - + controller = DummyController(cache_dpath) root = 'annot' - depc = dtool.DependencyCache( - root_tablename=root, - get_root_uuid=ut.identity, - default_fname=default_fname, - cache_dpath=cache_dpath, + depc = dtool.DependencyCache.as_named( + controller, + root, + ut.identity, + table_name=None, + root_getters=None, use_globals=False, ) @@ -202,20 +202,23 @@ def testdata_depc4(in_memory=True): def testdata_custom_annot_depc(dummy_dependencies, in_memory=True): from wbia import dtool + from wbia.dtool.example_depcache import DummyController # put the test cache in the dtool repo dtool_repo = dirname(ut.get_module_dir(dtool)) cache_dpath = join(dtool_repo, 'DEPCACHE5') - # FIXME: this only puts the sql files in memory - default_fname = ':memory:' if in_memory else None + + controller = DummyController(cache_dpath) root = 'annot' - depc = dtool.DependencyCache( - root_tablename=root, - get_root_uuid=ut.identity, - default_fname=default_fname, - cache_dpath=cache_dpath, + depc = dtool.DependencyCache.as_named( + controller, + root, + ut.identity, + table_name=None, + root_getters=None, use_globals=False, ) + # ---------- register_dummy_config = depc_34_helper(depc) From a61e2653c8512e888ffa9ce804af45c0583b7932 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 23:00:57 -0700 Subject: [PATCH 103/294] Refactor depcache example's DependencyCache creation The last commit's copy&paste make this an obvious necessity. 
--- wbia/dtool/example_depcache2.py | 86 ++++++++++++--------------------- 1 file changed, 32 insertions(+), 54 deletions(-) diff --git a/wbia/dtool/example_depcache2.py b/wbia/dtool/example_depcache2.py index 6b146d8154..ceb7448d2e 100644 --- a/wbia/dtool/example_depcache2.py +++ b/wbia/dtool/example_depcache2.py @@ -1,10 +1,36 @@ # -*- coding: utf-8 -*- -import utool as ut +from pathlib import Path -# import numpy as np -from os.path import join, dirname +import utool as ut from six.moves import zip +from wbia.dtool.depcache_control import DependencyCache +from wbia.dtool.example_depcache import DummyController + + +HERE = Path(__file__).parent.resolve() + + +def _depc_factory(name, cache_dir): + """DependencyCache factory for the examples + + Args: + name (str): name of the cache (e.g. 'annot') + cache_dir (str): name of the cache directory + + """ + cache_dpath = HERE / cache_dir + controller = DummyController(cache_dpath) + depc = DependencyCache.as_named( + controller, + name, + ut.identity, + table_name=None, + root_getters=None, + use_globals=False, + ) + return depc + def depc_34_helper(depc): def register_dummy_config(tablename, parents, **kwargs): @@ -85,23 +111,7 @@ def testdata_depc3(in_memory=True): >>> #depc['viewpoint_classification'].show_input_graph() >>> ut.show_if_requested() """ - from wbia import dtool - from wbia.dtool.example_depcache import DummyController - - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE3') - - controller = DummyController(cache_dpath) - root = 'annot' - depc = dtool.DependencyCache.as_named( - controller, - root, - ut.identity, - table_name=None, - root_getters=None, - use_globals=False, - ) + depc = _depc_factory('annot', 'DEPCACHE3') # ---------- # dummy_cols = dict(colnames=['data'], coltypes=[np.ndarray]) @@ -155,23 +165,7 @@ def testdata_depc4(in_memory=True): >>> #depc['viewpoint_classification'].show_input_graph() >>> ut.show_if_requested() """ - from wbia import dtool - from wbia.dtool.example_depcache import DummyController - - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE3') - - controller = DummyController(cache_dpath) - root = 'annot' - depc = dtool.DependencyCache.as_named( - controller, - root, - ut.identity, - table_name=None, - root_getters=None, - use_globals=False, - ) + depc = _depc_factory('annot', 'DEPCACHE4') # ---------- # dummy_cols = dict(colnames=['data'], coltypes=[np.ndarray]) @@ -201,23 +195,7 @@ def testdata_depc4(in_memory=True): def testdata_custom_annot_depc(dummy_dependencies, in_memory=True): - from wbia import dtool - from wbia.dtool.example_depcache import DummyController - - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE5') - - controller = DummyController(cache_dpath) - root = 'annot' - depc = dtool.DependencyCache.as_named( - controller, - root, - ut.identity, - table_name=None, - root_getters=None, - use_globals=False, - ) + depc = _depc_factory('annot', 'DEPCACHE5') # ---------- register_dummy_config = depc_34_helper(depc) From fea729aed583ff70a5e2b5f25b8deb9a92151c76 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 20 Oct 2020 23:41:00 -0700 Subject: [PATCH 104/294] Ensure the example's cache directory exists --- .gitignore | 2 +- wbia/dtool/example_depcache.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git 
a/.gitignore b/.gitignore index efe3950d95..c3fa622017 100644 --- a/.gitignore +++ b/.gitignore @@ -36,7 +36,7 @@ testdb*/ tmp.txt vsone.*.cPkl vsone.*.json -wbia/DEPCACHE*/ +wbia/dtool/DEPCACHE*/ # Translations *.mo diff --git a/wbia/dtool/example_depcache.py b/wbia/dtool/example_depcache.py index 931b5ac108..99a1056fe7 100644 --- a/wbia/dtool/example_depcache.py +++ b/wbia/dtool/example_depcache.py @@ -4,15 +4,21 @@ python -m dtool.example_depcache --exec-dummy_example_depcacahe --show python -m dtool.depcache_control --exec-make_graph --show """ +from pathlib import Path +from os.path import join + import utool as ut import numpy as np import uuid -from os.path import join, dirname from six.moves import zip + from wbia.dtool import depcache_control from wbia import dtool +HERE = Path(__file__).parent.resolve() + + if False: # Example of global registration DUMMY_ROOT_TABLENAME = 'dummy_annot' @@ -38,10 +44,11 @@ class DummyController: """Just enough (IBEIS) controller to make the dependency cache examples work""" def __init__(self, cache_dpath): - self.cache_dpath = cache_dpath + self.cache_dpath = Path(cache_dpath) + self.cache_dpath.mkdir(exist_ok=True) def make_cache_db_uri(self, name): - return f"sqlite:///{self.cache_dpath}/{name}.sqlite" + return f'sqlite:///{self.cache_dpath}/{name}.sqlite' def get_cachedir(self): return self.cache_dpath @@ -233,13 +240,10 @@ def testdata_depc(fname=None): def get_root_uuid(aid_list): return ut.lmap(ut.hashable_to_uuid, aid_list) - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE') - if not fname: fname = dummy_root + cache_dpath = HERE / 'DEPCACHE' controller = DummyController(cache_dpath) depc = dtool.DependencyCache.as_named( controller, From 719c9b16e9ecdd78b3c4bdecc1e9b9567fd06a48 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 21 Oct 2020 17:07:47 -0700 Subject: [PATCH 105/294] Replace the default constructor for DependencyCache Replace DependencyCache's default constructor with `as_named` classmethod. 
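As a quick orientation for this change, here is a minimal sketch of what construction looks like before and after the patch; `controller` stands in for any object exposing `get_cachedir()` (for example the `DummyController` from the example modules), and the name/flag values are purely illustrative.

```
# Sketch only -- not taken verbatim from the patch.
import utool as ut
from wbia.dtool.depcache_control import DependencyCache

# old keyword-based constructor (removed below):
#   depc = DependencyCache(root_tablename='annot', cache_dpath='./DEPCACHE',
#                          get_root_uuid=ut.identity, use_globals=False)

# new constructor (the former `as_named` signature):
depc = DependencyCache(
    controller,       # assumed: an (IBEIS-like) controller instance
    'annot',          # name of this cache
    ut.identity,      # get_root_uuid
    use_globals=False,
)
```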
--- wbia/control/IBEISControl.py | 6 ++-- wbia/dtool/depcache_control.py | 49 --------------------------------- wbia/dtool/example_depcache.py | 2 +- wbia/dtool/example_depcache2.py | 2 +- 4 files changed, 5 insertions(+), 54 deletions(-) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 15dd5a71e1..4c1f58bc46 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -807,7 +807,7 @@ def _init_sqldbstaging(self, request_stagingversion=None): def _init_depcache(self): # Initialize dependency cache for images image_root_getters = {} - self.depc_image = dtool.DependencyCache.as_named( + self.depc_image = dtool.DependencyCache( self, const.IMAGE_TABLE, self.get_image_uuids, @@ -831,7 +831,7 @@ def _init_depcache(self): 'theta': self.get_annot_thetas, 'occurrence_text': self.get_annot_occurrence_text, } - self.depc_annot = dtool.DependencyCache.as_named( + self.depc_annot = dtool.DependencyCache( self, const.ANNOTATION_TABLE, self.get_annot_visual_uuids, @@ -846,7 +846,7 @@ def _init_depcache(self): # Initialize dependency cache for parts part_root_getters = {} - self.depc_part = dtool.DependencyCache.as_named( + self.depc_part = dtool.DependencyCache( self, const.PART_TABLE, self.get_part_uuids, diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index d70ae2468d..3a3a8e25e5 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -86,51 +86,6 @@ def _wrapper(func): class DependencyCache: def __init__( self, - root_tablename=None, - cache_dpath='./DEPCACHE', - controller=None, - default_fname=None, - # root_asobject=None, - get_root_uuid=None, - root_getters=None, - use_globals=True, - ): - if default_fname is None: - # ??? So 'None_primary_cache' is a good name? 
- default_fname = root_tablename + '_primary_cache' - self.root_getters = root_getters - # Root of all dependencies - self.root_tablename = root_tablename - self.name = root_tablename - # Directory all cachefiles are stored in - self.cache_dpath = ut.truepath(cache_dpath) - # Parent (ibs) controller - self.controller = controller - # Internal dictionary of dependant tables - self.cachetable_dict = {} - self.configclass_dict = {} - self.requestclass_dict = {} - self.resultclass_dict = {} - # Mapping of database connections by name - # - names populated by _CoreDependencyCache._register_prop - # - values populated by _CoreDependencyCache.initialize - self._db_by_name = {} - # Function to map a root rowid to an object - # self._root_asobject = root_asobject - self._use_globals = use_globals - self.default_fname = default_fname - if get_root_uuid is None: - logger.info('WARNING NEED UUID FUNCTION') - # HACK - get_root_uuid = ut.identity - self.get_root_uuid = get_root_uuid - self.delete_exclude_tables = {} - # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible - self._debug = False - - @classmethod - def as_named( - cls, controller, name, get_root_uuid, @@ -151,8 +106,6 @@ def as_named( if table_name is None: table_name = name - self = cls.__new__(cls) - self.name = name # Parent (ibs) controller self.controller = controller @@ -184,8 +137,6 @@ def as_named( # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible self._debug = False - return self - def __repr__(self): return f'' diff --git a/wbia/dtool/example_depcache.py b/wbia/dtool/example_depcache.py index 99a1056fe7..4da19e3af1 100644 --- a/wbia/dtool/example_depcache.py +++ b/wbia/dtool/example_depcache.py @@ -245,7 +245,7 @@ def get_root_uuid(aid_list): cache_dpath = HERE / 'DEPCACHE' controller = DummyController(cache_dpath) - depc = dtool.DependencyCache.as_named( + depc = dtool.DependencyCache( controller, fname, get_root_uuid, diff --git a/wbia/dtool/example_depcache2.py b/wbia/dtool/example_depcache2.py index ceb7448d2e..1fb243b036 100644 --- a/wbia/dtool/example_depcache2.py +++ b/wbia/dtool/example_depcache2.py @@ -21,7 +21,7 @@ def _depc_factory(name, cache_dir): """ cache_dpath = HERE / cache_dir controller = DummyController(cache_dpath) - depc = DependencyCache.as_named( + depc = DependencyCache( controller, name, ut.identity, From e2bbe43c9442f43a061cf3de8f6dd8d8d6b31100 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 30 Nov 2020 16:43:07 -0800 Subject: [PATCH 106/294] Move hsdb conversion code to script Moving this logic to a separate script. It's one less thing I'll need to ignore/stumble-over when looking through the code. 
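A usage sketch for the entry point added below; the database path is hypothetical, and the `CliRunner` call simply drives the same click command in-process (handy for tests).

```
# Installed as a console script this is run as:
#   wbia-convert-hsdb --db-dir /data/my_hotspotter_db
# The equivalent in-process invocation:
from click.testing import CliRunner
from wbia.cli.convert_hsdb import main

runner = CliRunner()
result = runner.invoke(main, ['--db-dir', '/data/my_hotspotter_db'])  # hypothetical path
print(result.exit_code, result.output)
```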
--- setup.py | 1 + wbia/cli/convert_hsdb.py | 41 ++++++++++++++++++++++++++++++++++++ wbia/control/IBEISControl.py | 25 ++++++++-------------- 3 files changed, 51 insertions(+), 16 deletions(-) create mode 100644 wbia/cli/convert_hsdb.py diff --git a/setup.py b/setup.py index 38e7fb3874..d6cac9a1a5 100755 --- a/setup.py +++ b/setup.py @@ -279,6 +279,7 @@ def gen_packages_items(): entry_points="""\ [console_scripts] wbia-init-testdbs = wbia.cli.testdbs:main + wbia-convert-hsdb = wbia.cli.convert_hsdb:main """, ) diff --git a/wbia/cli/convert_hsdb.py b/wbia/cli/convert_hsdb.py new file mode 100644 index 0000000000..a0f8bf7594 --- /dev/null +++ b/wbia/cli/convert_hsdb.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +"""Script to convert hotspotter database (HSDB) to a WBIA compatible database""" +import sys + +import click + +from wbia.dbio.ingest_hsdb import is_hsdb, is_succesful_convert, convert_hsdb_to_wbia + + +@click.command() +@click.option( + '--db-dir', required=True, type=click.Path(exists=True), help='database location' +) +def main(db_dir): + """Convert hotspotter database (HSDB) to a WBIA compatible database""" + click.echo(f'⏳ working on {db_dir}') + if is_hsdb(db_dir): + click.echo('✅ confirmed hotspotter database') + else: + click.echo('❌ not a hotspotter database') + sys.exit(1) + if is_succesful_convert(db_dir): + click.echo('✅ already converted hotspotter database') + sys.exit(0) + + convert_hsdb_to_wbia( + db_dir, + ensure=True, + verbose=True, + ) + + if is_succesful_convert(db_dir): + click.echo('✅ successfully converted database') + else: + click.echo('❌ unsuccessfully converted... further investigation necessary') + sys.exit(1) + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 4c1f58bc46..544efdf3eb 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -43,7 +43,6 @@ from six.moves import zip from os.path import join, split, realpath from wbia.init import sysres -from wbia.dbio import ingest_hsdb from wbia import constants as const from wbia.control import accessor_decors, controller_inject from wbia.dtool.dump import dump @@ -265,21 +264,15 @@ def request_IBEISController( if force_serial: assert ibs.force_serial, 'set use_cache=False in wbia.opendb' else: - # Convert hold hotspotter dirs if necessary - if check_hsdb and ingest_hsdb.check_unconverted_hsdb(dbdir): - ibs = ingest_hsdb.convert_hsdb_to_wbia( - dbdir, ensure=ensure, wbaddr=wbaddr, verbose=verbose - ) - else: - ibs = IBEISController( - dbdir=dbdir, - ensure=ensure, - wbaddr=wbaddr, - verbose=verbose, - force_serial=force_serial, - request_dbversion=request_dbversion, - request_stagingversion=request_stagingversion, - ) + ibs = IBEISController( + dbdir=dbdir, + ensure=ensure, + wbaddr=wbaddr, + verbose=verbose, + force_serial=force_serial, + request_dbversion=request_dbversion, + request_stagingversion=request_stagingversion, + ) __IBEIS_CONTROLLER_CACHE__[dbdir] = ibs return ibs From b3470f6de9a193792b8d2d1e47057c98c5e87127 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 30 Nov 2020 22:24:27 -0800 Subject: [PATCH 107/294] Stream all wbia logging to stderr This simply configures the wbia logger to output. We lost the logging when we switched to using the logging module. This brings it back in a temporary fashion. Better configuration support through ini based configuration or some other means to be added later. 
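A hedged sketch of what that later configuration support could look like, using the standard library's `dictConfig` instead of the hard-coded handler; the logger name `'wbia'` and the chosen levels are assumptions, not part of this patch.

```
import logging.config

logging.config.dictConfig({
    'version': 1,
    'disable_existing_loggers': False,
    'handlers': {
        'stderr': {'class': 'logging.StreamHandler', 'level': 'DEBUG'},
    },
    'loggers': {
        'wbia': {'handlers': ['stderr'], 'level': 'DEBUG', 'propagate': False},
    },
})
```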
--- wbia/entry_points.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wbia/entry_points.py b/wbia/entry_points.py index 25e5c92577..a047c1828c 100644 --- a/wbia/entry_points.py +++ b/wbia/entry_points.py @@ -73,6 +73,11 @@ def _init_wbia(dbdir=None, verbose=None, use_cache=True, web=None, **kwargs): params.parse_args() from wbia.control import IBEISControl + # Set up logging + # TODO (30-Nov-12020) This is intended to be a temporary fix to logging. + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + if verbose is None: verbose = ut.VERBOSE if verbose and NOT_QUIET: From 780b4dff74848aa6217e21ba847d1e77b8fe34d1 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 30 Nov 2020 22:36:38 -0800 Subject: [PATCH 108/294] Fix logging in controller method Just touching up the logging so that it's using logging read of utool. --- wbia/control/manual_meta_funcs.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/wbia/control/manual_meta_funcs.py b/wbia/control/manual_meta_funcs.py index 5d3bb4ba33..0b84737442 100644 --- a/wbia/control/manual_meta_funcs.py +++ b/wbia/control/manual_meta_funcs.py @@ -1011,14 +1011,10 @@ def _init_config(ibs): try: general_config = ut.load_cPkl(config_fpath, verbose=ut.VERBOSE) except IOError as ex: - if ut.VERBOSE: - ut.printex(ex, 'failed to genral load config', iswarning=True) + logger.error('*** failed to load general config', exc_info=ex) general_config = {} current_species = general_config.get('current_species', None) - if ut.VERBOSE and ut.NOT_QUIET: - logger.info( - '[_init_config] general_config.current_species = %r' % (current_species,) - ) + logger.info('[_init_config] general_config.current_species = %r' % (current_species,)) # ##### # species_list = ibs.get_database_species() From 0655be8ce709f8c4624015b75807a144d8dda1f3 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 1 Dec 2020 22:51:04 -0800 Subject: [PATCH 109/294] Refactor primary species lookup to a single query There are two things wrong with how this method is written: 1) does a shit load of database queries 2) orders and aggregates outside the database. This does 1 + (N*1) + (N*1) queries against the database, where N is the number of annotations (aids). So for the manta slim dataset this is about 97k queries for this one operation (i.e. get_primary_species). It's kinda surprising it did so well. The ordering and aggregation was done outside of the database. This effectively ignores one of the great powers of sql databases. This change reduces the operation to one query, letting the database power through joining the data, aggregating it and ordering it. 
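Roughly the single statement the refactor issues, written out as plain SQL for readers who want the shape of the query without the sqlalchemy plumbing; `ibs` and the aid values are assumed/illustrative.

```
from sqlalchemy import text

stmt = text("""
    SELECT species.species_text,
           COUNT(annotations.annot_rowid) AS num_annots
    FROM annotations
    LEFT OUTER JOIN species
        ON annotations.species_rowid = species.species_rowid
    WHERE annotations.annot_rowid IN (1, 2, 3)   -- illustrative aids
    GROUP BY species.species_text
    ORDER BY num_annots DESC
""")
rows = ibs.db.connection.execute(stmt).fetchall()  # ibs assumed to be an opened controller
```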
--- wbia/other/ibsfuncs.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 54a566bd9b..69e45df7ca 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -17,6 +17,7 @@ import types import functools import re +from collections import OrderedDict from six.moves import zip, range, map, reduce from os.path import split, join, exists import numpy as np @@ -3611,15 +3612,41 @@ def get_primary_database_species(ibs, aid_list=None, speedhack=True): return 'zebra_grevys' if aid_list is None: aid_list = ibs.get_valid_aids(is_staged=None) - species_list = ibs.get_annot_species_texts(aid_list) - species_hist = ut.dict_hist(species_list) - if len(species_hist) == 0: + + annotations = ibs.db._reflect_table('annotations') + species = ibs.db._reflect_table('species') + + from sqlalchemy.sql import select, func, desc + + stmt = ( + select( + [ + species.c.species_text, + func.count(annotations.c.annot_rowid).label('num_annots'), + ] + ) + .select_from( + annotations.outerjoin( + species, annotations.c.species_rowid == species.c.species_rowid + ) + ) + .where(annotations.c.annot_rowid.in_(aid_list)) + .group_by('species_text') + .order_by(desc('num_annots')) + ) + results = ibs.db.connection.execute(stmt) + + species_count = OrderedDict() + for row in results: + species_text = row.species_text + if species_text is None: + species_text = const.UNKNOWN + species_count[species_text] = row.num_annots + + if not species_count: primary_species = const.UNKNOWN else: - frequent_species = sorted( - species_hist.items(), key=lambda item: item[1], reverse=True - ) - primary_species = frequent_species[0][0] + primary_species = species_count.popitem(last=False)[0] # FIFO return primary_species From f506a0701641e98ce75d7e7c8034f6a7cde1f33e Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 2 Dec 2020 00:04:36 -0800 Subject: [PATCH 110/294] Move the counting query to the count function Move the query from get_database_primary_species to get_database_species_count. Then use this function within get_database_primary_species. 
--- wbia/other/ibsfuncs.py | 67 +++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 69e45df7ca..a6385ca0a8 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -3613,36 +3613,7 @@ def get_primary_database_species(ibs, aid_list=None, speedhack=True): if aid_list is None: aid_list = ibs.get_valid_aids(is_staged=None) - annotations = ibs.db._reflect_table('annotations') - species = ibs.db._reflect_table('species') - - from sqlalchemy.sql import select, func, desc - - stmt = ( - select( - [ - species.c.species_text, - func.count(annotations.c.annot_rowid).label('num_annots'), - ] - ) - .select_from( - annotations.outerjoin( - species, annotations.c.species_rowid == species.c.species_rowid - ) - ) - .where(annotations.c.annot_rowid.in_(aid_list)) - .group_by('species_text') - .order_by(desc('num_annots')) - ) - results = ibs.db.connection.execute(stmt) - - species_count = OrderedDict() - for row in results: - species_text = row.species_text - if species_text is None: - species_text = const.UNKNOWN - species_count[species_text] = row.num_annots - + species_count = ibs.get_database_species_count(aid_list) if not species_count: primary_species = const.UNKNOWN else: @@ -3691,14 +3662,42 @@ def get_database_species_count(ibs, aid_list=None): >>> ibs = wbia.opendb('testdb1') >>> result = ut.repr2(ibs.get_database_species_count(), nl=False) >>> print(result) - {'____': 3, 'bear_polar': 2, 'zebra_grevys': 2, 'zebra_plains': 6} + {'zebra_plains': 6, '____': 3, 'bear_polar': 2, 'zebra_grevys': 2} """ if aid_list is None: aid_list = ibs.get_valid_aids() - species_list = ibs.get_annot_species_texts(aid_list) - species_count_dict = ut.item_hist(species_list) - return species_count_dict + + annotations = ibs.db._reflect_table('annotations') + species = ibs.db._reflect_table('species') + + from sqlalchemy.sql import select, func, desc + + stmt = ( + select( + [ + species.c.species_text, + func.count(annotations.c.annot_rowid).label('num_annots'), + ] + ) + .select_from( + annotations.outerjoin( + species, annotations.c.species_rowid == species.c.species_rowid + ) + ) + .where(annotations.c.annot_rowid.in_(aid_list)) + .group_by('species_text') + .order_by(desc('num_annots')) + ) + results = ibs.db.connection.execute(stmt) + + species_count = OrderedDict() + for row in results: + species_text = row.species_text + if species_text is None: + species_text = const.UNKNOWN + species_count[species_text] = row.num_annots + return species_count @register_ibs_method From 137cb3b57ab89d269c42c465e1f098b3005bd153 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 2 Dec 2020 21:48:39 +0000 Subject: [PATCH 111/294] Execute sql in batches in get_database_species_count When we have a large database, the aid list passed to `get_database_species_count` can be so big that causes this sqlite error: ``` [main()] IBEIS LOAD encountered exception: (sqlite3.OperationalError) too many SQL variables [SQL: SELECT species.species_text, count(annotations.annot_rowid) AS num_annots FROM annotations LEFT OUTER JOIN species ON annotations.species_rowid = species.species_rowid WHERE annotations.annot_rowid IN (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ... ``` We can execute the sql in batches instead. It seems the limit is defined in `SQLITE_MAX_VARIABLE_NUMBER` which is by default 999. 
(See https://sqlite.org/limits.html) --- wbia/other/ibsfuncs.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index a6385ca0a8..bfc9bbdfbb 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -3648,7 +3648,7 @@ def get_dominant_species(ibs, aid_list): @register_ibs_method -def get_database_species_count(ibs, aid_list=None): +def get_database_species_count(ibs, aid_list=None, BATCH_SIZE=25000): """ CommandLine: @@ -3660,9 +3660,9 @@ def get_database_species_count(ibs, aid_list=None): >>> import wbia # NOQA >>> #print(ut.repr2(wbia.opendb('PZ_Master0').get_database_species_count())) >>> ibs = wbia.opendb('testdb1') - >>> result = ut.repr2(ibs.get_database_species_count(), nl=False) + >>> result = ut.repr2(ibs.get_database_species_count(BATCH_SIZE=2), nl=False) >>> print(result) - {'zebra_plains': 6, '____': 3, 'bear_polar': 2, 'zebra_grevys': 2} + {'zebra_plains': 6, '____': 3, 'zebra_grevys': 2, 'bear_polar': 2} """ if aid_list is None: @@ -3671,8 +3671,9 @@ def get_database_species_count(ibs, aid_list=None): annotations = ibs.db._reflect_table('annotations') species = ibs.db._reflect_table('species') - from sqlalchemy.sql import select, func, desc + from sqlalchemy.sql import select, func, desc, bindparam + species_count = OrderedDict() stmt = ( select( [ @@ -3685,18 +3686,21 @@ def get_database_species_count(ibs, aid_list=None): species, annotations.c.species_rowid == species.c.species_rowid ) ) - .where(annotations.c.annot_rowid.in_(aid_list)) + .where(annotations.c.annot_rowid.in_(bindparam('aids', expanding=True))) .group_by('species_text') .order_by(desc('num_annots')) ) - results = ibs.db.connection.execute(stmt) - - species_count = OrderedDict() - for row in results: - species_text = row.species_text - if species_text is None: - species_text = const.UNKNOWN - species_count[species_text] = row.num_annots + for batch in range(int(len(aid_list) / BATCH_SIZE) + 1): + aids = aid_list[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE] + results = ibs.db.connection.execute(stmt, {'aids': aids}) + + for row in results: + species_text = row.species_text + if species_text is None: + species_text = const.UNKNOWN + species_count[species_text] = ( + species_count.get(species_text, 0) + row.num_annots + ) return species_count From 2b982ec8470f758119f4bc7d76bd43a778fbc69f Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 10 Dec 2020 00:03:05 +0000 Subject: [PATCH 112/294] Rewrite SQLDatabaseController.get to use less queries `SQLDatabaseController.get` was making one query per "id", for example, ``` select annot_staged_flag from annotations where rowid = 1; select annot_staged_flag from annotations where rowid = 2; select annot_staged_flag from annotations where rowid = 3; ``` etc. This causes `/api/annot/` to be very slow. This commit changes it to: ``` select rowid, annot_staged_flag from annotations where rowid in (1, 2, 3); ``` so that we can get all the information in one query. This speeds up `/api/annot/`. 
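The heart of that change is sqlalchemy's "expanding" IN bind parameter; below is a minimal, self-contained sketch of the pattern against a throwaway in-memory table (column names are illustrative), not the controller code itself.

```
import sqlalchemy
from sqlalchemy import bindparam, text
from sqlalchemy.sql import column, table

engine = sqlalchemy.create_engine('sqlite://')
with engine.connect() as conn:
    conn.execute(text('CREATE TABLE annotations (annot_rowid INTEGER, annot_staged_flag INTEGER)'))
    conn.execute(text('INSERT INTO annotations VALUES (1, 0), (2, 1), (3, 0)'))

    annotations = table('annotations', column('annot_rowid'), column('annot_staged_flag'))
    stmt = sqlalchemy.select(
        [annotations.c.annot_rowid, annotations.c.annot_staged_flag]
    ).where(annotations.c.annot_rowid.in_(bindparam('ids', expanding=True)))

    # one round trip for all ids instead of one query per id
    print(conn.execute(stmt, {'ids': [1, 2, 3]}).fetchall())
```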
--- wbia/dtool/sql_control.py | 61 +++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 0d0d20124f..a6babab3a8 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -24,6 +24,7 @@ from wbia.dtool import lite from wbia.dtool.dump import dumps +from wbia.dtool.types import Integer print, rrr, profile = ut.inject2(__name__) @@ -1141,6 +1142,7 @@ def get( id_colname='rowid', eager=True, assume_unique=False, + BATCH_SIZE=250000, **kwargs, ): """Get rows of data by ID @@ -1203,23 +1205,52 @@ def get( if id_iter is None: where_clause = None params_iter = [] + + return self.get_where( + tblname, colnames, params_iter, where_clause, eager=eager, **kwargs + ) + + id_iter = list(id_iter) # id_iter could be a set + table = self._reflect_table(tblname) + result_map = {} + if id_colname == 'rowid': # rowid isn't an actual column in sqlite + id_column = sqlalchemy.sql.column('rowid', Integer) else: - id_param_name = '_identifier' - where_clause = text(id_colname + f' = :{id_param_name}') - if id_colname == 'rowid': - # Cast all item values to in, in case values are numpy.integer* - # Strangely allow for None values - id_iter = [id_ is not None and int(id_) or id_ for id_ in id_iter] - else: # b/c rowid doesn't really exist as a column - column = self._reflect_table(tblname).c[id_colname] - where_clause = where_clause.bindparams( - bindparam(id_param_name, type_=column.type) - ) - params_iter = [{id_param_name: id} for id in id_iter] + id_column = table.c[id_colname] + stmt = sqlalchemy.select([id_column] + [table.c[c] for c in colnames]) + stmt = stmt.where(id_column.in_(bindparam('value', expanding=True))) + for batch in range(int(len(id_iter) / BATCH_SIZE) + 1): + val_list = self.executeone( + stmt, + {'value': id_iter[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE]}, + ) - return self.get_where( - tblname, colnames, params_iter, where_clause, eager=eager, **kwargs - ) + for val in val_list: + if not kwargs.get('keepwrap', False) and len(val[1:]) == 1: + values = val[1] + else: + values = val[1:] + existing = result_map.setdefault(val[0], set()) + if isinstance(existing, set): + try: + existing.add(values) + except TypeError: + # unhashable type + result_map[val[0]] = list(result_map[val[0]]) + if values not in result_map[val[0]]: + result_map[val[0]].append(values) + elif values not in existing: + existing.append(values) + + results = [] + for id_ in id_iter: + result = sorted(list(result_map.get(id_, set()))) + if kwargs.get('unpack_scalars', True) and isinstance(result, list): + results.append(_unpacker(result)) + else: + results.append(result) + + return results def set( self, From e710dcb8b5ce3c7abf4aeb5eb1c20d8dc33c2429 Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 10 Dec 2020 23:16:31 +0000 Subject: [PATCH 113/294] Reduce queries in SQLDatabaseController.get_where_eq `/view/` was taking a really long time to respond because of `update_all_image_special_imageset`. 
The code does a large number of queries like this using `SQLDatabaseController.get_where_eq`: ``` SELECT imageset_image_relationship.image_rowid FROM imageset_image_relationship WHERE imageset_image_relationship.image_rowid = 72591 AND imageset_image_relationship.imageset_rowid = 1 SELECT imageset_image_relationship.image_rowid FROM imageset_image_relationship WHERE imageset_image_relationship.image_rowid = 72592 AND imageset_image_relationship.imageset_rowid = 1 SELECT imageset_image_relationship.image_rowid FROM imageset_image_relationship WHERE imageset_image_relationship.image_rowid = 72593 AND imageset_image_relationship.imageset_rowid = 1 ``` Get the same information but reduce the number of queries by grouping the queries together, for example: ``` SELECT image_rowid, imageset_rowid, imageset_image_relationship.image_rowid FROM imageset_image_relationship WHERE imageset_image_relationship.image_rowid = 72591 AND imageset_image_relationship.imageset_rowid = 1 OR imageset_image_relationship.image_rowid = 72592 AND imageset_image_relationship.imageset_rowid = 1 OR imageset_image_relationship.image_rowid = 72593 AND imageset_image_relationship.imageset_rowid = 1 ``` --- wbia/dtool/sql_control.py | 90 ++++++++++++++++++++++++++++++++------- 1 file changed, 75 insertions(+), 15 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index a6babab3a8..40ff22f170 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -955,6 +955,7 @@ def get_where_eq( where_colnames, unpack_scalars=True, op='AND', + BATCH_SIZE=250000, **kwargs, ): """Executes a SQL select where the given parameters match/equal @@ -973,23 +974,82 @@ def get_where_eq( (default: True) """ + if len(where_colnames) == 1: + return self.get( + tblname, + colnames, + id_iter=(p[0] for p in params_iter), + id_colname=where_colnames[0], + unpack_scalars=unpack_scalars, + BATCH_SIZE=BATCH_SIZE, + **kwargs, + ) + params_iter = list(params_iter) table = self._reflect_table(tblname) - # Build the equality conditions using column type information. - # This allows us to bind the parameter with the correct type. - equal_conditions = [ - (table.c[c] == bindparam(c, type_=table.c[c].type)) for c in where_colnames - ] - gate_func = {'and': sqlalchemy.and_, 'or': sqlalchemy.or_}[op.lower()] - where_clause = gate_func(*equal_conditions) - params = [dict(zip(where_colnames, p)) for p in params_iter] - return self.get_where( - tblname, - colnames, - params, - where_clause, - unpack_scalars=unpack_scalars, - **kwargs, + if op.lower() != 'and' or not params_iter: + # Build the equality conditions using column type information. + # This allows us to bind the parameter with the correct type. 
+ equal_conditions = [ + (table.c[c] == bindparam(c, type_=table.c[c].type)) + for c in where_colnames + ] + gate_func = {'and': sqlalchemy.and_, 'or': sqlalchemy.or_}[op.lower()] + where_clause = gate_func(*equal_conditions) + params = [dict(zip(where_colnames, p)) for p in params_iter] + return self.get_where( + tblname, + colnames, + params, + where_clause, + unpack_scalars=unpack_scalars, + **kwargs, + ) + + params_per_batch = int(BATCH_SIZE / len(params_iter[0])) + result_map = {} + stmt = sqlalchemy.select( + [table.c[c] for c in tuple(where_colnames) + tuple(colnames)] + ) + stmt = stmt.where( + sqlalchemy.tuple_(*[table.c[c] for c in where_colnames]).in_( + sqlalchemy.sql.bindparam('params', expanding=True) + ) ) + for batch in range(int(len(params_iter) / params_per_batch) + 1): + val_list = self.executeone( + stmt, + { + 'params': params_iter[ + batch * params_per_batch : (batch + 1) * params_per_batch + ] + }, + ) + for val in val_list: + key = val[: len(params_iter[0])] + values = val[len(params_iter[0]) :] + if not kwargs.get('keepwrap', False) and len(values) == 1: + values = values[0] + existing = result_map.setdefault(key, set()) + if isinstance(existing, set): + try: + existing.add(values) + except TypeError: + # unhashable type + result_map[key] = list(result_map[key]) + if values not in result_map[key]: + result_map[key].append(values) + elif values not in existing: + existing.append(values) + + results = [] + for id_ in params_iter: + result = sorted(list(result_map.get(tuple(id_), set()))) + if unpack_scalars and isinstance(result, list): + results.append(_unpacker(result)) + else: + results.append(result) + + return results def get_where_eq_set( self, From a48857b417dce8239e08badcc2e2b53da90ab7e7 Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 10 Dec 2020 23:50:30 +0000 Subject: [PATCH 114/294] Change all "AND" get_where to get_where_eq `get_where_eq` has been optimized to reduce the number of queries to the database so changing `get_where` to `get_where_eq`. --- wbia/control/manual_gsgrelate_funcs.py | 5 ++--- wbia/control/manual_lblannot_funcs.py | 15 ++++++--------- wbia/control/manual_lblimage_funcs.py | 20 ++++++++------------ wbia/control/manual_meta_funcs.py | 10 ++++------ wbia/control/manual_review_funcs.py | 10 ++++------ wbia/dtool/sql_control.py | 5 +++-- 6 files changed, 27 insertions(+), 38 deletions(-) diff --git a/wbia/control/manual_gsgrelate_funcs.py b/wbia/control/manual_gsgrelate_funcs.py index 09c8384175..daeb66ec18 100644 --- a/wbia/control/manual_gsgrelate_funcs.py +++ b/wbia/control/manual_gsgrelate_funcs.py @@ -55,13 +55,12 @@ def get_image_gsgrids(ibs, gid_list): list_ (list): a list of imageset-image-relationship rowids for each imageid""" # TODO: Group type params_iter = ((gid,) for gid in gid_list) - where_clause = 'image_rowid=?' # list of relationships for each image - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('image_rowid',), unpack_scalars=False, ) return gsgrids_list diff --git a/wbia/control/manual_lblannot_funcs.py b/wbia/control/manual_lblannot_funcs.py index b44a7b195b..d93b24da9e 100644 --- a/wbia/control/manual_lblannot_funcs.py +++ b/wbia/control/manual_lblannot_funcs.py @@ -100,9 +100,8 @@ def get_lblannot_rowid_from_superkey(ibs, lbltype_rowid_list, value_list): """ colnames = ('lblannot_rowid',) params_iter = zip(lbltype_rowid_list, value_list) - where_clause = 'lbltype_rowid=? AND lblannot_value=?' 
- lblannot_rowid_list = ibs.db.get_where( - const.LBLANNOT_TABLE, colnames, params_iter, where_clause + lblannot_rowid_list = ibs.db.get_where_eq( + const.LBLANNOT_TABLE, colnames, params_iter, ('lbltype_rowid', 'lblannot_value') ) # BIG HACK FOR ENFORCING UNKNOWN LBLANNOTS HAVE ROWID 0 lblannot_rowid_list = [ @@ -169,13 +168,12 @@ def get_alr_annot_rowids_from_lblannot_rowid(ibs, lblannot_rowid_list): # FIXME: SLOW # if verbose: # logger.info(ut.get_caller_name(N=list(range(0, 20)))) - where_clause = 'lblannot_rowid=?' params_iter = [(lblannot_rowid,) for lblannot_rowid in lblannot_rowid_list] - aids_list = ibs.db.get_where( + aids_list = ibs.db.get_where_eq( const.AL_RELATION_TABLE, ('annot_rowid',), params_iter, - where_clause, + ('lblannot_rowid',), unpack_scalars=False, ) return aids_list @@ -223,9 +221,8 @@ def get_alrid_from_superkey(ibs, aid_list, lblannot_rowid_list): """ colnames = ('annot_rowid',) params_iter = zip(aid_list, lblannot_rowid_list) - where_clause = 'annot_rowid=? AND lblannot_rowid=?' - alrid_list = ibs.db.get_where( - const.AL_RELATION_TABLE, colnames, params_iter, where_clause + alrid_list = ibs.db.get_where_eq( + const.AL_RELATION_TABLE, colnames, params_iter, ('annot_rowid', 'lblannot_rowid') ) return alrid_list diff --git a/wbia/control/manual_lblimage_funcs.py b/wbia/control/manual_lblimage_funcs.py index 2fb40c89a1..fc5847e770 100644 --- a/wbia/control/manual_lblimage_funcs.py +++ b/wbia/control/manual_lblimage_funcs.py @@ -101,9 +101,8 @@ def get_lblimage_rowid_from_superkey(ibs, lbltype_rowid_list, value_list): """ colnames = ('lblimage_rowid',) params_iter = zip(lbltype_rowid_list, value_list) - where_clause = 'lbltype_rowid=? AND lblimage_value=?' - lblimage_rowid_list = ibs.db.get_where( - const.LBLIMAGE_TABLE, colnames, params_iter, where_clause + lblimage_rowid_list = ibs.db.get_where_eq( + const.LBLIMAGE_TABLE, colnames, params_iter, ('lbltype_rowid', 'lblimage_value') ) return lblimage_rowid_list @@ -172,13 +171,12 @@ def get_lblimage_gids(ibs, lblimage_rowid_list): # FIXME: SLOW # if verbose: # logger.info(ut.get_caller_name(N=list(range(0, 20)))) - where_clause = 'lblimage_rowid=?' params_iter = [(lblimage_rowid,) for lblimage_rowid in lblimage_rowid_list] - gids_list = ibs.db.get_where( + gids_list = ibs.db.get_where_eq( const.GL_RELATION_TABLE, ('image_rowid',), params_iter, - where_clause, + ('lblimage_rowid',), unpack_scalars=False, ) return gids_list @@ -226,9 +224,8 @@ def get_glrid_from_superkey(ibs, gid_list, lblimage_rowid_list): """ colnames = ('image_rowid',) params_iter = zip(gid_list, lblimage_rowid_list) - where_clause = 'image_rowid=? AND lblimage_rowid=?' - glrid_list = ibs.db.get_where( - const.GL_RELATION_TABLE, colnames, params_iter, where_clause + glrid_list = ibs.db.get_where_eq( + const.GL_RELATION_TABLE, colnames, params_iter, ('image_rowid', 'lblimage_rowid') ) return glrid_list @@ -242,12 +239,11 @@ def get_image_glrids(ibs, gid_list): be only of a specific lbltype/category/type """ params_iter = ((gid,) for gid in gid_list) - where_clause = 'image_rowid=?' 
- glrids_list = ibs.db.get_where( + glrids_list = ibs.db.get_where_eq( const.GL_RELATION_TABLE, ('glr_rowid',), params_iter, - where_clause=where_clause, + ('image_rowid',), unpack_scalars=False, ) return glrids_list diff --git a/wbia/control/manual_meta_funcs.py b/wbia/control/manual_meta_funcs.py index 0b84737442..81cc4f1e05 100644 --- a/wbia/control/manual_meta_funcs.py +++ b/wbia/control/manual_meta_funcs.py @@ -900,13 +900,12 @@ def get_metadata_value(ibs, metadata_key_list, db): URL: /api/metadata/value/ """ params_iter = ((metadata_key,) for metadata_key in metadata_key_list) - where_clause = 'metadata_key=?' # list of relationships for each image - metadata_value_list = db.get_where( + metadata_value_list = db.get_where_eq( const.METADATA_TABLE, ('metadata_value',), params_iter, - where_clause, + ('metadata_key',), unpack_scalars=True, ) return metadata_value_list @@ -924,13 +923,12 @@ def get_metadata_rowid_from_metadata_key(ibs, metadata_key_list, db): """ db = db[0] # Unwrap tuple, required by @accessor_decors.getter_1to1 decorator params_iter = ((metadata_key,) for metadata_key in metadata_key_list) - where_clause = 'metadata_key=?' # list of relationships for each image - metadata_rowid_list = db.get_where( + metadata_rowid_list = db.get_where_eq( const.METADATA_TABLE, ('metadata_rowid',), params_iter, - where_clause, + ('metadata_key',), unpack_scalars=True, ) return metadata_rowid_list diff --git a/wbia/control/manual_review_funcs.py b/wbia/control/manual_review_funcs.py index e585ddfe65..ab37ae5552 100644 --- a/wbia/control/manual_review_funcs.py +++ b/wbia/control/manual_review_funcs.py @@ -566,9 +566,8 @@ def get_review_decisions_from_only(ibs, aid_list, eager=True, nInput=None): REVIEW_EVIDENCE_DECISION, ) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' % (REVIEW_AID1) - review_tuple_decisions_list = ibs.staging.get_where( - const.REVIEW_TABLE, colnames, params_iter, where_clause, unpack_scalars=False + review_tuple_decisions_list = ibs.staging.get_where_eq( + const.REVIEW_TABLE, colnames, params_iter, (REVIEW_AID1,), unpack_scalars=False ) return review_tuple_decisions_list @@ -586,9 +585,8 @@ def get_review_rowids_from_only(ibs, aid_list, eager=True, nInput=None): """ colnames = (REVIEW_ROWID,) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' % (REVIEW_AID1) - review_rowids = ibs.staging.get_where( - const.REVIEW_TABLE, colnames, params_iter, where_clause, unpack_scalars=False + review_rowids = ibs.staging.get_where_eq( + const.REVIEW_TABLE, colnames, params_iter, (REVIEW_AID1,), unpack_scalars=False ) return review_rowids diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 40ff22f170..2bb0037f3e 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1675,10 +1675,11 @@ def get_metadata_val(self, key, eval_=False, default=None): """ val is the repr string unless eval_ is true """ - where_clause = 'metadata_key=?' 
colnames = ('metadata_value',) params_iter = [(key,)] - vals = self.get_where(METADATA_TABLE_NAME, colnames, params_iter, where_clause) + vals = self.get_where_eq( + METADATA_TABLE_NAME, colnames, params_iter, ('metadata_key',) + ) assert len(vals) == 1, 'duplicate keys in metadata table' val = vals[0] if val is None: From 2d949c9dc8953d0377b6c0654c5b5a5030a8a742 Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 14 Dec 2020 19:03:39 +0000 Subject: [PATCH 115/294] Upgrade sqlalchemy to >=1.4.0b1 We were using sqlalchemy 1.3.20 but it has a bug where if we do: ``` stmt = select([table.c['column_one']).where( tuple_(table.c['uuid_column'], table.c['integer_column']).in_( bindparam('params', expanding=True) ) ) conn.execute(stmt, [(uuid_value, int_value)]) ``` `uuid_value` is processed correctly as uuid but `int_value` is processed also as uuid for some reason. This is fixed in sqlalchemy 1.4.0b1. --- requirements/runtime.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index b18841028d..d36a4b2b6a 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -54,7 +54,7 @@ simplejson>=3.6.5 sip six>=1.10.0 -sqlalchemy +sqlalchemy>=1.4.0b1 statsmodels>=0.6.1 torch From 885cf9c9aab956c4912abfb0c981fff5f4749b3f Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 15 Dec 2020 18:18:34 +0000 Subject: [PATCH 116/294] Add imgsetid_list to get_valid_gids `get_valid_gids` was only taking one `imgsetid` causing some code to call it many many times. This created lots of small queries making web requests very slow. Add `imgsetid_list` so `get_valid_gids` can be called once. --- wbia/control/manual_image_funcs.py | 12 ++++++++++-- wbia/gui/guiback.py | 4 +--- wbia/web/routes.py | 14 ++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/wbia/control/manual_image_funcs.py b/wbia/control/manual_image_funcs.py index bb13020355..75c6b4a019 100644 --- a/wbia/control/manual_image_funcs.py +++ b/wbia/control/manual_image_funcs.py @@ -112,7 +112,13 @@ def _get_all_image_rowids(ibs): @accessor_decors.ider @register_api('/api/image/', methods=['GET']) def get_valid_gids( - ibs, imgsetid=None, require_unixtime=False, require_gps=None, reviewed=None, **kwargs + ibs, + imgsetid=None, + imgsetid_list=(), + require_unixtime=False, + require_gps=None, + reviewed=None, + **kwargs ): r""" Args: @@ -147,8 +153,10 @@ def get_valid_gids( >>> print(result) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] """ - if imgsetid is None: + if imgsetid is None and not imgsetid_list: gid_list = ibs._get_all_gids() + elif imgsetid_list: + gid_list = ibs.get_imageset_gids(imgsetid_list) else: assert not ut.isiterable(imgsetid) gid_list = ibs.get_imageset_gids(imgsetid) diff --git a/wbia/gui/guiback.py b/wbia/gui/guiback.py index 339d978483..b628d6e25b 100644 --- a/wbia/gui/guiback.py +++ b/wbia/gui/guiback.py @@ -1686,9 +1686,7 @@ def merge_imagesets(back, imgsetid_list, destination_imgsetid): destination_imgsetid = imgsetid_list[destination_index] deprecated_imgsetids = list(imgsetid_list) deprecated_imgsetids.pop(destination_index) - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid) for imgsetid in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) imgsetid_list = [destination_imgsetid] * len(gid_list) ibs.set_image_imgsetids(gid_list, imgsetid_list) ibs.delete_imagesets(deprecated_imgsetids) diff --git a/wbia/web/routes.py b/wbia/web/routes.py index 84070ad708..4b69cc03e7 
100644 --- a/wbia/web/routes.py +++ b/wbia/web/routes.py @@ -1484,7 +1484,7 @@ def view_imagesets(**kwargs): all_gid_list = ibs.get_valid_gids() all_aid_list = ibs.get_valid_aids() - gids_list = [ibs.get_valid_gids(imgsetid=imgsetid_) for imgsetid_ in imgsetid_list] + gids_list = ibs.get_valid_gids(imgsetid_list=imgsetid_list) num_gids = list(map(len, gids_list)) ###################################################################################### @@ -1777,9 +1777,7 @@ def view_images(**kwargs): None if imgsetid_ == 'None' or imgsetid_ == '' else int(imgsetid_) for imgsetid_ in imgsetid_list ] - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid) for imgsetid_ in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) else: gid_list = ibs.get_valid_gids() filtered = False @@ -1868,9 +1866,7 @@ def view_annotations(**kwargs): None if imgsetid_ == 'None' or imgsetid_ == '' else int(imgsetid_) for imgsetid_ in imgsetid_list ] - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid_) for imgsetid_ in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) aid_list = ut.flatten(ibs.get_image_aids(gid_list)) else: aid_list = ibs.get_valid_aids() @@ -2028,9 +2024,7 @@ def view_names(**kwargs): None if imgsetid_ == 'None' or imgsetid_ == '' else int(imgsetid_) for imgsetid_ in imgsetid_list ] - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid_) for imgsetid_ in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) aid_list = ut.flatten(ibs.get_image_aids(gid_list)) nid_list = ibs.get_annot_name_rowids(aid_list) else: From f69bffe552f7e84a7412c89a2a6f2c18082c4a77 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 16 Dec 2020 21:14:19 +0000 Subject: [PATCH 117/294] Return early for filter_* functions if input is empty A lot of filter_* functions do some database queries even if the input is empty. Given that we're filtering the input, if the input is empty, there's nothing we need to do so we just need to return. 
--- wbia/control/manual_annot_funcs.py | 2 ++ wbia/control/manual_part_funcs.py | 2 ++ wbia/init/filter_annots.py | 6 ++++++ wbia/other/ibsfuncs.py | 14 ++++++++++++++ wbia/web/routes.py | 20 ++++++++++++++++++++ 5 files changed, 44 insertions(+) diff --git a/wbia/control/manual_annot_funcs.py b/wbia/control/manual_annot_funcs.py index c238ca8f16..6ad08986fb 100644 --- a/wbia/control/manual_annot_funcs.py +++ b/wbia/control/manual_annot_funcs.py @@ -838,6 +838,8 @@ def filter_annotation_set( is_canonical=None, min_timedelta=None, ): + if not aid_list: # no need to filter if empty + return aid_list # -- valid aid filtering -- # filter by is_exemplar if is_exemplar is True: diff --git a/wbia/control/manual_part_funcs.py b/wbia/control/manual_part_funcs.py index ae75761133..c7ffcd8c4e 100644 --- a/wbia/control/manual_part_funcs.py +++ b/wbia/control/manual_part_funcs.py @@ -107,6 +107,8 @@ def filter_part_set( viewpoint='no-filter', minqual=None, ): + if not part_rowid_list: # no need to filter if empty + return part_rowid_list # -- valid part_rowid filtering -- # filter by is_staged diff --git a/wbia/init/filter_annots.py b/wbia/init/filter_annots.py index 87458a27a5..448a806b5a 100644 --- a/wbia/init/filter_annots.py +++ b/wbia/init/filter_annots.py @@ -1322,6 +1322,9 @@ def filter_annots_independent( logger.info('No annot filter returning') return avail_aids + if not avail_aids: # no need to filter if empty + return avail_aids + VerbosityContext = verb_context('FILTER_INDEPENDENT', aidcfg, verbose) VerbosityContext.startfilter(withpre=withpre) @@ -1599,6 +1602,9 @@ def filter_annots_intragroup( logger.info('No annot filter returning') return avail_aids + if not avail_aids: # no need to filter if empty + return avail_aids + VerbosityContext = verb_context('FILTER_INTRAGROUP', aidcfg, verbose) VerbosityContext.startfilter(withpre=withpre) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index bfc9bbdfbb..70d5676179 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -150,6 +150,8 @@ def filter_junk_annotations(ibs, aid_list): >>> result = str(filtered_aid_list) >>> print(result) """ + if not aid_list: # no need to filter if empty + return aid_list isjunk_list = ibs.get_annot_isjunk(aid_list) filtered_aid_list = ut.filterfalse_items(aid_list, isjunk_list) return filtered_aid_list @@ -4897,6 +4899,8 @@ def filter_aids_to_quality(ibs, aid_list, minqual, unknown_ok=True, speedhack=Tr >>> x1 = filter_aids_to_quality(ibs, aid_list, 'good', True, speedhack=True) >>> x2 = filter_aids_to_quality(ibs, aid_list, 'good', True, speedhack=False) """ + if not aid_list: # no need to filter if empty + return aid_list if speedhack: list_repr = ','.join(map(str, aid_list)) minqual_int = const.QUALITY_TEXT_TO_INT[minqual] @@ -4923,6 +4927,8 @@ def filter_aids_to_viewpoint(ibs, aid_list, valid_yaws, unknown_ok=True): valid_yaws = ['primary', 'primary1', 'primary-1'] """ + if not aid_list: # no need to filter if empty + return aid_list def rectify_view_category(view): @ut.memoize @@ -4983,6 +4989,8 @@ def filter_aids_without_name(ibs, aid_list, invert=False, speedhack=True): >>> assert np.all(np.array(annots2_.nids) < 0) >>> assert len(annots2_) == 4 """ + if not aid_list: # no need to filter if empty + return aid_list if speedhack: list_repr = ','.join(map(str, aid_list)) if invert: @@ -5037,6 +5045,8 @@ def filter_annots_using_minimum_timedelta(ibs, aid_list, min_timedelta): >>> wbia.other.dbinfo.hackshow_names(ibs, filtered_aids) >>> ut.show_if_requested() """ + if not aid_list: # 
no need to filter if empty + return aid_list import vtool as vt # min_timedelta = 60 * 60 * 24 @@ -5092,6 +5102,8 @@ def filter_aids_without_timestamps(ibs, aid_list, invert=False): Removes aids without timestamps aid_list = ibs.get_valid_aids() """ + if not aid_list: # no need to filter if empty + return aid_list unixtime_list = ibs.get_annot_image_unixtimes(aid_list) flag_list = [unixtime != -1 for unixtime in unixtime_list] if invert: @@ -5126,6 +5138,8 @@ def filter_aids_to_species(ibs, aid_list, species, speedhack=True): >>> print(result) aid_list_ = [9, 10] """ + if not aid_list: # no need to filter if empty + return aid_list species_rowid = ibs.get_species_rowids_from_text(species) if speedhack: list_repr = ','.join(map(str, aid_list)) diff --git a/wbia/web/routes.py b/wbia/web/routes.py index 4b69cc03e7..1fec360737 100644 --- a/wbia/web/routes.py +++ b/wbia/web/routes.py @@ -279,6 +279,8 @@ def _date_list(gid_list): return date_list def filter_annots_imageset(aid_list): + if not aid_list: # no need to filter if empty + return aid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -296,6 +298,8 @@ def filter_annots_imageset(aid_list): return aid_list def filter_images_imageset(gid_list): + if not gid_list: # no need to filter if empty + return gid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -313,6 +317,8 @@ def filter_images_imageset(gid_list): return gid_list def filter_names_imageset(nid_list): + if not nid_list: # no need to filter if empty + return nid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -333,6 +339,8 @@ def filter_names_imageset(nid_list): return nid_list def filter_annots_general(ibs, aid_list): + if not aid_list: # no need to filter if empty + return aid_list if ibs.dbname == 'GGR-IBEIS': # Grevy's filter_kw = { @@ -1013,6 +1021,8 @@ def _date_list(gid_list): return date_list def filter_annots_imageset(aid_list): + if not aid_list: # no need to filter if empty + return aid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -1030,6 +1040,8 @@ def filter_annots_imageset(aid_list): return aid_list def filter_annots_general(ibs, aid_list): + if not aid_list: # no need to filter if empty + return aid_list if ibs.dbname == 'GGR-IBEIS': # Grevy's filter_kw = { @@ -1233,6 +1245,8 @@ def _date_list(gid_list): return date_list def filter_species_of_interest(gid_list): + if not gid_list: # no need to filter if empty + return gid_list wanted_set = set(['zebra_plains', 'zebra_grevys', 'giraffe_masai']) aids_list = ibs.get_image_aids(gid_list) speciess_list = ut.unflat_map(ibs.get_annot_species_texts, aids_list) @@ -1245,6 +1259,8 @@ def filter_species_of_interest(gid_list): return gid_list_filtered def filter_viewpoints_of_interest(gid_list, allowed_viewpoint_list): + if not gid_list: # no need to filter if empty + return gid_list aids_list = ibs.get_image_aids(gid_list) wanted_set = set(allowed_viewpoint_list) viewpoints_list = ut.unflat_map(ibs.get_annot_viewpoints, aids_list) @@ -1257,6 +1273,8 @@ def filter_viewpoints_of_interest(gid_list, allowed_viewpoint_list): return gid_list_filtered def filter_bad_metadata(gid_list): + if not gid_list: # no need to filter if empty + return gid_list wanted_set = set(['2015/03/01', '2015/03/02', '2016/01/30', '2016/01/31']) date_list = _date_list(gid_list) gps_list = ibs.get_image_gps(gid_list) @@ -1267,6 +1285,8 @@ def filter_bad_metadata(gid_list): return gid_list_filtered def 
filter_bad_quality(gid_list, allowed_quality_list): + if not gid_list: # no need to filter if empty + return gid_list aids_list = ibs.get_image_aids(gid_list) wanted_set = set(allowed_quality_list) qualities_list = ut.unflat_map(ibs.get_annot_quality_texts, aids_list) From 1e70fc1bc1072a07f969e4b18d4f475bbc99625d Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 23 Dec 2020 19:21:35 +0000 Subject: [PATCH 118/294] Downgrade scikit-learn to 0.23.2 Tests are failing after the upgrade to scikit-learn 0.24.0: ``` Traceback (most recent call last): File "/wbia/wbia-utool/utool/util_io.py", line 357, in load_cPkl data = FixRenamedUnpickler(file_).load() UnicodeDecodeError: 'ascii' codec can't decode byte 0x98 in position 0: ordinal not in range(128) During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 1750, in _chunk_compute_dirty_rows config, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 1680, in _compute_dirty_rows proptup_gen = list(proptup_gen) File "/wbia/wildbook-ia/wbia/core_annots.py", line 805, in compute_probchip grouped_probchips.append(list(gen)) File "/wbia/wildbook-ia/wbia/core_annots.py", line 831, in cnn_probchips for chunk in ut.ichunks(mask_gen, 256): File "/wbia/wbia-utool/utool/util_iter.py", line 439, in ichunks_noborder for chunk in chunks_with_sentinals: File "/wbia/wbia-plugin-cnn/wbia_cnn/_plugin.py", line 1002, in generate_species_background model_state = ut.load_cPkl(model_state_fpath) File "/wbia/wbia-utool/utool/util_io.py", line 362, in load_cPkl data = FixRenamedUnpickler(file_, encoding='latin1').load() File "/wbia/wbia-utool/utool/util_io.py", line 276, in find_class return super(FixRenamedUnpickler, self).find_class(module, name) ModuleNotFoundError: No module named 'sklearn.preprocessing.label' ``` The module is `sklearn.preprocessing.label` and the name is `LabelEncoder`. It has been moved to `sklearn.preprocessing`. The file that failed to unpickle is `/root/.cache/wbia_cnn/background.candidacy.zebra_plains.pkl`. We'll fix it later. 
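For the record, the eventual fix hinted at above ("we'll fix it later") could look something like the following module-rename shim; this is a hedged sketch, not the actual fix, and the pickle path is illustrative.

```
import pickle

RENAMED_MODULES = {'sklearn.preprocessing.label': 'sklearn.preprocessing'}

class RenamedModuleUnpickler(pickle.Unpickler):
    """Unpickler that transparently remaps moved scikit-learn modules."""

    def find_class(self, module, name):
        return super().find_class(RENAMED_MODULES.get(module, module), name)

with open('background.candidacy.zebra_plains.pkl', 'rb') as file_:  # illustrative path
    model_state = RenamedModuleUnpickler(file_, encoding='latin1').load()
```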
--- requirements/runtime.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index d36a4b2b6a..6268f0ef20 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -41,7 +41,7 @@ pyzmq>=14.7.0 requests>=2.5.0 scikit-image>=0.12.3 -scikit-learn>=0.17.1 +scikit-learn>=0.17.1,<0.24.0 scipy>=0.18.0 sentry-sdk>=0.10.2 From 49f3fa77f9a0bd8f1e597de6b2c08ba674d04bcf Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Sun, 27 Dec 2020 23:44:09 -0800 Subject: [PATCH 119/294] Existing databases storer UUIDs with little-endian encoding --- wbia/dtool/types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index 651f965b0b..d615f3f10f 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -150,10 +150,10 @@ def process(value): return value else: if not isinstance(value, uuid.UUID): - return uuid.UUID(value).bytes + return uuid.UUID(value).bytes_le else: # hexstring - return value.bytes + return value.bytes_le return process @@ -163,7 +163,7 @@ def process(value): return value else: if not isinstance(value, uuid.UUID): - return uuid.UUID(bytes=value) + return uuid.UUID(bytes_le=value) else: return value From 26c5c605fb92061a554f843276397c2b9406158c Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 29 Dec 2020 16:01:41 +0000 Subject: [PATCH 120/294] Add test to make sure uuid is stored as little endian Before sqlalchemy, we were storing uuid as bytes in little endian order. We have no tests to make sure that those existing uuids are being read back out correctly so just adding one now. I checked that this test failed without commit "49f3fa77f" (Existing databases storer UUIDs with little-endian encoding), so it is testing the right thing. 
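The distinction these two patches revolve around, shown in isolation with a fixed UUID so the byte orders are easy to compare; this is plain standard-library behaviour, nothing project-specific.

```
import uuid

u = uuid.UUID('12345678-1234-5678-1234-567812345678')
# .bytes is big-endian; .bytes_le stores the first three fields little-endian,
# which is how the pre-sqlalchemy databases wrote their UUID blobs.
assert u.bytes.hex() == '12345678123456781234567812345678'
assert u.bytes_le.hex() == '78563412341278561234567812345678'
assert uuid.UUID(bytes_le=u.bytes_le) == u
```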
--- wbia/tests/dtool/test_types.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/wbia/tests/dtool/test_types.py b/wbia/tests/dtool/test_types.py index d342a3ca03..8b7c26c09b 100644 --- a/wbia/tests/dtool/test_types.py +++ b/wbia/tests/dtool/test_types.py @@ -193,3 +193,20 @@ def test_uuid(db): results = db.execute(stmt) selected_value = results.fetchone()[0] assert selected_value == insert_value + + +def test_le_uuid(db): + db.execute(text('CREATE TABLE test(x UUID)')) + + # Insert a uuid value but explicitly stored as little endian + # (the way uuids were stored before sqlalchemy) + insert_value = uuid.uuid4() + stmt = text('INSERT INTO test(x) VALUES (:x)') + db.execute(stmt, x=insert_value.bytes_le) + + # Query for the value + stmt = text('SELECT x FROM test') + stmt = stmt.columns(x=UUID) + results = db.execute(stmt) + selected_value = results.fetchone()[0] + assert selected_value == insert_value From bfd0e64a5fc808cf97543c2550bffd8f36212286 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 22 Dec 2020 15:44:58 +0000 Subject: [PATCH 121/294] Catch StopIteration in /turk/identification/lnbnn When there's no more to review, `/turk/identification/lnbnn/` returns some json with traceback: ``` {"status": {"success": false, "code": 400, "message": "Route error, Python Exception thrown: 'no more to review!'", "cache": -1}, "response": "Traceback (most recent call last): File \"/wbia/wbia-utool/utool/util_dev.py\", line 3652, in pop val, key = self._heappop(_heap) IndexError: index out of range During handling of the above exception, another exception occurred: Traceback (most recent call last): File \"/wbia/wildbook-ia/wbia/algo/graph/mixin_priority.py\", line 290, in pop edge, priority = infr._pop() File \"/wbia/wildbook-ia/wbia/algo/graph/mixin_priority.py\", line 37, in _pop (e, (p, _)) = infr.queue.pop(*args) File \"/wbia/wbia-utool/utool/util_dev.py\", line 3658, in pop raise IndexError('queue is empty') IndexError: queue is empty During handling of the above exception, another exception occurred: Traceback (most recent call last): File \"/wbia/wildbook-ia/wbia/control/controller_inject.py\", line 1217, in translated_call result = func(**kwargs) File \"/wbia/wildbook-ia/wbia/web/routes.py\", line 4043, in turk_identification values = query_object.pop() File \"/wbia/wildbook-ia/wbia/algo/graph/mixin_priority.py\", line 292, in pop raise StopIteration('no more to review!') StopIteration: no more to review! ``` Instead, return a message saying "no more to review!" --- wbia/tests/web/__init__.py | 0 wbia/tests/web/test_routes.py | 10 ++++++++++ wbia/web/routes.py | 5 ++++- wbia/web/templates/simple.html | 7 +++++++ 4 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 wbia/tests/web/__init__.py create mode 100644 wbia/tests/web/test_routes.py create mode 100644 wbia/web/templates/simple.html diff --git a/wbia/tests/web/__init__.py b/wbia/tests/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wbia/tests/web/test_routes.py b/wbia/tests/web/test_routes.py new file mode 100644 index 0000000000..4c5f04f4db --- /dev/null +++ b/wbia/tests/web/test_routes.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +import wbia + + +def test_turk_identification_no_more_to_review(): + with wbia.opendb_bg_web('testdb2', managed=True) as web_ibs: + resp = web_ibs.get('/turk/identification/lnbnn/') + assert resp.status_code == 200 + assert b'Traceback' not in resp.content, resp.content + assert b'

No more to review!' in resp.content, resp.content diff --git a/wbia/web/routes.py index 1fec360737..b91bf29611 100644 --- a/wbia/web/routes.py +++ b/wbia/web/routes.py @@ -4054,7 +4054,10 @@ def turk_identification( review_cfg[ 'max_num' ] = global_feedback_limit # Controls the top X to be randomly sampled and displayed to all concurrent users - values = query_object.pop() + try: + values = query_object.pop() + except StopIteration as e: + return appf.template(None, 'simple', title=str(e).capitalize()) (review_aid1_list, review_aid2_list), review_confidence = values review_aid1_list = [review_aid1_list] review_aid2_list = [review_aid2_list] diff --git a/wbia/web/templates/simple.html b/wbia/web/templates/simple.html new file mode 100644 index 0000000000..b25b89a2dc --- /dev/null +++ b/wbia/web/templates/simple.html @@ -0,0 +1,7 @@ +{% extends "layout.html" %} +{% block content %}
+    {{ title }}
+    {{ content }}
+{% endblock %} From b90aba5c73b27a5fa35f7e20c874b587dcf65f25 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 29 Dec 2020 15:57:18 -0800 Subject: [PATCH 122/294] Add None fallback to JSON column types --- wbia/dtool/types.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index d615f3f10f..e5f893226c 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -46,13 +46,19 @@ def get_col_spec(self, **kw): def bind_processor(self, dialect): def process(value): - return to_json(value) + if value is None: + return None + else: + return to_json(value) return process def result_processor(self, dialect, coltype): def process(value): - return from_json(value) + if value is None: + return value + else: + return from_json(value) return process From dbab2f979bc94db98ea03ceb1d92ff608e8932b1 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 00:04:29 -0800 Subject: [PATCH 123/294] Fix call to IBEISController methods that use 'self' Fixes several routes on the IBEISController from the change in 08ad98e141f304fad5593d7e5d04117416e049c0 --- wbia/control/controller_inject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/control/controller_inject.py b/wbia/control/controller_inject.py index 068b9bd7ae..12c7ae2ca6 100644 --- a/wbia/control/controller_inject.py +++ b/wbia/control/controller_inject.py @@ -590,7 +590,7 @@ def translate_wbia_webcall(func, *args, **kwargs): output = func(**kwargs) except TypeError: try: - output = func(ibs=ibs, **kwargs) + output = func(ibs, **kwargs) except WebException: raise except Exception as ex2: # NOQA From 42b3221d6c2c93c309d939891e72bd53a42484b0 Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 24 Dec 2020 13:59:26 +0000 Subject: [PATCH 124/294] Upgrade scikit-learn to 0.24.0 and wbia-utool to 3.3.3 wbia-utool 3.3.3 includes changes to use scikit-learn 0.24.0. scikit-learn 0.24.0 has changed the way StratifiedKFold split works and we can no longer specify `shuffle=True` and `random_state=rng`: ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 5, abs: 49, in >>> wbia, smk, qreq_ = testdata_smk() File "/wbia/wildbook-ia/wbia/algo/smk/smk_pipeline.py", line 577, in testdata_smk skf = sklearn.model_selection.StratifiedKFold(**xvalkw) File "/virtualenv/env3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 63, in inner_f return f(*args, **kwargs) File "/virtualenv/env3/lib/python3.7/site-packages/sklearn/model_selection/_split.py", line 637, in __init__ random_state=random_state) File "/virtualenv/env3/lib/python3.7/site-packages/sklearn/utils/validation.py", line 63, in inner_f return f(*args, **kwargs) File "/virtualenv/env3/lib/python3.7/site-packages/sklearn/model_selection/_split.py", line 291, in __init__ 'Setting a random_state has no effect since shuffle is ' ValueError: Setting a random_state has no effect since shuffle is False. You should leave random_state to its default (None), or set shuffle=True. DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/algo/smk/match_chips5.py::EstimatorRequest.shallowcopy:0 ``` So I removed `random_state`. Also update imports after scikit-learn moved classes and modules around. 
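The StratifiedKFold change described above is easy to reproduce in isolation: on scikit-learn 0.24 the constructor rejects a `random_state` when `shuffle` is False (the combination the old `xvalkw` used). A hedged sketch with toy data, not wbia code:

```
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.arange(8).reshape(-1, 1)
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# shuffle=False without random_state: fine (what this patch switches to)
skf = StratifiedKFold(n_splits=4, shuffle=False)
train_idx, test_idx = next(skf.split(X, y))

# shuffle=False with random_state: raises on 0.24+
try:
    StratifiedKFold(n_splits=4, shuffle=False, random_state=0)
except ValueError as ex:
    print(ex)  # Setting a random_state has no effect since shuffle is False ...

# shuffle=True with random_state is the other valid combination
StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
```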
--- requirements/runtime.txt | 4 ++-- wbia/algo/smk/smk_pipeline.py | 3 +-- wbia/algo/verif/sklearn_utils.py | 8 +++----- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 6268f0ef20..d34b7b2175 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -41,7 +41,7 @@ pyzmq>=14.7.0 requests>=2.5.0 scikit-image>=0.12.3 -scikit-learn>=0.17.1,<0.24.0 +scikit-learn>=0.24.0 scipy>=0.18.0 sentry-sdk>=0.10.2 @@ -70,5 +70,5 @@ wbia-pyflann >= 3.1.0 wbia-pyhesaff >= 3.0.2 wbia-pyrf >= 3.0.0 -wbia-utool >= 3.3.1 +wbia-utool >= 3.3.3 wbia-vtool >= 3.2.1 diff --git a/wbia/algo/smk/smk_pipeline.py b/wbia/algo/smk/smk_pipeline.py index 3058ea7f68..2ab6d0a8f3 100644 --- a/wbia/algo/smk/smk_pipeline.py +++ b/wbia/algo/smk/smk_pipeline.py @@ -571,8 +571,7 @@ def testdata_smk(*args, **kwargs): # import sklearn.model_selection ibs, aid_list = wbia.testdata_aids(defaultdb='PZ_MTEST') nid_list = np.array(ibs.annots(aid_list).nids) - rng = ut.ensure_rng(0) - xvalkw = dict(n_splits=4, shuffle=False, random_state=rng) + xvalkw = dict(n_splits=4, shuffle=False) skf = sklearn.model_selection.StratifiedKFold(**xvalkw) train_idx, test_idx = six.next(skf.split(aid_list, nid_list)) diff --git a/wbia/algo/verif/sklearn_utils.py b/wbia/algo/verif/sklearn_utils.py index b637bf7748..94185e752e 100644 --- a/wbia/algo/verif/sklearn_utils.py +++ b/wbia/algo/verif/sklearn_utils.py @@ -512,11 +512,9 @@ def amean(x, w=None): # and BM * MK MCC? def matthews_corrcoef(y_true, y_pred, sample_weight=None): - from sklearn.metrics.classification import ( - _check_targets, - LabelEncoder, - confusion_matrix, - ) + from sklearn.preprocessing import LabelEncoder + from sklearn.metrics import confusion_matrix + from sklearn.metrics._classification import _check_targets y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in {'binary', 'multiclass'}: From bc2b1a6546c482f01364c1c3e4862e157baa81ec Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Mon, 30 Nov 2020 17:31:37 -0800 Subject: [PATCH 125/294] updates apis_sync to use new .jpg method in prev. 
commit e030fce1, also https by default --- wbia/web/apis_sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/web/apis_sync.py b/wbia/web/apis_sync.py index 594289db53..7b97ca55f7 100644 --- a/wbia/web/apis_sync.py +++ b/wbia/web/apis_sync.py @@ -47,7 +47,7 @@ REMOTE_UUID = None -REMOTE_URL = 'http://%s:%s' % (REMOTE_DOMAIN, REMOTE_PORT) +REMOTE_URL = 'https://%s:%s' % (REMOTE_DOMAIN, REMOTE_PORT) REMOTE_UUID = None if REMOTE_UUID is None else uuid.UUID(REMOTE_UUID) @@ -366,7 +366,7 @@ def sync_get_training_data(ibs, species_name, force_update=False, **kwargs): name_texts = ibs._sync_get_annot_endpoint('/api/annot/name/text/', aid_list) name_uuids = ibs._sync_get_annot_endpoint('/api/annot/name/uuid/', aid_list) images = ibs._sync_get_annot_endpoint('/api/annot/image/rowid/', aid_list) - gpaths = [ibs._construct_route_url_ibs('/api/image/src/%s/' % gid) for gid in images] + gpaths = [ibs._construct_route_url_ibs('/api/image/src/%s.jpg' % gid) for gid in images] specieses = [species_name] * len(aid_list) gid_list = ibs.add_images(gpaths) From 0b097bd993446b921bb548b9aefa59965ff88895 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Thu, 26 Nov 2020 01:20:59 -0800 Subject: [PATCH 126/294] Add repository for the orientation plugin --- devops/Dockerfile | 19 +++++++++++-------- devops/provision/Dockerfile | 8 ++++++++ wbia/control/IBEISControl.py | 6 ++++++ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/devops/Dockerfile b/devops/Dockerfile index 7e87f845de..428ed40c31 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -16,19 +16,22 @@ RUN set -ex \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-kaggle7/wbia_kaggle7 \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ + && cd /wbia/wbia-plugin-orientation/ \ + && git stash && git pull && git stash pop || git reset --hard origin/compatibility \ && find /wbia -name '.git' -type d -print0 | xargs -0 rm -rf \ && find /wbia -name '_skbuild' -type d -print0 | xargs -0 rm -rf # Run smoke tests RUN set -ex \ - && /virtualenv/env3/bin/python -c "import wbia; from wbia.__main__ import smoke_test; smoke_test()" \ - && /virtualenv/env3/bin/python -c "import wbia_cnn; from wbia_cnn.__main__ import main; main()" \ - && /virtualenv/env3/bin/python -c "import wbia_pie; from wbia_pie.__main__ import main; main()" \ - && /virtualenv/env3/bin/python -c "import wbia_flukematch; from wbia_flukematch.plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_curvrank; from wbia_curvrank._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia; from wbia.__main__ import smoke_test; smoke_test()" \ + && /virtualenv/env3/bin/python -c "import wbia_cnn; from wbia_cnn.__main__ import main; main()" \ + && /virtualenv/env3/bin/python -c "import wbia_pie; from wbia_pie.__main__ import main; main()" \ + && /virtualenv/env3/bin/python -c "import wbia_orientation; from wbia_orientation.__main__ import main; main()" \ + && /virtualenv/env3/bin/python -c "import wbia_flukematch; from wbia_flukematch.plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_curvrank; from wbia_curvrank._plugin import *" \ + && 
/virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ && find /wbia/wbia* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ && find /wbia/wbia* -name '*.so' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ && find /wbia/wildbook* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index 47fe001e53..97091bc4c1 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -50,6 +50,10 @@ RUN set -ex \ && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-pie.git \ && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-lca.git +RUN set -ex \ + && cd /wbia \ + && git clone --branch compatibility https://github.com/WildbookOrg/wbia-plugin-orientation.git + # git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-2d-orientation.git # Clone third-party WBIA plug-in repositories @@ -138,6 +142,10 @@ RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ && cd /wbia/wbia-plugin-lca \ && pip install -e .' +RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ + && cd /wbia/wbia-plugin-orientation \ + && pip install -e .' + RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ && cd /wbia/wbia-plugin-flukematch \ && ./unix_build.sh \ diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 544efdf3eb..5443e327a4 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -140,6 +140,12 @@ (('--no-2d-orient', '--no2dorient'), 'wbia_2d_orientation._plugin'), ] + +if ut.get_argflag('--orient'): + AUTOLOAD_PLUGIN_MODNAMES += [ + (('--no-orient', '--noorient'), 'wbia_orientation._plugin'), + ] + if ut.get_argflag('--pie'): AUTOLOAD_PLUGIN_MODNAMES += [ (('--no-pie', '--nopie'), 'wbia_pie._plugin'), From 37047fc00813e9bf22b4cd0e38f1ba66b4d99544 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Thu, 26 Nov 2020 02:09:55 -0800 Subject: [PATCH 127/294] Added new plugin config for the orienter --- devops/Dockerfile | 2 ++ wbia/core_annots.py | 43 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/devops/Dockerfile b/devops/Dockerfile index 428ed40c31..6effeec172 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -18,6 +18,8 @@ RUN set -ex \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-orientation/ \ && git stash && git pull && git stash pop || git reset --hard origin/compatibility \ + && cd /wbia/wildbook-ia/ \ + && git stash && git pull && git stash pop || git reset --hard origin/add-orientation-plugin \ && find /wbia -name '.git' -type d -print0 | xargs -0 rm -rf \ && find /wbia -name '_skbuild' -type d -print0 | xargs -0 rm -rf diff --git a/wbia/core_annots.py b/wbia/core_annots.py index f9497cb4ac..023fce8990 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2173,7 +2173,11 @@ def compute_aoi2(depc, aid_list, config=None): class OrienterConfig(dtool.Config): _param_info_list = [ - ut.ParamInfo('orienter_algo', 'deepsense', valid_values=['deepsense']), + ut.ParamInfo( + 'orienter_algo', + 'plugin:orientation', + valid_values=['deepsense, plugin:orientation'], + ), ut.ParamInfo('orienter_weight_filepath', None), 
] _sub_config_list = [ChipConfig] @@ -2201,9 +2205,10 @@ def compute_orients_annotations(depc, aid_list, config=None): (float, str): tup CommandLine: - python -m wbia.core_annots --exec-compute_orients_annotations --deepsense + pytest wbia/core_annots.py::compute_orients_annotations:0 + pytest wbia/core_annots.py::compute_orients_annotations:1 - Example: + Doctest: >>> # DISABLE_DOCTEST >>> from wbia.core_images import * # NOQA >>> import wbia @@ -2223,6 +2228,29 @@ def compute_orients_annotations(depc, aid_list, config=None): >>> ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) >>> result_list = depc.get_property('orienter', aid_list, None, config=config) >>> print(result_list) + + Doctest: + >>> # ENABLE_DOCTEST + >>> from wbia.core_annots import * # NOQA + >>> import wbia + >>> defaultdb = 'testdb_identification' + >>> ibs = wbia.opendb(defaultdb=defaultdb) + >>> import utool as ut + >>> ut.embed() + >>> depc = ibs.depc_annot + >>> aid_list = ibs.get_valid_aids()[-16:-8] + >>> config = {'orienter_algo': 'plugin:orientation'} + >>> # depc.delete_property('orienter', aid_list) + >>> result_list = depc.get_property('orienter', aid_list, None, config=config) + >>> xtl_list = list(map(int, map(np.around, ut.take_column(result_list, 0)))) + >>> ytl_list = list(map(int, map(np.around, ut.take_column(result_list, 1)))) + >>> w_list = list(map(int, map(np.around, ut.take_column(result_list, 2)))) + >>> h_list = list(map(int, map(np.around, ut.take_column(result_list, 3)))) + >>> theta_list = ut.take_column(result_list, 4) + >>> bbox_list = list(zip(xtl_list, ytl_list, w_list, h_list)) + >>> ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) + >>> result_list = depc.get_property('orienter', aid_list, None, config=config) + >>> print(result_list) """ logger.info('[ibs] Process Annotation Labels') logger.info('config = %r' % (config,)) @@ -2265,6 +2293,15 @@ def compute_orients_annotations(depc, aid_list, config=None): result_gen.append(result) except Exception: raise RuntimeError('Deepsense orienter not working!') + elif config['orienter_algo'] in ['plugin:orientation']: + logger.info('[ibs] orienting using Orientation Plug-in') + try: + import utool as ut + + ut.embed() + from wbia_orientation import _plugin # NOQA + except Exception: + raise RuntimeError('Orientation plug-in not working!') else: raise ValueError( 'specified orienter algo is not supported in config = %r' % (config,) From e7c9069d1567e2cf5d4d5cbe9c069aa66a61e301 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 1 Dec 2020 19:49:10 -0800 Subject: [PATCH 128/294] Added non-ajax call --- wbia/web/apis.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/wbia/web/apis.py b/wbia/web/apis.py index 03ee516329..5c6617f581 100644 --- a/wbia/web/apis.py +++ b/wbia/web/apis.py @@ -54,6 +54,16 @@ def web_embed(*args, **kwargs): ut.embed() +@register_route( + '/api/image/src/.jpg', + methods=['GET'], + __route_prefix_check__=False, + __route_authenticate__=False, +) +def image_src_api_ext(*args, **kwargs): + return image_src_api(*args, **kwargs) + + # Special function that is a route only to ignore the JSON response, but is # actually (and should be) an API call @register_route( From 0a4d24807871a5e3dbc3253b53e4b960288dcc9d Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 1 Dec 2020 19:50:39 -0800 Subject: [PATCH 129/294] Added non-ajax callc --- wbia/web/apis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wbia/web/apis.py b/wbia/web/apis.py index 5c6617f581..f9c5e8408f 100644 --- 
a/wbia/web/apis.py +++ b/wbia/web/apis.py @@ -58,6 +58,7 @@ def web_embed(*args, **kwargs): '/api/image/src/.jpg', methods=['GET'], __route_prefix_check__=False, + __route_postfix_check__=False, __route_authenticate__=False, ) def image_src_api_ext(*args, **kwargs): From 4f63292a3017fd886af4a8946d0cd6a5625eab45 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 2 Dec 2020 13:28:53 -0800 Subject: [PATCH 130/294] Added better logging levels --- wbia/other/ibsfuncs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 70d5676179..76057649dc 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -49,6 +49,10 @@ logger = logging.getLogger('wbia') +logging.getLogger().setLevel(logging.INFO) +logger.setLevel(logging.INFO) + + # Must import class before injection CLASS_INJECT_KEY, register_ibs_method = controller_inject.make_ibs_register_decorator( __name__ From aa311beb4c354a2cbb39d74309375352a8d13876 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 9 Dec 2020 12:15:19 -0800 Subject: [PATCH 131/294] WIP --- wbia/core_annots.py | 61 +++++++++++++++++++++++++++++++++++++----- wbia/other/ibsfuncs.py | 5 ++-- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 023fce8990..0c6381ed4c 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2206,7 +2206,7 @@ def compute_orients_annotations(depc, aid_list, config=None): CommandLine: pytest wbia/core_annots.py::compute_orients_annotations:0 - pytest wbia/core_annots.py::compute_orients_annotations:1 + python -m wbia.core_annots --exec-compute_orients_annotations:1 --orient Doctest: >>> # DISABLE_DOCTEST @@ -2231,14 +2231,22 @@ def compute_orients_annotations(depc, aid_list, config=None): Doctest: >>> # ENABLE_DOCTEST - >>> from wbia.core_annots import * # NOQA >>> import wbia - >>> defaultdb = 'testdb_identification' - >>> ibs = wbia.opendb(defaultdb=defaultdb) + >>> import random >>> import utool as ut - >>> ut.embed() + >>> from wbia.init import sysres + >>> import numpy as np + >>> dbdir = sysres.ensure_testdb_orientation() + >>> ibs = wbia.opendb(dbdir=dbdir) + >>> aid_list = ibs.get_valid_aids() + >>> note_list = ibs.get_annot_notes(aid_list) + >>> species_list = ibs.get_annot_species(aid_list) + >>> flag_list = [ + >>> note == 'random-01' and species == 'right_whale_head' + >>> for note, species in zip(note_list, species_list) + >>> ] + >>> aid_list = ut.compress(aid_list, flag_list) >>> depc = ibs.depc_annot - >>> aid_list = ibs.get_valid_aids()[-16:-8] >>> config = {'orienter_algo': 'plugin:orientation'} >>> # depc.delete_property('orienter', aid_list) >>> result_list = depc.get_property('orienter', aid_list, None, config=config) @@ -2296,10 +2304,43 @@ def compute_orients_annotations(depc, aid_list, config=None): elif config['orienter_algo'] in ['plugin:orientation']: logger.info('[ibs] orienting using Orientation Plug-in') try: + from wbia_orientation import _plugin # NOQA + + species_tag_mapping = {'right_whale_head': 'rightwhale'} + + species_list = ibs.get_annot_species(aid_list) + + species_dict = {} + for aid, species in zip(aid_list, species_list): + if species not in species_dict: + species_dict[species] = [] + species_dict[species].append(aid) + + results_dict = {} + species_key_list = sorted(species_dict.keys()) + for species in species_key_list: + species_tag = species_tag_mapping.get(species, species) + message = 'Orientation plug-in does not support species_tag = %r' % ( + species_tag, + ) 
+ assert species_tag in _plugin.MODEL_URLS, message + assert species_tag in _plugin.CONFIGS, message + aid_list_ = sorted(species_dict[species]) + print( + 'Computing %d orientations for species = %r' + % ( + len(aid_list_), + species, + ) + ) + + output_list, theta_list = _plugin.wbia_plugin_detect_oriented_box( + ibs, aid_list_, species_tag, plot_samples=False + ) + import utool as ut ut.embed() - from wbia_orientation import _plugin # NOQA except Exception: raise RuntimeError('Orientation plug-in not working!') else: @@ -2310,3 +2351,9 @@ def compute_orients_annotations(depc, aid_list, config=None): # yield detections for result in result_gen: yield result + + +if __name__ == '__main__': + import xdoctest as xdoc + + xdoc.doctest_module(__file__) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 76057649dc..1318dbf097 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -48,9 +48,8 @@ (print, rrr, profile) = ut.inject2(__name__, '[ibsfuncs]') logger = logging.getLogger('wbia') - -logging.getLogger().setLevel(logging.INFO) -logger.setLevel(logging.INFO) +logging.getLogger().setLevel(logging.DEBUG) +logger.setLevel(logging.DEBUG) # Must import class before injection From 2dc28f2f5c737ffb7183a26a45bf78de153c5c88 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 30 Dec 2020 10:11:55 -0800 Subject: [PATCH 132/294] Update branches --- devops/Dockerfile | 2 +- devops/provision/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/devops/Dockerfile b/devops/Dockerfile index 6effeec172..b190682e7b 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -17,7 +17,7 @@ RUN set -ex \ && cd /wbia/wbia-plugin-kaggle7/wbia_kaggle7 \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-orientation/ \ - && git stash && git pull && git stash pop || git reset --hard origin/compatibility \ + && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wildbook-ia/ \ && git stash && git pull && git stash pop || git reset --hard origin/add-orientation-plugin \ && find /wbia -name '.git' -type d -print0 | xargs -0 rm -rf \ diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index 97091bc4c1..1e4afa1928 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -52,7 +52,7 @@ RUN set -ex \ RUN set -ex \ && cd /wbia \ - && git clone --branch compatibility https://github.com/WildbookOrg/wbia-plugin-orientation.git + && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-orientation.git # git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-2d-orientation.git From 325d854ca481faafce9709140dbe8b88b35c6fe4 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 4 Jan 2021 13:11:12 -0800 Subject: [PATCH 133/294] Added functionoality for the orientation plugin with depc --- wbia/core_annots.py | 142 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 131 insertions(+), 11 deletions(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 0c6381ed4c..9448d13b72 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2176,7 +2176,7 @@ class OrienterConfig(dtool.Config): ut.ParamInfo( 'orienter_algo', 'plugin:orientation', - valid_values=['deepsense, plugin:orientation'], + valid_values=['deepsense', 'plugin:orientation'], ), ut.ParamInfo('orienter_weight_filepath', None), ] @@ -2206,7 +2206,7 @@ def compute_orients_annotations(depc, aid_list, config=None): CommandLine: 
pytest wbia/core_annots.py::compute_orients_annotations:0 - python -m wbia.core_annots --exec-compute_orients_annotations:1 --orient + python -m xdoctest /Users/jason.parham/code/wildbook-ia/wbia/core_annots.py compute_orients_annotations:1 --orient Doctest: >>> # DISABLE_DOCTEST @@ -2226,11 +2226,10 @@ def compute_orients_annotations(depc, aid_list, config=None): >>> theta_list = ut.take_column(result_list, 4) >>> bbox_list = list(zip(xtl_list, ytl_list, w_list, h_list)) >>> ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) - >>> result_list = depc.get_property('orienter', aid_list, None, config=config) >>> print(result_list) Doctest: - >>> # ENABLE_DOCTEST + >>> # DISABLE_DOCTEST >>> import wbia >>> import random >>> import utool as ut @@ -2246,6 +2245,7 @@ def compute_orients_annotations(depc, aid_list, config=None): >>> for note, species in zip(note_list, species_list) >>> ] >>> aid_list = ut.compress(aid_list, flag_list) + >>> aid_list = aid_list[:10] >>> depc = ibs.depc_annot >>> config = {'orienter_algo': 'plugin:orientation'} >>> # depc.delete_property('orienter', aid_list) @@ -2256,8 +2256,7 @@ def compute_orients_annotations(depc, aid_list, config=None): >>> h_list = list(map(int, map(np.around, ut.take_column(result_list, 3)))) >>> theta_list = ut.take_column(result_list, 4) >>> bbox_list = list(zip(xtl_list, ytl_list, w_list, h_list)) - >>> ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) - >>> result_list = depc.get_property('orienter', aid_list, None, config=config) + >>> # ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) >>> print(result_list) """ logger.info('[ibs] Process Annotation Labels') @@ -2305,8 +2304,8 @@ def compute_orients_annotations(depc, aid_list, config=None): logger.info('[ibs] orienting using Orientation Plug-in') try: from wbia_orientation import _plugin # NOQA - - species_tag_mapping = {'right_whale_head': 'rightwhale'} + from wbia_orientation.utils.data_manipulation import get_object_aligned_box + import vtool as vt species_list = ibs.get_annot_species(aid_list) @@ -2319,7 +2318,7 @@ def compute_orients_annotations(depc, aid_list, config=None): results_dict = {} species_key_list = sorted(species_dict.keys()) for species in species_key_list: - species_tag = species_tag_mapping.get(species, species) + species_tag = _plugin.SPECIES_MODEL_TAG_MAPPING.get(species, species) message = 'Orientation plug-in does not support species_tag = %r' % ( species_tag, ) @@ -2338,9 +2337,130 @@ def compute_orients_annotations(depc, aid_list, config=None): ibs, aid_list_, species_tag, plot_samples=False ) - import utool as ut + for aid_, predicted_output, predicted_theta in zip( + aid_list_, output_list, theta_list + ): + xc, yc, xt, yt, w = predicted_output + predicted_verts = get_object_aligned_box(xc, yc, xt, yt, w) + predicted_verts = np.around(np.array(predicted_verts)).astype( + np.int64 + ) + predicted_verts = tuple(map(tuple, predicted_verts.tolist())) + + calculated_theta = np.arctan2(yt - yc, xt - xc) + np.deg2rad(90) + predicted_rot = vt.rotation_around_mat3x3( + calculated_theta * -1.0, xc, yc + ) + predicted_aligned_verts = vt.transform_points_with_homography( + predicted_rot, np.array(predicted_verts).T + ).T + predicted_aligned_verts = np.around(predicted_aligned_verts).astype( + np.int64 + ) + predicted_aligned_verts = tuple( + map(tuple, predicted_aligned_verts.tolist()) + ) + + predicted_bbox = vt.bboxes_from_vert_list([predicted_aligned_verts])[ + 0 + ] + ( + predicted_xtl, + predicted_ytl, + predicted_w, + 
predicted_h, + ) = predicted_bbox + + result = ( + predicted_xtl, + predicted_ytl, + predicted_w, + predicted_h, + calculated_theta, + # predicted_theta, + ) + results_dict[aid_] = result + + if False: + from itertools import combinations + + predicted_bbox_verts = vt.verts_list_from_bboxes_list( + [predicted_bbox] + )[0] + predicted_bbox_rot = vt.rotation_around_bbox_mat3x3( + calculated_theta, predicted_bbox + ) + predicted_bbox_rotated_verts = ( + vt.transform_points_with_homography( + predicted_bbox_rot, np.array(predicted_bbox_verts).T + ).T + ) + predicted_bbox_rotated_verts = np.around( + predicted_bbox_rotated_verts + ).astype(np.int64) + predicted_bbox_rotated_verts = tuple( + map(tuple, predicted_bbox_rotated_verts.tolist()) + ) + + gid_ = ibs.get_annot_gids(aid_) + image = ibs.get_images(gid_) + + original_bbox = ibs.get_annot_bboxes(aid_) + original_theta = ibs.get_annot_thetas(aid_) + original_verts = vt.verts_list_from_bboxes_list([original_bbox])[ + 0 + ] + original_rot = vt.rotation_around_bbox_mat3x3( + original_theta, original_bbox + ) + rotated_verts = vt.transform_points_with_homography( + original_rot, np.array(original_verts).T + ).T + rotated_verts = np.around(rotated_verts).astype(np.int64) + rotated_verts = tuple(map(tuple, rotated_verts.tolist())) + + color = (255, 0, 0) + for vert in original_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(original_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (0, 0, 255) + for vert in rotated_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(rotated_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (0, 255, 0) + for vert in predicted_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(predicted_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (255, 255, 0) + for vert in predicted_aligned_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(predicted_aligned_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (0, 255, 255) + for vert in predicted_bbox_rotated_verts: + cv2.circle(image, vert, 10, color, -1) + + for vert1, vert2 in combinations(predicted_bbox_rotated_verts, 2): + cv2.line(image, vert1, vert2, color, 1) + + cv2.imwrite('/tmp/image.%d.png' % (aid_), image) + + result_gen = [] + for aid in aid_list: + result = results_dict[aid] + result_gen.append(result) - ut.embed() except Exception: raise RuntimeError('Orientation plug-in not working!') else: From 01a96d08e28675301eb4915141c16aca6648f45c Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 12:37:36 -0800 Subject: [PATCH 134/294] Update logging --- wbia/entry_points.py | 2 +- wbia/other/ibsfuncs.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/wbia/entry_points.py b/wbia/entry_points.py index a047c1828c..d3f5da515a 100644 --- a/wbia/entry_points.py +++ b/wbia/entry_points.py @@ -75,7 +75,7 @@ def _init_wbia(dbdir=None, verbose=None, use_cache=True, web=None, **kwargs): # Set up logging # TODO (30-Nov-12020) This is intended to be a temporary fix to logging. 
- logger.setLevel(logging.DEBUG) + # logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) if verbose is None: diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 1318dbf097..3bbf2c8293 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -48,8 +48,8 @@ (print, rrr, profile) = ut.inject2(__name__, '[ibsfuncs]') logger = logging.getLogger('wbia') -logging.getLogger().setLevel(logging.DEBUG) -logger.setLevel(logging.DEBUG) +# logging.getLogger().setLevel(logging.DEBUG) +# logger.setLevel(logging.DEBUG) # Must import class before injection From 9692adb8c3fa8201223da240adbd09f5a0211041 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 15:34:37 -0800 Subject: [PATCH 135/294] Turn off temporary logging --- wbia/entry_points.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/entry_points.py b/wbia/entry_points.py index d3f5da515a..3a2fd648a0 100644 --- a/wbia/entry_points.py +++ b/wbia/entry_points.py @@ -76,7 +76,7 @@ def _init_wbia(dbdir=None, verbose=None, use_cache=True, web=None, **kwargs): # Set up logging # TODO (30-Nov-12020) This is intended to be a temporary fix to logging. # logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler()) + # logger.addHandler(logging.StreamHandler()) if verbose is None: verbose = ut.VERBOSE From 9417d35cf14cbb18a1b1314ed0723710c1332a1c Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 16:05:30 -0800 Subject: [PATCH 136/294] Change Docker wbia branch back to develop --- devops/Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/devops/Dockerfile b/devops/Dockerfile index b190682e7b..bc213ae98a 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -18,8 +18,6 @@ RUN set -ex \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-orientation/ \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ - && cd /wbia/wildbook-ia/ \ - && git stash && git pull && git stash pop || git reset --hard origin/add-orientation-plugin \ && find /wbia -name '.git' -type d -print0 | xargs -0 rm -rf \ && find /wbia -name '_skbuild' -type d -print0 | xargs -0 rm -rf From f205df630691b1f32c7d0145f99d00fbd37837ee Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Mon, 23 Nov 2020 11:45:14 -0800 Subject: [PATCH 137/294] WIP on feature caching --- wbia/core_parts.py | 109 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/wbia/core_parts.py b/wbia/core_parts.py index b5ec1183f0..05971627ed 100644 --- a/wbia/core_parts.py +++ b/wbia/core_parts.py @@ -6,6 +6,7 @@ import logging import utool as ut import numpy as np +from wbia import dtool from wbia.control.controller_inject import register_preprocs, register_subprops from wbia import core_annots @@ -83,3 +84,111 @@ def compute_part_chip(depc, part_rowid_list, config=None): for result in result_list: yield result logger.info('Done Preprocessing Part Chips') + + +class PartAssignmentFeatureConfig(dtool.Config): + _param_info_list = [] + + +@derived_attribute( + tablename='part_assignment_features', + parents=['parts', 'annotations'], + colnames=[ + 'p_xtl', 'p_ytl', 'p_w', 'p_h', + 'a_xtl', 'a_ytl', 'a_w', 'a_h', + 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_annot', + 'part_area_relative_annot' + ], + coltypes=[ + float, float, float, float, + float, float, float, float, + float, float, float, float, + float, float, float + ], + 
configclass=PartAssignmentFeatureConfig, + fname='part_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def compute_assignment_features(depc, part_rowid_list, aid_list, config=None): + assert len(part_rowid_list) == len(aid_list) + ibs = depc.controller + + part_bboxes = ibs.get_part_bboxes(part_rowid_list) + annot_bboxes = ibs.get_annot_bboxes(aid_list) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + annot_areas = [bbox[2] * bbox[3] for bbox in annot_bboxes] + p_area_relative_annot = [part_area / annot_area + for (part_area, annot_area) in zip(part_areas, annot_areas)] + + intersect_bboxes = _bbox_intersections(part_bboxes, annot_bboxes) + intersect_areas = [w * h if w > 0 and h > 0 else 0 + for (_,_,w,h) in intersect_bboxes] + + int_area_relative_part = [int_area / part_area for int_area, part_area + in zip(intersect_areas, part_areas)] + int_area_relative_annot = [int_area / annot_area for int_area, annot_area + in zip(intersect_areas, annot_areas)] + + result_list = list(zip(part_bboxes, annot_bboxes, intersect_bboxes, + int_area_relative_part, int_area_relative_annot, p_area_relative_annot)) + + for (part_bbox, annot_bbox, intersect_bbox, int_area_relative_part, + int_area_relative_annot, p_area_relative_annot) in result_list: + yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], + annot_bbox[0], annot_bbox[1], annot_bbox[2], annot_bbox[3], + intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], + intersect_area_relative_part, + intersect_area_relative_annot, + part_area_relative_annot) + + + +def _bbox_intersections(bboxes_a, bboxes_b): + corner_bboxes_a = _bbox_to_corner_format(bboxes_a) + corner_bboxes_b = _bbox_to_corner_format(bboxes_b) + + intersect_xtls = [max(xtl_a, xtl_b) + for ((xtl_a, _, _, _),(xtl_b, _, _, _)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_ytls = [max(ytl_a, ytl_b) + for ((_, ytl_a, _, _),(_, ytl_b, _, _)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_xbrs = [min(xbr_a, xbr_b) + for ((_, _, xbr_a, _),(_, _, xbr_b, _)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_ybrs = [min(ybr_a, ybr_b) + for ((_, _, _, ybr_a),(_, _, _, ybr_b)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_widths = [int_xbr - int_xtl for int_xbr, int_xtl + in zip(intersect_xbrs, intersect_xtls)] + + intersect_heights = [int_ybr - int_ytl for int_ybr, int_ytl + in zip(intersect_ybrs, intersect_ytls)] + + intersect_bboxes = list(zip(intersect_xtls, intersect_ytls, + intersect_widths, intersect_heights)) + + return intersect_bboxes + + + +# converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) +def _bbox_to_corner_format(bboxes): + corner_bboxes = [(bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]) + for bbox in bboxes] + return corner_bboxes + + + + + + + From 5fc64d7bf5a039fb0580b0dc61bb8140d03a8043 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 15 Dec 2020 13:11:29 -0800 Subject: [PATCH 138/294] WIP on assigner --- wbia/control/manual_part_funcs.py | 70 +++++++++--- wbia/core_annots.py | 175 +++++++++++++++++++++++++++++- wbia/core_parts.py | 107 +----------------- wbia/dtool/depcache_control.py | 2 + 4 files changed, 231 insertions(+), 123 deletions(-) diff --git a/wbia/control/manual_part_funcs.py b/wbia/control/manual_part_funcs.py index c7ffcd8c4e..03866f74f9 100644 --- a/wbia/control/manual_part_funcs.py +++ b/wbia/control/manual_part_funcs.py @@ -29,7 +29,7 @@ PART_NOTE = 'part_note' PART_NUM_VERTS = 'part_num_verts' PART_ROWID = 'part_rowid' -# 
PART_TAG_TEXT = 'part_tag_text' +PART_TAG_TEXT = 'part_tag_text' PART_THETA = 'part_theta' PART_VERTS = 'part_verts' PART_UUID = 'part_uuid' @@ -1088,21 +1088,21 @@ def set_part_viewpoints(ibs, part_rowid_list, viewpoint_list): ibs.db.set(const.PART_TABLE, ('part_viewpoint',), val_iter, id_iter) -# @register_ibs_method -# @accessor_decors.setter -# def set_part_tag_text(ibs, part_rowid_list, part_tags_list, duplicate_behavior='error'): -# r""" part_tags_list -> part.part_tags[part_rowid_list] +@register_ibs_method +@accessor_decors.setter +def set_part_tag_text(ibs, part_rowid_list, part_tags_list, duplicate_behavior='error'): + r""" part_tags_list -> part.part_tags[part_rowid_list] -# Args: -# part_rowid_list -# part_tags_list + Args: + part_rowid_list + part_tags_list -# """ -# #logger.info('[ibs] set_part_tag_text of part_rowid_list=%r to tags=%r' % (part_rowid_list, part_tags_list)) -# id_iter = part_rowid_list -# colnames = (PART_TAG_TEXT,) -# ibs.db.set(const.PART_TABLE, colnames, part_tags_list, -# id_iter, duplicate_behavior=duplicate_behavior) + """ + #logger.info('[ibs] set_part_tag_text of part_rowid_list=%r to tags=%r' % (part_rowid_list, part_tags_list)) + id_iter = part_rowid_list + colnames = (PART_TAG_TEXT,) + ibs.db.set(const.PART_TABLE, colnames, part_tags_list, + id_iter, duplicate_behavior=duplicate_behavior) @register_ibs_method @@ -1485,6 +1485,48 @@ def set_part_contour(ibs, part_rowid_list, contour_dict_list): ibs.db.set(const.PART_TABLE, ('part_contour_json',), val_list, id_iter) +# setting up the Wild Dog data for assigner training +# def get_corresponding_aids(ibs, part_rowid_list, from_aids=None): +# if from_aids is None: +# from_aids = ibs.get_valid_aids() + + +def get_corresponding_aids_slow(ibs, part_rowid_list, from_aids): + part_bboxes = ibs.get_part_bboxes(part_rowid_list) + annot_bboxes = ibs.get_annot_bboxes(from_aids) + annot_gids = ibs.get_annot_gids(from_aids) + from collections import defaultdict + bbox_gid_to_aids = defaultdict(int) + for aid, gid, bbox in zip(from_aids, annot_gids, annot_bboxes): + bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] = aid + part_gids = ibs.get_part_image_rowids(parts) + part_rowid_to_aid = {part_id: bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] for part_id, gid, bbox in zip(part_rowid_list, part_gids, part_bboxes)} + + part_aids = [part_rowid_to_aid[partid] for partid in parts] + part_parent_aids = ibs.get_part_aids(part_rowid_list) + + # parents might be non-unique so we gotta make a unique name for each parent + parent_aid_to_part_rowids = defaultdict(list) + for part_rowid, parent_aid in zip(part_rowid_list, part_parent_aids): + parent_aid_to_part_rowids[parent_aid] += [part_rowid] + + part_annot_names = [','.join(str(p) for p in parent_aid_to_part_rowids[parent_aid]) for parent_aid in part_parent_aids] + + + # now assign names so we can associate the part annots with the non-part annots + new_part_names = ['part-%s' % part_rowid for part_rowid in part_rowid_list] + + +def sort_parts_by_tags(ibs, part_rowid_list): + tags = ibs.get_part_tag_text(part_rowid_list) + from collections import defaultdict + tag_to_rowids = new defaultdict(list) + for tag, part_rowid in zip(tags, part_rowid_list): + tag_to_rowids[tag] += [part_rowid] + parts_by_tags = [tag_to_rowdids[tag] for tag in tag_to_rowdids.keys()] + return parts_by_tags + + # ========== # Testdata # ========== diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 9448d13b72..9f43edf4bd 100644 --- a/wbia/core_annots.py +++ 
b/wbia/core_annots.py @@ -62,6 +62,7 @@ (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') +CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) derived_attribute = register_preprocs['annot'] register_subprop = register_subprops['annot'] @@ -2473,7 +2474,175 @@ def compute_orients_annotations(depc, aid_list, config=None): yield result -if __name__ == '__main__': - import xdoctest as xdoc +# for assigning part-annots to body-annots of the same individual: +class PartAssignmentFeatureConfig(dtool.Config): + _param_info_list = [] + + +@derived_attribute( + tablename='part_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', 'p_ytl', 'p_w', 'p_h', + 'b_xtl', 'b_ytl', 'b_w', 'b_h', + 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body' + ], + coltypes=[ + int, int, int, int, + int, int, int, int, + int, int, int, int, + float, float, float + ], + configclass=PartAssignmentFeatureConfig, + fname='part_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = _are_part_annots(part_aid_list) + assert all(are_parts_parts), 'all part_aids must be part annots.' + bodies_are_parts = _are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [part_area / body_area + for (part_area, body_area) in zip(part_areas, body_areas)] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
+ intersect_areas = [w * h if w > 0 and h > 0 else 0 + for (_, _, w, h) in intersect_bboxes] + + int_area_relative_part = [int_area / part_area for int_area, part_area + in zip(intersect_areas, part_areas)] + int_area_relative_body = [int_area / body_area for int_area, body_area + in zip(intersect_areas, body_areas)] + + result_list = list(zip( + part_bboxes, body_bboxes, intersect_bboxes, + int_area_relative_part, int_area_relative_body, part_area_relative_body + )) + + for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, + int_area_relative_body, part_area_relative_body) in result_list: + yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body) + + +def _are_part_annots(ibs, aid_list): + species = ibs.get_annot_species(aid_list) + are_parts = ['+' in specie for specie in species] + return are_parts + + +def _bbox_intersections(bboxes_a, bboxes_b): + corner_bboxes_a = _bbox_to_corner_format(bboxes_a) + corner_bboxes_b = _bbox_to_corner_format(bboxes_b) + + intersect_xtls = [max(xtl_a, xtl_b) + for ((xtl_a, _, _, _), (xtl_b, _, _, _)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_ytls = [max(ytl_a, ytl_b) + for ((_, ytl_a, _, _), (_, ytl_b, _, _)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_xbrs = [min(xbr_a, xbr_b) + for ((_, _, xbr_a, _), (_, _, xbr_b, _)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_ybrs = [min(ybr_a, ybr_b) + for ((_, _, _, ybr_a), (_, _, _, ybr_b)) + in zip(corner_bboxes_a, corner_bboxes_b)] + + intersect_widths = [int_xbr - int_xtl for int_xbr, int_xtl + in zip(intersect_xbrs, intersect_xtls)] + + intersect_heights = [int_ybr - int_ytl for int_ybr, int_ytl + in zip(intersect_ybrs, intersect_ytls)] + + intersect_bboxes = list(zip( + intersect_xtls, intersect_ytls, intersect_widths, intersect_heights)) + + return intersect_bboxes + + +# converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) +def _bbox_to_corner_format(bboxes): + corner_bboxes = [(bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) + for bbox in bboxes] + return corner_bboxes + + +def all_part_pairs(ibs, gid_list): + all_aids = ibs.get_image_aids(gid_list) + all_aids_are_parts = [_are_part_annots(ibs, aids) for aids in all_aids] + all_part_aids = [[aid for (aid, part) in zip(aids, are_parts) if part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] + all_body_aids = [[aid for (aid, part) in zip(aids, are_parts) if not part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] + part_body_parallel_lists = [_all_pairs_parallel(parts, bodies) for parts, bodies in zip(all_part_aids, all_body_aids)] + all_parts = [aid for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[0]] + all_bodies = [aid for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[1]] + return all_parts, all_bodies + + +def _all_pairs_parallel(list_a, list_b): + pairs = [(a, b) for a in list_a for b in list_b] + pairs_a = [pair[0] for pair in pairs] + pairs_b = [pair[1] for pair in pairs] + return pairs_a, pairs_b + + +# for wild dog dev +def wd_assigner_data(ibs): + all_aids = ibs.get_valid_aids() + ia_classes = ibs.get_annot_species(all_aids) + part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] + part_gids = 
list(set(ibs.get_annot_gids(part_aids))) + all_pairs = all_part_pairs(ibs, part_gids) + all_feats = ibs.depc_annot.get('part_assignment_features', all_pairs) + names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] + ground_truth = [n1 == n2 for (n1, n2) in zip(names[0],names[1])] + # we now have all features and the ground truths, time to to a train/test split + train_feats, test_feats = train_test_split(all_feats) + train_truth, test_truth = train_test_split(ground_truth) + assigner_data = {'data': train_feats, 'target': train_truth, + 'test': test_feats, 'test_truth': test_truth} + return assigner_data + + +def train_test_split(item_list, random_seed=777, test_size=0.1): + import random + import math + random.seed(random_seed) + sample_size = math.floor(len(item_list) * test_size) + all_indices = list(range(len(item_list))) + test_indices = random.sample(all_indices, sample_size) + test_items = [item_list[i] for i in test_indices] + train_indices = sorted(list( + set(all_indices) - set(test_indices) + )) + train_items = [item_list[i] for i in train_indices] + return train_items, test_items + + + - xdoc.doctest_module(__file__) diff --git a/wbia/core_parts.py b/wbia/core_parts.py index 05971627ed..744592f927 100644 --- a/wbia/core_parts.py +++ b/wbia/core_parts.py @@ -9,6 +9,7 @@ from wbia import dtool from wbia.control.controller_inject import register_preprocs, register_subprops from wbia import core_annots +from wbia.constants import ANNOTATION_TABLE, PART_TABLE (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') @@ -86,109 +87,3 @@ def compute_part_chip(depc, part_rowid_list, config=None): logger.info('Done Preprocessing Part Chips') -class PartAssignmentFeatureConfig(dtool.Config): - _param_info_list = [] - - -@derived_attribute( - tablename='part_assignment_features', - parents=['parts', 'annotations'], - colnames=[ - 'p_xtl', 'p_ytl', 'p_w', 'p_h', - 'a_xtl', 'a_ytl', 'a_w', 'a_h', - 'int_xtl', 'int_ytl', 'int_w', 'int_h', - 'intersect_area_relative_part', - 'intersect_area_relative_annot', - 'part_area_relative_annot' - ], - coltypes=[ - float, float, float, float, - float, float, float, float, - float, float, float, float, - float, float, float - ], - configclass=PartAssignmentFeatureConfig, - fname='part_assignment_features', - rm_extern_on_delete=True, - chunksize=256, -) -def compute_assignment_features(depc, part_rowid_list, aid_list, config=None): - assert len(part_rowid_list) == len(aid_list) - ibs = depc.controller - - part_bboxes = ibs.get_part_bboxes(part_rowid_list) - annot_bboxes = ibs.get_annot_bboxes(aid_list) - - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - annot_areas = [bbox[2] * bbox[3] for bbox in annot_bboxes] - p_area_relative_annot = [part_area / annot_area - for (part_area, annot_area) in zip(part_areas, annot_areas)] - - intersect_bboxes = _bbox_intersections(part_bboxes, annot_bboxes) - intersect_areas = [w * h if w > 0 and h > 0 else 0 - for (_,_,w,h) in intersect_bboxes] - - int_area_relative_part = [int_area / part_area for int_area, part_area - in zip(intersect_areas, part_areas)] - int_area_relative_annot = [int_area / annot_area for int_area, annot_area - in zip(intersect_areas, annot_areas)] - - result_list = list(zip(part_bboxes, annot_bboxes, intersect_bboxes, - int_area_relative_part, int_area_relative_annot, p_area_relative_annot)) - - for (part_bbox, annot_bbox, intersect_bbox, int_area_relative_part, - int_area_relative_annot, p_area_relative_annot) in result_list: - yield 
(part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], - annot_bbox[0], annot_bbox[1], annot_bbox[2], annot_bbox[3], - intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], - intersect_area_relative_part, - intersect_area_relative_annot, - part_area_relative_annot) - - - -def _bbox_intersections(bboxes_a, bboxes_b): - corner_bboxes_a = _bbox_to_corner_format(bboxes_a) - corner_bboxes_b = _bbox_to_corner_format(bboxes_b) - - intersect_xtls = [max(xtl_a, xtl_b) - for ((xtl_a, _, _, _),(xtl_b, _, _, _)) - in zip(corner_bboxes_a, corner_bboxes_b)] - - intersect_ytls = [max(ytl_a, ytl_b) - for ((_, ytl_a, _, _),(_, ytl_b, _, _)) - in zip(corner_bboxes_a, corner_bboxes_b)] - - intersect_xbrs = [min(xbr_a, xbr_b) - for ((_, _, xbr_a, _),(_, _, xbr_b, _)) - in zip(corner_bboxes_a, corner_bboxes_b)] - - intersect_ybrs = [min(ybr_a, ybr_b) - for ((_, _, _, ybr_a),(_, _, _, ybr_b)) - in zip(corner_bboxes_a, corner_bboxes_b)] - - intersect_widths = [int_xbr - int_xtl for int_xbr, int_xtl - in zip(intersect_xbrs, intersect_xtls)] - - intersect_heights = [int_ybr - int_ytl for int_ybr, int_ytl - in zip(intersect_ybrs, intersect_ytls)] - - intersect_bboxes = list(zip(intersect_xtls, intersect_ytls, - intersect_widths, intersect_heights)) - - return intersect_bboxes - - - -# converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) -def _bbox_to_corner_format(bboxes): - corner_bboxes = [(bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]) - for bbox in bboxes] - return corner_bboxes - - - - - - - diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 3a3a8e25e5..f5f6e43be9 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -470,6 +470,8 @@ def rectify_input_tuple(self, exi_inputs, input_tuple): input_tuple_ = (input_tuple_,) if len(exi_inputs) != len(input_tuple_): msg = '#expected=%d, #got=%d' % (len(exi_inputs), len(input_tuple_)) + print(msg) + ut.embed() raise ValueError(msg) # rectify input depth From 4b296384ebb07c1388f5009a5478b29eebd38987 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Fri, 18 Dec 2020 15:20:02 -0800 Subject: [PATCH 139/294] WIP before breaking off assigner.py --- wbia/core_annots.py | 533 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 523 insertions(+), 10 deletions(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 9f43edf4bd..75bb3ae907 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -55,10 +55,12 @@ import numpy as np import cv2 import wbia.constants as const -from wbia.control.controller_inject import register_preprocs, register_subprops +from wbia.control.controller_inject import register_preprocs, register_subprops, make_ibs_register_decorator from wbia.algo.hots.chip_match import ChipMatch from wbia.algo.hots import neighbor_index +from sklearn import preprocessing + (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') @@ -2508,13 +2510,236 @@ def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None) part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(part_aid_list) - assert all(are_parts_parts), 'all part_aids must be part annots.' - bodies_are_parts = _are_part_annots(body_aid_list) + parts_are_parts = _are_part_annots(ibs, part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
+ bodies_are_parts = _are_part_annots(ibs, body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [part_area / body_area + for (part_area, body_area) in zip(part_areas, body_areas)] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. + intersect_areas = [w * h if w > 0 and h > 0 else 0 + for (_, _, w, h) in intersect_bboxes] + + int_area_relative_part = [int_area / part_area for int_area, part_area + in zip(intersect_areas, part_areas)] + int_area_relative_body = [int_area / body_area for int_area, body_area + in zip(intersect_areas, body_areas)] + + result_list = list(zip( + part_bboxes, body_bboxes, intersect_bboxes, + int_area_relative_part, int_area_relative_body, part_area_relative_body + )) + + for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, + int_area_relative_body, part_area_relative_body) in result_list: + yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body) + + +@derived_attribute( + tablename='normalized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', 'p_ytl', 'p_w', 'p_h', + 'b_xtl', 'b_ytl', 'b_w', 'b_h', + 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body' + ], + coltypes=[ + float, float, float, float, + float, float, float, float, + float, float, float, float, + float, float, float + ], + configclass=PartAssignmentFeatureConfig, + fname='normalized_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = _are_part_annots(ibs, part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = _are_part_annots(ibs, body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [part_area / body_area + for (part_area, body_area) in zip(part_areas, body_areas)] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
+ intersect_areas = [w * h if w > 0 and h > 0 else 0 + for (_, _, w, h) in intersect_bboxes] + + int_area_relative_part = [int_area / part_area for int_area, part_area + in zip(intersect_areas, part_areas)] + int_area_relative_body = [int_area / body_area for int_area, body_area + in zip(intersect_areas, body_areas)] + + result_list = list(zip( + part_bboxes, body_bboxes, intersect_bboxes, + int_area_relative_part, int_area_relative_body, part_area_relative_body + )) + + for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, + int_area_relative_body, part_area_relative_body) in result_list: + yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body) + + +@derived_attribute( + tablename='standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', 'p_ytl', 'p_w', 'p_h', + 'b_xtl', 'b_ytl', 'b_w', 'b_h', + 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body' + ], + coltypes=[ + float, float, float, float, + float, float, float, float, + float, float, float, float, + float, float, float + ], + configclass=PartAssignmentFeatureConfig, + fname='standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = _are_part_annots(ibs, part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = _are_part_annots(ibs, body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [part_area / body_area + for (part_area, body_area) in zip(part_areas, body_areas)] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
+ intersect_areas = [w * h if w > 0 and h > 0 else 0 + for (_, _, w, h) in intersect_bboxes] + + int_area_relative_part = [int_area / part_area for int_area, part_area + in zip(intersect_areas, part_areas)] + int_area_relative_body = [int_area / body_area for int_area, body_area + in zip(intersect_areas, body_areas)] + + int_area_relative_part = preprocessing.scale(int_area_relative_part) + int_area_relative_body = preprocessing.scale(int_area_relative_body) + part_area_relative_body = preprocessing.scale(part_area_relative_body) + + result_list = list(zip( + part_bboxes, body_bboxes, intersect_bboxes, + int_area_relative_part, int_area_relative_body, part_area_relative_body + )) + + for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, + int_area_relative_body, part_area_relative_body) in result_list: + yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body) + + +# like the above but bboxes are also standardized +@derived_attribute( + tablename='mega_standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', 'p_ytl', 'p_w', 'p_h', + 'b_xtl', 'b_ytl', 'b_w', 'b_h', + 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body' + ], + coltypes=[ + float, float, float, float, + float, float, float, float, + float, float, float, float, + float, float, float + ], + configclass=PartAssignmentFeatureConfig, + fname='mega_standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def mega_standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = _are_part_annots(ibs, part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
+ bodies_are_parts = _are_part_annots(ibs, body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' part_bboxes = ibs.get_annot_bboxes(part_aid_list) body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_bboxes = _standardized_bboxes(part_bboxes) + body_bboxes = _standardized_bboxes(body_bboxes) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] @@ -2531,6 +2756,10 @@ def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None) int_area_relative_body = [int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas)] + int_area_relative_part = preprocessing.scale(int_area_relative_part) + int_area_relative_body = preprocessing.scale(int_area_relative_body) + part_area_relative_body = preprocessing.scale(part_area_relative_body) + result_list = list(zip( part_bboxes, body_bboxes, intersect_bboxes, int_area_relative_part, int_area_relative_body, part_area_relative_body @@ -2546,6 +2775,258 @@ def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None) part_area_relative_body) +@derived_attribute( + tablename='theta_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', + 'p_center_x', 'p_center_y', + 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', + 'int_area_scalar', 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body' + ], + coltypes=[ + float, float, float, float, float, float, float, float, float, float, + float, float, float, float, float, float, + float, float, float, float, float, float, float + ], + configclass=PartAssignmentFeatureConfig, + fname='theta_assignment_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = _are_part_annots(ibs, part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
+ bodies_are_parts = _are_part_annots(ibs, body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [part.intersection(body) + for part, body in zip(part_polys, body_polys)] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [part + body - intersect for (part, body, intersect) + in zip(part_areas, body_areas, intersect_areas)] + int_over_unions = [intersect / union for (intersect, union) + in zip(intersect_areas, union_areas)] + + part_body_distances = [part.distance(body) + for part, body in zip(part_polys, body_polys)] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [part.distance(body) for part, body + in zip(part_centroids, body_centroids)] + + int_over_parts = [int_area / part_area for part_area, int_area + in zip(part_areas, intersect_areas)] + + int_over_bodys = [int_area / body_area for body_area, int_area + in zip(body_areas, intersect_areas)] + + part_over_bodys = [part_area / body_area for part_area, body_area + in zip(part_areas, body_areas)] + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list(zip( + part_verts, part_centroids, body_bboxes, body_centroids, + int_area_scalars, part_body_distances, part_body_centroid_dists, + int_over_unions, int_over_parts, int_over_bodys, part_over_bodys + )) + + for (part_vert, part_center, body_bbox, body_center, + int_area_scalar, part_body_distance, part_body_centroid_dist, + int_over_union, int_over_part, int_over_body, part_over_body) in result_list: + yield (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], + part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], + part_center.x, part_center.y, + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + body_center.x, body_center.y, + int_area_scalar, part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + + +@derived_attribute( + tablename='theta_standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', + 'p_center_x', 'p_center_y', + 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', + 'int_area_scalar', 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body' + ], + coltypes=[ + float, float, float, float, float, 
float, float, float, float, float, + float, float, float, float, float, float, + float, float, float, float, float, float, float + ], + configclass=PartAssignmentFeatureConfig, + fname='theta_standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=2560000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def theta_standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = _are_part_annots(ibs, part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = _are_part_annots(ibs, body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [part.intersection(body) + for part, body in zip(part_polys, body_polys)] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + int_area_scalars = preprocessing.scale(int_area_scalars) + + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [part + body - intersect for (part, body, intersect) + in zip(part_areas, body_areas, intersect_areas)] + int_over_unions = [intersect / union for (intersect, union) + in zip(intersect_areas, union_areas)] + int_over_unions = preprocessing.scale(int_over_unions) + + part_body_distances = [part.distance(body) + for part, body in zip(part_polys, body_polys)] + part_body_distances = preprocessing.scale(part_body_distances) + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [part.distance(body) for part, body + in zip(part_centroids, body_centroids)] + part_body_centroid_dists = preprocessing.scale(part_body_centroid_dists) + + int_over_parts = [int_area / part_area for part_area, int_area + in zip(part_areas, intersect_areas)] + int_over_parts = preprocessing.scale(int_over_parts) + + int_over_bodys = [int_area / body_area for body_area, int_area + in zip(body_areas, intersect_areas)] + int_over_bodys = preprocessing.scale(int_over_bodys) + + part_over_bodys = [part_area / body_area for part_area, body_area + in zip(part_areas, body_areas)] + part_over_bodys = preprocessing.scale(part_over_bodys) + + + + #standardization + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list(zip( + part_verts, part_centroids, 
body_bboxes, body_centroids, + int_area_scalars, part_body_distances, part_body_centroid_dists, + int_over_unions, int_over_parts, int_over_bodys, part_over_bodys + )) + + for (part_vert, part_center, body_bbox, body_center, + int_area_scalar, part_body_distance, part_body_centroid_dist, + int_over_union, int_over_part, int_over_body, part_over_body) in result_list: + yield (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], + part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], + part_center.x, part_center.y, + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + body_center.x, body_center.y, + int_area_scalar, part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + + +def _norm_bboxes(bbox_list, width_list, height_list): + normed_boxes = [(bbox[0]/w, bbox[1]/h, bbox[2]/w, bbox[3]/h) + for (bbox, w, h) + in zip(bbox_list, width_list, height_list)] + return normed_boxes + + +def _norm_vertices(verts_list, width_list, height_list): + normed_verts = [[[x / w , y / h] for x, y in vert] + for vert, w, h + in zip(verts_list, width_list, height_list) + ] + return normed_verts + + +# does this even make any sense? let's find out experimentally +def _standardized_bboxes(bbox_list): + xtls = preprocessing.scale([bbox[0] for bbox in bbox_list]) + ytls = preprocessing.scale([bbox[1] for bbox in bbox_list]) + wids = preprocessing.scale([bbox[2] for bbox in bbox_list]) + heis = preprocessing.scale([bbox[3] for bbox in bbox_list]) + standardized_bboxes = list(zip(xtls, ytls, wids, heis)) + return standardized_bboxes + + def _are_part_annots(ibs, aid_list): species = ibs.get_annot_species(aid_list) are_parts = ['+' in specie for specie in species] @@ -2583,6 +3064,30 @@ def _bbox_intersections(bboxes_a, bboxes_b): return intersect_bboxes +def _theta_aware_intersect_areas(verts_list_a, verts_list_b): + import shapely + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] + polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] + intersect_areas = [poly1.intersection(poly2).area + for poly1, poly2 in zip(polys_a, polys_b)] + return intersect_areas + + +def _all_centroids(verts_list_a, verts_list_b): + import shapely + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] + polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] + intersect_polys = [poly1.intersection(poly2) for poly1, poly2 in zip(polys_a, polys_b)] + + centroids_a = [poly.centroid for poly in polys_a] + centroids_b = [poly.centroid for poly in polys_b] + centroids_int = [poly.centroid for poly in intersect_polys] + + return centroids_a, centroids_b, centroids_int + +def _polygons_to_centroid_coords(polygon_list): + centroids = [poly.centroid for poly in polygon_list] + # converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) def _bbox_to_corner_format(bboxes): @@ -2612,16 +3117,27 @@ def _all_pairs_parallel(list_a, list_b): # for wild dog dev +@register_ibs_method def wd_assigner_data(ibs): + return wd_training_data('part_assignment_features') + + +@register_ibs_method +def wd_normed_assigner_data(ibs): + return wd_training_data('normalized_assignment_features') + + +@register_ibs_method +def wd_training_data(ibs, depc_table_name='part_assignment_features'): all_aids = ibs.get_valid_aids() ia_classes = ibs.get_annot_species(all_aids) part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] part_gids = list(set(ibs.get_annot_gids(part_aids))) 
all_pairs = all_part_pairs(ibs, part_gids) - all_feats = ibs.depc_annot.get('part_assignment_features', all_pairs) + all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] ground_truth = [n1 == n2 for (n1, n2) in zip(names[0],names[1])] - # we now have all features and the ground truths, time to to a train/test split + train_feats, test_feats = train_test_split(all_feats) train_truth, test_truth = train_test_split(ground_truth) assigner_data = {'data': train_feats, 'target': train_truth, @@ -2629,6 +3145,7 @@ def wd_assigner_data(ibs): return assigner_data + def train_test_split(item_list, random_seed=777, test_size=0.1): import random import math @@ -2642,7 +3159,3 @@ def train_test_split(item_list, random_seed=777, test_size=0.1): )) train_items = [item_list[i] for i in train_indices] return train_items, test_items - - - - From 9fc477857427cf46c68866939bdb3c67cf85e754 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 22 Dec 2020 15:02:02 -0800 Subject: [PATCH 140/294] WIP with more assigner features and assigner.py --- wbia/algo/detect/__init__.py | 2 + wbia/algo/detect/assigner.py | 448 +++++++++++++++++++++++++++++++++++ wbia/core_annots.py | 63 ----- 3 files changed, 450 insertions(+), 63 deletions(-) create mode 100644 wbia/algo/detect/assigner.py diff --git a/wbia/algo/detect/__init__.py b/wbia/algo/detect/__init__.py index 5c742d556c..f400e72594 100644 --- a/wbia/algo/detect/__init__.py +++ b/wbia/algo/detect/__init__.py @@ -5,6 +5,7 @@ from wbia.algo.detect import grabmodels from wbia.algo.detect import randomforest from wbia.algo.detect import yolo +from wbia.algo.detect import assigner # from wbia.algo.detect import selectivesearch # from wbia.algo.detect import ssd @@ -93,6 +94,7 @@ def get_reload_subs(mod): ('grabmodels', None), ('randomforest', None), ('yolo', None), + ('assigner', None), # ('selectivesearch', None), # ('ssd', None), # ('fasterrcnn', None), diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py new file mode 100644 index 0000000000..f3de772644 --- /dev/null +++ b/wbia/algo/detect/assigner.py @@ -0,0 +1,448 @@ +import logging +from os.path import expanduser, join +from wbia import constants as const +from wbia.control.controller_inject import register_preprocs, register_subprops, make_ibs_register_decorator +import utool as ut +import numpy as np +import random +import os +from collections import OrderedDict, defaultdict +from datetime import datetime +import time + +from sklearn import preprocessing +from tune_sklearn import TuneGridSearchCV + +# shitload of scikit classifiers +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.datasets import make_moons, make_circles, make_classification +from sklearn.neural_network import MLPClassifier +from sklearn.neighbors import KNeighborsClassifier +from sklearn.svm import SVC +from sklearn.gaussian_process import GaussianProcessClassifier +from sklearn.gaussian_process.kernels import RBF +from sklearn.tree import DecisionTreeClassifier +from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier +from sklearn.naive_bayes import GaussianNB +from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis +from sklearn.model_selection import GridSearchCV + + +# bunch of classifier models for training + + +(print, rrr, profile) = 
ut.inject2(__name__, '[orientation]') +logger = logging.getLogger('wbia') + +CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) + + +PARALLEL = not const.CONTAINERIZED +INPUT_SIZE = 224 + +ARCHIVE_URL_DICT = {} + + +CLASSIFIER_OPTIONS = [ + # { + # "name": "Nearest Neighbors", + # "clf": KNeighborsClassifier(3), + # "param_options": { + # 'n_neighbors': [3,5,11,19], + # 'weights': ['uniform', 'distance'], + # 'metric': ['euclidean', 'manhattan'], + # } + # }, + # { + # "name": "Linear SVM", + # "clf": SVC(kernel="linear", C=0.025), + # "param_options": { + # 'C': [1, 10, 100, 1000], + # 'kernel': ['linear'], + # } + # }, + # { + # "name": "RBF SVM", + # "clf": SVC(gamma=2, C=1), + # "param_options": { + # 'C': [1, 10, 100, 1000], + # 'gamma': [0.001, 0.0001], + # 'kernel': ['rbf'] + # }, + # }, + { + "name": "Decision Tree", + "clf": DecisionTreeClassifier(), #max_depth=5 + "param_options": { + 'max_depth': np.arange(1,12), + 'max_leaf_nodes': [2, 5, 10, 20, 50, 100] + } + }, + { + "name": "Random Forest", + "clf": RandomForestClassifier(), #max_depth=5, n_estimators=10, max_features=1 + "param_options": { + 'bootstrap': [True, False], + 'max_depth': [10, 50, 100, None], + 'max_features': ['auto', 'sqrt'], + 'min_samples_leaf': [1, 2, 4], + 'min_samples_split': [2, 5, 10], + 'n_estimators': [200, 1000, 1500, 2000] + } + }, + { + "name": "Neural Net", + "clf": MLPClassifier(), #alpha=1, max_iter=1000 + "param_options": { + 'hidden_layer_sizes': [(10,30,10),(20,)], + 'activation': ['tanh', 'relu'], + 'solver': ['sgd', 'adam'], + 'alpha': [0.0001, 0.05], + 'learning_rate': ['constant','adaptive'], + } + }, + { + "name": "AdaBoost", + "clf": AdaBoostClassifier(), + "param_options": { + 'n_estimators': np.arange(10, 310, 50), + 'learning_rate': [0.01, 0.05, 0.1, 1], + } + }, + # { + # "name": "Naive Bayes", + # "clf": GaussianNB(), + # "param_options": {} # no hyperparams to optimize + # }, + # { + # "name": "QDA", + # "clf": QuadraticDiscriminantAnalysis(), + # "param_options": { + # 'reg_param': [0.1, 0.2, 0.3, 0.4, 0.5] + # } + # } +] + + + +# for model exploration +classifier_names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", + "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", + "Naive Bayes", "QDA"] + +classifiers = [ + KNeighborsClassifier(3), + SVC(kernel="linear", C=0.025), + SVC(gamma=2, C=1), + DecisionTreeClassifier(max_depth=5), + RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1), + MLPClassifier(alpha=1, max_iter=1000), + AdaBoostClassifier(), + GaussianNB(), + QuadraticDiscriminantAnalysis()] + + +slow_classifier_names = "Gaussian Process" +slow_classifiers = GaussianProcessClassifier(1.0 * RBF(1.0)), + + +def classifier_report(clf, name, assigner_data): + print('%s CLASSIFIER REPORT ' % name) + print(' %s: calling clf.fit' % str(datetime.now())) + clf.fit(assigner_data['data'], assigner_data['target']) + print(' %s: done training, making prediction ' % str(datetime.now())) + preds = clf.predict(assigner_data['test']) + print(' %s: done with predictions, computing accuracy' % str(datetime.now())) + agree = [pred == truth for pred, truth in zip(preds, assigner_data['test_truth'])] + accuracy = agree.count(True) / len(agree) + print(' %s accuracy' % accuracy) + print() + return accuracy + + +@register_ibs_method +def compare_ass_classifiers(ibs, depc_table_name='theta_assignment_features', print_accs=False): + + assigner_data = ibs.wd_training_data(depc_table_name) + + accuracies = OrderedDict() + for classifier in 
CLASSIFIER_OPTIONS: + accuracy = classifier_report(classifier['clf'], classifier['name'], assigner_data) + accuracies[classifier['name']] = accuracy + + # handy for e.g. pasting into excel + if print_accs: + just_accuracy = [accuracies[name] for name in accuracies.keys()] + print(just_accuracy) + + return accuracies + + +@register_ibs_method +def tune_ass_classifiers(ibs, depc_table_name='theta_assignment_features'): + + assigner_data = ibs.wd_training_data(depc_table_name) + + accuracies = OrderedDict() + best_acc = 0 + best_clf_name = '' + best_clf_params = {} + for classifier in CLASSIFIER_OPTIONS: + print("Tuning %s" % classifier['name']) + accuracy, best_params = ibs._tune_grid_search(classifier['clf'], classifier['param_options'], assigner_data) + print() + accuracies[classifier['name']] = { + 'accuracy': accuracy, + 'best_params': best_params + } + if accuracy > best_acc: + best_acc = accuracy + best_clf_name = classifier['name'] + best_clf_params = best_params + + + print('best performance: %s using %s with params %s' % + best_acc, best_clf_name, best_clf_params) + + return accuracies + + +@register_ibs_method +def _tune_grid_search(ibs, clf, parameters, assigner_data=None): + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + X_train = assigner_data['data'] + y_train = assigner_data['target'] + X_test = assigner_data['test'] + y_test = assigner_data['test_truth'] + + tune_search = GridSearchCV( + clf, + parameters, + ) + + start = time.time() + tune_search.fit(X_train, y_train) + end = time.time() + print("Tune Fit Time: %s" % (end - start)) + pred = tune_search.predict(X_test) + accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) + print("Tune Accuracy: %s" % accuracy) + print("best parms : %s" % tune_search.best_params_) + + return accuracy, tune_search.best_params_ + + +@register_ibs_method +def _tune_random_search(ibs, clf, parameters, assigner_data=None): + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + X_train = assigner_data['data'] + y_train = assigner_data['target'] + X_test = assigner_data['test'] + y_test = assigner_data['test_truth'] + + tune_search = GridSearchCV( + clf, + parameters, + ) + + start = time.time() + tune_search.fit(X_train, y_train) + end = time.time() + print("Tune Fit Time: %s" % (end - start)) + pred = tune_search.predict(X_test) + accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) + print("Tune Accuracy: %s" % accuracy) + print("best parms : %s" % tune_search.best_params_) + + return accuracy, tune_search.best_params_ + + +# for wild dog dev +@register_ibs_method +def wd_assigner_data(ibs): + return wd_training_data('part_assignment_features') + + +@register_ibs_method +def wd_normed_assigner_data(ibs): + return wd_training_data('normalized_assignment_features') + + +@register_ibs_method +def wd_training_data(ibs, depc_table_name='theta_assignment_features'): + all_aids = ibs.get_valid_aids() + ia_classes = ibs.get_annot_species(all_aids) + part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] + part_gids = list(set(ibs.get_annot_gids(part_aids))) + all_pairs = all_part_pairs(ibs, part_gids) + all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) + names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] + ground_truth = [n1 == n2 for (n1, n2) in zip(names[0],names[1])] + + # train_feats, test_feats = train_test_split(all_feats) + # train_truth, test_truth = train_test_split(ground_truth) + 
pairs_in_train = ibs.gid_train_test_split(all_pairs[0]) # we could pass just the pair aids or just the body aids bc gids are the same + train_feats, test_feats = _split_list(all_feats, pairs_in_train) + train_truth, test_truth = _split_list(ground_truth, pairs_in_train) + + assigner_data = {'data': train_feats, 'target': train_truth, + 'test': test_feats, 'test_truth': test_truth} + + return assigner_data + + + +@register_ibs_method +def _are_part_annots(ibs, aid_list): + species = ibs.get_annot_species(aid_list) + are_parts = ['+' in specie for specie in species] + return are_parts + + +def all_part_pairs(ibs, gid_list): + all_aids = ibs.get_image_aids(gid_list) + all_aids_are_parts = [ibs._are_part_annots(aids) for aids in all_aids] + all_part_aids = [[aid for (aid, part) in zip(aids, are_parts) if part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] + all_body_aids = [[aid for (aid, part) in zip(aids, are_parts) if not part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] + part_body_parallel_lists = [_all_pairs_parallel(parts, bodies) for parts, bodies in zip(all_part_aids, all_body_aids)] + all_parts = [aid for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[0]] + all_bodies = [aid for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[1]] + return all_parts, all_bodies + + +def _all_pairs_parallel(list_a, list_b): + pairs = [(a, b) for a in list_a for b in list_b] + pairs_a = [pair[0] for pair in pairs] + pairs_b = [pair[1] for pair in pairs] + return pairs_a, pairs_b + + +def train_test_split(item_list, random_seed=777, test_size=0.1): + import random + import math + random.seed(random_seed) + sample_size = math.floor(len(item_list) * test_size) + all_indices = list(range(len(item_list))) + test_indices = random.sample(all_indices, sample_size) + test_items = [item_list[i] for i in test_indices] + train_indices = sorted(list( + set(all_indices) - set(test_indices) + )) + train_items = [item_list[i] for i in train_indices] + return train_items, test_items + + +@register_ibs_method +def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): + print('calling gid_train_test_split') + gid_list = ibs.get_annot_gids(aid_list) + gid_set = list(set(gid_list)) + import random + import math + random.seed(random_seed) + n_test_gids = math.floor(len(gid_set) * test_size) + test_gids = set(random.sample(gid_set, n_test_gids)) + aid_in_train = [gid not in test_gids for gid in gid_list] + return aid_in_train + + +def _split_list(item_list, is_in_first_group_list): + first_group = ut.compress(item_list, is_in_first_group_list) + is_in_second_group = [not b for b in is_in_first_group_list] + second_group = ut.compress(item_list, is_in_second_group) + return first_group, second_group + + +@register_ibs_method +def _assign_parts(ibs, all_aids): + gids = ibs.get_annot_gids(all_aids) + gid_to_aids = DefaultDict(list) + for gid, aid in zip(gids, all_aids): + gid_to_aids[gid] += aid + + all_pairs = [] + all_unassigned_aids = [] + + for gid in gid_to_aids.keys(): + this_pairs, this_unassigned = _assign_parts_one_image(ibs, gid_to_aids[gid]) + all_pairs.append(this_pairs) + all_unassigned_aids.append(this_unassigned) + + return all_pairs, all_unassigned_aids + + + +@register_ibs_method +def _assign_parts_one_image(ibs, aid_list): + + are_part_aids = _are_part_annots(ibs, all_aids) + part_aids = ut.compress(all_aids, are_part_aids) + body_aids = ut.compress(all_aids, [not p for p in 
are_part_aids]) + + gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) + num_images = len(set(gids)) + assert num_images is 1 + + # parallel lists representing all possible part/body pairs + all_pairs_parallel = _all_pairs_parallel(part_aids, body_aids) + pair_parts, pair_bodies = all_pairs_parallel + + + assigner_features = ibs.depc_annot.get('theta_assignment_features', all_pairs_parallel) + assigner_classifier = _load_assigner_classifier(part_aids) + + assigner_scores = assigner_classifier.predict(assigner_features) + good_pairs, unassigned_aids = _make_assignments(pair_parts, pair_bodies, assigner_scores) + + +def _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score=0.5): + + sorted_scored_pairs = [(part, body, score) for part, body, score in + sorted(zip(pair_parts, pair_bodies, assigner_scores), + key=lambda pbscore: pbscore[2], reverse=True)] + + assigned_pairs = [] + assigned_parts = set() + assigned_bodies = set() + n_bodies = len(set(pair_bodies)) + n_parts = len(set(pair_parts)) + n_true_pairs = min(n_bodies, n_parts) + for part_aid, body_aid, score in sorted_scored_pairs: + assign_this_pair = part_aid not in assigned_parts and \ + body_aid not in assigned_bodies and \ + score >= cutoff_score + + if assign_this_pair: + assigned_pairs.append((part_aid, body_aid)) + assigned_parts.add(part_aid) + assigned_bodies.add(body_aid) + + if len(assigned_parts) is n_true_pairs \ + or len(assigned_bodies) is n_true_pairs \ + or score > cutoff_score: + break + + unassigned_parts = set(pair_parts) - set(assigned_parts) + unassigned_bodies = set(pair_bodies) - set(assigned_bodies) + unassigned_aids = sorted(list(unassigned_parts) + list(unassigned_bodies)) + + return assigned_pairs, unassigned_aids + + + + + + + + + diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 75bb3ae907..edbf238d21 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -3096,66 +3096,3 @@ def _bbox_to_corner_format(bboxes): return corner_bboxes -def all_part_pairs(ibs, gid_list): - all_aids = ibs.get_image_aids(gid_list) - all_aids_are_parts = [_are_part_annots(ibs, aids) for aids in all_aids] - all_part_aids = [[aid for (aid, part) in zip(aids, are_parts) if part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] - all_body_aids = [[aid for (aid, part) in zip(aids, are_parts) if not part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] - part_body_parallel_lists = [_all_pairs_parallel(parts, bodies) for parts, bodies in zip(all_part_aids, all_body_aids)] - all_parts = [aid for part_body_parallel_list in part_body_parallel_lists - for aid in part_body_parallel_list[0]] - all_bodies = [aid for part_body_parallel_list in part_body_parallel_lists - for aid in part_body_parallel_list[1]] - return all_parts, all_bodies - - -def _all_pairs_parallel(list_a, list_b): - pairs = [(a, b) for a in list_a for b in list_b] - pairs_a = [pair[0] for pair in pairs] - pairs_b = [pair[1] for pair in pairs] - return pairs_a, pairs_b - - -# for wild dog dev -@register_ibs_method -def wd_assigner_data(ibs): - return wd_training_data('part_assignment_features') - - -@register_ibs_method -def wd_normed_assigner_data(ibs): - return wd_training_data('normalized_assignment_features') - - -@register_ibs_method -def wd_training_data(ibs, depc_table_name='part_assignment_features'): - all_aids = ibs.get_valid_aids() - ia_classes = ibs.get_annot_species(all_aids) - part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] - part_gids = 
list(set(ibs.get_annot_gids(part_aids))) - all_pairs = all_part_pairs(ibs, part_gids) - all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) - names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] - ground_truth = [n1 == n2 for (n1, n2) in zip(names[0],names[1])] - - train_feats, test_feats = train_test_split(all_feats) - train_truth, test_truth = train_test_split(ground_truth) - assigner_data = {'data': train_feats, 'target': train_truth, - 'test': test_feats, 'test_truth': test_truth} - return assigner_data - - - -def train_test_split(item_list, random_seed=777, test_size=0.1): - import random - import math - random.seed(random_seed) - sample_size = math.floor(len(item_list) * test_size) - all_indices = list(range(len(item_list))) - test_indices = random.sample(all_indices, sample_size) - test_items = [item_list[i] for i in test_indices] - train_indices = sorted(list( - set(all_indices) - set(test_indices) - )) - train_items = [item_list[i] for i in train_indices] - return train_items, test_items From 79cf07255f7de38e7bcc1a1c94ee655209e4fb97 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 29 Dec 2020 13:33:57 -0800 Subject: [PATCH 141/294] WIP with progress on assigner integration --- wbia/algo/detect/assigner.py | 291 +++++++++++++++++++++++++++-------- wbia/core_annots.py | 44 +++--- 2 files changed, 247 insertions(+), 88 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index f3de772644..3d650ec0bf 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -34,18 +34,26 @@ # bunch of classifier models for training - -(print, rrr, profile) = ut.inject2(__name__, '[orientation]') +(print, rrr, profile) = ut.inject2(__name__, '[assigner]') logger = logging.getLogger('wbia') CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) - PARALLEL = not const.CONTAINERIZED INPUT_SIZE = 224 ARCHIVE_URL_DICT = {} +INMEM_ASSIGNER_MODELS = {} + +SPECIES_TO_ASSIGNER_MODELFILE = { + 'wild_dog': '/tmp/assigner_model.joblib', + 'wild_dog_dark': '/tmp/assigner_model.joblib', + 'wild_dog_light': '/tmp/assigner_model.joblib', + 'wild_dog_puppy': '/tmp/assigner_model.joblib', + 'wild_dog_standard': '/tmp/assigner_model.joblib', + 'wild_dog_tan': '/tmp/assigner_model.joblib', +} CLASSIFIER_OPTIONS = [ # { @@ -76,43 +84,43 @@ # }, { "name": "Decision Tree", - "clf": DecisionTreeClassifier(), #max_depth=5 + "clf": DecisionTreeClassifier(), # max_depth=5 "param_options": { - 'max_depth': np.arange(1,12), + 'max_depth': np.arange(1, 12), 'max_leaf_nodes': [2, 5, 10, 20, 50, 100] } }, - { - "name": "Random Forest", - "clf": RandomForestClassifier(), #max_depth=5, n_estimators=10, max_features=1 - "param_options": { - 'bootstrap': [True, False], - 'max_depth': [10, 50, 100, None], - 'max_features': ['auto', 'sqrt'], - 'min_samples_leaf': [1, 2, 4], - 'min_samples_split': [2, 5, 10], - 'n_estimators': [200, 1000, 1500, 2000] - } - }, - { - "name": "Neural Net", - "clf": MLPClassifier(), #alpha=1, max_iter=1000 - "param_options": { - 'hidden_layer_sizes': [(10,30,10),(20,)], - 'activation': ['tanh', 'relu'], - 'solver': ['sgd', 'adam'], - 'alpha': [0.0001, 0.05], - 'learning_rate': ['constant','adaptive'], - } - }, - { - "name": "AdaBoost", - "clf": AdaBoostClassifier(), - "param_options": { - 'n_estimators': np.arange(10, 310, 50), - 'learning_rate': [0.01, 0.05, 0.1, 1], - } - }, + # { + # "name": "Random Forest", + # "clf": RandomForestClassifier(), #max_depth=5, n_estimators=10, max_features=1 + # 
"param_options": { + # 'bootstrap': [True, False], + # 'max_depth': [10, 50, 100, None], + # 'max_features': ['auto', 'sqrt'], + # 'min_samples_leaf': [1, 2, 4], + # 'min_samples_split': [2, 5, 10], + # 'n_estimators': [200, 1000, 1500, 2000] + # } + # }, + # { + # "name": "Neural Net", + # "clf": MLPClassifier(), #alpha=1, max_iter=1000 + # "param_options": { + # 'hidden_layer_sizes': [(10,30,10),(20,)], + # 'activation': ['tanh', 'relu'], + # 'solver': ['sgd', 'adam'], + # 'alpha': [0.0001, 0.05], + # 'learning_rate': ['constant','adaptive'], + # } + # }, + # { + # "name": "AdaBoost", + # "clf": AdaBoostClassifier(), + # "param_options": { + # 'n_estimators': np.arange(10, 310, 50), + # 'learning_rate': [0.01, 0.05, 0.1, 1], + # } + # }, # { # "name": "Naive Bayes", # "clf": GaussianNB(), @@ -127,12 +135,10 @@ # } ] - - # for model exploration classifier_names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", - "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", - "Naive Bayes", "QDA"] + "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", + "Naive Bayes", "QDA"] classifiers = [ KNeighborsClassifier(3), @@ -204,7 +210,6 @@ def tune_ass_classifiers(ibs, depc_table_name='theta_assignment_features'): best_clf_name = classifier['name'] best_clf_params = best_params - print('best performance: %s using %s with params %s' % best_acc, best_clf_name, best_clf_params) @@ -221,7 +226,7 @@ def _tune_grid_search(ibs, clf, parameters, assigner_data=None): X_test = assigner_data['test'] y_test = assigner_data['test_truth'] - tune_search = GridSearchCV( + tune_search = TuneGridSearchCV( clf, parameters, ) @@ -285,21 +290,24 @@ def wd_training_data(ibs, depc_table_name='theta_assignment_features'): all_pairs = all_part_pairs(ibs, part_gids) all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] - ground_truth = [n1 == n2 for (n1, n2) in zip(names[0],names[1])] + ground_truth = [n1 == n2 for (n1, n2) in zip(names[0], names[1])] # train_feats, test_feats = train_test_split(all_feats) # train_truth, test_truth = train_test_split(ground_truth) - pairs_in_train = ibs.gid_train_test_split(all_pairs[0]) # we could pass just the pair aids or just the body aids bc gids are the same - train_feats, test_feats = _split_list(all_feats, pairs_in_train) - train_truth, test_truth = _split_list(ground_truth, pairs_in_train) + pairs_in_train = ibs.gid_train_test_split(all_pairs[0]) # we could pass just the pair aids or just the body aids bc gids are the same + train_feats, test_feats = split_list(all_feats, pairs_in_train) + train_truth, test_truth = split_list(ground_truth, pairs_in_train) + + all_pairs_tuple = [(part, body) for part, body in zip(all_pairs[0], all_pairs[1])] + train_pairs, test_pairs = split_list(all_pairs_tuple, pairs_in_train) assigner_data = {'data': train_feats, 'target': train_truth, - 'test': test_feats, 'test_truth': test_truth} + 'test': test_feats, 'test_truth': test_truth, + 'train_pairs': train_pairs, 'test_pairs': test_pairs} return assigner_data - @register_ibs_method def _are_part_annots(ibs, aid_list): species = ibs.get_annot_species(aid_list) @@ -344,6 +352,39 @@ def train_test_split(item_list, random_seed=777, test_size=0.1): @register_ibs_method def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): + r""" + Makes a gid-wise train-test split. This avoids potential overfitting when a network + is trained on some annots from one image and tested on others from the same image. 
+ + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): annot ids to split + random_seed: to make this split reproducible + test_size: portion of gids reserved for test data + + Yields: + a boolean flag_list of which aids are in the training set. Returning the flag_list + allows the user to filter multiple lists with one gid_train_test_split call + + + CommandLine: + python -m algo.detect.assigner gid_train_test_split + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> ibs = assigner_testdb_ibs() + >>> #TODO: get input aids somehow + >>> train_flag_list = ibs.gid_train_test_split(aids) + >>> train_aids, test_aids = split_list(aids, train_flag_list) + >>> train_gids = ibs.get_annot_gids(train_aids) + >>> test_gids = ibs.get_annot_gids(test_aids) + >>> train_gid_set = set(train_gids) + >>> test_gid_set = set(test_gids) + >>> assert len(train_gid_set & test_gid_set) is 0 + >>> assert len(train_gids) + len(test_gids) == len(aids) + """ print('calling gid_train_test_split') gid_list = ibs.get_annot_gids(aid_list) gid_set = list(set(gid_list)) @@ -356,7 +397,7 @@ def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): return aid_in_train -def _split_list(item_list, is_in_first_group_list): +def split_list(item_list, is_in_first_group_list): first_group = ut.compress(item_list, is_in_first_group_list) is_in_second_group = [not b for b in is_in_first_group_list] second_group = ut.compress(item_list, is_in_second_group) @@ -364,30 +405,30 @@ def _split_list(item_list, is_in_first_group_list): @register_ibs_method -def _assign_parts(ibs, all_aids): +def assign_parts(ibs, all_aids, cutoff_score=0.5): gids = ibs.get_annot_gids(all_aids) - gid_to_aids = DefaultDict(list) + gid_to_aids = defaultdict(list) for gid, aid in zip(gids, all_aids): - gid_to_aids[gid] += aid + gid_to_aids[gid] += [aid] all_pairs = [] all_unassigned_aids = [] for gid in gid_to_aids.keys(): - this_pairs, this_unassigned = _assign_parts_one_image(ibs, gid_to_aids[gid]) - all_pairs.append(this_pairs) - all_unassigned_aids.append(this_unassigned) + this_pairs, this_unassigned = _assign_parts_one_image(ibs, gid_to_aids[gid], cutoff_score) + all_pairs += (this_pairs) + all_unassigned_aids += this_unassigned return all_pairs, all_unassigned_aids @register_ibs_method -def _assign_parts_one_image(ibs, aid_list): +def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): - are_part_aids = _are_part_annots(ibs, all_aids) - part_aids = ut.compress(all_aids, are_part_aids) - body_aids = ut.compress(all_aids, [not p for p in are_part_aids]) + are_part_aids = _are_part_annots(ibs, aid_list) + part_aids = ut.compress(aid_list, are_part_aids) + body_aids = ut.compress(aid_list, [not p for p in are_part_aids]) gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) num_images = len(set(gids)) @@ -399,10 +440,13 @@ def _assign_parts_one_image(ibs, aid_list): assigner_features = ibs.depc_annot.get('theta_assignment_features', all_pairs_parallel) - assigner_classifier = _load_assigner_classifier(part_aids) + assigner_classifier = load_assigner_classifier(ibs, part_aids) - assigner_scores = assigner_classifier.predict(assigner_features) - good_pairs, unassigned_aids = _make_assignments(pair_parts, pair_bodies, assigner_scores) + assigner_scores = assigner_classifier.predict_proba(assigner_features) + # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true 
probabilities + assigner_scores = [score[1] for score in assigner_scores] + good_pairs, unassigned_aids = _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score) + return good_pairs, unassigned_aids def _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score=0.5): @@ -439,6 +483,129 @@ def _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score=0.5 return assigned_pairs, unassigned_aids +def load_assigner_classifier(ibs, aid_list, fallback_species='wild_dog'): + species_with_part = ibs.get_annot_species(aid_list[0]) + species = species_with_part.split('+')[0] + if species in INMEM_ASSIGNER_MODELS.keys(): + clf = INMEM_ASSIGNER_MODELS[species] + else: + if species not in SPECIES_TO_ASSIGNER_MODELFILE.keys(): + print("WARNING: Assigner called for species %s which does not have an assigner modelfile specified. Falling back to the model for %s" % species, fallback_species) + species = fallback_species + + model_fpath = SPECIES_TO_ASSIGNER_MODELFILE[species] + from joblib import load + clf = load(model_fpath) + + return clf + + +def check_accuracy(ibs, assigner_data, cutoff_score=0.5): + + all_aids = [] + for pair in assigner_data['test_pairs']: + all_aids.extend(list(pair)) + all_aids = sorted(list(set(all_aids))) + + all_pairs, all_unassigned_aids = ibs.assign_parts(all_aids, cutoff_score) + + gid_to_assigner_results = gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids) + gid_to_ground_truth = gid_keyed_ground_truth(ibs, assigner_data) + + correct_gids = [] + incorrect_gids = [] + gids_with_false_positives = 0 + n_false_positives = 0 + gids_with_false_negatives = 0 + n_false_negatives = 0 + gids_with_both_errors = 0 + for gid in gid_to_assigner_results.keys(): + assigned_pairs = set(gid_to_assigner_results[gid]['pairs']) + ground_t_pairs = set(gid_to_ground_truth[gid]['pairs']) + false_negatives = len(ground_t_pairs - assigned_pairs) + false_positives = len(assigned_pairs - ground_t_pairs) + n_false_negatives += false_negatives + + if false_negatives > 0: + gids_with_false_negatives += 1 + n_false_positives += false_positives + if false_positives > 0: + gids_with_false_positives += 1 + if false_negatives > 0 and false_positives > 0: + gids_with_both_errors += 1 + + pairs_equal = sorted(gid_to_assigner_results[gid]['pairs']) == sorted(gid_to_ground_truth[gid]['pairs']) + if pairs_equal: + correct_gids += [gid] + else: + incorrect_gids +=[gid] + + accuracy = len(correct_gids) / len(gid_to_assigner_results.keys()) + print('accuracy with cutoff of %s: %s' % (cutoff_score, accuracy)) + print(' %s false positives on %s error images' % (n_false_positives, gids_with_false_positives)) + print(' %s false negatives on %s error images' % (n_false_negatives, gids_with_false_negatives)) + print(' %s images with both errors' % (gids_with_both_errors)) + return accuracy + + +def gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids): + one_from_each_pair = [p[0] for p in all_pairs] + pair_gids = ibs.get_annot_gids(one_from_each_pair) + unassigned_gids = ibs.get_annot_gids(all_unassigned_aids) + + gid_to_pairs = defaultdict(list) + for pair, gid in zip(all_pairs, pair_gids): + gid_to_pairs[gid] += [pair] + + gid_to_unassigned = defaultdict(list) + for aid, gid in zip(all_unassigned_aids, unassigned_gids): + gid_to_unassigned[gid] += [aid] + + gid_to_assigner_results = {} + for gid in (set(gid_to_pairs.keys()) | set(gid_to_unassigned.keys())): + gid_to_assigner_results[gid] = { + 'pairs': gid_to_pairs[gid], + 'unassigned': 
gid_to_unassigned[gid] + } + + return gid_to_assigner_results + + +def gid_keyed_ground_truth(ibs, assigner_data): + test_pairs = assigner_data['test_pairs'] + test_truth = assigner_data['test_truth'] + assert len(test_pairs) == len(test_truth) + + aid_from_each_pair = [p[0] for p in test_pairs] + gids_for_pairs = ibs.get_annot_gids(aid_from_each_pair) + + gid_to_pairs = defaultdict(list) + gid_to_paired_aids = defaultdict(set) # to know which have not been in any pair + gid_to_all_aids = defaultdict(set) + for pair, is_true_pair, gid in zip(test_pairs, test_truth, gids_for_pairs): + gid_to_all_aids[gid] = gid_to_all_aids[gid] | set(pair) + if is_true_pair: + gid_to_pairs[gid] += [pair] + gid_to_paired_aids[gid] = gid_to_paired_aids[gid] | set(pair) + + gid_to_unassigned_aids = defaultdict(list) + for gid in gid_to_all_aids.keys(): + gid_to_unassigned_aids[gid] = list(gid_to_all_aids[gid] - gid_to_paired_aids[gid]) + + + gid_to_assigner_results = {} + for gid in (set(gid_to_pairs.keys()) | set(gid_to_unassigned_aids.keys())): + gid_to_assigner_results[gid] = { + 'pairs': gid_to_pairs[gid], + 'unassigned': gid_to_unassigned_aids[gid] + } + + return gid_to_assigner_results + + + + + diff --git a/wbia/core_annots.py b/wbia/core_annots.py index edbf238d21..da61ed92ee 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2510,9 +2510,9 @@ def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None) part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(ibs, part_aid_list) + parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = _are_part_annots(ibs, body_aid_list) + bodies_are_parts = ibs._are_part_annots(body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' part_bboxes = ibs.get_annot_bboxes(part_aid_list) @@ -2577,9 +2577,9 @@ def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=No part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(ibs, part_aid_list) + parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = _are_part_annots(ibs, body_aid_list) + bodies_are_parts = ibs._are_part_annots(body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' part_bboxes = ibs.get_annot_bboxes(part_aid_list) @@ -2648,9 +2648,9 @@ def standardized_assignment_features(depc, part_aid_list, body_aid_list, config= part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(ibs, part_aid_list) + parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' 
- bodies_are_parts = _are_part_annots(ibs, body_aid_list) + bodies_are_parts = ibs._are_part_annots(body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' part_bboxes = ibs.get_annot_bboxes(part_aid_list) @@ -2724,9 +2724,9 @@ def mega_standardized_assignment_features(depc, part_aid_list, body_aid_list, co part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(ibs, part_aid_list) + parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = _are_part_annots(ibs, body_aid_list) + bodies_are_parts = ibs._are_part_annots(body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' part_bboxes = ibs.get_annot_bboxes(part_aid_list) @@ -2739,8 +2739,6 @@ def mega_standardized_assignment_features(depc, part_aid_list, body_aid_list, co part_bboxes = _standardized_bboxes(part_bboxes) body_bboxes = _standardized_bboxes(body_bboxes) - - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] part_area_relative_body = [part_area / body_area @@ -2809,9 +2807,9 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(ibs, part_aid_list) + parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = _are_part_annots(ibs, body_aid_list) + bodies_are_parts = ibs._are_part_annots(body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' im_widths = ibs.get_image_widths(part_gids) @@ -2829,7 +2827,6 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): # just to make int_areas more comparable via ML methods, and since all distances < 1 int_area_scalars = [math.sqrt(area) for area in intersect_areas] - part_bboxes = ibs.get_annot_bboxes(part_aid_list) body_bboxes = ibs.get_annot_bboxes(body_aid_list) part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) @@ -2839,7 +2836,7 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): union_areas = [part + body - intersect for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas)] int_over_unions = [intersect / union for (intersect, union) - in zip(intersect_areas, union_areas)] + in zip(intersect_areas, union_areas)] part_body_distances = [part.distance(body) for part, body in zip(part_polys, body_polys)] @@ -2848,16 +2845,16 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): body_centroids = [poly.centroid for poly in body_polys] part_body_centroid_dists = [part.distance(body) for part, body - in zip(part_centroids, body_centroids)] + in zip(part_centroids, body_centroids)] int_over_parts = [int_area / part_area for part_area, int_area - in zip(part_areas, intersect_areas)] + in zip(part_areas, intersect_areas)] int_over_bodys = [int_area / body_area for body_area, int_area - in zip(body_areas, intersect_areas)] + in zip(body_areas, intersect_areas)] part_over_bodys = [part_area / body_area for part_area, body_area - in zip(part_areas, body_areas)] + 
in zip(part_areas, body_areas)] # note that here only parts have thetas, hence only returning body bboxes result_list = list(zip( @@ -2917,9 +2914,9 @@ def theta_standardized_assignment_features(depc, part_aid_list, body_aid_list, c part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' - parts_are_parts = _are_part_annots(ibs, part_aid_list) + parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = _are_part_annots(ibs, body_aid_list) + bodies_are_parts = ibs._are_part_annots(body_aid_list) assert not any(bodies_are_parts), 'body_aids cannot be part annots' im_widths = ibs.get_image_widths(part_gids) @@ -3027,12 +3024,6 @@ def _standardized_bboxes(bbox_list): return standardized_bboxes -def _are_part_annots(ibs, aid_list): - species = ibs.get_annot_species(aid_list) - are_parts = ['+' in specie for specie in species] - return are_parts - - def _bbox_intersections(bboxes_a, bboxes_b): corner_bboxes_a = _bbox_to_corner_format(bboxes_a) corner_bboxes_b = _bbox_to_corner_format(bboxes_b) @@ -3096,3 +3087,4 @@ def _bbox_to_corner_format(bboxes): return corner_bboxes +core_annots.py From b85c9104a0498bcdfa25c036013dbf9769e17122 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Mon, 4 Jan 2021 17:15:14 -0800 Subject: [PATCH 142/294] adds wild dog assigner testdb zip url --- wbia/constants.py | 3 +++ wbia/init/sysres.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/wbia/constants.py b/wbia/constants.py index 61ec041698..fef2d8f3a5 100644 --- a/wbia/constants.py +++ b/wbia/constants.py @@ -354,6 +354,9 @@ class ZIPPED_URLS(object): # NOQA ORIENTATION = ( 'https://wildbookiarepository.azureedge.net/databases/testdb_orientation.zip' ) + ASSIGNER = ( + 'https://wildbookiarepository.azureedge.net/databases/testdb_assigner.zip' + ) K7_EXAMPLE = 'https://wildbookiarepository.azureedge.net/databases/testdb_kaggle7.zip' diff --git a/wbia/init/sysres.py b/wbia/init/sysres.py index f6743a02a3..69fd0f2069 100644 --- a/wbia/init/sysres.py +++ b/wbia/init/sysres.py @@ -865,6 +865,10 @@ def ensure_testdb_identification_example(): return ensure_db_from_url(const.ZIPPED_URLS.ID_EXAMPLE) +def ensure_testdb_assigner(): + return ensure_db_from_url(const.ZIPPED_URLS.ASSIGNER) + + def ensure_testdb_kaggle7(): return ensure_db_from_url(const.ZIPPED_URLS.K7_EXAMPLE) From c480f2f66e57271e83d6cef07f0fb308968f5a15 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Mon, 4 Jan 2021 17:18:04 -0800 Subject: [PATCH 143/294] adds a bunch more assigner features, most of which were not super useful in assigner training. still deciding if we want to keep all of these --- wbia/core_annots.py | 263 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 261 insertions(+), 2 deletions(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index da61ed92ee..65dcdf21bd 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -59,6 +59,8 @@ from wbia.algo.hots.chip_match import ChipMatch from wbia.algo.hots import neighbor_index +from math import sqrt + from sklearn import preprocessing (print, rrr, profile) = ut.inject2(__name__) @@ -2880,6 +2882,265 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): ) +# just like theta_assignement_features above but with a one-hot encoding of viewpoints +# viewpoints are a boolean value for each viewpoint. 
will possibly need to modify this for other species +@derived_attribute( + tablename='assigner_viewpoint_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', + 'p_center_x', 'p_center_y', + 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', + 'int_area_scalar', 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + 'part_is_left', 'part_is_right', 'part_is_up', 'part_is_down', 'part_is_front', 'part_is_back', + 'body_is_left', 'body_is_right', 'body_is_up', 'body_is_down', 'body_is_front', 'body_is_back', + ], + coltypes=[ + float, float, float, float, float, float, float, float, float, float, + float, float, float, float, float, float, + float, float, float, float, float, float, float, + bool, bool, bool, bool, bool, bool, + bool, bool, bool, bool, bool, bool, + ], + configclass=PartAssignmentFeatureConfig, + fname='assigner_viewpoint_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [part.intersection(body) + for part, body in zip(part_polys, body_polys)] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [part + body - intersect for (part, body, intersect) + in zip(part_areas, body_areas, intersect_areas)] + int_over_unions = [intersect / union for (intersect, union) + in zip(intersect_areas, union_areas)] + + part_body_distances = [part.distance(body) + for part, body in zip(part_polys, body_polys)] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [part.distance(body) for part, body + in zip(part_centroids, body_centroids)] + + int_over_parts = [int_area / part_area for part_area, int_area + in zip(part_areas, intersect_areas)] + 
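# A minimal standalone sketch of the one-hot viewpoint encoding that this table
# adds on top of the geometric features: get_annot_lrudfb_bools (defined further
# down) turns a viewpoint string into six left/right/up/down/front/back flags
# via substring checks. The viewpoint value below is illustrative only.
def _lrudfb_bools_sketch(view):
    return ['left' in view, 'right' in view,
            'up' in view, 'down' in view,
            'front' in view, 'back' in view]

assert _lrudfb_bools_sketch('frontleft') == [True, False, False, False, True, False]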
+ int_over_bodys = [int_area / body_area for body_area, int_area + in zip(body_areas, intersect_areas)] + + part_over_bodys = [part_area / body_area for part_area, body_area + in zip(part_areas, body_areas)] + + part_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) + body_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list(zip( + part_verts, part_centroids, body_bboxes, body_centroids, + int_area_scalars, part_body_distances, part_body_centroid_dists, + int_over_unions, int_over_parts, int_over_bodys, part_over_bodys, + part_lrudfb_bools, body_lrudfb_bools + )) + + for (part_vert, part_center, body_bbox, body_center, + int_area_scalar, part_body_distance, part_body_centroid_dist, + int_over_union, int_over_part, int_over_body, part_over_body, + part_lrudfb_bool, body_lrudfb_bool) in result_list: + ans = (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], + part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], + part_center.x, part_center.y, + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + body_center.x, body_center.y, + int_area_scalar, part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body) + ans += tuple(part_lrudfb_bool) + ans += tuple(body_lrudfb_bool) + yield ans + + +@derived_attribute( + tablename='assigner_viewpoint_unit_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', + 'p_center_x', 'p_center_y', + 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', + 'int_area_scalar', 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + 'part_is_left', 'part_is_right', 'part_is_up', 'part_is_down', 'part_is_front', 'part_is_back', + 'body_is_left', 'body_is_right', 'body_is_up', 'body_is_down', 'body_is_front', 'body_is_back', + ], + coltypes=[ + float, float, float, float, float, float, float, float, float, float, + float, float, float, float, float, float, + float, float, float, float, float, float, float, + float, float, float, float, float, float, + float, float, float, float, float, float, + ], + configclass=PartAssignmentFeatureConfig, + fname='assigner_viewpoint_unit_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
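# A minimal standalone sketch of how this table differs from
# assigner_viewpoint_features above: the six viewpoint flags are normalized into
# a unit vector by get_annot_lrudfb_unit_vector (defined further down), so a
# viewpoint with two flags set contributes 1/sqrt(2) per flag. Values below are
# illustrative only.
from math import sqrt

def _lrudfb_unit_vector_sketch(bools):
    # a zero length is replaced by -1 to avoid division-by-zero on empty views,
    # mirroring get_annot_lrudfb_unit_vector
    length = sqrt(bools.count(True)) or -1
    return [float(b) / length for b in bools]

# e.g. a 'frontleft' view -> [0.707..., 0.0, 0.0, 0.0, 0.707..., 0.0]
print(_lrudfb_unit_vector_sketch([True, False, False, False, True, False]))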
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [part.intersection(body) + for part, body in zip(part_polys, body_polys)] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [part + body - intersect for (part, body, intersect) + in zip(part_areas, body_areas, intersect_areas)] + int_over_unions = [intersect / union for (intersect, union) + in zip(intersect_areas, union_areas)] + + part_body_distances = [part.distance(body) + for part, body in zip(part_polys, body_polys)] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [part.distance(body) for part, body + in zip(part_centroids, body_centroids)] + + int_over_parts = [int_area / part_area for part_area, int_area + in zip(part_areas, intersect_areas)] + + int_over_bodys = [int_area / body_area for body_area, int_area + in zip(body_areas, intersect_areas)] + + part_over_bodys = [part_area / body_area for part_area, body_area + in zip(part_areas, body_areas)] + + part_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) + body_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list(zip( + part_verts, part_centroids, body_bboxes, body_centroids, + int_area_scalars, part_body_distances, part_body_centroid_dists, + int_over_unions, int_over_parts, int_over_bodys, part_over_bodys, + part_lrudfb_vects, body_lrudfb_vects + )) + + for (part_vert, part_center, body_bbox, body_center, + int_area_scalar, part_body_distance, part_body_centroid_dist, + int_over_union, int_over_part, int_over_body, part_over_body, + part_lrudfb_vect, body_lrudfb_vect) in result_list: + ans = (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], + part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], + part_center.x, part_center.y, + body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], + body_center.x, body_center.y, + int_area_scalar, part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body) + ans += tuple(part_lrudfb_vect) + ans += tuple(body_lrudfb_vect) + yield ans + + +# left, right, up, down, front, back booleans, useful for assigner classification and other cases where we might want viewpoint as an input for an ML model +def get_annot_lrudfb_bools(ibs, aid_list): + views = 
ibs.get_annot_viewpoints(aid_list) + bool_arrays = [['left' in view, 'right' in view, + 'up' in view, 'down' in view, + 'front' in view, 'back' in view] for view in views] + return bool_arrays + + +def get_annot_lrudfb_unit_vector(ibs, aid_list): + bool_arrays = get_annot_lrudfb_bools(ibs, aid_list) + float_arrays = [[float(b) for b in lrudfb] for lrudfb in bool_arrays] + lrudfb_lengths = [sqrt(lrudfb.count(True)) for lrudfb in bool_arrays] + # lying just to avoid division by zero errors + lrudfb_lengths = [l if l != 0 else -1 for l in lrudfb_lengths] + unit_float_array = [[f / length for f in lrudfb] for lrudfb, length + in zip(float_arrays, lrudfb_lengths)] + + return unit_float_array + + @derived_attribute( tablename='theta_standardized_assignment_features', parents=['annotations', 'annotations'], @@ -3086,5 +3347,3 @@ def _bbox_to_corner_format(bboxes): for bbox in bboxes] return corner_bboxes - -core_annots.py From f80a0369bdbfcef2a80261f0621884cf1d1c16e3 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Mon, 4 Jan 2021 17:48:04 -0800 Subject: [PATCH 144/294] splits assigner into training and deploy parts, adds tests. --- wbia/algo/detect/__init__.py | 1 + wbia/algo/detect/assigner.py | 572 ++++++++++------------------- wbia/algo/detect/train_assigner.py | 477 ++++++++++++++++++++++++ 3 files changed, 667 insertions(+), 383 deletions(-) create mode 100644 wbia/algo/detect/train_assigner.py diff --git a/wbia/algo/detect/__init__.py b/wbia/algo/detect/__init__.py index f400e72594..416d644ada 100644 --- a/wbia/algo/detect/__init__.py +++ b/wbia/algo/detect/__init__.py @@ -7,6 +7,7 @@ from wbia.algo.detect import yolo from wbia.algo.detect import assigner +# from wbia.algo.detect import train_assigner # from wbia.algo.detect import selectivesearch # from wbia.algo.detect import ssd # from wbia.algo.detect import fasterrcnn diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index 3d650ec0bf..75b1dc8c97 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -10,27 +10,17 @@ from datetime import datetime import time +# illustration imports +from shutil import copy +from PIL import Image, ImageDraw +import wbia.plottool as pt + + from sklearn import preprocessing from tune_sklearn import TuneGridSearchCV # shitload of scikit classifiers import numpy as np -import matplotlib.pyplot as plt -from matplotlib.colors import ListedColormap -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler -from sklearn.datasets import make_moons, make_circles, make_classification -from sklearn.neural_network import MLPClassifier -from sklearn.neighbors import KNeighborsClassifier -from sklearn.svm import SVC -from sklearn.gaussian_process import GaussianProcessClassifier -from sklearn.gaussian_process.kernels import RBF -from sklearn.tree import DecisionTreeClassifier -from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier -from sklearn.naive_bayes import GaussianNB -from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.model_selection import GridSearchCV - # bunch of classifier models for training @@ -47,275 +37,72 @@ INMEM_ASSIGNER_MODELS = {} SPECIES_TO_ASSIGNER_MODELFILE = { - 'wild_dog': '/tmp/assigner_model.joblib', - 'wild_dog_dark': '/tmp/assigner_model.joblib', - 'wild_dog_light': '/tmp/assigner_model.joblib', - 'wild_dog_puppy': '/tmp/assigner_model.joblib', - 'wild_dog_standard': '/tmp/assigner_model.joblib', - 'wild_dog_tan': '/tmp/assigner_model.joblib', + 
'wild_dog': '/tmp/balanced_wd.joblib', + 'wild_dog_dark': '/tmp/balanced_wd.joblib', + 'wild_dog_light': '/tmp/balanced_wd.joblib', + 'wild_dog_puppy': '/tmp/balanced_wd.joblib', + 'wild_dog_standard': '/tmp/balanced_wd.joblib', + 'wild_dog_tan': '/tmp/balanced_wd.joblib', } -CLASSIFIER_OPTIONS = [ - # { - # "name": "Nearest Neighbors", - # "clf": KNeighborsClassifier(3), - # "param_options": { - # 'n_neighbors': [3,5,11,19], - # 'weights': ['uniform', 'distance'], - # 'metric': ['euclidean', 'manhattan'], - # } - # }, - # { - # "name": "Linear SVM", - # "clf": SVC(kernel="linear", C=0.025), - # "param_options": { - # 'C': [1, 10, 100, 1000], - # 'kernel': ['linear'], - # } - # }, - # { - # "name": "RBF SVM", - # "clf": SVC(gamma=2, C=1), - # "param_options": { - # 'C': [1, 10, 100, 1000], - # 'gamma': [0.001, 0.0001], - # 'kernel': ['rbf'] - # }, - # }, - { - "name": "Decision Tree", - "clf": DecisionTreeClassifier(), # max_depth=5 - "param_options": { - 'max_depth': np.arange(1, 12), - 'max_leaf_nodes': [2, 5, 10, 20, 50, 100] - } - }, - # { - # "name": "Random Forest", - # "clf": RandomForestClassifier(), #max_depth=5, n_estimators=10, max_features=1 - # "param_options": { - # 'bootstrap': [True, False], - # 'max_depth': [10, 50, 100, None], - # 'max_features': ['auto', 'sqrt'], - # 'min_samples_leaf': [1, 2, 4], - # 'min_samples_split': [2, 5, 10], - # 'n_estimators': [200, 1000, 1500, 2000] - # } - # }, - # { - # "name": "Neural Net", - # "clf": MLPClassifier(), #alpha=1, max_iter=1000 - # "param_options": { - # 'hidden_layer_sizes': [(10,30,10),(20,)], - # 'activation': ['tanh', 'relu'], - # 'solver': ['sgd', 'adam'], - # 'alpha': [0.0001, 0.05], - # 'learning_rate': ['constant','adaptive'], - # } - # }, - # { - # "name": "AdaBoost", - # "clf": AdaBoostClassifier(), - # "param_options": { - # 'n_estimators': np.arange(10, 310, 50), - # 'learning_rate': [0.01, 0.05, 0.1, 1], - # } - # }, - # { - # "name": "Naive Bayes", - # "clf": GaussianNB(), - # "param_options": {} # no hyperparams to optimize - # }, - # { - # "name": "QDA", - # "clf": QuadraticDiscriminantAnalysis(), - # "param_options": { - # 'reg_param': [0.1, 0.2, 0.3, 0.4, 0.5] - # } - # } -] - -# for model exploration -classifier_names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", - "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", - "Naive Bayes", "QDA"] - -classifiers = [ - KNeighborsClassifier(3), - SVC(kernel="linear", C=0.025), - SVC(gamma=2, C=1), - DecisionTreeClassifier(max_depth=5), - RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1), - MLPClassifier(alpha=1, max_iter=1000), - AdaBoostClassifier(), - GaussianNB(), - QuadraticDiscriminantAnalysis()] - - -slow_classifier_names = "Gaussian Process" -slow_classifiers = GaussianProcessClassifier(1.0 * RBF(1.0)), - - -def classifier_report(clf, name, assigner_data): - print('%s CLASSIFIER REPORT ' % name) - print(' %s: calling clf.fit' % str(datetime.now())) - clf.fit(assigner_data['data'], assigner_data['target']) - print(' %s: done training, making prediction ' % str(datetime.now())) - preds = clf.predict(assigner_data['test']) - print(' %s: done with predictions, computing accuracy' % str(datetime.now())) - agree = [pred == truth for pred, truth in zip(preds, assigner_data['test_truth'])] - accuracy = agree.count(True) / len(agree) - print(' %s accuracy' % accuracy) - print() - return accuracy - - -@register_ibs_method -def compare_ass_classifiers(ibs, depc_table_name='theta_assignment_features', print_accs=False): - - assigner_data = 
ibs.wd_training_data(depc_table_name) - - accuracies = OrderedDict() - for classifier in CLASSIFIER_OPTIONS: - accuracy = classifier_report(classifier['clf'], classifier['name'], assigner_data) - accuracies[classifier['name']] = accuracy - - # handy for e.g. pasting into excel - if print_accs: - just_accuracy = [accuracies[name] for name in accuracies.keys()] - print(just_accuracy) - - return accuracies - - -@register_ibs_method -def tune_ass_classifiers(ibs, depc_table_name='theta_assignment_features'): - - assigner_data = ibs.wd_training_data(depc_table_name) - - accuracies = OrderedDict() - best_acc = 0 - best_clf_name = '' - best_clf_params = {} - for classifier in CLASSIFIER_OPTIONS: - print("Tuning %s" % classifier['name']) - accuracy, best_params = ibs._tune_grid_search(classifier['clf'], classifier['param_options'], assigner_data) - print() - accuracies[classifier['name']] = { - 'accuracy': accuracy, - 'best_params': best_params - } - if accuracy > best_acc: - best_acc = accuracy - best_clf_name = classifier['name'] - best_clf_params = best_params - - print('best performance: %s using %s with params %s' % - best_acc, best_clf_name, best_clf_params) - - return accuracies - - -@register_ibs_method -def _tune_grid_search(ibs, clf, parameters, assigner_data=None): - if assigner_data is None: - assigner_data = ibs.wd_training_data() - - X_train = assigner_data['data'] - y_train = assigner_data['target'] - X_test = assigner_data['test'] - y_test = assigner_data['test_truth'] - - tune_search = TuneGridSearchCV( - clf, - parameters, - ) - - start = time.time() - tune_search.fit(X_train, y_train) - end = time.time() - print("Tune Fit Time: %s" % (end - start)) - pred = tune_search.predict(X_test) - accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) - print("Tune Accuracy: %s" % accuracy) - print("best parms : %s" % tune_search.best_params_) - - return accuracy, tune_search.best_params_ - - -@register_ibs_method -def _tune_random_search(ibs, clf, parameters, assigner_data=None): - if assigner_data is None: - assigner_data = ibs.wd_training_data() - - X_train = assigner_data['data'] - y_train = assigner_data['target'] - X_test = assigner_data['test'] - y_test = assigner_data['test_truth'] - - tune_search = GridSearchCV( - clf, - parameters, - ) - - start = time.time() - tune_search.fit(X_train, y_train) - end = time.time() - print("Tune Fit Time: %s" % (end - start)) - pred = tune_search.predict(X_test) - accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) - print("Tune Accuracy: %s" % accuracy) - print("best parms : %s" % tune_search.best_params_) - - return accuracy, tune_search.best_params_ - - -# for wild dog dev -@register_ibs_method -def wd_assigner_data(ibs): - return wd_training_data('part_assignment_features') - @register_ibs_method -def wd_normed_assigner_data(ibs): - return wd_training_data('normalized_assignment_features') - - -@register_ibs_method -def wd_training_data(ibs, depc_table_name='theta_assignment_features'): - all_aids = ibs.get_valid_aids() - ia_classes = ibs.get_annot_species(all_aids) - part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] - part_gids = list(set(ibs.get_annot_gids(part_aids))) - all_pairs = all_part_pairs(ibs, part_gids) - all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) - names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] - ground_truth = [n1 == n2 for (n1, n2) in zip(names[0], names[1])] - - # train_feats, test_feats = 
train_test_split(all_feats) - # train_truth, test_truth = train_test_split(ground_truth) - pairs_in_train = ibs.gid_train_test_split(all_pairs[0]) # we could pass just the pair aids or just the body aids bc gids are the same - train_feats, test_feats = split_list(all_feats, pairs_in_train) - train_truth, test_truth = split_list(ground_truth, pairs_in_train) - - all_pairs_tuple = [(part, body) for part, body in zip(all_pairs[0], all_pairs[1])] - train_pairs, test_pairs = split_list(all_pairs_tuple, pairs_in_train) - - assigner_data = {'data': train_feats, 'target': train_truth, - 'test': test_feats, 'test_truth': test_truth, - 'train_pairs': train_pairs, 'test_pairs': test_pairs} +def _are_part_annots(ibs, aid_list): + r""" + returns a boolean list representing if each aid in aid_list is a part annot. + This determination is made by the presence of a '+' in the species. - return assigner_data + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): annot ids to split + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots -@register_ibs_method -def _are_part_annots(ibs, aid_list): + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> aids = ibs.get_valid_aids() + >>> result = ibs._are_part_annots(aids) + [False, False, True, True, False, True, False, True] + """ species = ibs.get_annot_species(aid_list) are_parts = ['+' in specie for specie in species] return are_parts def all_part_pairs(ibs, gid_list): + r""" + Returns all possible part,body pairs from aids in gid_list, in the format of + two parralel lists: the first being all parts, the second all bodies + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + gid_list (int): gids in question + + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> gids = ibs.get_valid_gids() + >>> all_part_pairs = all_part_pairs(ibs, gids) + >>> parts = all_part_pairs[0] + >>> bodies = all_part_pairs[1] + >>> all_aids = ibs.get_image_aids(gids) + >>> all_aids = [aid for aids in all_aids for aid in aids] # flatten + >>> assert (set(parts) & set(bodies)) == set({}) + >>> assert (set(parts) | set(bodies)) == set(all_aids) + >>> result = all_part_pairs + ([3, 3, 4, 4, 6, 8], [1, 2, 1, 2, 5, 7]) + """ all_aids = ibs.get_image_aids(gid_list) all_aids_are_parts = [ibs._are_part_annots(aids) for aids in all_aids] all_part_aids = [[aid for (aid, part) in zip(aids, are_parts) if part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] @@ -329,102 +116,99 @@ def all_part_pairs(ibs, gid_list): def _all_pairs_parallel(list_a, list_b): + # is tested by all_part_pairs above pairs = [(a, b) for a in list_a for b in list_b] pairs_a = [pair[0] for pair in pairs] pairs_b = [pair[1] for pair in pairs] return pairs_a, pairs_b -def train_test_split(item_list, random_seed=777, test_size=0.1): - import random - import math - random.seed(random_seed) - sample_size = math.floor(len(item_list) * test_size) - all_indices = list(range(len(item_list))) - test_indices = random.sample(all_indices, sample_size) - test_items = [item_list[i] for i in test_indices] - train_indices = sorted(list( - set(all_indices) - set(test_indices) - )) - train_items = [item_list[i] for i in 
train_indices] - return train_items, test_items - - @register_ibs_method -def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): +def assign_parts(ibs, all_aids, cutoff_score=0.5): r""" - Makes a gid-wise train-test split. This avoids potential overfitting when a network - is trained on some annots from one image and tested on others from the same image. + Main assigner method; makes assignments on all_aids based on assigner scores. Args: ibs (IBEISController): IBEIS / WBIA controller object - aid_list (int): annot ids to split - random_seed: to make this split reproducible - test_size: portion of gids reserved for test data - - Yields: - a boolean flag_list of which aids are in the training set. Returning the flag_list - allows the user to filter multiple lists with one gid_train_test_split call + aid_list (int): aids in question + cutoff_score: the threshold for the aids' assigner scores, under which no assignments are made + Returns: + tuple of two lists: all_assignments (a list of tuples, each tuple grouping + aids assigned to a single animal), and all_unassigned_aids, which are the aids that did not meet the cutoff_score or whose body/part CommandLine: - python -m algo.detect.assigner gid_train_test_split + python -m wbia.algo.detect.assigner _are_part_annots Example: >>> # ENABLE_DOCTEST >>> import utool as ut >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * >>> ibs = assigner_testdb_ibs() - >>> #TODO: get input aids somehow - >>> train_flag_list = ibs.gid_train_test_split(aids) - >>> train_aids, test_aids = split_list(aids, train_flag_list) - >>> train_gids = ibs.get_annot_gids(train_aids) - >>> test_gids = ibs.get_annot_gids(test_aids) - >>> train_gid_set = set(train_gids) - >>> test_gid_set = set(test_gids) - >>> assert len(train_gid_set & test_gid_set) is 0 - >>> assert len(train_gids) + len(test_gids) == len(aids) + >>> aids = ibs.get_valid_aids() + >>> result = ibs.assign_parts(aids) + >>> assigned_pairs = result[0] + >>> unassigned_aids = result[1] + >>> assigned_aids = [item for pair in assigned_pairs for item in pair] + >>> # no overlap between assigned and unassigned aids + >>> assert (set(assigned_aids) & set(unassigned_aids) == set({})) + >>> # all aids are either assigned or unassigned + >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) + >>> ([(3, 1), (6, 5), (8, 7)], [2, 4]) """ - print('calling gid_train_test_split') - gid_list = ibs.get_annot_gids(aid_list) - gid_set = list(set(gid_list)) - import random - import math - random.seed(random_seed) - n_test_gids = math.floor(len(gid_set) * test_size) - test_gids = set(random.sample(gid_set, n_test_gids)) - aid_in_train = [gid not in test_gids for gid in gid_list] - return aid_in_train - - -def split_list(item_list, is_in_first_group_list): - first_group = ut.compress(item_list, is_in_first_group_list) - is_in_second_group = [not b for b in is_in_first_group_list] - second_group = ut.compress(item_list, is_in_second_group) - return first_group, second_group - - -@register_ibs_method -def assign_parts(ibs, all_aids, cutoff_score=0.5): gids = ibs.get_annot_gids(all_aids) gid_to_aids = defaultdict(list) for gid, aid in zip(gids, all_aids): gid_to_aids[gid] += [aid] - all_pairs = [] + all_assignments = [] all_unassigned_aids = [] for gid in gid_to_aids.keys(): this_pairs, this_unassigned = _assign_parts_one_image(ibs, gid_to_aids[gid], cutoff_score) - all_pairs += (this_pairs) + all_assignments += (this_pairs) all_unassigned_aids += 
this_unassigned - return all_pairs, all_unassigned_aids + return all_assignments, all_unassigned_aids @register_ibs_method def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): + r""" + Main assigner method; makes assignments on all_aids based on assigner scores. + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): aids in question + cutoff_score: the threshold for the aids' assigner scores, under which no assignments are made + + Returns: + tuple of two lists: all_assignments (a list of tuples, each tuple grouping + aids assigned to a single animal), and all_unassigned_aids, which are the aids that did not meet the cutoff_score or whose body/part + + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> gid = 1 + >>> aids = ibs.get_image_aids(gid) + >>> result = ibs._assign_parts_one_image(aids) + >>> assigned_pairs = result[0] + >>> unassigned_aids = result[1] + >>> assigned_aids = [item for pair in assigned_pairs for item in pair] + >>> # no overlap between assigned and unassigned aids + >>> assert (set(assigned_aids) & set(unassigned_aids) == set({})) + >>> # all aids are either assigned or unassigned + >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) + >>> ([(3, 1)], [2, 4]) + """ are_part_aids = _are_part_annots(ibs, aid_list) part_aids = ut.compress(aid_list, are_part_aids) @@ -432,14 +216,13 @@ def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) num_images = len(set(gids)) - assert num_images is 1 + assert num_images is 1, "_assign_parts_one_image called on multiple images' aids" # parallel lists representing all possible part/body pairs all_pairs_parallel = _all_pairs_parallel(part_aids, body_aids) pair_parts, pair_bodies = all_pairs_parallel - - assigner_features = ibs.depc_annot.get('theta_assignment_features', all_pairs_parallel) + assigner_features = ibs.depc_annot.get('assigner_viewpoint_features', all_pairs_parallel) assigner_classifier = load_assigner_classifier(ibs, part_aids) assigner_scores = assigner_classifier.predict_proba(assigner_features) @@ -500,52 +283,64 @@ def load_assigner_classifier(ibs, aid_list, fallback_species='wild_dog'): return clf -def check_accuracy(ibs, assigner_data, cutoff_score=0.5): - - all_aids = [] - for pair in assigner_data['test_pairs']: - all_aids.extend(list(pair)) - all_aids = sorted(list(set(all_aids))) - - all_pairs, all_unassigned_aids = ibs.assign_parts(all_aids, cutoff_score) - - gid_to_assigner_results = gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids) - gid_to_ground_truth = gid_keyed_ground_truth(ibs, assigner_data) - - correct_gids = [] - incorrect_gids = [] - gids_with_false_positives = 0 - n_false_positives = 0 - gids_with_false_negatives = 0 - n_false_negatives = 0 - gids_with_both_errors = 0 - for gid in gid_to_assigner_results.keys(): - assigned_pairs = set(gid_to_assigner_results[gid]['pairs']) - ground_t_pairs = set(gid_to_ground_truth[gid]['pairs']) - false_negatives = len(ground_t_pairs - assigned_pairs) - false_positives = len(assigned_pairs - ground_t_pairs) - n_false_negatives += false_negatives - - if false_negatives > 0: - gids_with_false_negatives += 1 - n_false_positives += false_positives - if false_positives > 0: - gids_with_false_positives += 1 - if 
false_negatives > 0 and false_positives > 0: - gids_with_both_errors += 1 - - pairs_equal = sorted(gid_to_assigner_results[gid]['pairs']) == sorted(gid_to_ground_truth[gid]['pairs']) - if pairs_equal: - correct_gids += [gid] +def illustrate_all_assignments(ibs, gid_to_assigner_results, gid_to_ground_truth, + target_dir='/tmp/assigner-illustrations-2/', limit=20): + + correct_dir = os.path.join(target_dir, 'correct/') + incorrect_dir = os.path.join(target_dir, 'incorrect/') + + for gid, assigned_aid_dict in gid_to_assigner_results.items()[:limit]: + ground_t_dict = gid_to_ground_truth[gid] + assigned_correctly = sorted(assigned_aid_dict['pairs']) == sorted(ground_t_dict['pairs']) + if assigned_correctly: + illustrate_assignments(ibs, gid, assigned_aid_dict, None, correct_dir) # don't need to illustrate gtruth if it's identical to assignment else: - incorrect_gids +=[gid] + illustrate_assignments(ibs, gid, assigned_aid_dict, ground_t_dict, incorrect_dir) + + print('illustrated assignments and saved them in %s' % target_dir) + + +# works on a single gid's worth of gid_keyed_assigner_results output +def illustrate_assignments(ibs, gid, assigned_aid_dict, gtruth_aid_dict, target_dir='/tmp/assigner-illustrations/'): + impath = ibs.get_image_paths(gid) + imext = os.path.splitext(impath)[1] + new_fname = os.path.join(target_dir, '%s%s' % (gid, imext)) + os.makedirs(target_dir, exist_ok=True) + copy(impath, new_fname) + + with Image.open(new_fname) as image: + _draw_all_annots(ibs, image, assigned_aid_dict, gtruth_aid_dict) + image.save(new_fname) + + +def _draw_all_annots(ibs, image, assigned_aid_dict, gtruth_aid_dict): + n_pairs = len(assigned_aid_dict['pairs']) + n_missing_pairs = 0 + # TODO: missing pair shit + n_unass = len(assigned_aid_dict['unassigned']) + n_groups = n_pairs + n_unass + colors = _pil_distinct_colors(n_groups) - accuracy = len(correct_gids) / len(gid_to_assigner_results.keys()) - print('accuracy with cutoff of %s: %s' % (cutoff_score, accuracy)) - print(' %s false positives on %s error images' % (n_false_positives, gids_with_false_positives)) - print(' %s false negatives on %s error images' % (n_false_negatives, gids_with_false_negatives)) - print(' %s images with both errors' % (gids_with_both_errors)) - return accuracy + draw = ImageDraw.Draw(image) + for i, pair in enumerate(assigned_aid_dict['pairs']): + _draw_bbox(ibs, draw, pair[0], colors[i]) + _draw_bbox(ibs, draw, pair[1], colors[i]) + + for i, aid in enumerate(assigned_aid_dict['unassigned'], start=n_pairs): + _draw_bbox(ibs, draw, aid, colors[i]) + + +def _pil_distinct_colors(n_colors): + float_colors = pt.distinct_colors(n_colors) + int_colors = [tuple([int(256 * f) for f in color]) for color in float_colors] + return int_colors + + +def _draw_bbox(ibs, pil_draw, aid, color): + verts = ibs.get_annot_rotated_verts(aid) + pil_verts = [tuple(vertex) for vertex in verts] + pil_verts += pil_verts[:1] # for the line between the last and first vertex + pil_draw.line(pil_verts, color, width=4) def gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids): @@ -603,13 +398,24 @@ def gid_keyed_ground_truth(ibs, assigner_data): return gid_to_assigner_results +@register_ibs_method +def assigner_testdb_ibs(): + # dbdir = sysres.ensure_testdb_assigner() + import wbia + dbdir = '/data/testdb_assigner' + ibs = wbia.opendb(dbdir=dbdir) + return ibs +if __name__ == '__main__': + r""" + CommandLine: + python -m wbia.algo.detect.assigner --allexamples + """ + import multiprocessing + multiprocessing.freeze_support() # for win32 
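# A minimal standalone sketch of the data structure returned by both
# gid_keyed_assigner_results and gid_keyed_ground_truth above: a dict keyed by
# image rowid, holding that image's (part, body) pairs plus any leftover aids.
# The aid values below are illustrative only.
_example_gid_keyed_results = {
    1: {'pairs': [(3, 1)], 'unassigned': [2, 4]},
    2: {'pairs': [(6, 5), (8, 7)], 'unassigned': []},
}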
+ import utool as ut # NOQA - - - - - + ut.doctest_funcs() diff --git a/wbia/algo/detect/train_assigner.py b/wbia/algo/detect/train_assigner.py new file mode 100644 index 0000000000..a2aa62d0f5 --- /dev/null +++ b/wbia/algo/detect/train_assigner.py @@ -0,0 +1,477 @@ +import logging +from os.path import expanduser, join +from wbia import constants as const +from wbia.control.controller_inject import register_preprocs, register_subprops, make_ibs_register_decorator + +from wbia.algo.detect.assigner import gid_keyed_assigner_results, gid_keyed_ground_truth, illustrate_all_assignments + +import utool as ut +import numpy as np +import random +import os +from collections import OrderedDict, defaultdict +from datetime import datetime +import time + +# illustration imports +from shutil import copy +from PIL import Image, ImageDraw +import wbia.plottool as pt + + +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.datasets import make_moons, make_circles, make_classification +from sklearn.neural_network import MLPClassifier +from sklearn.neighbors import KNeighborsClassifier +from sklearn.svm import SVC +from sklearn.gaussian_process import GaussianProcessClassifier +from sklearn.gaussian_process.kernels import RBF +from sklearn.tree import DecisionTreeClassifier +from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier +from sklearn.naive_bayes import GaussianNB +from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis +from sklearn.model_selection import GridSearchCV + + +CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) + + +CLASSIFIER_OPTIONS = [ + { + "name": "Nearest Neighbors", + "clf": KNeighborsClassifier(3), + "param_options": { + 'n_neighbors': [3, 5, 11, 19], + 'weights': ['uniform', 'distance'], + 'metric': ['euclidean', 'manhattan'], + } + }, + { + "name": "Linear SVM", + "clf": SVC(kernel="linear", C=0.025), + "param_options": { + 'C': [1, 10, 100, 1000], + 'kernel': ['linear'], + } + }, + { + "name": "RBF SVM", + "clf": SVC(gamma=2, C=1), + "param_options": { + 'C': [1, 10, 100, 1000], + 'gamma': [0.001, 0.0001], + 'kernel': ['rbf'] + }, + }, + { + "name": "Decision Tree", + "clf": DecisionTreeClassifier(), # max_depth=5 + "param_options": { + 'max_depth': np.arange(1, 12), + 'max_leaf_nodes': [2, 5, 10, 20, 50, 100] + } + }, + # { + # "name": "Random Forest", + # "clf": RandomForestClassifier(), #max_depth=5, n_estimators=10, max_features=1 + # "param_options": { + # 'bootstrap': [True, False], + # 'max_depth': [10, 50, 100, None], + # 'max_features': ['auto', 'sqrt'], + # 'min_samples_leaf': [1, 2, 4], + # 'min_samples_split': [2, 5, 10], + # 'n_estimators': [200, 1000, 2000] + # } + # }, + # { + # "name": "Neural Net", + # "clf": MLPClassifier(), #alpha=1, max_iter=1000 + # "param_options": { + # 'hidden_layer_sizes': [(10,30,10),(20,)], + # 'activation': ['tanh', 'relu'], + # 'solver': ['sgd', 'adam'], + # 'alpha': [0.0001, 0.05], + # 'learning_rate': ['constant','adaptive'], + # } + # }, + { + "name": "AdaBoost", + "clf": AdaBoostClassifier(), + "param_options": { + 'n_estimators': np.arange(10, 310, 50), + 'learning_rate': [0.01, 0.05, 0.1, 1], + } + }, + { + "name": "Naive Bayes", + "clf": GaussianNB(), + "param_options": {} # no hyperparams to optimize + }, + { + "name": "QDA", + "clf": QuadraticDiscriminantAnalysis(), + "param_options": { + 'reg_param': [0.1, 0.2, 0.3, 0.4, 
0.5] + } + } +] + + +classifier_names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", + "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", + "Naive Bayes", "QDA"] + +classifiers = [ + KNeighborsClassifier(3), + SVC(kernel="linear", C=0.025), + SVC(gamma=2, C=1), + DecisionTreeClassifier(max_depth=5), + RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1), + MLPClassifier(alpha=1, max_iter=1000), + AdaBoostClassifier(), + GaussianNB(), + QuadraticDiscriminantAnalysis()] + + +slow_classifier_names = "Gaussian Process" +slow_classifiers = GaussianProcessClassifier(1.0 * RBF(1.0)), + + +def classifier_report(clf, name, assigner_data): + print('%s CLASSIFIER REPORT ' % name) + print(' %s: calling clf.fit' % str(datetime.now())) + clf.fit(assigner_data['data'], assigner_data['target']) + print(' %s: done training, making prediction ' % str(datetime.now())) + preds = clf.predict(assigner_data['test']) + print(' %s: done with predictions, computing accuracy' % str(datetime.now())) + agree = [pred == truth for pred, truth in zip(preds, assigner_data['test_truth'])] + accuracy = agree.count(True) / len(agree) + print(' %s accuracy' % accuracy) + print() + return accuracy + + +@register_ibs_method +def compare_ass_classifiers(ibs, depc_table_name='assigner_viewpoint_features', print_accs=False): + + assigner_data = ibs.wd_training_data(depc_table_name) + + accuracies = OrderedDict() + for classifier in CLASSIFIER_OPTIONS: + accuracy = classifier_report(classifier['clf'], classifier['name'], assigner_data) + accuracies[classifier['name']] = accuracy + + # handy for e.g. pasting into excel + if print_accs: + just_accuracy = [accuracies[name] for name in accuracies.keys()] + print(just_accuracy) + + return accuracies + + +@register_ibs_method +def tune_ass_classifiers(ibs, depc_table_name='assigner_viewpoint_unit_features'): + + assigner_data = ibs.wd_training_data(depc_table_name) + + accuracies = OrderedDict() + best_acc = 0 + best_clf_name = '' + best_clf_params = {} + for classifier in CLASSIFIER_OPTIONS: + print("Tuning %s" % classifier['name']) + accuracy, best_params = ibs._tune_grid_search(classifier['clf'], classifier['param_options'], assigner_data) + print() + accuracies[classifier['name']] = { + 'accuracy': accuracy, + 'best_params': best_params + } + if accuracy > best_acc: + best_acc = accuracy + best_clf_name = classifier['name'] + best_clf_params = best_params + + print('best performance: %s using %s with params %s' % + (best_acc, best_clf_name, best_clf_params)) + + return accuracies + + +@register_ibs_method +def _tune_grid_search(ibs, clf, parameters, assigner_data=None): + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + X_train = assigner_data['data'] + y_train = assigner_data['target'] + X_test = assigner_data['test'] + y_test = assigner_data['test_truth'] + + tune_search = GridSearchCV( # TuneGridSearchCV( + clf, + parameters, + ) + + start = time.time() + tune_search.fit(X_train, y_train) + end = time.time() + print("Tune Fit Time: %s" % (end - start)) + pred = tune_search.predict(X_test) + accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) + print("Tune Accuracy: %s" % accuracy) + print("best parms : %s" % tune_search.best_params_) + + return accuracy, tune_search.best_params_ + + +@register_ibs_method +def _tune_random_search(ibs, clf, parameters, assigner_data=None): + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + X_train = assigner_data['data'] + y_train = assigner_data['target'] 
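# A minimal standalone sketch of the sklearn GridSearchCV pattern that these
# tuning helpers wrap, run on toy data so it is self-contained; the classifier
# and parameter grid below are illustrative only.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X_toy, y_toy = make_classification(n_samples=200, random_state=0)
toy_search = GridSearchCV(DecisionTreeClassifier(), {'max_depth': [2, 4, 8]})
toy_search.fit(X_toy, y_toy)
print(toy_search.best_params_, toy_search.best_score_)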
+ X_test = assigner_data['test'] + y_test = assigner_data['test_truth'] + + tune_search = GridSearchCV( + clf, + parameters, + ) + + start = time.time() + tune_search.fit(X_train, y_train) + end = time.time() + print("Tune Fit Time: %s" % (end - start)) + pred = tune_search.predict(X_test) + accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) + print("Tune Accuracy: %s" % accuracy) + print("best parms : %s" % tune_search.best_params_) + + return accuracy, tune_search.best_params_ + + +# for wild dog dev +@register_ibs_method +def wd_assigner_data(ibs): + return wd_training_data('part_assignment_features') + + +@register_ibs_method +def wd_normed_assigner_data(ibs): + return wd_training_data('normalized_assignment_features') + + +@register_ibs_method +def wd_training_data(ibs, depc_table_name='assigner_viewpoint_features', balance_t_f=True): + all_aids = ibs.get_valid_aids() + ia_classes = ibs.get_annot_species(all_aids) + part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] + part_gids = list(set(ibs.get_annot_gids(part_aids))) + all_pairs = all_part_pairs(ibs, part_gids) + all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) + names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] + ground_truth = [n1 == n2 for (n1, n2) in zip(names[0], names[1])] + + # train_feats, test_feats = train_test_split(all_feats) + # train_truth, test_truth = train_test_split(ground_truth) + pairs_in_train = ibs.gid_train_test_split(all_pairs[0]) # we could pass just the pair aids or just the body aids bc gids are the same + train_feats, test_feats = split_list(all_feats, pairs_in_train) + train_truth, test_truth = split_list(ground_truth, pairs_in_train) + + all_pairs_tuple = [(part, body) for part, body in zip(all_pairs[0], all_pairs[1])] + train_pairs, test_pairs = split_list(all_pairs_tuple, pairs_in_train) + + if balance_t_f: + train_balance_flags = balance_true_false_training_pairs(train_truth) + train_truth = ut.compress(train_truth, train_balance_flags) + train_feats = ut.compress(train_feats, train_balance_flags) + train_pairs = ut.compress(train_pairs, train_balance_flags) + + test_balance_flags = balance_true_false_training_pairs(test_truth) + test_truth = ut.compress(test_truth, test_balance_flags) + test_feats = ut.compress(test_feats, test_balance_flags) + test_pairs = ut.compress(test_pairs, test_balance_flags) + + + assigner_data = {'data': train_feats, 'target': train_truth, + 'test': test_feats, 'test_truth': test_truth, + 'train_pairs': train_pairs, 'test_pairs': test_pairs} + + return assigner_data + + +# returns flags so we can compress other lists +def balance_true_false_training_pairs(ground_truth, seed=777): + n_true = ground_truth.count(True) + # there's always more false samples than true when we're looking at all pairs + false_indices = [i for i, ground_t in enumerate(ground_truth) if not ground_t] + import random + random.seed(seed) + subsampled_false_indices = random.sample(false_indices, n_true) + # for quick membership check + subsampled_false_indices = set(subsampled_false_indices) + # keep all true flags, and the subsampled false ones + keep_flags = [gt or (i in subsampled_false_indices) for i, gt in enumerate(ground_truth)] + return keep_flags + + +def train_test_split(item_list, random_seed=777, test_size=0.1): + import random + import math + random.seed(random_seed) + sample_size = math.floor(len(item_list) * test_size) + all_indices = list(range(len(item_list))) + test_indices = 
random.sample(all_indices, sample_size) + test_items = [item_list[i] for i in test_indices] + train_indices = sorted(list( + set(all_indices) - set(test_indices) + )) + train_items = [item_list[i] for i in train_indices] + return train_items, test_items + + +@register_ibs_method +def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): + r""" + Makes a gid-wise train-test split. This avoids potential overfitting when a network + is trained on some annots from one image and tested on others from the same image. + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): annot ids to split + random_seed: to make this split reproducible + test_size: portion of gids reserved for test data + + Yields: + a boolean flag_list of which aids are in the training set. Returning the flag_list + allows the user to filter multiple lists with one gid_train_test_split call + + + CommandLine: + python -m wbia.algo.detect.train_assigner gid_train_test_split + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> aids = ibs.get_valid_aids() + >>> all_gids = set(ibs.get_annot_gids(aids)) + >>> test_size = 0.34 # we want floor(test_size*3) to equal 1 + >>> aid_in_train = gid_train_test_split(ibs, aids, test_size=test_size) + >>> train_aids = ut.compress(aids, aid_in_train) + >>> aid_in_test = [not train for train in aid_in_train] + >>> test_aids = ut.compress(aids, aid_in_test) + >>> train_gids = set(ibs.get_annot_gids(train_aids)) + >>> test_gids = set(ibs.get_annot_gids(test_aids)) + >>> assert len(train_gids & test_gids) is 0 + >>> assert len(train_gids) + len(test_gids) == len(all_gids) + >>> assert len(train_gids) is 2 + >>> assert len(test_gids) is 1 + >>> result = aid_in_train # note one gid has 4 aids, the other two 2 + [False, False, False, False, True, True, True, True] + """ + print('calling gid_train_test_split') + gid_list = ibs.get_annot_gids(aid_list) + gid_set = list(set(gid_list)) + import math + random.seed(random_seed) + n_test_gids = math.floor(len(gid_set) * test_size) + test_gids = set(random.sample(gid_set, n_test_gids)) + aid_in_train = [gid not in test_gids for gid in gid_list] + return aid_in_train + + +def split_list(item_list, is_in_first_group_list): + first_group = ut.compress(item_list, is_in_first_group_list) + is_in_second_group = [not b for b in is_in_first_group_list] + second_group = ut.compress(item_list, is_in_second_group) + return first_group, second_group + + +def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): + + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + all_aids = [] + for pair in assigner_data['test_pairs']: + all_aids.extend(list(pair)) + all_aids = sorted(list(set(all_aids))) + + all_pairs, all_unassigned_aids = ibs.assign_parts(all_aids, cutoff_score) + + gid_to_assigner_results = gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids) + gid_to_ground_truth = gid_keyed_ground_truth(ibs, assigner_data) + + if illustrate: + illustrate_all_assignments(ibs, gid_to_assigner_results, gid_to_ground_truth) + + correct_gids = [] + incorrect_gids = [] + gids_with_false_positives = 0 + n_false_positives = 0 + gids_with_false_negatives = 0 + gids_with_false_neg_allowing_errors = [0, 0, 0] + max_allowed_errors = len(gids_with_false_neg_allowing_errors) + n_false_negatives = 0 + gids_with_both_errors = 0 + for gid in 
gid_to_assigner_results.keys(): + assigned_pairs = set(gid_to_assigner_results[gid]['pairs']) + ground_t_pairs = set(gid_to_ground_truth[gid]['pairs']) + false_negatives = len(ground_t_pairs - assigned_pairs) + false_positives = len(assigned_pairs - ground_t_pairs) + n_false_negatives += false_negatives + + if false_negatives > 0: + gids_with_false_negatives += 1 + if false_negatives >= 2: + false_neg_log_index = min(false_negatives - 2, max_allowed_errors - 1) # ie, if we have 2 errors, we have a false neg even allowing 1 error, in index 0 of that list + try: + gids_with_false_neg_allowing_errors[false_neg_log_index] += 1 + except: + ut.embed() + + n_false_positives += false_positives + if false_positives > 0: + gids_with_false_positives += 1 + if false_negatives > 0 and false_positives > 0: + gids_with_both_errors += 1 + + pairs_equal = sorted(gid_to_assigner_results[gid]['pairs']) == sorted(gid_to_ground_truth[gid]['pairs']) + if pairs_equal: + correct_gids += [gid] + else: + incorrect_gids += [gid] + + n_gids = len(gid_to_assigner_results.keys()) + accuracy = len(correct_gids) / n_gids + incorrect_gids = n_gids - len(correct_gids) + acc_allowing_errors = [1 - (nerrors / n_gids) + for nerrors in gids_with_false_neg_allowing_errors] + print('accuracy with cutoff of %s: %s' % (cutoff_score, accuracy)) + for i, acc_allowing_error in enumerate(acc_allowing_errors): + print(' allowing %s errors, acc = %s' % (i + 1, acc_allowing_error)) + print(' %s false positives on %s error images' % (n_false_positives, gids_with_false_positives)) + print(' %s false negatives on %s error images' % (n_false_negatives, gids_with_false_negatives)) + print(' %s images with both errors' % (gids_with_both_errors)) + return accuracy + + +if __name__ == '__main__': + r""" + CommandLine: + python -m wbia.algo.detect.train_assigner --allexamples + """ + import multiprocessing + + multiprocessing.freeze_support() # for win32 + import utool as ut # NOQA + + ut.doctest_funcs() + From d6e51b939ff8012cc8f3e82307ebdf70dd4c382b Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Mon, 4 Jan 2021 17:56:38 -0800 Subject: [PATCH 145/294] adds model URLs and reconfigures config constants --- wbia/algo/detect/assigner.py | 49 ++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index 75b1dc8c97..e0bf496e92 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -32,17 +32,39 @@ PARALLEL = not const.CONTAINERIZED INPUT_SIZE = 224 -ARCHIVE_URL_DICT = {} - INMEM_ASSIGNER_MODELS = {} -SPECIES_TO_ASSIGNER_MODELFILE = { - 'wild_dog': '/tmp/balanced_wd.joblib', - 'wild_dog_dark': '/tmp/balanced_wd.joblib', - 'wild_dog_light': '/tmp/balanced_wd.joblib', - 'wild_dog_puppy': '/tmp/balanced_wd.joblib', - 'wild_dog_standard': '/tmp/balanced_wd.joblib', - 'wild_dog_tan': '/tmp/balanced_wd.joblib', +SPECIES_CONFIG_MAP = { + 'wild_dog': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features' + } + 'wild_dog_dark': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features' + } + 'wild_dog_light': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 
'assigner_viewpoint_features' + } + 'wild_dog_puppy': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features' + } + 'wild_dog_standard': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features' + } + 'wild_dog_tan': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features' + } } @@ -272,11 +294,12 @@ def load_assigner_classifier(ibs, aid_list, fallback_species='wild_dog'): if species in INMEM_ASSIGNER_MODELS.keys(): clf = INMEM_ASSIGNER_MODELS[species] else: - if species not in SPECIES_TO_ASSIGNER_MODELFILE.keys(): + if species not in SPECIES_CONFIG_MAP.keys(): print("WARNING: Assigner called for species %s which does not have an assigner modelfile specified. Falling back to the model for %s" % species, fallback_species) species = fallback_species - model_fpath = SPECIES_TO_ASSIGNER_MODELFILE[species] + model_url = SPECIES_CONFIG_MAP[species]['model_url'] + model_fpath = from joblib import load clf = load(model_fpath) @@ -400,9 +423,9 @@ def gid_keyed_ground_truth(ibs, assigner_data): @register_ibs_method def assigner_testdb_ibs(): - # dbdir = sysres.ensure_testdb_assigner() import wbia - dbdir = '/data/testdb_assigner' + dbdir = sysres.ensure_testdb_assigner() + # dbdir = '/data/testdb_assigner' ibs = wbia.opendb(dbdir=dbdir) return ibs From 3000e03be4bc0cadc9162906eae2e92b31617f8f Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 5 Jan 2021 12:37:47 -0800 Subject: [PATCH 146/294] uses utool do download model file --- wbia/algo/detect/assigner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index e0bf496e92..26db9e5c96 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -299,7 +299,7 @@ def load_assigner_classifier(ibs, aid_list, fallback_species='wild_dog'): species = fallback_species model_url = SPECIES_CONFIG_MAP[species]['model_url'] - model_fpath = + model_fpath = ut.grab_file_url(model_url) from joblib import load clf = load(model_fpath) From d3410d5ef7841920390b42be833231cafc5c54d8 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 12:43:32 -0800 Subject: [PATCH 147/294] Small fixes --- wbia/control/manual_part_funcs.py | 76 ++++++++++++++++++------------- wbia/init/sysres.py | 8 ++-- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/wbia/control/manual_part_funcs.py b/wbia/control/manual_part_funcs.py index 03866f74f9..2f9635f710 100644 --- a/wbia/control/manual_part_funcs.py +++ b/wbia/control/manual_part_funcs.py @@ -29,7 +29,7 @@ PART_NOTE = 'part_note' PART_NUM_VERTS = 'part_num_verts' PART_ROWID = 'part_rowid' -PART_TAG_TEXT = 'part_tag_text' +PART_TAG_TEXT = 'part_tag_text' PART_THETA = 'part_theta' PART_VERTS = 'part_verts' PART_UUID = 'part_uuid' @@ -1091,18 +1091,23 @@ def set_part_viewpoints(ibs, part_rowid_list, viewpoint_list): @register_ibs_method @accessor_decors.setter def set_part_tag_text(ibs, part_rowid_list, part_tags_list, duplicate_behavior='error'): - r""" part_tags_list -> part.part_tags[part_rowid_list] + r"""part_tags_list -> part.part_tags[part_rowid_list] Args: part_rowid_list part_tags_list """ - #logger.info('[ibs] 
set_part_tag_text of part_rowid_list=%r to tags=%r' % (part_rowid_list, part_tags_list)) + # logger.info('[ibs] set_part_tag_text of part_rowid_list=%r to tags=%r' % (part_rowid_list, part_tags_list)) id_iter = part_rowid_list colnames = (PART_TAG_TEXT,) - ibs.db.set(const.PART_TABLE, colnames, part_tags_list, - id_iter, duplicate_behavior=duplicate_behavior) + ibs.db.set( + const.PART_TABLE, + colnames, + part_tags_list, + id_iter, + duplicate_behavior=duplicate_behavior, + ) @register_ibs_method @@ -1491,40 +1496,47 @@ def set_part_contour(ibs, part_rowid_list, contour_dict_list): # from_aids = ibs.get_valid_aids() -def get_corresponding_aids_slow(ibs, part_rowid_list, from_aids): - part_bboxes = ibs.get_part_bboxes(part_rowid_list) - annot_bboxes = ibs.get_annot_bboxes(from_aids) - annot_gids = ibs.get_annot_gids(from_aids) - from collections import defaultdict - bbox_gid_to_aids = defaultdict(int) - for aid, gid, bbox in zip(from_aids, annot_gids, annot_bboxes): - bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] = aid - part_gids = ibs.get_part_image_rowids(parts) - part_rowid_to_aid = {part_id: bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] for part_id, gid, bbox in zip(part_rowid_list, part_gids, part_bboxes)} +# def get_corresponding_aids_slow(ibs, part_rowid_list, from_aids): +# part_bboxes = ibs.get_part_bboxes(part_rowid_list) +# annot_bboxes = ibs.get_annot_bboxes(from_aids) +# annot_gids = ibs.get_annot_gids(from_aids) +# from collections import defaultdict + +# bbox_gid_to_aids = defaultdict(int) +# for aid, gid, bbox in zip(from_aids, annot_gids, annot_bboxes): +# bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] = aid +# part_gids = ibs.get_part_image_rowids(parts) +# part_rowid_to_aid = { +# part_id: bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] +# for part_id, gid, bbox in zip(part_rowid_list, part_gids, part_bboxes) +# } - part_aids = [part_rowid_to_aid[partid] for partid in parts] - part_parent_aids = ibs.get_part_aids(part_rowid_list) +# part_aids = [part_rowid_to_aid[partid] for partid in parts] +# part_parent_aids = ibs.get_part_aids(part_rowid_list) - # parents might be non-unique so we gotta make a unique name for each parent - parent_aid_to_part_rowids = defaultdict(list) - for part_rowid, parent_aid in zip(part_rowid_list, part_parent_aids): - parent_aid_to_part_rowids[parent_aid] += [part_rowid] +# # parents might be non-unique so we gotta make a unique name for each parent +# parent_aid_to_part_rowids = defaultdict(list) +# for part_rowid, parent_aid in zip(part_rowid_list, part_parent_aids): +# parent_aid_to_part_rowids[parent_aid] += [part_rowid] - part_annot_names = [','.join(str(p) for p in parent_aid_to_part_rowids[parent_aid]) for parent_aid in part_parent_aids] +# part_annot_names = [ +# ','.join(str(p) for p in parent_aid_to_part_rowids[parent_aid]) +# for parent_aid in part_parent_aids +# ] +# # now assign names so we can associate the part annots with the non-part annots +# new_part_names = ['part-%s' % part_rowid for part_rowid in part_rowid_list] - # now assign names so we can associate the part annots with the non-part annots - new_part_names = ['part-%s' % part_rowid for part_rowid in part_rowid_list] +# def sort_parts_by_tags(ibs, part_rowid_list): +# tags = ibs.get_part_tag_text(part_rowid_list) +# from collections import defaultdict -def sort_parts_by_tags(ibs, part_rowid_list): - tags = ibs.get_part_tag_text(part_rowid_list) - from collections import defaultdict - tag_to_rowids = new 
defaultdict(list) - for tag, part_rowid in zip(tags, part_rowid_list): - tag_to_rowids[tag] += [part_rowid] - parts_by_tags = [tag_to_rowdids[tag] for tag in tag_to_rowdids.keys()] - return parts_by_tags +# tag_to_rowids = defaultdict(list) +# for tag, part_rowid in zip(tags, part_rowid_list): +# tag_to_rowids[tag] += [part_rowid] +# parts_by_tags = [tag_to_rowdids[tag] for tag in tag_to_rowdids.keys()] +# return parts_by_tags # ========== diff --git a/wbia/init/sysres.py b/wbia/init/sysres.py index 69fd0f2069..98afdfc811 100644 --- a/wbia/init/sysres.py +++ b/wbia/init/sysres.py @@ -861,14 +861,14 @@ def ensure_testdb_orientation(): return ensure_db_from_url(const.ZIPPED_URLS.ORIENTATION) -def ensure_testdb_identification_example(): - return ensure_db_from_url(const.ZIPPED_URLS.ID_EXAMPLE) - - def ensure_testdb_assigner(): return ensure_db_from_url(const.ZIPPED_URLS.ASSIGNER) +def ensure_testdb_identification_example(): + return ensure_db_from_url(const.ZIPPED_URLS.ID_EXAMPLE) + + def ensure_testdb_kaggle7(): return ensure_db_from_url(const.ZIPPED_URLS.K7_EXAMPLE) From 9a089225cee3d0e5a3fc10569a4a7ce772bd644e Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 12:53:24 -0800 Subject: [PATCH 148/294] Linted and fixed errors --- wbia/algo/detect/assigner.py | 180 ++-- wbia/algo/detect/train_assigner.py | 236 +++-- wbia/constants.py | 4 +- wbia/core_annots.py | 1285 ++++++++++++++++++++-------- wbia/core_parts.py | 8 +- 5 files changed, 1202 insertions(+), 511 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index 26db9e5c96..6b4aac7d9d 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -1,14 +1,24 @@ +# -*- coding: utf-8 -*- import logging -from os.path import expanduser, join + +# from os.path import expanduser, join from wbia import constants as const -from wbia.control.controller_inject import register_preprocs, register_subprops, make_ibs_register_decorator +from wbia.control.controller_inject import ( + # register_preprocs, + # register_subprops, + make_ibs_register_decorator, +) import utool as ut -import numpy as np -import random + +# import numpy as np +# import random import os -from collections import OrderedDict, defaultdict -from datetime import datetime -import time + +# from collections import OrderedDict +from collections import defaultdict + +# from datetime import datetime +# import time # illustration imports from shutil import copy @@ -16,11 +26,11 @@ import wbia.plottool as pt -from sklearn import preprocessing -from tune_sklearn import TuneGridSearchCV +# from sklearn import preprocessing +# from tune_sklearn import TuneGridSearchCV # shitload of scikit classifiers -import numpy as np +# import numpy as np # bunch of classifier models for training @@ -38,33 +48,33 @@ 'wild_dog': { 'model_file': '/tmp/balanced_wd.joblib', 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', - 'annot_feature_col': 'assigner_viewpoint_features' - } + 'annot_feature_col': 'assigner_viewpoint_features', + }, 'wild_dog_dark': { 'model_file': '/tmp/balanced_wd.joblib', 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', - 'annot_feature_col': 'assigner_viewpoint_features' - } + 'annot_feature_col': 'assigner_viewpoint_features', + }, 'wild_dog_light': { 'model_file': '/tmp/balanced_wd.joblib', 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', - 'annot_feature_col': 'assigner_viewpoint_features' - } + 
'annot_feature_col': 'assigner_viewpoint_features', + }, 'wild_dog_puppy': { 'model_file': '/tmp/balanced_wd.joblib', 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', - 'annot_feature_col': 'assigner_viewpoint_features' - } + 'annot_feature_col': 'assigner_viewpoint_features', + }, 'wild_dog_standard': { 'model_file': '/tmp/balanced_wd.joblib', 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', - 'annot_feature_col': 'assigner_viewpoint_features' - } + 'annot_feature_col': 'assigner_viewpoint_features', + }, 'wild_dog_tan': { 'model_file': '/tmp/balanced_wd.joblib', 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', - 'annot_feature_col': 'assigner_viewpoint_features' - } + 'annot_feature_col': 'assigner_viewpoint_features', + }, } @@ -127,13 +137,28 @@ def all_part_pairs(ibs, gid_list): """ all_aids = ibs.get_image_aids(gid_list) all_aids_are_parts = [ibs._are_part_annots(aids) for aids in all_aids] - all_part_aids = [[aid for (aid, part) in zip(aids, are_parts) if part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] - all_body_aids = [[aid for (aid, part) in zip(aids, are_parts) if not part] for (aids, are_parts) in zip(all_aids, all_aids_are_parts)] - part_body_parallel_lists = [_all_pairs_parallel(parts, bodies) for parts, bodies in zip(all_part_aids, all_body_aids)] - all_parts = [aid for part_body_parallel_list in part_body_parallel_lists - for aid in part_body_parallel_list[0]] - all_bodies = [aid for part_body_parallel_list in part_body_parallel_lists - for aid in part_body_parallel_list[1]] + all_part_aids = [ + [aid for (aid, part) in zip(aids, are_parts) if part] + for (aids, are_parts) in zip(all_aids, all_aids_are_parts) + ] + all_body_aids = [ + [aid for (aid, part) in zip(aids, are_parts) if not part] + for (aids, are_parts) in zip(all_aids, all_aids_are_parts) + ] + part_body_parallel_lists = [ + _all_pairs_parallel(parts, bodies) + for parts, bodies in zip(all_part_aids, all_body_aids) + ] + all_parts = [ + aid + for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[0] + ] + all_bodies = [ + aid + for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[1] + ] return all_parts, all_bodies @@ -188,14 +213,15 @@ def assign_parts(ibs, all_aids, cutoff_score=0.5): all_unassigned_aids = [] for gid in gid_to_aids.keys(): - this_pairs, this_unassigned = _assign_parts_one_image(ibs, gid_to_aids[gid], cutoff_score) - all_assignments += (this_pairs) + this_pairs, this_unassigned = _assign_parts_one_image( + ibs, gid_to_aids[gid], cutoff_score + ) + all_assignments += this_pairs all_unassigned_aids += this_unassigned return all_assignments, all_unassigned_aids - @register_ibs_method def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): r""" @@ -231,54 +257,66 @@ def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) >>> ([(3, 1)], [2, 4]) """ - are_part_aids = _are_part_annots(ibs, aid_list) part_aids = ut.compress(aid_list, are_part_aids) body_aids = ut.compress(aid_list, [not p for p in are_part_aids]) gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) num_images = len(set(gids)) - assert num_images is 1, "_assign_parts_one_image called on multiple images' aids" + assert num_images == 1, "_assign_parts_one_image called on multiple images' aids" # parallel lists representing all 
possible part/body pairs all_pairs_parallel = _all_pairs_parallel(part_aids, body_aids) pair_parts, pair_bodies = all_pairs_parallel - assigner_features = ibs.depc_annot.get('assigner_viewpoint_features', all_pairs_parallel) + assigner_features = ibs.depc_annot.get( + 'assigner_viewpoint_features', all_pairs_parallel + ) assigner_classifier = load_assigner_classifier(ibs, part_aids) assigner_scores = assigner_classifier.predict_proba(assigner_features) # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true probabilities assigner_scores = [score[1] for score in assigner_scores] - good_pairs, unassigned_aids = _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score) + good_pairs, unassigned_aids = _make_assignments( + pair_parts, pair_bodies, assigner_scores, cutoff_score + ) return good_pairs, unassigned_aids def _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score=0.5): - sorted_scored_pairs = [(part, body, score) for part, body, score in - sorted(zip(pair_parts, pair_bodies, assigner_scores), - key=lambda pbscore: pbscore[2], reverse=True)] + sorted_scored_pairs = [ + (part, body, score) + for part, body, score in sorted( + zip(pair_parts, pair_bodies, assigner_scores), + key=lambda pbscore: pbscore[2], + reverse=True, + ) + ] assigned_pairs = [] assigned_parts = set() assigned_bodies = set() n_bodies = len(set(pair_bodies)) - n_parts = len(set(pair_parts)) + n_parts = len(set(pair_parts)) n_true_pairs = min(n_bodies, n_parts) for part_aid, body_aid, score in sorted_scored_pairs: - assign_this_pair = part_aid not in assigned_parts and \ - body_aid not in assigned_bodies and \ - score >= cutoff_score + assign_this_pair = ( + part_aid not in assigned_parts + and body_aid not in assigned_bodies + and score >= cutoff_score + ) if assign_this_pair: assigned_pairs.append((part_aid, body_aid)) assigned_parts.add(part_aid) assigned_bodies.add(body_aid) - if len(assigned_parts) is n_true_pairs \ - or len(assigned_bodies) is n_true_pairs \ - or score > cutoff_score: + if ( + len(assigned_parts) is n_true_pairs + or len(assigned_bodies) is n_true_pairs + or score > cutoff_score + ): break unassigned_parts = set(pair_parts) - set(assigned_parts) @@ -295,36 +333,58 @@ def load_assigner_classifier(ibs, aid_list, fallback_species='wild_dog'): clf = INMEM_ASSIGNER_MODELS[species] else: if species not in SPECIES_CONFIG_MAP.keys(): - print("WARNING: Assigner called for species %s which does not have an assigner modelfile specified. Falling back to the model for %s" % species, fallback_species) + print( + 'WARNING: Assigner called for species %s which does not have an assigner modelfile specified. 
Falling back to the model for %s' + % species, + fallback_species, + ) species = fallback_species model_url = SPECIES_CONFIG_MAP[species]['model_url'] model_fpath = ut.grab_file_url(model_url) from joblib import load + clf = load(model_fpath) return clf -def illustrate_all_assignments(ibs, gid_to_assigner_results, gid_to_ground_truth, - target_dir='/tmp/assigner-illustrations-2/', limit=20): +def illustrate_all_assignments( + ibs, + gid_to_assigner_results, + gid_to_ground_truth, + target_dir='/tmp/assigner-illustrations-2/', + limit=20, +): correct_dir = os.path.join(target_dir, 'correct/') incorrect_dir = os.path.join(target_dir, 'incorrect/') for gid, assigned_aid_dict in gid_to_assigner_results.items()[:limit]: ground_t_dict = gid_to_ground_truth[gid] - assigned_correctly = sorted(assigned_aid_dict['pairs']) == sorted(ground_t_dict['pairs']) + assigned_correctly = sorted(assigned_aid_dict['pairs']) == sorted( + ground_t_dict['pairs'] + ) if assigned_correctly: - illustrate_assignments(ibs, gid, assigned_aid_dict, None, correct_dir) # don't need to illustrate gtruth if it's identical to assignment + illustrate_assignments( + ibs, gid, assigned_aid_dict, None, correct_dir + ) # don't need to illustrate gtruth if it's identical to assignment else: - illustrate_assignments(ibs, gid, assigned_aid_dict, ground_t_dict, incorrect_dir) + illustrate_assignments( + ibs, gid, assigned_aid_dict, ground_t_dict, incorrect_dir + ) print('illustrated assignments and saved them in %s' % target_dir) # works on a single gid's worth of gid_keyed_assigner_results output -def illustrate_assignments(ibs, gid, assigned_aid_dict, gtruth_aid_dict, target_dir='/tmp/assigner-illustrations/'): +def illustrate_assignments( + ibs, + gid, + assigned_aid_dict, + gtruth_aid_dict, + target_dir='/tmp/assigner-illustrations/', +): impath = ibs.get_image_paths(gid) imext = os.path.splitext(impath)[1] new_fname = os.path.join(target_dir, '%s%s' % (gid, imext)) @@ -338,7 +398,7 @@ def illustrate_assignments(ibs, gid, assigned_aid_dict, gtruth_aid_dict, target_ def _draw_all_annots(ibs, image, assigned_aid_dict, gtruth_aid_dict): n_pairs = len(assigned_aid_dict['pairs']) - n_missing_pairs = 0 + # n_missing_pairs = 0 # TODO: missing pair shit n_unass = len(assigned_aid_dict['unassigned']) n_groups = n_pairs + n_unass @@ -380,10 +440,10 @@ def gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids): gid_to_unassigned[gid] += [aid] gid_to_assigner_results = {} - for gid in (set(gid_to_pairs.keys()) | set(gid_to_unassigned.keys())): + for gid in set(gid_to_pairs.keys()) | set(gid_to_unassigned.keys()): gid_to_assigner_results[gid] = { 'pairs': gid_to_pairs[gid], - 'unassigned': gid_to_unassigned[gid] + 'unassigned': gid_to_unassigned[gid], } return gid_to_assigner_results @@ -410,12 +470,11 @@ def gid_keyed_ground_truth(ibs, assigner_data): for gid in gid_to_all_aids.keys(): gid_to_unassigned_aids[gid] = list(gid_to_all_aids[gid] - gid_to_paired_aids[gid]) - gid_to_assigner_results = {} - for gid in (set(gid_to_pairs.keys()) | set(gid_to_unassigned_aids.keys())): + for gid in set(gid_to_pairs.keys()) | set(gid_to_unassigned_aids.keys()): gid_to_assigner_results[gid] = { 'pairs': gid_to_pairs[gid], - 'unassigned': gid_to_unassigned_aids[gid] + 'unassigned': gid_to_unassigned_aids[gid], } return gid_to_assigner_results @@ -424,6 +483,8 @@ def gid_keyed_ground_truth(ibs, assigner_data): @register_ibs_method def assigner_testdb_ibs(): import wbia + import sysres + dbdir = sysres.ensure_testdb_assigner() # dbdir = 
'/data/testdb_assigner' ibs = wbia.opendb(dbdir=dbdir) @@ -441,4 +502,3 @@ def assigner_testdb_ibs(): import utool as ut # NOQA ut.doctest_funcs() - diff --git a/wbia/algo/detect/train_assigner.py b/wbia/algo/detect/train_assigner.py index a2aa62d0f5..a907208a0b 100644 --- a/wbia/algo/detect/train_assigner.py +++ b/wbia/algo/detect/train_assigner.py @@ -1,29 +1,41 @@ -import logging -from os.path import expanduser, join -from wbia import constants as const -from wbia.control.controller_inject import register_preprocs, register_subprops, make_ibs_register_decorator - -from wbia.algo.detect.assigner import gid_keyed_assigner_results, gid_keyed_ground_truth, illustrate_all_assignments +# -*- coding: utf-8 -*- +# import logging +# from os.path import expanduser, join +# from wbia import constants as const +from wbia.control.controller_inject import ( + # register_preprocs, + # register_subprops, + make_ibs_register_decorator, +) + +from wbia.algo.detect.assigner import ( + gid_keyed_assigner_results, + gid_keyed_ground_truth, + illustrate_all_assignments, + all_part_pairs, +) import utool as ut import numpy as np import random -import os -from collections import OrderedDict, defaultdict + +# import os +from collections import OrderedDict + +# from collections import defaultdict from datetime import datetime import time # illustration imports -from shutil import copy -from PIL import Image, ImageDraw -import wbia.plottool as pt +# from shutil import copy +# from PIL import Image, ImageDraw +# import wbia.plottool as pt -import matplotlib.pyplot as plt -from matplotlib.colors import ListedColormap -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler -from sklearn.datasets import make_moons, make_circles, make_classification +# import matplotlib.pyplot as plt +# from matplotlib.colors import ListedColormap +# from sklearn.model_selection import train_test_split +# from sklearn.preprocessing import StandardScaler from sklearn.neural_network import MLPClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC @@ -41,38 +53,38 @@ CLASSIFIER_OPTIONS = [ { - "name": "Nearest Neighbors", - "clf": KNeighborsClassifier(3), - "param_options": { + 'name': 'Nearest Neighbors', + 'clf': KNeighborsClassifier(3), + 'param_options': { 'n_neighbors': [3, 5, 11, 19], 'weights': ['uniform', 'distance'], 'metric': ['euclidean', 'manhattan'], - } + }, }, { - "name": "Linear SVM", - "clf": SVC(kernel="linear", C=0.025), - "param_options": { + 'name': 'Linear SVM', + 'clf': SVC(kernel='linear', C=0.025), + 'param_options': { 'C': [1, 10, 100, 1000], 'kernel': ['linear'], - } + }, }, { - "name": "RBF SVM", - "clf": SVC(gamma=2, C=1), - "param_options": { + 'name': 'RBF SVM', + 'clf': SVC(gamma=2, C=1), + 'param_options': { 'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], - 'kernel': ['rbf'] + 'kernel': ['rbf'], }, }, { - "name": "Decision Tree", - "clf": DecisionTreeClassifier(), # max_depth=5 - "param_options": { + 'name': 'Decision Tree', + 'clf': DecisionTreeClassifier(), # max_depth=5 + 'param_options': { 'max_depth': np.arange(1, 12), - 'max_leaf_nodes': [2, 5, 10, 20, 50, 100] - } + 'max_leaf_nodes': [2, 5, 10, 20, 50, 100], + }, }, # { # "name": "Random Forest", @@ -98,46 +110,53 @@ # } # }, { - "name": "AdaBoost", - "clf": AdaBoostClassifier(), - "param_options": { + 'name': 'AdaBoost', + 'clf': AdaBoostClassifier(), + 'param_options': { 'n_estimators': np.arange(10, 310, 50), 'learning_rate': [0.01, 0.05, 0.1, 1], - } + }, }, { 
- "name": "Naive Bayes", - "clf": GaussianNB(), - "param_options": {} # no hyperparams to optimize + 'name': 'Naive Bayes', + 'clf': GaussianNB(), + 'param_options': {}, # no hyperparams to optimize }, { - "name": "QDA", - "clf": QuadraticDiscriminantAnalysis(), - "param_options": { - 'reg_param': [0.1, 0.2, 0.3, 0.4, 0.5] - } - } + 'name': 'QDA', + 'clf': QuadraticDiscriminantAnalysis(), + 'param_options': {'reg_param': [0.1, 0.2, 0.3, 0.4, 0.5]}, + }, ] -classifier_names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", - "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", - "Naive Bayes", "QDA"] +classifier_names = [ + 'Nearest Neighbors', + 'Linear SVM', + 'RBF SVM', + 'Decision Tree', + 'Random Forest', + 'Neural Net', + 'AdaBoost', + 'Naive Bayes', + 'QDA', +] classifiers = [ KNeighborsClassifier(3), - SVC(kernel="linear", C=0.025), + SVC(kernel='linear', C=0.025), SVC(gamma=2, C=1), DecisionTreeClassifier(max_depth=5), RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1), MLPClassifier(alpha=1, max_iter=1000), AdaBoostClassifier(), GaussianNB(), - QuadraticDiscriminantAnalysis()] + QuadraticDiscriminantAnalysis(), +] -slow_classifier_names = "Gaussian Process" -slow_classifiers = GaussianProcessClassifier(1.0 * RBF(1.0)), +slow_classifier_names = 'Gaussian Process' +slow_classifiers = (GaussianProcessClassifier(1.0 * RBF(1.0)),) def classifier_report(clf, name, assigner_data): @@ -155,7 +174,9 @@ def classifier_report(clf, name, assigner_data): @register_ibs_method -def compare_ass_classifiers(ibs, depc_table_name='assigner_viewpoint_features', print_accs=False): +def compare_ass_classifiers( + ibs, depc_table_name='assigner_viewpoint_features', print_accs=False +): assigner_data = ibs.wd_training_data(depc_table_name) @@ -182,20 +203,24 @@ def tune_ass_classifiers(ibs, depc_table_name='assigner_viewpoint_unit_features' best_clf_name = '' best_clf_params = {} for classifier in CLASSIFIER_OPTIONS: - print("Tuning %s" % classifier['name']) - accuracy, best_params = ibs._tune_grid_search(classifier['clf'], classifier['param_options'], assigner_data) + print('Tuning %s' % classifier['name']) + accuracy, best_params = ibs._tune_grid_search( + classifier['clf'], classifier['param_options'], assigner_data + ) print() accuracies[classifier['name']] = { 'accuracy': accuracy, - 'best_params': best_params + 'best_params': best_params, } if accuracy > best_acc: best_acc = accuracy best_clf_name = classifier['name'] best_clf_params = best_params - print('best performance: %s using %s with params %s' % - (best_acc, best_clf_name, best_clf_params)) + print( + 'best performance: %s using %s with params %s' + % (best_acc, best_clf_name, best_clf_params) + ) return accuracies @@ -218,11 +243,11 @@ def _tune_grid_search(ibs, clf, parameters, assigner_data=None): start = time.time() tune_search.fit(X_train, y_train) end = time.time() - print("Tune Fit Time: %s" % (end - start)) + print('Tune Fit Time: %s' % (end - start)) pred = tune_search.predict(X_test) accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) - print("Tune Accuracy: %s" % accuracy) - print("best parms : %s" % tune_search.best_params_) + print('Tune Accuracy: %s' % accuracy) + print('best parms : %s' % tune_search.best_params_) return accuracy, tune_search.best_params_ @@ -245,11 +270,11 @@ def _tune_random_search(ibs, clf, parameters, assigner_data=None): start = time.time() tune_search.fit(X_train, y_train) end = time.time() - print("Tune Fit Time: %s" % (end - start)) + print('Tune Fit 
Time: %s' % (end - start)) pred = tune_search.predict(X_test) accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) - print("Tune Accuracy: %s" % accuracy) - print("best parms : %s" % tune_search.best_params_) + print('Tune Accuracy: %s' % accuracy) + print('best parms : %s' % tune_search.best_params_) return accuracy, tune_search.best_params_ @@ -266,7 +291,9 @@ def wd_normed_assigner_data(ibs): @register_ibs_method -def wd_training_data(ibs, depc_table_name='assigner_viewpoint_features', balance_t_f=True): +def wd_training_data( + ibs, depc_table_name='assigner_viewpoint_features', balance_t_f=True +): all_aids = ibs.get_valid_aids() ia_classes = ibs.get_annot_species(all_aids) part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] @@ -278,7 +305,9 @@ def wd_training_data(ibs, depc_table_name='assigner_viewpoint_features', balance # train_feats, test_feats = train_test_split(all_feats) # train_truth, test_truth = train_test_split(ground_truth) - pairs_in_train = ibs.gid_train_test_split(all_pairs[0]) # we could pass just the pair aids or just the body aids bc gids are the same + pairs_in_train = ibs.gid_train_test_split( + all_pairs[0] + ) # we could pass just the pair aids or just the body aids bc gids are the same train_feats, test_feats = split_list(all_feats, pairs_in_train) train_truth, test_truth = split_list(ground_truth, pairs_in_train) @@ -296,10 +325,14 @@ def wd_training_data(ibs, depc_table_name='assigner_viewpoint_features', balance test_feats = ut.compress(test_feats, test_balance_flags) test_pairs = ut.compress(test_pairs, test_balance_flags) - - assigner_data = {'data': train_feats, 'target': train_truth, - 'test': test_feats, 'test_truth': test_truth, - 'train_pairs': train_pairs, 'test_pairs': test_pairs} + assigner_data = { + 'data': train_feats, + 'target': train_truth, + 'test': test_feats, + 'test_truth': test_truth, + 'train_pairs': train_pairs, + 'test_pairs': test_pairs, + } return assigner_data @@ -310,28 +343,30 @@ def balance_true_false_training_pairs(ground_truth, seed=777): # there's always more false samples than true when we're looking at all pairs false_indices = [i for i, ground_t in enumerate(ground_truth) if not ground_t] import random + random.seed(seed) subsampled_false_indices = random.sample(false_indices, n_true) # for quick membership check subsampled_false_indices = set(subsampled_false_indices) # keep all true flags, and the subsampled false ones - keep_flags = [gt or (i in subsampled_false_indices) for i, gt in enumerate(ground_truth)] + keep_flags = [ + gt or (i in subsampled_false_indices) for i, gt in enumerate(ground_truth) + ] return keep_flags -def train_test_split(item_list, random_seed=777, test_size=0.1): - import random - import math - random.seed(random_seed) - sample_size = math.floor(len(item_list) * test_size) - all_indices = list(range(len(item_list))) - test_indices = random.sample(all_indices, sample_size) - test_items = [item_list[i] for i in test_indices] - train_indices = sorted(list( - set(all_indices) - set(test_indices) - )) - train_items = [item_list[i] for i in train_indices] - return train_items, test_items +# def train_test_split(item_list, random_seed=777, test_size=0.1): +# import random +# import math + +# random.seed(random_seed) +# sample_size = math.floor(len(item_list) * test_size) +# all_indices = list(range(len(item_list))) +# test_indices = random.sample(all_indices, sample_size) +# test_items = [item_list[i] for i in test_indices] +# train_indices = 
sorted(list(set(all_indices) - set(test_indices))) +# train_items = [item_list[i] for i in train_indices] +# return train_items, test_items @register_ibs_method @@ -380,6 +415,7 @@ def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): gid_list = ibs.get_annot_gids(aid_list) gid_set = list(set(gid_list)) import math + random.seed(random_seed) n_test_gids = math.floor(len(gid_set) * test_size) test_gids = set(random.sample(gid_set, n_test_gids)) @@ -406,7 +442,9 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): all_pairs, all_unassigned_aids = ibs.assign_parts(all_aids, cutoff_score) - gid_to_assigner_results = gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids) + gid_to_assigner_results = gid_keyed_assigner_results( + ibs, all_pairs, all_unassigned_aids + ) gid_to_ground_truth = gid_keyed_ground_truth(ibs, assigner_data) if illustrate: @@ -431,10 +469,12 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): if false_negatives > 0: gids_with_false_negatives += 1 if false_negatives >= 2: - false_neg_log_index = min(false_negatives - 2, max_allowed_errors - 1) # ie, if we have 2 errors, we have a false neg even allowing 1 error, in index 0 of that list + false_neg_log_index = min( + false_negatives - 2, max_allowed_errors - 1 + ) # ie, if we have 2 errors, we have a false neg even allowing 1 error, in index 0 of that list try: gids_with_false_neg_allowing_errors[false_neg_log_index] += 1 - except: + except Exception: ut.embed() n_false_positives += false_positives @@ -443,7 +483,9 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): if false_negatives > 0 and false_positives > 0: gids_with_both_errors += 1 - pairs_equal = sorted(gid_to_assigner_results[gid]['pairs']) == sorted(gid_to_ground_truth[gid]['pairs']) + pairs_equal = sorted(gid_to_assigner_results[gid]['pairs']) == sorted( + gid_to_ground_truth[gid]['pairs'] + ) if pairs_equal: correct_gids += [gid] else: @@ -452,13 +494,20 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): n_gids = len(gid_to_assigner_results.keys()) accuracy = len(correct_gids) / n_gids incorrect_gids = n_gids - len(correct_gids) - acc_allowing_errors = [1 - (nerrors / n_gids) - for nerrors in gids_with_false_neg_allowing_errors] + acc_allowing_errors = [ + 1 - (nerrors / n_gids) for nerrors in gids_with_false_neg_allowing_errors + ] print('accuracy with cutoff of %s: %s' % (cutoff_score, accuracy)) for i, acc_allowing_error in enumerate(acc_allowing_errors): print(' allowing %s errors, acc = %s' % (i + 1, acc_allowing_error)) - print(' %s false positives on %s error images' % (n_false_positives, gids_with_false_positives)) - print(' %s false negatives on %s error images' % (n_false_negatives, gids_with_false_negatives)) + print( + ' %s false positives on %s error images' + % (n_false_positives, gids_with_false_positives) + ) + print( + ' %s false negatives on %s error images' + % (n_false_negatives, gids_with_false_negatives) + ) print(' %s images with both errors' % (gids_with_both_errors)) return accuracy @@ -474,4 +523,3 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): import utool as ut # NOQA ut.doctest_funcs() - diff --git a/wbia/constants.py b/wbia/constants.py index fef2d8f3a5..4fd9621504 100644 --- a/wbia/constants.py +++ b/wbia/constants.py @@ -354,9 +354,7 @@ class ZIPPED_URLS(object): # NOQA ORIENTATION = ( 
'https://wildbookiarepository.azureedge.net/databases/testdb_orientation.zip' ) - ASSIGNER = ( - 'https://wildbookiarepository.azureedge.net/databases/testdb_assigner.zip' - ) + ASSIGNER = 'https://wildbookiarepository.azureedge.net/databases/testdb_assigner.zip' K7_EXAMPLE = 'https://wildbookiarepository.azureedge.net/databases/testdb_kaggle7.zip' diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 65dcdf21bd..6e4ece6276 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -55,7 +55,11 @@ import numpy as np import cv2 import wbia.constants as const -from wbia.control.controller_inject import register_preprocs, register_subprops, make_ibs_register_decorator +from wbia.control.controller_inject import ( + register_preprocs, + register_subprops, + make_ibs_register_decorator, +) from wbia.algo.hots.chip_match import ChipMatch from wbia.algo.hots import neighbor_index @@ -2487,18 +2491,38 @@ class PartAssignmentFeatureConfig(dtool.Config): tablename='part_assignment_features', parents=['annotations', 'annotations'], colnames=[ - 'p_xtl', 'p_ytl', 'p_w', 'p_h', - 'b_xtl', 'b_ytl', 'b_w', 'b_h', - 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', 'intersect_area_relative_part', 'intersect_area_relative_body', - 'part_area_relative_body' + 'part_area_relative_body', ], coltypes=[ - int, int, int, int, - int, int, int, int, - int, int, int, int, - float, float, float + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='part_assignment_features', @@ -2511,7 +2535,9 @@ def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None) part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -2522,50 +2548,97 @@ def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None) part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [part_area / body_area - for (part_area, body_area) in zip(part_areas, body_areas)] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
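`_bbox_intersections` is defined elsewhere in core_annots.py and is not part of this diff. A minimal sketch consistent with the comment above (negative width/height standing in for the x/y gap when the two boxes do not overlap) is given below; the bbox layout `(xtl, ytl, w, h)` matches `get_annot_bboxes`, but the body is an assumption rather than the actual helper.

```python
# Sketch only (assumed behavior, not the wbia implementation).
# Bboxes are (xtl, ytl, w, h); when a part and body box do not overlap,
# the returned w/h are negative and measure the horizontal/vertical gap.
def _bbox_intersections_sketch(part_bboxes, body_bboxes):
    intersections = []
    for (px, py, pw, ph), (bx, by, bw, bh) in zip(part_bboxes, body_bboxes):
        int_xtl = max(px, bx)
        int_ytl = max(py, by)
        int_w = min(px + pw, bx + bw) - int_xtl  # negative => horizontal gap
        int_h = min(py + ph, by + bh) - int_ytl  # negative => vertical gap
        intersections.append((int_xtl, int_ytl, int_w, int_h))
    return intersections
```

The `w * h if w > 0 and h > 0 else 0` guard in the lines below then maps any non-overlapping pair to an intersection area of zero.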
- intersect_areas = [w * h if w > 0 and h > 0 else 0 - for (_, _, w, h) in intersect_bboxes] + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] - int_area_relative_part = [int_area / part_area for int_area, part_area - in zip(intersect_areas, part_areas)] - int_area_relative_body = [int_area / body_area for int_area, body_area - in zip(intersect_areas, body_areas)] + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] - result_list = list(zip( - part_bboxes, body_bboxes, intersect_bboxes, - int_area_relative_part, int_area_relative_body, part_area_relative_body - )) + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) - for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, - int_area_relative_body, part_area_relative_body) in result_list: - yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body) + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) @derived_attribute( tablename='normalized_assignment_features', parents=['annotations', 'annotations'], colnames=[ - 'p_xtl', 'p_ytl', 'p_w', 'p_h', - 'b_xtl', 'b_ytl', 'b_w', 'b_h', - 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', 'intersect_area_relative_part', 'intersect_area_relative_body', - 'part_area_relative_body' + 'part_area_relative_body', ], coltypes=[ - float, float, float, float, - float, float, float, float, - float, float, float, float, - float, float, float + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='normalized_assignment_features', @@ -2578,7 +2651,9 @@ def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=No part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' 
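The 'normalized' variant of this table differs from `compute_assignment_features` mainly in that geometry is rescaled by image size before the ratios are computed, so the features are comparable across image resolutions; the `_norm_bboxes` / `_norm_vertices` helpers used by the later tables in this diff are likewise defined elsewhere in core_annots.py. A sketch of that normalization, under the assumption that it simply divides by image width and height:

```python
# Sketch only (assumed behavior, not the wbia implementation).
def _norm_bboxes_sketch(bbox_list, im_widths, im_heights):
    # (xtl, ytl, w, h) in pixels -> same layout with coordinates in [0, 1]
    return [
        (xtl / w_im, ytl / h_im, w / w_im, h / h_im)
        for (xtl, ytl, w, h), w_im, h_im in zip(bbox_list, im_widths, im_heights)
    ]


def _norm_vertices_sketch(verts_list, im_widths, im_heights):
    # polygon vertices [(x1, y1), ...] in pixels -> unit-square coordinates
    return [
        [(x / w_im, y / h_im) for (x, y) in verts]
        for verts, w_im, h_im in zip(verts_list, im_widths, im_heights)
    ]
```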
bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -2593,50 +2668,97 @@ def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=No part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [part_area / body_area - for (part_area, body_area) in zip(part_areas, body_areas)] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. - intersect_areas = [w * h if w > 0 and h > 0 else 0 - for (_, _, w, h) in intersect_bboxes] + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] - int_area_relative_part = [int_area / part_area for int_area, part_area - in zip(intersect_areas, part_areas)] - int_area_relative_body = [int_area / body_area for int_area, body_area - in zip(intersect_areas, body_areas)] + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] - result_list = list(zip( - part_bboxes, body_bboxes, intersect_bboxes, - int_area_relative_part, int_area_relative_body, part_area_relative_body - )) + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) - for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, - int_area_relative_body, part_area_relative_body) in result_list: - yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body) + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) @derived_attribute( tablename='standardized_assignment_features', parents=['annotations', 'annotations'], colnames=[ - 'p_xtl', 'p_ytl', 'p_w', 'p_h', - 'b_xtl', 'b_ytl', 'b_w', 'b_h', - 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', 'intersect_area_relative_part', 'intersect_area_relative_body', - 'part_area_relative_body' + 'part_area_relative_body', ], coltypes=[ - float, float, float, float, - float, float, float, float, - float, float, float, float, - float, float, float + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='standardized_assignment_features', @@ -2649,7 +2771,9 @@ def standardized_assignment_features(depc, part_aid_list, body_aid_list, config= part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - 
assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -2664,36 +2788,63 @@ def standardized_assignment_features(depc, part_aid_list, body_aid_list, config= part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [part_area / body_area - for (part_area, body_area) in zip(part_areas, body_areas)] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. - intersect_areas = [w * h if w > 0 and h > 0 else 0 - for (_, _, w, h) in intersect_bboxes] + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] - int_area_relative_part = [int_area / part_area for int_area, part_area - in zip(intersect_areas, part_areas)] - int_area_relative_body = [int_area / body_area for int_area, body_area - in zip(intersect_areas, body_areas)] + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] int_area_relative_part = preprocessing.scale(int_area_relative_part) int_area_relative_body = preprocessing.scale(int_area_relative_body) part_area_relative_body = preprocessing.scale(part_area_relative_body) - result_list = list(zip( - part_bboxes, body_bboxes, intersect_bboxes, - int_area_relative_part, int_area_relative_body, part_area_relative_body - )) + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) - for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, - int_area_relative_body, part_area_relative_body) in result_list: - yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body) + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) # like the above but bboxes are also standardized @@ -2701,31 +2852,55 @@ def standardized_assignment_features(depc, part_aid_list, body_aid_list, config= tablename='mega_standardized_assignment_features', parents=['annotations', 'annotations'], colnames=[ - 'p_xtl', 'p_ytl', 'p_w', 'p_h', - 'b_xtl', 'b_ytl', 'b_w', 'b_h', - 'int_xtl', 'int_ytl', 'int_w', 'int_h', + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 
'int_h', 'intersect_area_relative_part', 'intersect_area_relative_body', - 'part_area_relative_body' + 'part_area_relative_body', ], coltypes=[ - float, float, float, float, - float, float, float, float, - float, float, float, float, - float, float, float + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='mega_standardized_assignment_features', rm_extern_on_delete=True, chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits ) -def mega_standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): +def mega_standardized_assignment_features( + depc, part_aid_list, body_aid_list, config=None +): ibs = depc.controller part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -2743,56 +2918,117 @@ def mega_standardized_assignment_features(depc, part_aid_list, body_aid_list, co part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [part_area / body_area - for (part_area, body_area) in zip(part_areas, body_areas)] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
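The 'standardized' and 'mega_standardized' tables differ from the plain features in that the derived ratios (and, per the comment above the 'mega' table definition, the bbox values as well) are passed through `sklearn.preprocessing.scale`, which z-scores each feature over the whole batch; that is also why `chunksize` is set so large here, since the means and standard deviations are only accurate when computed over many pairs at once. A small self-contained illustration of the call:

```python
# Sketch only: preprocessing.scale shifts a feature column to zero mean and
# unit variance, using statistics computed over the whole batch it is given.
from sklearn import preprocessing

int_over_part = [0.9, 0.4, 0.0, 0.7]
scaled = preprocessing.scale(int_over_part)
print(scaled.mean(), scaled.std())  # ~0.0 and 1.0
```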
- intersect_areas = [w * h if w > 0 and h > 0 else 0 - for (_, _, w, h) in intersect_bboxes] + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] - int_area_relative_part = [int_area / part_area for int_area, part_area - in zip(intersect_areas, part_areas)] - int_area_relative_body = [int_area / body_area for int_area, body_area - in zip(intersect_areas, body_areas)] + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] int_area_relative_part = preprocessing.scale(int_area_relative_part) int_area_relative_body = preprocessing.scale(int_area_relative_body) part_area_relative_body = preprocessing.scale(part_area_relative_body) - result_list = list(zip( - part_bboxes, body_bboxes, intersect_bboxes, - int_area_relative_part, int_area_relative_body, part_area_relative_body - )) + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) - for (part_bbox, body_bbox, intersect_bbox, int_area_relative_part, - int_area_relative_body, part_area_relative_body) in result_list: - yield (part_bbox[0], part_bbox[1], part_bbox[2], part_bbox[3], - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - intersect_bbox[0], intersect_bbox[1], intersect_bbox[2], intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body) + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) @derived_attribute( tablename='theta_assignment_features', parents=['annotations', 'annotations'], colnames=[ - 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', - 'p_center_x', 'p_center_y', - 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', - 'int_area_scalar', 'part_body_distance', + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', 'part_body_centroid_dist', 'int_over_union', 'int_over_part', 'int_over_body', - 'part_over_body' + 'part_over_body', ], coltypes=[ - float, float, float, float, float, float, float, float, float, float, - float, float, float, float, float, float, - float, float, float, float, float, float, float + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='theta_assignment_features', @@ -2808,7 +3044,9 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only 
compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -2823,8 +3061,9 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): body_verts = _norm_vertices(body_verts, im_widths, im_heights) part_polys = [geometry.Polygon(vert) for vert in part_verts] body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [part.intersection(body) - for part, body in zip(part_polys, body_polys)] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] intersect_areas = [poly.area for poly in intersect_polys] # just to make int_areas more comparable via ML methods, and since all distances < 1 int_area_scalars = [math.sqrt(area) for area in intersect_areas] @@ -2835,50 +3074,91 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [part + body - intersect for (part, body, intersect) - in zip(part_areas, body_areas, intersect_areas)] - int_over_unions = [intersect / union for (intersect, union) - in zip(intersect_areas, union_areas)] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] - part_body_distances = [part.distance(body) - for part, body in zip(part_polys, body_polys)] + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] part_centroids = [poly.centroid for poly in part_polys] body_centroids = [poly.centroid for poly in body_polys] - part_body_centroid_dists = [part.distance(body) for part, body - in zip(part_centroids, body_centroids)] + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] - int_over_parts = [int_area / part_area for part_area, int_area - in zip(part_areas, intersect_areas)] + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] - int_over_bodys = [int_area / body_area for body_area, int_area - in zip(body_areas, intersect_areas)] + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] - part_over_bodys = [part_area / body_area for part_area, body_area - in zip(part_areas, body_areas)] + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] # note that here only parts have thetas, hence only returning body bboxes - result_list = list(zip( - part_verts, part_centroids, body_bboxes, body_centroids, - int_area_scalars, part_body_distances, part_body_centroid_dists, - int_over_unions, int_over_parts, int_over_bodys, part_over_bodys - )) - - for (part_vert, part_center, body_bbox, body_center, - int_area_scalar, part_body_distance, part_body_centroid_dist, - int_over_union, int_over_part, int_over_body, part_over_body) in result_list: - yield (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], - part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], - part_center.x, part_center.y, - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - 
body_center.x, body_center.y, - int_area_scalar, part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) in result_list: + yield ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, ) @@ -2888,24 +3168,78 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): tablename='assigner_viewpoint_features', parents=['annotations', 'annotations'], colnames=[ - 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', - 'p_center_x', 'p_center_y', - 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', - 'int_area_scalar', 'part_body_distance', + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', 'part_body_centroid_dist', 'int_over_union', 'int_over_part', 'int_over_body', 'part_over_body', - 'part_is_left', 'part_is_right', 'part_is_up', 'part_is_down', 'part_is_front', 'part_is_back', - 'body_is_left', 'body_is_right', 'body_is_up', 'body_is_down', 'body_is_front', 'body_is_back', + 'part_is_left', + 'part_is_right', + 'part_is_up', + 'part_is_down', + 'part_is_front', + 'part_is_back', + 'body_is_left', + 'body_is_right', + 'body_is_up', + 'body_is_down', + 'body_is_front', + 'body_is_back', ], coltypes=[ - float, float, float, float, float, float, float, float, float, float, - float, float, float, float, float, float, - float, float, float, float, float, float, float, - bool, bool, bool, bool, bool, bool, - bool, bool, bool, bool, bool, bool, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, ], configclass=PartAssignmentFeatureConfig, fname='assigner_viewpoint_features', @@ -2921,7 +3255,9 @@ def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None) part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' 
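The viewpoint-aware table extends the geometric features with six boolean flags per annotation (left, right, up, down, front, back) via `get_annot_lrudfb_bools`, which is defined elsewhere in core_annots.py. A minimal sketch of the idea, assuming viewpoints are strings such as 'left' or 'frontright' and that unlabeled annotations yield all-False rows:

```python
# Sketch only (assumed viewpoint encoding, not the wbia implementation).
LRUDFB = ['left', 'right', 'up', 'down', 'front', 'back']


def lrudfb_bools_sketch(viewpoint_list):
    rows = []
    for viewpoint in viewpoint_list:
        viewpoint = viewpoint or ''  # None-safe for unlabeled annots
        rows.append([direction in viewpoint for direction in LRUDFB])
    return rows


# lrudfb_bools_sketch(['frontleft', None])
# -> [[True, False, False, False, True, False],
#     [False, False, False, False, False, False]]
```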
bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -2936,8 +3272,9 @@ def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None) body_verts = _norm_vertices(body_verts, im_widths, im_heights) part_polys = [geometry.Polygon(vert) for vert in part_verts] body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [part.intersection(body) - for part, body in zip(part_polys, body_polys)] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] intersect_areas = [poly.area for poly in intersect_polys] # just to make int_areas more comparable via ML methods, and since all distances < 1 int_area_scalars = [math.sqrt(area) for area in intersect_areas] @@ -2948,55 +3285,99 @@ def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None) body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [part + body - intersect for (part, body, intersect) - in zip(part_areas, body_areas, intersect_areas)] - int_over_unions = [intersect / union for (intersect, union) - in zip(intersect_areas, union_areas)] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] - part_body_distances = [part.distance(body) - for part, body in zip(part_polys, body_polys)] + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] part_centroids = [poly.centroid for poly in part_polys] body_centroids = [poly.centroid for poly in body_polys] - part_body_centroid_dists = [part.distance(body) for part, body - in zip(part_centroids, body_centroids)] + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] - int_over_parts = [int_area / part_area for part_area, int_area - in zip(part_areas, intersect_areas)] + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] - int_over_bodys = [int_area / body_area for body_area, int_area - in zip(body_areas, intersect_areas)] + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] - part_over_bodys = [part_area / body_area for part_area, body_area - in zip(part_areas, body_areas)] + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] part_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) body_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) # note that here only parts have thetas, hence only returning body bboxes - result_list = list(zip( - part_verts, part_centroids, body_bboxes, body_centroids, - int_area_scalars, part_body_distances, part_body_centroid_dists, - int_over_unions, int_over_parts, int_over_bodys, part_over_bodys, - part_lrudfb_bools, body_lrudfb_bools - )) - - for (part_vert, part_center, body_bbox, body_center, - int_area_scalar, part_body_distance, part_body_centroid_dist, - int_over_union, int_over_part, int_over_body, part_over_body, - part_lrudfb_bool, body_lrudfb_bool) in result_list: - ans = (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], - part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], - part_center.x, part_center.y, - body_bbox[0], body_bbox[1], 
body_bbox[2], body_bbox[3], - body_center.x, body_center.y, - int_area_scalar, part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body) + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + part_lrudfb_bools, + body_lrudfb_bools, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + part_lrudfb_bool, + body_lrudfb_bool, + ) in result_list: + ans = ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) ans += tuple(part_lrudfb_bool) ans += tuple(body_lrudfb_bool) yield ans @@ -3006,24 +3387,78 @@ def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None) tablename='assigner_viewpoint_unit_features', parents=['annotations', 'annotations'], colnames=[ - 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', - 'p_center_x', 'p_center_y', - 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', - 'int_area_scalar', 'part_body_distance', + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', 'part_body_centroid_dist', 'int_over_union', 'int_over_part', 'int_over_body', 'part_over_body', - 'part_is_left', 'part_is_right', 'part_is_up', 'part_is_down', 'part_is_front', 'part_is_back', - 'body_is_left', 'body_is_right', 'body_is_up', 'body_is_down', 'body_is_front', 'body_is_back', + 'part_is_left', + 'part_is_right', + 'part_is_up', + 'part_is_down', + 'part_is_front', + 'part_is_back', + 'body_is_left', + 'body_is_right', + 'body_is_up', + 'body_is_down', + 'body_is_front', + 'body_is_back', ], coltypes=[ - float, float, float, float, float, float, float, float, float, float, - float, float, float, float, float, float, - float, float, float, float, float, float, float, - float, float, float, float, float, float, - float, float, float, float, float, float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='assigner_viewpoint_unit_features', @@ -3039,7 +3474,9 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = 
ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -3054,8 +3491,9 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= body_verts = _norm_vertices(body_verts, im_widths, im_heights) part_polys = [geometry.Polygon(vert) for vert in part_verts] body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [part.intersection(body) - for part, body in zip(part_polys, body_polys)] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] intersect_areas = [poly.area for poly in intersect_polys] # just to make int_areas more comparable via ML methods, and since all distances < 1 int_area_scalars = [math.sqrt(area) for area in intersect_areas] @@ -3066,55 +3504,99 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [part + body - intersect for (part, body, intersect) - in zip(part_areas, body_areas, intersect_areas)] - int_over_unions = [intersect / union for (intersect, union) - in zip(intersect_areas, union_areas)] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] - part_body_distances = [part.distance(body) - for part, body in zip(part_polys, body_polys)] + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] part_centroids = [poly.centroid for poly in part_polys] body_centroids = [poly.centroid for poly in body_polys] - part_body_centroid_dists = [part.distance(body) for part, body - in zip(part_centroids, body_centroids)] + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] - int_over_parts = [int_area / part_area for part_area, int_area - in zip(part_areas, intersect_areas)] + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] - int_over_bodys = [int_area / body_area for body_area, int_area - in zip(body_areas, intersect_areas)] + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] - part_over_bodys = [part_area / body_area for part_area, body_area - in zip(part_areas, body_areas)] + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] part_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) body_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) # note that here only parts have thetas, hence only returning body bboxes - result_list = list(zip( - part_verts, part_centroids, body_bboxes, body_centroids, - int_area_scalars, part_body_distances, part_body_centroid_dists, - int_over_unions, int_over_parts, int_over_bodys, part_over_bodys, - part_lrudfb_vects, body_lrudfb_vects - )) - - for (part_vert, part_center, body_bbox, body_center, - int_area_scalar, part_body_distance, part_body_centroid_dist, - int_over_union, int_over_part, int_over_body, part_over_body, - part_lrudfb_vect, body_lrudfb_vect) in result_list: - ans = (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], - part_vert[2][0], 
part_vert[2][1], part_vert[3][0], part_vert[3][1], - part_center.x, part_center.y, - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - body_center.x, body_center.y, - int_area_scalar, part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body) + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + part_lrudfb_vects, + body_lrudfb_vects, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + part_lrudfb_vect, + body_lrudfb_vect, + ) in result_list: + ans = ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) ans += tuple(part_lrudfb_vect) ans += tuple(body_lrudfb_vect) yield ans @@ -3123,9 +3605,17 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= # left, right, up, down, front, back booleans, useful for assigner classification and other cases where we might want viewpoint as an input for an ML model def get_annot_lrudfb_bools(ibs, aid_list): views = ibs.get_annot_viewpoints(aid_list) - bool_arrays = [['left' in view, 'right' in view, - 'up' in view, 'down' in view, - 'front' in view, 'back' in view] for view in views] + bool_arrays = [ + [ + 'left' in view, + 'right' in view, + 'up' in view, + 'down' in view, + 'front' in view, + 'back' in view, + ] + for view in views + ] return bool_arrays @@ -3134,9 +3624,11 @@ def get_annot_lrudfb_unit_vector(ibs, aid_list): float_arrays = [[float(b) for b in lrudfb] for lrudfb in bool_arrays] lrudfb_lengths = [sqrt(lrudfb.count(True)) for lrudfb in bool_arrays] # lying just to avoid division by zero errors - lrudfb_lengths = [l if l != 0 else -1 for l in lrudfb_lengths] - unit_float_array = [[f / length for f in lrudfb] for lrudfb, length - in zip(float_arrays, lrudfb_lengths)] + lrudfb_lengths = [length if length != 0 else -1 for length in lrudfb_lengths] + unit_float_array = [ + [f / length for f in lrudfb] + for lrudfb, length in zip(float_arrays, lrudfb_lengths) + ] return unit_float_array @@ -3145,27 +3637,63 @@ def get_annot_lrudfb_unit_vector(ibs, aid_list): tablename='theta_standardized_assignment_features', parents=['annotations', 'annotations'], colnames=[ - 'p_v1_x', 'p_v1_y', 'p_v2_x', 'p_v2_y', 'p_v3_x', 'p_v3_y', 'p_v4_x', 'p_v4_y', - 'p_center_x', 'p_center_y', - 'b_xtl', 'b_ytl', 'b_xbr', 'b_ybr', 'b_center_x', 'b_center_y', - 'int_area_scalar', 'part_body_distance', + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', 'part_body_centroid_dist', 'int_over_union', 'int_over_part', 'int_over_body', - 'part_over_body' + 'part_over_body', ], coltypes=[ - float, float, float, float, float, float, float, float, float, float, - float, float, 
float, float, float, float, - float, float, float, float, float, float, float + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, ], configclass=PartAssignmentFeatureConfig, fname='theta_standardized_assignment_features', rm_extern_on_delete=True, chunksize=2560000, # chunk size is huge bc we need accurate means and stdevs of various traits ) -def theta_standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): +def theta_standardized_assignment_features( + depc, part_aid_list, body_aid_list, config=None +): from shapely import geometry import math @@ -3174,7 +3702,9 @@ def theta_standardized_assignment_features(depc, part_aid_list, body_aid_list, c part_gids = ibs.get_annot_gids(part_aid_list) body_gids = ibs.get_annot_gids(body_aid_list) - assert part_gids == body_gids, 'can only compute assignment features on aids in the same image' + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' parts_are_parts = ibs._are_part_annots(part_aid_list) assert all(parts_are_parts), 'all part_aids must be part annots.' bodies_are_parts = ibs._are_part_annots(body_aid_list) @@ -3189,89 +3719,129 @@ def theta_standardized_assignment_features(depc, part_aid_list, body_aid_list, c body_verts = _norm_vertices(body_verts, im_widths, im_heights) part_polys = [geometry.Polygon(vert) for vert in part_verts] body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [part.intersection(body) - for part, body in zip(part_polys, body_polys)] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] intersect_areas = [poly.area for poly in intersect_polys] # just to make int_areas more comparable via ML methods, and since all distances < 1 int_area_scalars = [math.sqrt(area) for area in intersect_areas] int_area_scalars = preprocessing.scale(int_area_scalars) - part_bboxes = ibs.get_annot_bboxes(part_aid_list) body_bboxes = ibs.get_annot_bboxes(body_aid_list) part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [part + body - intersect for (part, body, intersect) - in zip(part_areas, body_areas, intersect_areas)] - int_over_unions = [intersect / union for (intersect, union) - in zip(intersect_areas, union_areas)] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] int_over_unions = preprocessing.scale(int_over_unions) - part_body_distances = [part.distance(body) - for part, body in zip(part_polys, body_polys)] + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] part_body_distances = preprocessing.scale(part_body_distances) part_centroids = [poly.centroid for poly in part_polys] body_centroids = [poly.centroid for poly in body_polys] - part_body_centroid_dists = [part.distance(body) for part, body - in zip(part_centroids, body_centroids)] + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] part_body_centroid_dists = 
preprocessing.scale(part_body_centroid_dists) - int_over_parts = [int_area / part_area for part_area, int_area - in zip(part_areas, intersect_areas)] + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] int_over_parts = preprocessing.scale(int_over_parts) - int_over_bodys = [int_area / body_area for body_area, int_area - in zip(body_areas, intersect_areas)] + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] int_over_bodys = preprocessing.scale(int_over_bodys) - part_over_bodys = [part_area / body_area for part_area, body_area - in zip(part_areas, body_areas)] + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] part_over_bodys = preprocessing.scale(part_over_bodys) - - - #standardization + # standardization # note that here only parts have thetas, hence only returning body bboxes - result_list = list(zip( - part_verts, part_centroids, body_bboxes, body_centroids, - int_area_scalars, part_body_distances, part_body_centroid_dists, - int_over_unions, int_over_parts, int_over_bodys, part_over_bodys - )) - - for (part_vert, part_center, body_bbox, body_center, - int_area_scalar, part_body_distance, part_body_centroid_dist, - int_over_union, int_over_part, int_over_body, part_over_body) in result_list: - yield (part_vert[0][0], part_vert[0][1], part_vert[1][0], part_vert[1][1], - part_vert[2][0], part_vert[2][1], part_vert[3][0], part_vert[3][1], - part_center.x, part_center.y, - body_bbox[0], body_bbox[1], body_bbox[2], body_bbox[3], - body_center.x, body_center.y, - int_area_scalar, part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) in result_list: + yield ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, ) def _norm_bboxes(bbox_list, width_list, height_list): - normed_boxes = [(bbox[0]/w, bbox[1]/h, bbox[2]/w, bbox[3]/h) - for (bbox, w, h) - in zip(bbox_list, width_list, height_list)] + normed_boxes = [ + (bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h) + for (bbox, w, h) in zip(bbox_list, width_list, height_list) + ] return normed_boxes def _norm_vertices(verts_list, width_list, height_list): - normed_verts = [[[x / w , y / h] for x, y in vert] - for vert, w, h - in zip(verts_list, width_list, height_list) - ] + normed_verts = [ + [[x / w, y / h] for x, y in vert] + for vert, w, h in zip(verts_list, width_list, height_list) + ] return normed_verts @@ -3289,47 +3859,60 @@ def _bbox_intersections(bboxes_a, bboxes_b): corner_bboxes_a = _bbox_to_corner_format(bboxes_a) corner_bboxes_b = _bbox_to_corner_format(bboxes_b) - intersect_xtls = [max(xtl_a, 
xtl_b) - for ((xtl_a, _, _, _), (xtl_b, _, _, _)) - in zip(corner_bboxes_a, corner_bboxes_b)] + intersect_xtls = [ + max(xtl_a, xtl_b) + for ((xtl_a, _, _, _), (xtl_b, _, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] - intersect_ytls = [max(ytl_a, ytl_b) - for ((_, ytl_a, _, _), (_, ytl_b, _, _)) - in zip(corner_bboxes_a, corner_bboxes_b)] + intersect_ytls = [ + max(ytl_a, ytl_b) + for ((_, ytl_a, _, _), (_, ytl_b, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] - intersect_xbrs = [min(xbr_a, xbr_b) - for ((_, _, xbr_a, _), (_, _, xbr_b, _)) - in zip(corner_bboxes_a, corner_bboxes_b)] + intersect_xbrs = [ + min(xbr_a, xbr_b) + for ((_, _, xbr_a, _), (_, _, xbr_b, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] - intersect_ybrs = [min(ybr_a, ybr_b) - for ((_, _, _, ybr_a), (_, _, _, ybr_b)) - in zip(corner_bboxes_a, corner_bboxes_b)] + intersect_ybrs = [ + min(ybr_a, ybr_b) + for ((_, _, _, ybr_a), (_, _, _, ybr_b)) in zip(corner_bboxes_a, corner_bboxes_b) + ] - intersect_widths = [int_xbr - int_xtl for int_xbr, int_xtl - in zip(intersect_xbrs, intersect_xtls)] + intersect_widths = [ + int_xbr - int_xtl for int_xbr, int_xtl in zip(intersect_xbrs, intersect_xtls) + ] - intersect_heights = [int_ybr - int_ytl for int_ybr, int_ytl - in zip(intersect_ybrs, intersect_ytls)] + intersect_heights = [ + int_ybr - int_ytl for int_ybr, int_ytl in zip(intersect_ybrs, intersect_ytls) + ] - intersect_bboxes = list(zip( - intersect_xtls, intersect_ytls, intersect_widths, intersect_heights)) + intersect_bboxes = list( + zip(intersect_xtls, intersect_ytls, intersect_widths, intersect_heights) + ) return intersect_bboxes + def _theta_aware_intersect_areas(verts_list_a, verts_list_b): import shapely + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] - intersect_areas = [poly1.intersection(poly2).area - for poly1, poly2 in zip(polys_a, polys_b)] + intersect_areas = [ + poly1.intersection(poly2).area for poly1, poly2 in zip(polys_a, polys_b) + ] return intersect_areas def _all_centroids(verts_list_a, verts_list_b): import shapely + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] - intersect_polys = [poly1.intersection(poly2) for poly1, poly2 in zip(polys_a, polys_b)] + intersect_polys = [ + poly1.intersection(poly2) for poly1, poly2 in zip(polys_a, polys_b) + ] centroids_a = [poly.centroid for poly in polys_a] centroids_b = [poly.centroid for poly in polys_b] @@ -3337,13 +3920,15 @@ def _all_centroids(verts_list_a, verts_list_b): return centroids_a, centroids_b, centroids_int + def _polygons_to_centroid_coords(polygon_list): centroids = [poly.centroid for poly in polygon_list] + return centroids # converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) def _bbox_to_corner_format(bboxes): - corner_bboxes = [(bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) - for bbox in bboxes] + corner_bboxes = [ + (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) for bbox in bboxes + ] return corner_bboxes - diff --git a/wbia/core_parts.py b/wbia/core_parts.py index 744592f927..55e6c48515 100644 --- a/wbia/core_parts.py +++ b/wbia/core_parts.py @@ -6,10 +6,12 @@ import logging import utool as ut import numpy as np -from wbia import dtool + +# from wbia import dtool from wbia.control.controller_inject import register_preprocs, register_subprops from wbia import core_annots -from wbia.constants import ANNOTATION_TABLE, 
PART_TABLE + +# from wbia.constants import ANNOTATION_TABLE, PART_TABLE (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') @@ -85,5 +87,3 @@ def compute_part_chip(depc, part_rowid_list, config=None): for result in result_list: yield result logger.info('Done Preprocessing Part Chips') - - From e632e54a7cd798f3dcdd7d697497ad39772799b8 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 5 Jan 2021 14:30:37 -0800 Subject: [PATCH 149/294] fixes sysres import --- wbia/algo/detect/assigner.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index 6b4aac7d9d..e9c65cb5f3 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -204,16 +204,16 @@ def assign_parts(ibs, all_aids, cutoff_score=0.5): >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) >>> ([(3, 1), (6, 5), (8, 7)], [2, 4]) """ + gids = ibs.get_annot_gids(all_aids) gid_to_aids = defaultdict(list) for gid, aid in zip(gids, all_aids): gid_to_aids[gid] += [aid] all_assignments = [] - all_unassigned_aids = [] for gid in gid_to_aids.keys(): - this_pairs, this_unassigned = _assign_parts_one_image( + this_pairs, this_unassigned = assign_parts_one_image( ibs, gid_to_aids[gid], cutoff_score ) all_assignments += this_pairs @@ -223,7 +223,7 @@ def assign_parts(ibs, all_aids, cutoff_score=0.5): @register_ibs_method -def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): +def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): r""" Main assigner method; makes assignments on all_aids based on assigner scores. @@ -247,7 +247,7 @@ def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): >>> ibs = assigner_testdb_ibs() >>> gid = 1 >>> aids = ibs.get_image_aids(gid) - >>> result = ibs._assign_parts_one_image(aids) + >>> result = ibs.assign_parts_one_image(aids) >>> assigned_pairs = result[0] >>> unassigned_aids = result[1] >>> assigned_aids = [item for pair in assigned_pairs for item in pair] @@ -257,13 +257,21 @@ def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) >>> ([(3, 1)], [2, 4]) """ + all_species = ibs.get_annot_species(aid_list) + # put unsupported species into the all_unassigned_aids list + assign_flag_list = [species in SPECIES_CONFIG_MAP.keys() + for species in all_species] + + unassigned_aids_noconfig = ut.filterfalse_items(aid_list, assign_flag_list) + aid_list = ut.compress(aid_list, assign_flag_list) + are_part_aids = _are_part_annots(ibs, aid_list) part_aids = ut.compress(aid_list, are_part_aids) body_aids = ut.compress(aid_list, [not p for p in are_part_aids]) gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) num_images = len(set(gids)) - assert num_images == 1, "_assign_parts_one_image called on multiple images' aids" + assert num_images == 1, "assign_parts_one_image called on multiple images' aids" # parallel lists representing all possible part/body pairs all_pairs_parallel = _all_pairs_parallel(part_aids, body_aids) @@ -280,6 +288,7 @@ def _assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): good_pairs, unassigned_aids = _make_assignments( pair_parts, pair_bodies, assigner_scores, cutoff_score ) + unassigned_aids = unassigned_aids_noconfig + unassigned_aids return good_pairs, unassigned_aids @@ -483,7 +492,7 @@ def gid_keyed_ground_truth(ibs, assigner_data): @register_ibs_method def assigner_testdb_ibs(): import wbia - import sysres + from wbia 
import sysres dbdir = sysres.ensure_testdb_assigner() # dbdir = '/data/testdb_assigner' From 6bf08f4c11f85778af4d599d4f9b799b46fd4cd7 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 5 Jan 2021 14:31:06 -0800 Subject: [PATCH 150/294] integrates assigner with API. Note that when assigner_algo is not None, the results list is now in a new format --- wbia/web/apis_detect.py | 80 +++++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/wbia/web/apis_detect.py b/wbia/web/apis_detect.py index 79acf69282..fcd84f54b3 100644 --- a/wbia/web/apis_detect.py +++ b/wbia/web/apis_detect.py @@ -11,7 +11,6 @@ from wbia.web import appfuncs as appf import numpy as np -(print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') CLASS_INJECT_KEY, register_ibs_method = controller_inject.make_ibs_register_decorator( @@ -440,6 +439,7 @@ def process_detection_html(ibs, **kwargs): return result_dict +# this is where the aids_list response from commit_localization_results is packaged & returned in json @register_ibs_method @accessor_decors.getter_1to1 def detect_cnn_json(ibs, gid_list, detect_func, config={}, **kwargs): @@ -465,40 +465,54 @@ def detect_cnn_json(ibs, gid_list, detect_func, config={}, **kwargs): >>> print(results_dict) """ # TODO: Return confidence here as well + def _json_result(ibs, aid): + result = { + 'id': aid, + 'uuid': ibs.get_annot_uuids(aid), + 'xtl': ibs.get_annot_bboxes(aid)[0], + 'ytl': ibs.get_annot_bboxes(aid)[1], + 'left': ibs.get_annot_bboxes(aid)[0], + 'top': ibs.get_annot_bboxes(aid)[1], + 'width': ibs.get_annot_bboxes(aid)[2], + 'height': ibs.get_annot_bboxes(aid)[3], + 'theta': round(ibs.get_annot_thetas(aid), 4), + 'confidence': round(ibs.get_annot_detect_confidence(aid), 4), + 'class': ibs.get_annot_species_texts(aid), + 'species': ibs.get_annot_species_texts(aid), + 'viewpoint': ibs.get_annot_viewpoints(aid), + 'quality': ibs.get_annot_qualities(aid), + 'multiple': ibs.get_annot_multiple(aid), + 'interest': ibs.get_annot_interest(aid), + } + return result + image_uuid_list = ibs.get_image_uuids(gid_list) ibs.assert_valid_gids(gid_list) - # Get detections from depc + # Get detections from depc --- this output will be affected by assigner aids_list = detect_func(gid_list, **config) - results_list = [ - [ - { - 'id': aid, - 'uuid': ibs.get_annot_uuids(aid), - 'xtl': ibs.get_annot_bboxes(aid)[0], - 'ytl': ibs.get_annot_bboxes(aid)[1], - 'left': ibs.get_annot_bboxes(aid)[0], - 'top': ibs.get_annot_bboxes(aid)[1], - 'width': ibs.get_annot_bboxes(aid)[2], - 'height': ibs.get_annot_bboxes(aid)[3], - 'theta': round(ibs.get_annot_thetas(aid), 4), - 'confidence': round(ibs.get_annot_detect_confidence(aid), 4), - 'class': ibs.get_annot_species_texts(aid), - 'species': ibs.get_annot_species_texts(aid), - 'viewpoint': ibs.get_annot_viewpoints(aid), - 'quality': ibs.get_annot_qualities(aid), - 'multiple': ibs.get_annot_multiple(aid), - 'interest': ibs.get_annot_interest(aid), - } - for aid in aid_list - ] - for aid_list in aids_list - ] + results_list = [] + has_assignments = False + for aid_list in aids_list: + result_list = [] + for aid in aid_list: + if not isinstance(aid, tuple): # we have an assignment + result = _json_result(ibs, aid) + else: + assert len(aid) > 0 + has_assignments = True + result = [] + for val in aid: + result.append(_json_result(ibs, val)) + result_list.append(result) + results_list.append(result_list) + score_list = [0.0] * len(gid_list) # Wrap up results with other information results_dict = { 
'image_uuid_list': image_uuid_list, 'results_list': results_list, 'score_list': score_list, + 'has_assignments': has_assignments, } return results_dict @@ -626,7 +640,6 @@ def models_cnn_lightnet(ibs, **kwargs): Method: PUT, GET URL: /api/labels/cnn/lightnet/ """ - def identity(x): return x @@ -895,6 +908,8 @@ def commit_localization_results( use_labeler_species=False, orienter_algo=None, orienter_model_tag=None, + assigner_algo=None, + assigner_model_tag=None, update_json_log=True, apply_nms_post_use_labeler_species=True, **kwargs, @@ -981,10 +996,21 @@ def commit_localization_results( if len(bbox_list) > 0: ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) + if assigner_algo is not None: + # aids_list is a list of lists of aids, now we want a list of lists of + all_assignments = [] + for aids in aids_list: + assigned, unassigned = ibs.assign_parts_one_image(aids) + # unassigned aids should also be tuples to indicate they went through the assigner + unassigned = [(aid,) for aid in unassigned] + all_assignments.append(assigned + unassigned) + aids_list = all_assignments + ibs._clean_species() if update_json_log: ibs.log_detections(aid_list) + # list of list of ints return aids_list From cc8ab58da7e0339533c28910babf3981719e8762 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 5 Jan 2021 14:48:50 -0800 Subject: [PATCH 151/294] slight change to how load_assigner_classifier is used, making it more robust --- wbia/algo/detect/assigner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index e9c65cb5f3..0509bda388 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -280,7 +280,8 @@ def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): assigner_features = ibs.depc_annot.get( 'assigner_viewpoint_features', all_pairs_parallel ) - assigner_classifier = load_assigner_classifier(ibs, part_aids) + # send all aids to this call just so it can find the right classifier model + assigner_classifier = load_assigner_classifier(ibs, body_aids+part_aids) assigner_scores = assigner_classifier.predict_proba(assigner_features) # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true probabilities From f8fd5269f12e053e89998dce66aec8053ae9d9c7 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 5 Jan 2021 15:01:59 -0800 Subject: [PATCH 152/294] fixes logic bug on assign_flag_list, cleans up imports --- wbia/algo/detect/assigner.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index 0509bda388..5025f138c4 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -3,38 +3,17 @@ # from os.path import expanduser, join from wbia import constants as const -from wbia.control.controller_inject import ( - # register_preprocs, - # register_subprops, - make_ibs_register_decorator, -) +from wbia.control.controller_inject import make_ibs_register_decorator import utool as ut - -# import numpy as np -# import random import os - -# from collections import OrderedDict from collections import defaultdict -# from datetime import datetime -# import time - # illustration imports from shutil import copy from PIL import Image, ImageDraw import wbia.plottool as pt -# from sklearn import preprocessing -# from tune_sklearn import TuneGridSearchCV - -# shitload of scikit classifiers -# import numpy as np - -# bunch 
of classifier models for training - -(print, rrr, profile) = ut.inject2(__name__, '[assigner]') logger = logging.getLogger('wbia') CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) @@ -204,13 +183,13 @@ def assign_parts(ibs, all_aids, cutoff_score=0.5): >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) >>> ([(3, 1), (6, 5), (8, 7)], [2, 4]) """ - gids = ibs.get_annot_gids(all_aids) gid_to_aids = defaultdict(list) for gid, aid in zip(gids, all_aids): gid_to_aids[gid] += [aid] all_assignments = [] + all_unassigned_aids = [] for gid in gid_to_aids.keys(): this_pairs, this_unassigned = assign_parts_one_image( @@ -259,8 +238,9 @@ def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): """ all_species = ibs.get_annot_species(aid_list) # put unsupported species into the all_unassigned_aids list + all_species_no_parts = [species.split('+')[0] for species in all_species] assign_flag_list = [species in SPECIES_CONFIG_MAP.keys() - for species in all_species] + for species in all_species_no_parts] unassigned_aids_noconfig = ut.filterfalse_items(aid_list, assign_flag_list) aid_list = ut.compress(aid_list, assign_flag_list) @@ -281,7 +261,7 @@ def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): 'assigner_viewpoint_features', all_pairs_parallel ) # send all aids to this call just so it can find the right classifier model - assigner_classifier = load_assigner_classifier(ibs, body_aids+part_aids) + assigner_classifier = load_assigner_classifier(ibs, body_aids + part_aids) assigner_scores = assigner_classifier.predict_proba(assigner_features) # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true probabilities From 68652cbfae7e50276c6f1f0920e7b14677ee8d21 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Tue, 5 Jan 2021 16:08:59 -0800 Subject: [PATCH 153/294] moves unused assigner features into train_assigner --- wbia/algo/detect/train_assigner.py | 1211 +++++++++++++++++++++++++++ wbia/core_annots.py | 1237 +--------------------------- 2 files changed, 1233 insertions(+), 1215 deletions(-) diff --git a/wbia/algo/detect/train_assigner.py b/wbia/algo/detect/train_assigner.py index a907208a0b..c5f511d9fd 100644 --- a/wbia/algo/detect/train_assigner.py +++ b/wbia/algo/detect/train_assigner.py @@ -523,3 +523,1214 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): import utool as ut # NOQA ut.doctest_funcs() + + +# additional assigner features to explore +class PartAssignmentFeatureConfig(dtool.Config): + _param_info_list = [] + +@derived_attribute( + tablename='part_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='part_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = 
ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +@derived_attribute( + tablename='normalized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='normalized_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
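+    # as in compute_assignment_features above, the yielded features are the raw part and
+    # body bboxes, their intersection bbox, and the three area ratios; the difference is
+    # that every bbox is first normalized by image width/height (via _norm_bboxes), so
+    # the features are independent of image resolution.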
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +@derived_attribute( + tablename='standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
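+    # same normalized bbox/intersection features as normalized_assignment_features, but
+    # the three area-ratio columns are additionally passed through preprocessing.scale
+    # (zero mean, unit variance); the very large chunksize above is what lets those
+    # statistics be computed over essentially the whole table in one pass.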
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + int_area_relative_part = preprocessing.scale(int_area_relative_part) + int_area_relative_body = preprocessing.scale(int_area_relative_body) + part_area_relative_body = preprocessing.scale(part_area_relative_body) + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +# like the above but bboxes are also standardized +@derived_attribute( + tablename='mega_standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='mega_standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def mega_standardized_assignment_features( + depc, part_aid_list, body_aid_list, config=None +): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
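+    # "mega" standardization: identical to standardized_assignment_features except that
+    # the normalized part and body bboxes are themselves standardized (via
+    # _standardized_bboxes) before the intersection and area-ratio features are computed.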
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_bboxes = _standardized_bboxes(part_bboxes) + body_bboxes = _standardized_bboxes(body_bboxes) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. + intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + int_area_relative_part = preprocessing.scale(int_area_relative_part) + int_area_relative_body = preprocessing.scale(int_area_relative_body) + part_area_relative_body = preprocessing.scale(part_area_relative_body) + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + +@derived_attribute( + tablename='theta_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='theta_assignment_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part 
annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) in result_list: + yield ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + + +@derived_attribute( + tablename='assigner_viewpoint_unit_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', 
+ 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + 'part_is_left', + 'part_is_right', + 'part_is_up', + 'part_is_down', + 'part_is_front', + 'part_is_back', + 'body_is_left', + 'body_is_right', + 'body_is_up', + 'body_is_down', + 'body_is_front', + 'body_is_back', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='assigner_viewpoint_unit_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + + part_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, 
part_aid_list) + body_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + part_lrudfb_vects, + body_lrudfb_vects, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + part_lrudfb_vect, + body_lrudfb_vect, + ) in result_list: + ans = ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + ans += tuple(part_lrudfb_vect) + ans += tuple(body_lrudfb_vect) + yield ans + +@derived_attribute( + tablename='theta_standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='theta_standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=2560000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def theta_standardized_assignment_features( + depc, part_aid_list, body_aid_list, config=None +): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
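For reference, the rotation-aware overlap features above reduce to a handful of shapely polygon operations on the image-normalized rotated verts. A minimal standalone sketch with made-up coordinates (not the wbia helpers themselves):

```
from shapely import geometry

# Two normalized rotated boxes as polygons (coordinates are illustrative).
part = geometry.Polygon([(0.1, 0.1), (0.3, 0.1), (0.3, 0.2), (0.1, 0.2)])
body = geometry.Polygon([(0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.0, 0.5)])

intersect = part.intersection(body).area
union = part.area + body.area - intersect

print(intersect / union)                      # int_over_union
print(intersect / part.area)                  # int_over_part
print(part.centroid.distance(body.centroid))  # part_body_centroid_dist
```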
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + int_area_scalars = preprocessing.scale(int_area_scalars) + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + int_over_unions = preprocessing.scale(int_over_unions) + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + part_body_distances = preprocessing.scale(part_body_distances) + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + part_body_centroid_dists = preprocessing.scale(part_body_centroid_dists) + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + int_over_parts = preprocessing.scale(int_over_parts) + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + int_over_bodys = preprocessing.scale(int_over_bodys) + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + part_over_bodys = preprocessing.scale(part_over_bodys) + + # standardization + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) in result_list: + yield ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + 
int_over_body, + part_over_body, + ) + +def get_annot_lrudfb_unit_vector(ibs, aid_list): + bool_arrays = get_annot_lrudfb_bools(ibs, aid_list) + float_arrays = [[float(b) for b in lrudfb] for lrudfb in bool_arrays] + lrudfb_lengths = [sqrt(lrudfb.count(True)) for lrudfb in bool_arrays] + # lying just to avoid division by zero errors + lrudfb_lengths = [length if length != 0 else -1 for length in lrudfb_lengths] + unit_float_array = [ + [f / length for f in lrudfb] + for lrudfb, length in zip(float_arrays, lrudfb_lengths) + ] + + return unit_float_array + +def _norm_bboxes(bbox_list, width_list, height_list): + normed_boxes = [ + (bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h) + for (bbox, w, h) in zip(bbox_list, width_list, height_list) + ] + return normed_boxes + + +def _norm_vertices(verts_list, width_list, height_list): + normed_verts = [ + [[x / w, y / h] for x, y in vert] + for vert, w, h in zip(verts_list, width_list, height_list) + ] + return normed_verts + + +# does this even make any sense? let's find out experimentally +def _standardized_bboxes(bbox_list): + xtls = preprocessing.scale([bbox[0] for bbox in bbox_list]) + ytls = preprocessing.scale([bbox[1] for bbox in bbox_list]) + wids = preprocessing.scale([bbox[2] for bbox in bbox_list]) + heis = preprocessing.scale([bbox[3] for bbox in bbox_list]) + standardized_bboxes = list(zip(xtls, ytls, wids, heis)) + return standardized_bboxes + + +def _bbox_intersections(bboxes_a, bboxes_b): + corner_bboxes_a = _bbox_to_corner_format(bboxes_a) + corner_bboxes_b = _bbox_to_corner_format(bboxes_b) + + intersect_xtls = [ + max(xtl_a, xtl_b) + for ((xtl_a, _, _, _), (xtl_b, _, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_ytls = [ + max(ytl_a, ytl_b) + for ((_, ytl_a, _, _), (_, ytl_b, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_xbrs = [ + min(xbr_a, xbr_b) + for ((_, _, xbr_a, _), (_, _, xbr_b, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_ybrs = [ + min(ybr_a, ybr_b) + for ((_, _, _, ybr_a), (_, _, _, ybr_b)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_widths = [ + int_xbr - int_xtl for int_xbr, int_xtl in zip(intersect_xbrs, intersect_xtls) + ] + + intersect_heights = [ + int_ybr - int_ytl for int_ybr, int_ytl in zip(intersect_ybrs, intersect_ytls) + ] + + intersect_bboxes = list( + zip(intersect_xtls, intersect_ytls, intersect_widths, intersect_heights) + ) + + return intersect_bboxes + + +def _all_centroids(verts_list_a, verts_list_b): + import shapely + + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] + polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] + intersect_polys = [ + poly1.intersection(poly2) for poly1, poly2 in zip(polys_a, polys_b) + ] + + centroids_a = [poly.centroid for poly in polys_a] + centroids_b = [poly.centroid for poly in polys_b] + centroids_int = [poly.centroid for poly in intersect_polys] + + return centroids_a, centroids_b, centroids_int + + +def _theta_aware_intersect_areas(verts_list_a, verts_list_b): + import shapely + + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] + polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] + intersect_areas = [ + poly1.intersection(poly2).area for poly1, poly2 in zip(polys_a, polys_b) + ] + return intersect_areas + + +# converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) +def _bbox_to_corner_format(bboxes): + corner_bboxes = [ + (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) for bbox in bboxes + ] + return 
corner_bboxes + + +def _polygons_to_centroid_coords(polygon_list): + centroids = [poly.centroid for poly in polygon_list] + return centroids + + diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 6e4ece6276..9f966f2431 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2487,681 +2487,6 @@ class PartAssignmentFeatureConfig(dtool.Config): _param_info_list = [] -@derived_attribute( - tablename='part_assignment_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_xtl', - 'p_ytl', - 'p_w', - 'p_h', - 'b_xtl', - 'b_ytl', - 'b_w', - 'b_h', - 'int_xtl', - 'int_ytl', - 'int_w', - 'int_h', - 'intersect_area_relative_part', - 'intersect_area_relative_body', - 'part_area_relative_body', - ], - coltypes=[ - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - float, - float, - float, - ], - configclass=PartAssignmentFeatureConfig, - fname='part_assignment_features', - rm_extern_on_delete=True, - chunksize=256, -) -def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None): - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [ - part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) - ] - - intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) - # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
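As the comment above notes, the intersection box in (xtl, ytl, w, h) form can come out with a negative width or height when the two annotations do not overlap, and the area is clamped to zero in that case. A standalone sketch of that computation (helper names here are illustrative, not the wbia functions):

```
def bbox_intersection(bbox_a, bbox_b):
    # bboxes are (xtl, ytl, w, h); a negative w/h means no overlap and its
    # magnitude is the x/y gap between the two boxes.
    ax, ay, aw, ah = bbox_a
    bx, by, bw, bh = bbox_b
    xtl = max(ax, bx)
    ytl = max(ay, by)
    w = min(ax + aw, bx + bw) - xtl
    h = min(ay + ah, by + bh) - ytl
    return (xtl, ytl, w, h)


def intersection_area(bbox_a, bbox_b):
    _, _, w, h = bbox_intersection(bbox_a, bbox_b)
    return w * h if w > 0 and h > 0 else 0


assert intersection_area((0, 0, 2, 2), (1, 1, 2, 2)) == 1
assert intersection_area((0, 0, 1, 1), (3, 0, 1, 1)) == 0  # boxes 2 apart on x
```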
- intersect_areas = [ - w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes - ] - - int_area_relative_part = [ - int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) - ] - int_area_relative_body = [ - int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) - ] - - result_list = list( - zip( - part_bboxes, - body_bboxes, - intersect_bboxes, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - ) - - for ( - part_bbox, - body_bbox, - intersect_bbox, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) in result_list: - yield ( - part_bbox[0], - part_bbox[1], - part_bbox[2], - part_bbox[3], - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - intersect_bbox[0], - intersect_bbox[1], - intersect_bbox[2], - intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - - -@derived_attribute( - tablename='normalized_assignment_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_xtl', - 'p_ytl', - 'p_w', - 'p_h', - 'b_xtl', - 'b_ytl', - 'b_w', - 'b_h', - 'int_xtl', - 'int_ytl', - 'int_w', - 'int_h', - 'intersect_area_relative_part', - 'intersect_area_relative_body', - 'part_area_relative_body', - ], - coltypes=[ - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - ], - configclass=PartAssignmentFeatureConfig, - fname='normalized_assignment_features', - rm_extern_on_delete=True, - chunksize=256, -) -def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=None): - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - im_widths = ibs.get_image_widths(part_gids) - im_heights = ibs.get_image_heights(part_gids) - part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) - body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) - - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [ - part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) - ] - - intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) - # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
- intersect_areas = [ - w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes - ] - - int_area_relative_part = [ - int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) - ] - int_area_relative_body = [ - int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) - ] - - result_list = list( - zip( - part_bboxes, - body_bboxes, - intersect_bboxes, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - ) - - for ( - part_bbox, - body_bbox, - intersect_bbox, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) in result_list: - yield ( - part_bbox[0], - part_bbox[1], - part_bbox[2], - part_bbox[3], - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - intersect_bbox[0], - intersect_bbox[1], - intersect_bbox[2], - intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - - -@derived_attribute( - tablename='standardized_assignment_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_xtl', - 'p_ytl', - 'p_w', - 'p_h', - 'b_xtl', - 'b_ytl', - 'b_w', - 'b_h', - 'int_xtl', - 'int_ytl', - 'int_w', - 'int_h', - 'intersect_area_relative_part', - 'intersect_area_relative_body', - 'part_area_relative_body', - ], - coltypes=[ - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - ], - configclass=PartAssignmentFeatureConfig, - fname='standardized_assignment_features', - rm_extern_on_delete=True, - chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits -) -def standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - im_widths = ibs.get_image_widths(part_gids) - im_heights = ibs.get_image_heights(part_gids) - part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) - body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) - - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [ - part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) - ] - - intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) - # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
- intersect_areas = [ - w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes - ] - - int_area_relative_part = [ - int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) - ] - int_area_relative_body = [ - int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) - ] - - int_area_relative_part = preprocessing.scale(int_area_relative_part) - int_area_relative_body = preprocessing.scale(int_area_relative_body) - part_area_relative_body = preprocessing.scale(part_area_relative_body) - - result_list = list( - zip( - part_bboxes, - body_bboxes, - intersect_bboxes, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - ) - - for ( - part_bbox, - body_bbox, - intersect_bbox, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) in result_list: - yield ( - part_bbox[0], - part_bbox[1], - part_bbox[2], - part_bbox[3], - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - intersect_bbox[0], - intersect_bbox[1], - intersect_bbox[2], - intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - - -# like the above but bboxes are also standardized -@derived_attribute( - tablename='mega_standardized_assignment_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_xtl', - 'p_ytl', - 'p_w', - 'p_h', - 'b_xtl', - 'b_ytl', - 'b_w', - 'b_h', - 'int_xtl', - 'int_ytl', - 'int_w', - 'int_h', - 'intersect_area_relative_part', - 'intersect_area_relative_body', - 'part_area_relative_body', - ], - coltypes=[ - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - ], - configclass=PartAssignmentFeatureConfig, - fname='mega_standardized_assignment_features', - rm_extern_on_delete=True, - chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits -) -def mega_standardized_assignment_features( - depc, part_aid_list, body_aid_list, config=None -): - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - im_widths = ibs.get_image_widths(part_gids) - im_heights = ibs.get_image_heights(part_gids) - part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) - body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) - - part_bboxes = _standardized_bboxes(part_bboxes) - body_bboxes = _standardized_bboxes(body_bboxes) - - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - part_area_relative_body = [ - part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) - ] - - intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) - # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
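The *_standardized_* variants of these tables pass each feature column through sklearn's preprocessing.scale, which centers it at mean 0 with unit variance over the whole batch; that is also why their chunksize is set so large. A small standalone example of the call (numbers are illustrative):

```
from sklearn import preprocessing

areas = [0.01, 0.04, 0.10, 0.25]
print(preprocessing.scale(areas))
# roughly [-0.97, -0.65, 0.0, 1.62]: mean ~0 and unit variance over the batch
```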
- intersect_areas = [ - w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes - ] - - int_area_relative_part = [ - int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) - ] - int_area_relative_body = [ - int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) - ] - - int_area_relative_part = preprocessing.scale(int_area_relative_part) - int_area_relative_body = preprocessing.scale(int_area_relative_body) - part_area_relative_body = preprocessing.scale(part_area_relative_body) - - result_list = list( - zip( - part_bboxes, - body_bboxes, - intersect_bboxes, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - ) - - for ( - part_bbox, - body_bbox, - intersect_bbox, - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) in result_list: - yield ( - part_bbox[0], - part_bbox[1], - part_bbox[2], - part_bbox[3], - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - intersect_bbox[0], - intersect_bbox[1], - intersect_bbox[2], - intersect_bbox[3], - int_area_relative_part, - int_area_relative_body, - part_area_relative_body, - ) - - -@derived_attribute( - tablename='theta_assignment_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_v1_x', - 'p_v1_y', - 'p_v2_x', - 'p_v2_y', - 'p_v3_x', - 'p_v3_y', - 'p_v4_x', - 'p_v4_y', - 'p_center_x', - 'p_center_y', - 'b_xtl', - 'b_ytl', - 'b_xbr', - 'b_ybr', - 'b_center_x', - 'b_center_y', - 'int_area_scalar', - 'part_body_distance', - 'part_body_centroid_dist', - 'int_over_union', - 'int_over_part', - 'int_over_body', - 'part_over_body', - ], - coltypes=[ - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - ], - configclass=PartAssignmentFeatureConfig, - fname='theta_assignment_features', - rm_extern_on_delete=True, - chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits -) -def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): - - from shapely import geometry - import math - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' 
- bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - im_widths = ibs.get_image_widths(part_gids) - im_heights = ibs.get_image_heights(part_gids) - - part_verts = ibs.get_annot_rotated_verts(part_aid_list) - body_verts = ibs.get_annot_rotated_verts(body_aid_list) - part_verts = _norm_vertices(part_verts, im_widths, im_heights) - body_verts = _norm_vertices(body_verts, im_widths, im_heights) - part_polys = [geometry.Polygon(vert) for vert in part_verts] - body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [ - part.intersection(body) for part, body in zip(part_polys, body_polys) - ] - intersect_areas = [poly.area for poly in intersect_polys] - # just to make int_areas more comparable via ML methods, and since all distances < 1 - int_area_scalars = [math.sqrt(area) for area in intersect_areas] - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) - body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [ - part + body - intersect - for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) - ] - int_over_unions = [ - intersect / union for (intersect, union) in zip(intersect_areas, union_areas) - ] - - part_body_distances = [ - part.distance(body) for part, body in zip(part_polys, body_polys) - ] - - part_centroids = [poly.centroid for poly in part_polys] - body_centroids = [poly.centroid for poly in body_polys] - - part_body_centroid_dists = [ - part.distance(body) for part, body in zip(part_centroids, body_centroids) - ] - - int_over_parts = [ - int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) - ] - - int_over_bodys = [ - int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) - ] - - part_over_bodys = [ - part_area / body_area for part_area, body_area in zip(part_areas, body_areas) - ] - - # note that here only parts have thetas, hence only returning body bboxes - result_list = list( - zip( - part_verts, - part_centroids, - body_bboxes, - body_centroids, - int_area_scalars, - part_body_distances, - part_body_centroid_dists, - int_over_unions, - int_over_parts, - int_over_bodys, - part_over_bodys, - ) - ) - - for ( - part_vert, - part_center, - body_bbox, - body_center, - int_area_scalar, - part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, - ) in result_list: - yield ( - part_vert[0][0], - part_vert[0][1], - part_vert[1][0], - part_vert[1][1], - part_vert[2][0], - part_vert[2][1], - part_vert[3][0], - part_vert[3][1], - part_center.x, - part_center.y, - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - body_center.x, - body_center.y, - int_area_scalar, - part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, - ) - - # just like theta_assignement_features above but with a one-hot encoding of viewpoints # viewpoints are a boolean value for each viewpoint. 
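The viewpoint features used by these assigner tables boil down to six left/right/up/down/front/back flags per annotation, optionally rescaled to a unit vector so multi-axis viewpoints such as 'upfront' do not outweigh single-axis ones. A standalone sketch of both encodings, assuming viewpoint strings like 'upfront' (the guards differ slightly from the real helpers):

```
from math import sqrt

AXES = ('left', 'right', 'up', 'down', 'front', 'back')


def lrudfb_bools(view):
    # Multi-hot viewpoint flags; an unknown (None) viewpoint maps to all False.
    if view is None:
        return [False] * len(AXES)
    return [axis in view for axis in AXES]


def lrudfb_unit_vector(view):
    # Scale so 'upfront' (two axes) has the same length as 'left' (one axis).
    bools = lrudfb_bools(view)
    length = sqrt(sum(bools)) or 1.0  # avoid division by zero for all-False
    return [float(b) / length for b in bools]


print(lrudfb_bools('upfront'))        # [False, False, True, False, True, False]
print(lrudfb_unit_vector('upfront'))  # two entries of ~0.707, rest 0.0
print(lrudfb_unit_vector(None))       # all zeros
```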
will possibly need to modify this for other species @derived_attribute( @@ -3228,244 +2553,25 @@ def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): float, float, float, - bool, - bool, - bool, - bool, - bool, - bool, - bool, - bool, - bool, - bool, - bool, - bool, - ], - configclass=PartAssignmentFeatureConfig, - fname='assigner_viewpoint_features', - rm_extern_on_delete=True, - chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits -) -def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None): - - from shapely import geometry - import math - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' - bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - im_widths = ibs.get_image_widths(part_gids) - im_heights = ibs.get_image_heights(part_gids) - - part_verts = ibs.get_annot_rotated_verts(part_aid_list) - body_verts = ibs.get_annot_rotated_verts(body_aid_list) - part_verts = _norm_vertices(part_verts, im_widths, im_heights) - body_verts = _norm_vertices(body_verts, im_widths, im_heights) - part_polys = [geometry.Polygon(vert) for vert in part_verts] - body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [ - part.intersection(body) for part, body in zip(part_polys, body_polys) - ] - intersect_areas = [poly.area for poly in intersect_polys] - # just to make int_areas more comparable via ML methods, and since all distances < 1 - int_area_scalars = [math.sqrt(area) for area in intersect_areas] - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) - body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [ - part + body - intersect - for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) - ] - int_over_unions = [ - intersect / union for (intersect, union) in zip(intersect_areas, union_areas) - ] - - part_body_distances = [ - part.distance(body) for part, body in zip(part_polys, body_polys) - ] - - part_centroids = [poly.centroid for poly in part_polys] - body_centroids = [poly.centroid for poly in body_polys] - - part_body_centroid_dists = [ - part.distance(body) for part, body in zip(part_centroids, body_centroids) - ] - - int_over_parts = [ - int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) - ] - - int_over_bodys = [ - int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) - ] - - part_over_bodys = [ - part_area / body_area for part_area, body_area in zip(part_areas, body_areas) - ] - - part_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) - body_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) - - # note that here only parts have thetas, hence only returning body bboxes - result_list = list( - zip( - part_verts, - part_centroids, - body_bboxes, - body_centroids, - int_area_scalars, - part_body_distances, - part_body_centroid_dists, - int_over_unions, - int_over_parts, 
- int_over_bodys, - part_over_bodys, - part_lrudfb_bools, - body_lrudfb_bools, - ) - ) - - for ( - part_vert, - part_center, - body_bbox, - body_center, - int_area_scalar, - part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, - part_lrudfb_bool, - body_lrudfb_bool, - ) in result_list: - ans = ( - part_vert[0][0], - part_vert[0][1], - part_vert[1][0], - part_vert[1][1], - part_vert[2][0], - part_vert[2][1], - part_vert[3][0], - part_vert[3][1], - part_center.x, - part_center.y, - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - body_center.x, - body_center.y, - int_area_scalar, - part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, - ) - ans += tuple(part_lrudfb_bool) - ans += tuple(body_lrudfb_bool) - yield ans - - -@derived_attribute( - tablename='assigner_viewpoint_unit_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_v1_x', - 'p_v1_y', - 'p_v2_x', - 'p_v2_y', - 'p_v3_x', - 'p_v3_y', - 'p_v4_x', - 'p_v4_y', - 'p_center_x', - 'p_center_y', - 'b_xtl', - 'b_ytl', - 'b_xbr', - 'b_ybr', - 'b_center_x', - 'b_center_y', - 'int_area_scalar', - 'part_body_distance', - 'part_body_centroid_dist', - 'int_over_union', - 'int_over_part', - 'int_over_body', - 'part_over_body', - 'part_is_left', - 'part_is_right', - 'part_is_up', - 'part_is_down', - 'part_is_front', - 'part_is_back', - 'body_is_left', - 'body_is_right', - 'body_is_up', - 'body_is_down', - 'body_is_front', - 'body_is_back', - ], - coltypes=[ - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, ], configclass=PartAssignmentFeatureConfig, - fname='assigner_viewpoint_unit_features', + fname='assigner_viewpoint_features', rm_extern_on_delete=True, chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits ) -def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config=None): +def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None): from shapely import geometry import math @@ -3535,8 +2641,8 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= part_area / body_area for part_area, body_area in zip(part_areas, body_areas) ] - part_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) - body_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) + part_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) + body_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) # note that here only parts have thetas, hence only returning body bboxes result_list = list( @@ -3552,8 +2658,8 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= int_over_parts, int_over_bodys, part_over_bodys, - part_lrudfb_vects, - body_lrudfb_vects, + part_lrudfb_bools, + body_lrudfb_bools, ) ) @@ -3569,8 +2675,8 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= int_over_part, int_over_body, part_over_body, - part_lrudfb_vect, - body_lrudfb_vect, + part_lrudfb_bool, + body_lrudfb_bool, ) in result_list: ans = ( part_vert[0][0], @@ -3597,8 +2703,8 @@ def 
assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= int_over_body, part_over_body, ) - ans += tuple(part_lrudfb_vect) - ans += tuple(body_lrudfb_vect) + ans += tuple(part_lrudfb_bool) + ans += tuple(body_lrudfb_bool) yield ans @@ -3619,216 +2725,6 @@ def get_annot_lrudfb_bools(ibs, aid_list): return bool_arrays -def get_annot_lrudfb_unit_vector(ibs, aid_list): - bool_arrays = get_annot_lrudfb_bools(ibs, aid_list) - float_arrays = [[float(b) for b in lrudfb] for lrudfb in bool_arrays] - lrudfb_lengths = [sqrt(lrudfb.count(True)) for lrudfb in bool_arrays] - # lying just to avoid division by zero errors - lrudfb_lengths = [length if length != 0 else -1 for length in lrudfb_lengths] - unit_float_array = [ - [f / length for f in lrudfb] - for lrudfb, length in zip(float_arrays, lrudfb_lengths) - ] - - return unit_float_array - - -@derived_attribute( - tablename='theta_standardized_assignment_features', - parents=['annotations', 'annotations'], - colnames=[ - 'p_v1_x', - 'p_v1_y', - 'p_v2_x', - 'p_v2_y', - 'p_v3_x', - 'p_v3_y', - 'p_v4_x', - 'p_v4_y', - 'p_center_x', - 'p_center_y', - 'b_xtl', - 'b_ytl', - 'b_xbr', - 'b_ybr', - 'b_center_x', - 'b_center_y', - 'int_area_scalar', - 'part_body_distance', - 'part_body_centroid_dist', - 'int_over_union', - 'int_over_part', - 'int_over_body', - 'part_over_body', - ], - coltypes=[ - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - float, - ], - configclass=PartAssignmentFeatureConfig, - fname='theta_standardized_assignment_features', - rm_extern_on_delete=True, - chunksize=2560000, # chunk size is huge bc we need accurate means and stdevs of various traits -) -def theta_standardized_assignment_features( - depc, part_aid_list, body_aid_list, config=None -): - - from shapely import geometry - import math - - ibs = depc.controller - - part_gids = ibs.get_annot_gids(part_aid_list) - body_gids = ibs.get_annot_gids(body_aid_list) - assert ( - part_gids == body_gids - ), 'can only compute assignment features on aids in the same image' - parts_are_parts = ibs._are_part_annots(part_aid_list) - assert all(parts_are_parts), 'all part_aids must be part annots.' 
- bodies_are_parts = ibs._are_part_annots(body_aid_list) - assert not any(bodies_are_parts), 'body_aids cannot be part annots' - - im_widths = ibs.get_image_widths(part_gids) - im_heights = ibs.get_image_heights(part_gids) - - part_verts = ibs.get_annot_rotated_verts(part_aid_list) - body_verts = ibs.get_annot_rotated_verts(body_aid_list) - part_verts = _norm_vertices(part_verts, im_widths, im_heights) - body_verts = _norm_vertices(body_verts, im_widths, im_heights) - part_polys = [geometry.Polygon(vert) for vert in part_verts] - body_polys = [geometry.Polygon(vert) for vert in body_verts] - intersect_polys = [ - part.intersection(body) for part, body in zip(part_polys, body_polys) - ] - intersect_areas = [poly.area for poly in intersect_polys] - # just to make int_areas more comparable via ML methods, and since all distances < 1 - int_area_scalars = [math.sqrt(area) for area in intersect_areas] - int_area_scalars = preprocessing.scale(int_area_scalars) - - part_bboxes = ibs.get_annot_bboxes(part_aid_list) - body_bboxes = ibs.get_annot_bboxes(body_aid_list) - part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) - body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) - part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] - body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] - union_areas = [ - part + body - intersect - for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) - ] - int_over_unions = [ - intersect / union for (intersect, union) in zip(intersect_areas, union_areas) - ] - int_over_unions = preprocessing.scale(int_over_unions) - - part_body_distances = [ - part.distance(body) for part, body in zip(part_polys, body_polys) - ] - part_body_distances = preprocessing.scale(part_body_distances) - - part_centroids = [poly.centroid for poly in part_polys] - body_centroids = [poly.centroid for poly in body_polys] - - part_body_centroid_dists = [ - part.distance(body) for part, body in zip(part_centroids, body_centroids) - ] - part_body_centroid_dists = preprocessing.scale(part_body_centroid_dists) - - int_over_parts = [ - int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) - ] - int_over_parts = preprocessing.scale(int_over_parts) - - int_over_bodys = [ - int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) - ] - int_over_bodys = preprocessing.scale(int_over_bodys) - - part_over_bodys = [ - part_area / body_area for part_area, body_area in zip(part_areas, body_areas) - ] - part_over_bodys = preprocessing.scale(part_over_bodys) - - # standardization - - # note that here only parts have thetas, hence only returning body bboxes - result_list = list( - zip( - part_verts, - part_centroids, - body_bboxes, - body_centroids, - int_area_scalars, - part_body_distances, - part_body_centroid_dists, - int_over_unions, - int_over_parts, - int_over_bodys, - part_over_bodys, - ) - ) - - for ( - part_vert, - part_center, - body_bbox, - body_center, - int_area_scalar, - part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - int_over_body, - part_over_body, - ) in result_list: - yield ( - part_vert[0][0], - part_vert[0][1], - part_vert[1][0], - part_vert[1][1], - part_vert[2][0], - part_vert[2][1], - part_vert[3][0], - part_vert[3][1], - part_center.x, - part_center.y, - body_bbox[0], - body_bbox[1], - body_bbox[2], - body_bbox[3], - body_center.x, - body_center.y, - int_area_scalar, - part_body_distance, - part_body_centroid_dist, - int_over_union, - int_over_part, - 
int_over_body, - part_over_body, - ) - - def _norm_bboxes(bbox_list, width_list, height_list): normed_boxes = [ (bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h) @@ -3843,92 +2739,3 @@ def _norm_vertices(verts_list, width_list, height_list): for vert, w, h in zip(verts_list, width_list, height_list) ] return normed_verts - - -# does this even make any sense? let's find out experimentally -def _standardized_bboxes(bbox_list): - xtls = preprocessing.scale([bbox[0] for bbox in bbox_list]) - ytls = preprocessing.scale([bbox[1] for bbox in bbox_list]) - wids = preprocessing.scale([bbox[2] for bbox in bbox_list]) - heis = preprocessing.scale([bbox[3] for bbox in bbox_list]) - standardized_bboxes = list(zip(xtls, ytls, wids, heis)) - return standardized_bboxes - - -def _bbox_intersections(bboxes_a, bboxes_b): - corner_bboxes_a = _bbox_to_corner_format(bboxes_a) - corner_bboxes_b = _bbox_to_corner_format(bboxes_b) - - intersect_xtls = [ - max(xtl_a, xtl_b) - for ((xtl_a, _, _, _), (xtl_b, _, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) - ] - - intersect_ytls = [ - max(ytl_a, ytl_b) - for ((_, ytl_a, _, _), (_, ytl_b, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) - ] - - intersect_xbrs = [ - min(xbr_a, xbr_b) - for ((_, _, xbr_a, _), (_, _, xbr_b, _)) in zip(corner_bboxes_a, corner_bboxes_b) - ] - - intersect_ybrs = [ - min(ybr_a, ybr_b) - for ((_, _, _, ybr_a), (_, _, _, ybr_b)) in zip(corner_bboxes_a, corner_bboxes_b) - ] - - intersect_widths = [ - int_xbr - int_xtl for int_xbr, int_xtl in zip(intersect_xbrs, intersect_xtls) - ] - - intersect_heights = [ - int_ybr - int_ytl for int_ybr, int_ytl in zip(intersect_ybrs, intersect_ytls) - ] - - intersect_bboxes = list( - zip(intersect_xtls, intersect_ytls, intersect_widths, intersect_heights) - ) - - return intersect_bboxes - - -def _theta_aware_intersect_areas(verts_list_a, verts_list_b): - import shapely - - polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] - polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] - intersect_areas = [ - poly1.intersection(poly2).area for poly1, poly2 in zip(polys_a, polys_b) - ] - return intersect_areas - - -def _all_centroids(verts_list_a, verts_list_b): - import shapely - - polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] - polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] - intersect_polys = [ - poly1.intersection(poly2) for poly1, poly2 in zip(polys_a, polys_b) - ] - - centroids_a = [poly.centroid for poly in polys_a] - centroids_b = [poly.centroid for poly in polys_b] - centroids_int = [poly.centroid for poly in intersect_polys] - - return centroids_a, centroids_b, centroids_int - - -def _polygons_to_centroid_coords(polygon_list): - centroids = [poly.centroid for poly in polygon_list] - return centroids - - -# converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) -def _bbox_to_corner_format(bboxes): - corner_bboxes = [ - (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]) for bbox in bboxes - ] - return corner_bboxes From 992eeb57a7db503f5b93bc03b413af393292308f Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 15:08:33 -0800 Subject: [PATCH 154/294] Linted and all tests passing --- wbia/algo/detect/assigner.py | 7 +++++-- wbia/algo/detect/train_assigner.py | 1 + wbia/web/apis_detect.py | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index 5025f138c4..bfc0fb6aab 100644 --- a/wbia/algo/detect/assigner.py +++ 
b/wbia/algo/detect/assigner.py @@ -78,6 +78,7 @@ def _are_part_annots(ibs, aid_list): >>> ibs = assigner_testdb_ibs() >>> aids = ibs.get_valid_aids() >>> result = ibs._are_part_annots(aids) + >>> print(result) [False, False, True, True, False, True, False, True] """ species = ibs.get_annot_species(aid_list) @@ -112,6 +113,7 @@ def all_part_pairs(ibs, gid_list): >>> assert (set(parts) & set(bodies)) == set({}) >>> assert (set(parts) | set(bodies)) == set(all_aids) >>> result = all_part_pairs + >>> print(result) ([3, 3, 4, 4, 6, 8], [1, 2, 1, 2, 5, 7]) """ all_aids = ibs.get_image_aids(gid_list) @@ -239,8 +241,9 @@ def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): all_species = ibs.get_annot_species(aid_list) # put unsupported species into the all_unassigned_aids list all_species_no_parts = [species.split('+')[0] for species in all_species] - assign_flag_list = [species in SPECIES_CONFIG_MAP.keys() - for species in all_species_no_parts] + assign_flag_list = [ + species in SPECIES_CONFIG_MAP.keys() for species in all_species_no_parts + ] unassigned_aids_noconfig = ut.filterfalse_items(aid_list, assign_flag_list) aid_list = ut.compress(aid_list, assign_flag_list) diff --git a/wbia/algo/detect/train_assigner.py b/wbia/algo/detect/train_assigner.py index c5f511d9fd..f604238d45 100644 --- a/wbia/algo/detect/train_assigner.py +++ b/wbia/algo/detect/train_assigner.py @@ -409,6 +409,7 @@ def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): >>> assert len(train_gids) is 2 >>> assert len(test_gids) is 1 >>> result = aid_in_train # note one gid has 4 aids, the other two 2 + >>> print(result) [False, False, False, False, True, True, True, True] """ print('calling gid_train_test_split') diff --git a/wbia/web/apis_detect.py b/wbia/web/apis_detect.py index fcd84f54b3..d31abd5297 100644 --- a/wbia/web/apis_detect.py +++ b/wbia/web/apis_detect.py @@ -640,6 +640,7 @@ def models_cnn_lightnet(ibs, **kwargs): Method: PUT, GET URL: /api/labels/cnn/lightnet/ """ + def identity(x): return x From 585ead05a0657c5193bed438056911b84fd21ff8 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 16:13:12 -0800 Subject: [PATCH 155/294] Formatted and linted --- wbia/algo/detect/train_assigner.py | 19 ++++++++++++++++--- wbia/core_annots.py | 4 ---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/wbia/algo/detect/train_assigner.py b/wbia/algo/detect/train_assigner.py index f604238d45..a76f2fc1a4 100644 --- a/wbia/algo/detect/train_assigner.py +++ b/wbia/algo/detect/train_assigner.py @@ -3,7 +3,7 @@ # from os.path import expanduser, join # from wbia import constants as const from wbia.control.controller_inject import ( - # register_preprocs, + register_preprocs, # register_subprops, make_ibs_register_decorator, ) @@ -17,6 +17,7 @@ import utool as ut import numpy as np +from wbia import dtool import random # import os @@ -26,6 +27,10 @@ from datetime import datetime import time +from math import sqrt + +from sklearn import preprocessing + # illustration imports # from shutil import copy # from PIL import Image, ImageDraw @@ -48,6 +53,9 @@ from sklearn.model_selection import GridSearchCV +derived_attribute = register_preprocs['annot'] + + CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) @@ -530,6 +538,7 @@ def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): class PartAssignmentFeatureConfig(dtool.Config): _param_info_list = [] + @derived_attribute( tablename='part_assignment_features', parents=['annotations', 
'annotations'], @@ -1019,6 +1028,7 @@ def mega_standardized_assignment_features( part_area_relative_body, ) + @derived_attribute( tablename='theta_assignment_features', parents=['annotations', 'annotations'], @@ -1422,6 +1432,7 @@ def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config= ans += tuple(body_lrudfb_vect) yield ans + @derived_attribute( tablename='theta_standardized_assignment_features', parents=['annotations', 'annotations'], @@ -1617,7 +1628,10 @@ def theta_standardized_assignment_features( part_over_body, ) + def get_annot_lrudfb_unit_vector(ibs, aid_list): + from wbia.core_annots import get_annot_lrudfb_bools + bool_arrays = get_annot_lrudfb_bools(ibs, aid_list) float_arrays = [[float(b) for b in lrudfb] for lrudfb in bool_arrays] lrudfb_lengths = [sqrt(lrudfb.count(True)) for lrudfb in bool_arrays] @@ -1630,6 +1644,7 @@ def get_annot_lrudfb_unit_vector(ibs, aid_list): return unit_float_array + def _norm_bboxes(bbox_list, width_list, height_list): normed_boxes = [ (bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h) @@ -1733,5 +1748,3 @@ def _bbox_to_corner_format(bboxes): def _polygons_to_centroid_coords(polygon_list): centroids = [poly.centroid for poly in polygon_list] return centroids - - diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 9f966f2431..4802fbd1d7 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -63,10 +63,6 @@ from wbia.algo.hots.chip_match import ChipMatch from wbia.algo.hots import neighbor_index -from math import sqrt - -from sklearn import preprocessing - (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') From a461561dd27aad695b2e8c15d3285846be044cc1 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 16:18:45 -0800 Subject: [PATCH 156/294] Small updates --- wbia/core_annots.py | 6 ++++++ wbia/web/apis_sync.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 4802fbd1d7..99af209729 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2735,3 +2735,9 @@ def _norm_vertices(verts_list, width_list, height_list): for vert, w, h in zip(verts_list, width_list, height_list) ] return normed_verts + + +if __name__ == '__main__': + import xdoctest as xdoc + + xdoc.doctest_module(__file__) diff --git a/wbia/web/apis_sync.py b/wbia/web/apis_sync.py index 7b97ca55f7..04f8139653 100644 --- a/wbia/web/apis_sync.py +++ b/wbia/web/apis_sync.py @@ -366,7 +366,9 @@ def sync_get_training_data(ibs, species_name, force_update=False, **kwargs): name_texts = ibs._sync_get_annot_endpoint('/api/annot/name/text/', aid_list) name_uuids = ibs._sync_get_annot_endpoint('/api/annot/name/uuid/', aid_list) images = ibs._sync_get_annot_endpoint('/api/annot/image/rowid/', aid_list) - gpaths = [ibs._construct_route_url_ibs('/api/image/src/%s.jpg' % gid) for gid in images] + gpaths = [ + ibs._construct_route_url_ibs('/api/image/src/%s.jpg' % gid) for gid in images + ] specieses = [species_name] * len(aid_list) gid_list = ibs.add_images(gpaths) From d3012f11294052beb74b5e4dc531dac263ac6c8f Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 16:20:31 -0800 Subject: [PATCH 157/294] Remove make_ibs_register_decorator from core_annots --- wbia/core_annots.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 99af209729..43c7a48d3d 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -55,18 +55,13 @@ import numpy as np import cv2 import 
wbia.constants as const -from wbia.control.controller_inject import ( - register_preprocs, - register_subprops, - make_ibs_register_decorator, -) +from wbia.control.controller_inject import register_preprocs, register_subprops from wbia.algo.hots.chip_match import ChipMatch from wbia.algo.hots import neighbor_index (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') -CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) derived_attribute = register_preprocs['annot'] register_subprop = register_subprops['annot'] From 60c15b154459345ebdb89224fab417ccf0736438 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 5 Jan 2021 16:22:53 -0800 Subject: [PATCH 158/294] Remove embed --- wbia/dtool/depcache_control.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index f5f6e43be9..3a3a8e25e5 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -470,8 +470,6 @@ def rectify_input_tuple(self, exi_inputs, input_tuple): input_tuple_ = (input_tuple_,) if len(exi_inputs) != len(input_tuple_): msg = '#expected=%d, #got=%d' % (len(exi_inputs), len(input_tuple_)) - print(msg) - ut.embed() raise ValueError(msg) # rectify input depth From 1cae0cebcd28ea2570be34f26c4bf05a4870fc03 Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Thu, 7 Jan 2021 11:46:52 -0800 Subject: [PATCH 159/294] fixes case where there are no part/body pairs in an image --- wbia/algo/detect/assigner.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py index bfc0fb6aab..2b3fced4cc 100644 --- a/wbia/algo/detect/assigner.py +++ b/wbia/algo/detect/assigner.py @@ -254,24 +254,29 @@ def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) num_images = len(set(gids)) - assert num_images == 1, "assign_parts_one_image called on multiple images' aids" + assert num_images <= 1, "assign_parts_one_image called on multiple images' aids" # parallel lists representing all possible part/body pairs all_pairs_parallel = _all_pairs_parallel(part_aids, body_aids) pair_parts, pair_bodies = all_pairs_parallel - assigner_features = ibs.depc_annot.get( - 'assigner_viewpoint_features', all_pairs_parallel - ) - # send all aids to this call just so it can find the right classifier model - assigner_classifier = load_assigner_classifier(ibs, body_aids + part_aids) - - assigner_scores = assigner_classifier.predict_proba(assigner_features) - # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true probabilities - assigner_scores = [score[1] for score in assigner_scores] - good_pairs, unassigned_aids = _make_assignments( - pair_parts, pair_bodies, assigner_scores, cutoff_score - ) + if len(pair_parts) > 0 and len(pair_bodies) > 0: + assigner_features = ibs.depc_annot.get( + 'assigner_viewpoint_features', all_pairs_parallel + ) + # send all aids to this call just so it can find the right classifier model + assigner_classifier = load_assigner_classifier(ibs, body_aids + part_aids) + + assigner_scores = assigner_classifier.predict_proba(assigner_features) + # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true probabilities + assigner_scores = [score[1] for score in assigner_scores] + good_pairs, unassigned_aids = 
_make_assignments( + pair_parts, pair_bodies, assigner_scores, cutoff_score + ) + else: + good_pairs = [] + unassigned_aids = aid_list + unassigned_aids = unassigned_aids_noconfig + unassigned_aids return good_pairs, unassigned_aids From a820c52469ca80acfbf809f84182d879105c3dfb Mon Sep 17 00:00:00 2001 From: Drew Blount Date: Thu, 7 Jan 2021 16:25:28 -0800 Subject: [PATCH 160/294] fixes None view bug that was killing detection jobs --- wbia/core_annots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 43c7a48d3d..8c30143736 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2711,7 +2711,7 @@ def get_annot_lrudfb_bools(ibs, aid_list): 'front' in view, 'back' in view, ] - for view in views + if view is not None else [False] * 6 for view in views ] return bool_arrays From ab4098fa1a9d09469727d763895243f6fb309b96 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Thu, 7 Jan 2021 21:08:54 -0800 Subject: [PATCH 161/294] Add render status log --- wbia/web/apis_query.py | 62 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/wbia/web/apis_query.py b/wbia/web/apis_query.py index d448d3dd9d..e1bb64ea80 100644 --- a/wbia/web/apis_query.py +++ b/wbia/web/apis_query.py @@ -762,6 +762,24 @@ def query_chips_graph_complete(ibs, aid_list, query_config_dict={}, k=5, **kwarg return result_dict +@register_ibs_method +def log_render_status(ibs, *args): + import os + + json_log_path = ibs.get_logdir_local() + json_log_filename = 'render.log' + json_log_filepath = os.path.join(json_log_path, json_log_filename) + logger.info('Logging renders added to: %r' % (json_log_filepath,)) + + try: + with open(json_log_filepath, 'a') as json_log_file: + line = ','.join(['%s' % (arg,) for arg in args]) + line = '%s\n' % (line,) + json_log_file.write(line) + except Exception: + logger.info('WRITE RENDER.LOG FAILED') + + @register_ibs_method @register_api('/api/query/graph/', methods=['GET', 'POST']) def query_chips_graph( @@ -772,7 +790,7 @@ def query_chips_graph( query_config_dict={}, echo_query_params=True, cache_images=True, - n=16, + n=30, view_orientation='horizontal', return_summary=True, **kwargs, @@ -963,6 +981,20 @@ def convert_to_uuid(nid): except Exception: filepath_matches = None extern_flag = 'error' + log_render_status( + ibs, + cm.qaid, + daid, + quuid, + duuid, + cm, + qreq_, + view_orientation, + True, + False, + filepath_matches, + extern_flag, + ) try: _, filepath_heatmask = ensure_review_image( ibs, @@ -976,6 +1008,20 @@ def convert_to_uuid(nid): except Exception: filepath_heatmask = None extern_flag = 'error' + log_render_status( + ibs, + cm.qaid, + daid, + quuid, + duuid, + cm, + qreq_, + view_orientation, + False, + True, + filepath_heatmask, + extern_flag, + ) try: _, filepath_clean = ensure_review_image( ibs, @@ -989,6 +1035,20 @@ def convert_to_uuid(nid): except Exception: filepath_clean = None extern_flag = 'error' + log_render_status( + ibs, + cm.qaid, + daid, + quuid, + duuid, + cm, + qreq_, + view_orientation, + False, + False, + filepath_clean, + extern_flag, + ) if filepath_matches is not None: args = ( From aef2b1427ce581986ff7ef234db9726fd2a5dcc4 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 6 Jan 2021 19:56:26 +0000 Subject: [PATCH 162/294] Fix db.get and db.get_where_eq when id is not integer The sql statement itself works, but the python code to extract the results wasn't working correctly. 
It uses a dict to look up the result for each id, but when the id is a string but the column is actually an int, the code fails to look up the result. Use the column type processors to transform the id to the same as what's returned in the sql query to make the look up work. The `image_src_api` test used to return 200 OK but with: ``` { "status": { "success": false, "code": 400, "message": "Route error, Python Exception thrown: 'image path should not be None'", "cache": -1 }, "response": "Traceback (most recent call last):\\n File \\"/wbia/wildbook-ia/wbia/control/controller_inject.py\\", line 1217, in translated_call\\n result = func(**kwargs)\\n File \\"/wbia/wildbook-ia/wbia/web/apis.py\\", line 108, in image_src_api\\n assert gpath is not None, 'image path should not be None'\\nAssertionError: image path should not be None\\n" } ``` Change the test so it actually asserts the content. --- wbia/dtool/sql_control.py | 50 ++++++++++++++++++++++++++++++++++++++- wbia/web/apis.py | 2 +- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 2bb0037f3e..912466f544 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -5,6 +5,7 @@ TODO; need to use some sort of sticky bit so sql files are created with reasonable permissions. """ +import functools import logging import collections import os @@ -24,7 +25,7 @@ from wbia.dtool import lite from wbia.dtool.dump import dumps -from wbia.dtool.types import Integer +from wbia.dtool.types import Integer, TYPE_TO_SQLTYPE print, rrr, profile = ut.inject2(__name__) @@ -1042,6 +1043,34 @@ def get_where_eq( existing.append(values) results = [] + processors = [] + for c in tuple(where_colnames): + + def process(column, a): + processor = column.type.bind_processor(self._engine.dialect) + if processor: + a = processor(a) + result_processor = column.type.result_processor( + self._engine.dialect, str(column.type) + ) + if result_processor: + return result_processor(a) + return a + + processors.append(functools.partial(process, table.c[c])) + + if params_iter: + first_params = params_iter[0] + if any( + not isinstance(a, bool) + and TYPE_TO_SQLTYPE.get(type(a)) != str(table.c[c].type) + for a, c in zip(first_params, where_colnames) + ): + params_iter = ( + (processor(raw_id) for raw_id, processor in zip(id_, processors)) + for id_ in params_iter + ) + for id_ in params_iter: result = sorted(list(result_map.get(tuple(id_), set()))) if unpack_scalars and isinstance(result, list): @@ -1303,6 +1332,25 @@ def get( existing.append(values) results = [] + + def process(a): + processor = id_column.type.bind_processor(self._engine.dialect) + if processor: + a = processor(a) + result_processor = id_column.type.result_processor( + self._engine.dialect, str(id_column.type) + ) + if result_processor: + return result_processor(a) + return a + + if id_iter: + first_id = id_iter[0] + if isinstance(first_id, bool) or TYPE_TO_SQLTYPE.get( + type(first_id) + ) != str(id_column.type): + id_iter = (process(id_) for id_ in id_iter) + for id_ in id_iter: result = sorted(list(result_map.get(id_, set()))) if kwargs.get('unpack_scalars', True) and isinstance(result, list): diff --git a/wbia/web/apis.py b/wbia/web/apis.py index f9c5e8408f..dc337b35ac 100644 --- a/wbia/web/apis.py +++ b/wbia/web/apis.py @@ -78,12 +78,12 @@ def image_src_api(rowid=None, thumbnail=False, fresh=False, **kwargs): Returns the image file of image Example: - >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # 
NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: ... resp = web_ibs.send_wbia_request('/api/image/src/1/', type_='get', json=False) >>> print(resp) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET From a133a95faac9a350f7f17993a26a5cf2a3dd7e05 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 00:24:08 -0800 Subject: [PATCH 163/294] Enable the web-tests on CI test runs --- .github/workflows/testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index e198616c1d..f534f32bf5 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -70,7 +70,7 @@ jobs: run: | mkdir -p data/work python -m wbia --set-workdir data/work --preload-exit - pytest --slow + pytest --slow --web-tests on-failure: # This is not in the 'test' job itself because it would otherwise notify once per matrix combination. From d8d0be733f2b4f3386d6ce519e6610de8468dfc2 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 29 Dec 2020 23:16:09 -0800 Subject: [PATCH 164/294] Use the Werkzeug Client for functional testing Instead of starting the web application in the background, use the flask test client (i.e. Werkzeug Client). This means we aren't creating a process that is creating child processes. --- wbia/__init__.py | 2 +- wbia/entry_points.py | 79 +++++++++++++++++++++++++++++++------------- 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/wbia/__init__.py b/wbia/__init__.py index d2fb0f1715..e6945edd7c 100644 --- a/wbia/__init__.py +++ b/wbia/__init__.py @@ -68,7 +68,7 @@ main_loop, opendb, opendb_in_background, - opendb_bg_web, + opendb_with_web, ) from wbia.control.IBEISControl import IBEISController from wbia.algo.hots.query_request import QueryRequest diff --git a/wbia/entry_points.py b/wbia/entry_points.py index 3a2fd648a0..32e21b1fcd 100644 --- a/wbia/entry_points.py +++ b/wbia/entry_points.py @@ -331,6 +331,62 @@ def opendb_in_background(*args, **kwargs): return proc +@contextmanager +def opendb_with_web(*args, with_job_engine=False, **kwargs): + """Opens the database and starts the web server. + + Returns: + ibs, client - IBEISController and Werkzeug Client + + Example: + >>> from wbia.entry_points import opendb_with_web + >>> expected_response_data = {'status': {'success': True, 'code': 200, 'message': '', 'cache': -1}, 'response': True} + >>> with opendb_with_web('testdb1') as (ibs, client): + ... response = client.get('/api/test/heartbeat/') + ... assert expected_response_data == response.json + + """ + from wbia.control.controller_inject import get_flask_app + + # Create the controller instance + ibs = opendb(*args, **kwargs) + if with_job_engine: + # TODO start jobs engine + pass + + # Create the web application + app = get_flask_app() + # ??? Gotta attach the controller to the application? 
+ setattr(app, 'ibs', ibs) + + # Return the controller and client instances to the caller + with app.test_client() as client: + yield ibs, client + + +def opendb_fg_web(*args, **kwargs): + """ + Ignore: + >>> from wbia.entry_points import * # NOQA + >>> kwargs = {'db': 'testdb1'} + >>> args = tuple() + + >>> import wbia + >>> ibs = wbia.opendb_fg_web() + + """ + # Gives you context inside the web app for testing + kwargs['start_web_loop'] = False + kwargs['web'] = True + kwargs['browser'] = False + ibs = opendb(*args, **kwargs) + from wbia.control import controller_inject + + app = controller_inject.get_flask_app() + ibs.app = app + return ibs + + def opendb_bg_web(*args, managed=False, **kwargs): """ Wrapper around opendb_in_background, returns a nice web_ibs @@ -494,29 +550,6 @@ def managed_server(): return web_ibs -def opendb_fg_web(*args, **kwargs): - """ - Ignore: - >>> from wbia.entry_points import * # NOQA - >>> kwargs = {'db': 'testdb1'} - >>> args = tuple() - - >>> import wbia - >>> ibs = wbia.opendb_fg_web() - - """ - # Gives you context inside the web app for testing - kwargs['start_web_loop'] = False - kwargs['web'] = True - kwargs['browser'] = False - ibs = opendb(*args, **kwargs) - from wbia.control import controller_inject - - app = controller_inject.get_flask_app() - ibs.app = app - return ibs - - def opendb( db=None, dbdir=None, From 5221708fd368e0de7f26103136b5ca159bbf284e Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 29 Dec 2020 23:39:38 -0800 Subject: [PATCH 165/294] Fix translate_wbia_webcall test --- wbia/control/controller_inject.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/wbia/control/controller_inject.py b/wbia/control/controller_inject.py index 12c7ae2ca6..58cb7539a8 100644 --- a/wbia/control/controller_inject.py +++ b/wbia/control/controller_inject.py @@ -489,35 +489,20 @@ def translate_wbia_webcall(func, *args, **kwargs): Returns: tuple: (output, True, 200, None, jQuery_callback) - CommandLine: - python -m wbia.control.controller_inject --exec-translate_wbia_webcall - python -m wbia.control.controller_inject --exec-translate_wbia_webcall --domain http://52.33.105.88 - Example: >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.control.controller_inject import * # NOQA >>> import wbia - >>> import time - >>> import wbia.web - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... aids = web_ibs.send_wbia_request('/api/annot/', 'get') - ... uuid_list = web_ibs.send_wbia_request('/api/annot/uuids/', aid_list=aids, json=False) - ... failrsp = web_ibs.send_wbia_request('/api/annot/uuids/', json=False) - ... failrsp2 = web_ibs.send_wbia_request('/api/query/chips/simple_dict//', 'get', qaid_list=[0], daid_list=[0], json=False) - ... log_text = web_ibs.send_wbia_request('/api/query/chips/simple_dict/', 'get', qaid_list=[0], daid_list=[0], json=False) - >>> time.sleep(.1) - >>> print('\n---\nuuid_list = %r' % (uuid_list,)) - >>> print('\n---\nfailrsp =\n%s' % (failrsp,)) - >>> print('\n---\nfailrsp2 =\n%s' % (failrsp2,)) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... aids = client.get('/api/annot/').json + ... failrsp = client.post('/api/annot/uuids/') + ... failrsp2 = client.get('/api/query/chips/simple_dict//', data={'qaid_list': [0], 'daid_list': [0]}) + ... 
log_text = client.get('/api/query/chips/simple_dict/', data={'qaid_list': [0], 'daid_list': [0]}) + >>> print('\n---\nfailrsp =\n%s' % (failrsp.data,)) + >>> print('\n---\nfailrsp2 =\n%s' % (failrsp2.data,)) >>> print('Finished test') + Finished test - Ignore: - app = get_flask_app() - with app.app_context(): - #ibs = wbia.opendb('testdb1') - func = ibs.get_annot_uuids - args = tuple() - kwargs = dict() """ assert len(args) == 0, 'There should not be any args=%r' % (args,) From aad87c29278f0c6c03a35b7c2283205dc54c40f3 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 00:07:28 -0800 Subject: [PATCH 166/294] Fix get_current_log_text test --- wbia/control/IBEISControl.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 5443e327a4..d8ec608881 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -1153,20 +1153,15 @@ def get_smart_patrol_dir(self, ensure=True): @register_api('/log/current/', methods=['GET']) def get_current_log_text(self): r""" - CommandLine: - python -m wbia.control.IBEISControl --exec-get_current_log_text - python -m wbia.control.IBEISControl --exec-get_current_log_text --domain http://52.33.105.88 Example: >>> # xdoctest: +REQUIRES(--web-tests) - >>> from wbia.control.IBEISControl import * # NOQA >>> import wbia - >>> import wbia.web - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/log/current/', 'get') - >>> print('\n-------Logs ----: \n' ) - >>> print(resp) - >>> print('\nL____ END LOGS ___\n') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/log/current/') + >>> resp.json + {'status': {'success': True, 'code': 200, 'message': '', 'cache': -1}, 'response': None} + """ text = ut.get_current_log_text() return text From 6034e53d5e14782392670aa2a72fc60e47352f13 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 00:22:40 -0800 Subject: [PATCH 167/294] Fix test_turk_identification_no_more_to_review test --- wbia/tests/web/test_routes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wbia/tests/web/test_routes.py b/wbia/tests/web/test_routes.py index 4c5f04f4db..967fdc9298 100644 --- a/wbia/tests/web/test_routes.py +++ b/wbia/tests/web/test_routes.py @@ -3,8 +3,8 @@ def test_turk_identification_no_more_to_review(): - with wbia.opendb_bg_web('testdb2', managed=True) as web_ibs: - resp = web_ibs.get('/turk/identification/lnbnn/') + with wbia.opendb_with_web('testdb2') as (ibs, client): + resp = client.get('/turk/identification/lnbnn/') assert resp.status_code == 200 - assert b'Traceback' not in resp.content, resp.content - assert b'
No more to review!' in resp.content, resp.content
+        assert b'Traceback' not in resp.data
+        assert b'No more to review!
' in resp.data From 83072436622dd9920a2b5996c9f90f31d7543860 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 00:33:14 -0800 Subject: [PATCH 168/294] Fix script docstring Just fixing the usage here. Untested. --- wbia/web/routes.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/wbia/web/routes.py b/wbia/web/routes.py index b91bf29611..cd58fcb500 100644 --- a/wbia/web/routes.py +++ b/wbia/web/routes.py @@ -3974,11 +3974,6 @@ def check_engine_identification_query_object( if current_app.QUERY_OBJECT_JOBID is None: current_app.QUERY_OBJECT = None current_app.QUERY_OBJECT_JOBID = ibs.start_web_query_all() - # import wbia - # web_ibs = wbia.opendb_bg_web(dbdir=ibs.dbdir, port=6000) - # query_object_jobid = web_ibs.send_wbia_request('/api/engine/query/graph/') - # logger.info('query_object_jobid = %r' % (query_object_jobid, )) - # current_app.QUERY_OBJECT_JOBID = query_object_jobid query_object_status_dict = ibs.get_job_status(current_app.QUERY_OBJECT_JOBID) args = ( @@ -4018,11 +4013,11 @@ def turk_identification( >>> # SCRIPT >>> from wbia.other.ibsfuncs import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: - ... resp = web_ibs.get('/turk/identification/lnbnn/') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/turk/identification/lnbnn/') >>> ut.quit_if_noshow() >>> import wbia.plottool as pt - >>> ut.render_html(resp.content) + >>> ut.render_html(resp.data.decode('utf8')) >>> ut.show_if_requested() """ from wbia.web import apis_query @@ -4762,8 +4757,8 @@ def turk_identification_hardcase(*args, **kwargs): >>> # SCRIPT >>> from wbia.other.ibsfuncs import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('PZ_Master1', managed=True) as web_ibs: - ... resp = web_ibs.get('/turk/identification/hardcase/') + >>> with wbia.opendb_with_web('PZ_Master1') as (ibs, client): + ... resp = client.get('/turk/identification/hardcase/') Ignore: import wbia @@ -4822,11 +4817,11 @@ def turk_identification_graph( >>> # SCRIPT >>> from wbia.other.ibsfuncs import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: - ... resp = web_ibs.get('/turk/identification/graph/') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... 
resp = client.get('/turk/identification/graph/') >>> ut.quit_if_noshow() >>> import wbia.plottool as pt - >>> ut.render_html(resp.content) + >>> ut.render_html(resp.data.decode('utf8')) >>> ut.show_if_requested() """ ibs = current_app.ibs From 57fe09fbdea132e019ee34217ba6da53dcab0196 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 13:36:29 -0800 Subject: [PATCH 169/294] Fix query_chips_graph_v2 doctest --- wbia/web/apis_query.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/wbia/web/apis_query.py b/wbia/web/apis_query.py index e1bb64ea80..12a6544897 100644 --- a/wbia/web/apis_query.py +++ b/wbia/web/apis_query.py @@ -1468,13 +1468,12 @@ def query_chips_graph_v2( >>> # Open local instance >>> ibs = wbia.opendb('PZ_MTEST') >>> uuid_list = ibs.annots().uuids[0:10] - >>> # Start up the web instance - >>> web_ibs = wbia.opendb_bg_web(db='PZ_MTEST', web=True, browser=False) >>> data = dict(annot_uuid_list=uuid_list) - >>> resp = web_ibs.send_wbia_request('/api/query/graph/v2/', **data) - >>> print('resp = %r' % (resp,)) - >>> #cmdict_list = json_dict['response'] - >>> #assert 'score_list' in cmdict_list[0] + >>> # Start up the web instance + >>> with wbia.opendb_with_web(db='PZ_MTEST') as (ibs, client): + ... resp = client.post('/api/query/graph/v2/', data=data) + >>> resp.json + {'status': {'success': False, 'code': 608, 'message': 'Invalid image and/or annotation UUIDs (0, 1)', 'cache': -1}, 'response': {'invalid_image_uuid_list': [], 'invalid_annot_uuid_list': [[0, 'c544d25f-fd03-5a2d-6611-cd77430ca251']]}} Example: >>> # DEBUG_SCRIPT From 479e60e1165eb4a79b1d98460bb32753662bb504 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 30 Dec 2020 13:38:08 -0800 Subject: [PATCH 170/294] Remove review_query_chips_test docstring script --- wbia/web/apis_query.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/wbia/web/apis_query.py b/wbia/web/apis_query.py index 12a6544897..fb63bd39c1 100644 --- a/wbia/web/apis_query.py +++ b/wbia/web/apis_query.py @@ -544,16 +544,6 @@ def review_graph_match_html( @register_route('/test/review/query/chip/', methods=['GET']) def review_query_chips_test(**kwargs): - """ - CommandLine: - python -m wbia.web.apis_query review_query_chips_test --show - - Example: - >>> # SCRIPT - >>> import wbia - >>> web_ibs = wbia.opendb_bg_web( - >>> browser=True, url_suffix='/test/review/query/chip/?__format__=true') - """ ibs = current_app.ibs # the old block curvature dtw From d2fd5da1020ed665922ed916f52137f704e7aca8 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 6 Jan 2021 14:39:56 +0000 Subject: [PATCH 171/294] Disable test in wbia/web/test_api.py `run_test_api` itself doesn't work anyway (as shown in the test) because `web_instance` is not an `IBEISController` and doesn't have attribute `get_web_port_via_scan`. On top of that, pytest doesn't finish unless ctrl+c is sent. 
``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 612, in run checker.check_exception(exc_got, want, runstate) File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 3, abs: 92, in >>> response = run_test_api() File "/wbia/wildbook-ia/wbia/web/test_api.py", line 106, in run_test_api web_port = web_instance.get_web_port_via_scan() AttributeError: 'KillableProcess' object has no attribute 'get_web_port_via_scan' DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/web/test_api.py::run_test_api:0 ``` --- wbia/web/test_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wbia/web/test_api.py b/wbia/web/test_api.py index d86a1f418c..ffb9f4f4cd 100755 --- a/wbia/web/test_api.py +++ b/wbia/web/test_api.py @@ -87,7 +87,7 @@ def run_test_api(): python -m wbia.web.test_api --test-run_test_api Example: - >>> # xdoctest: +REQUIRES(--web-tests) + >>> # DISABLE_DOCTEST >>> from wbia.web.test_api import * # NOQA >>> response = run_test_api() >>> print('Server response: %r' % (response, )) @@ -103,6 +103,8 @@ def run_test_api(): # Get the application port from the background process if APPLICATION_PORT is None: + # FIXME web_instance is a KillableProcess, not IBEISController, + # it doesn't have get_web_port_via_scan web_port = web_instance.get_web_port_via_scan() if web_port is None: raise ValueError('IA web server is not running on any expected port') From f5961932df9266384553cec055a0cab549ba6b94 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 8 Jan 2021 15:47:53 -0800 Subject: [PATCH 172/294] Fix apis module routes doctests to use opendb_with_web --- wbia/web/apis.py | 55 ++++++++++++++++++------------------------------ 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/wbia/web/apis.py b/wbia/web/apis.py index dc337b35ac..798a3e12b6 100644 --- a/wbia/web/apis.py +++ b/wbia/web/apis.py @@ -80,9 +80,9 @@ def image_src_api(rowid=None, thumbnail=False, fresh=False, **kwargs): Example: >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/image/src/1/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/image/src/1/') + >>> print(resp.data) b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: @@ -136,12 +136,13 @@ def annot_src_api(rowid=None, fresh=False, **kwargs): Example: >>> # xdoctest: +REQUIRES(--slow) - >>> # WEB_DOCTEST + >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/annot/src/1/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/annot/src/1/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -184,12 +185,13 @@ def background_src_api(rowid=None, fresh=False, **kwargs): Example: >>> # xdoctest: +REQUIRES(--slow) - >>> # WEB_DOCTEST + >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... 
resp = web_ibs.send_wbia_request('/api/background/src/1/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/background/src/1/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -234,9 +236,10 @@ def image_src_api_json(uuid=None, **kwargs): >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/image/src/json/0a9bc03d-a75e-8d14-0153-e2949502aba7/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/image/src/json/0a9bc03d-a75e-8d14-0153-e2949502aba7/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -436,37 +439,21 @@ def image_upload_zip(**kwargs): @register_api('/api/test/helloworld/', methods=['GET', 'POST', 'DELETE', 'PUT']) def hello_world(*args, **kwargs): """ - CommandLine: - python -m wbia.web.apis --exec-hello_world:0 - python -m wbia.web.apis --exec-hello_world:1 Example: - >>> # xdoctest: +REQUIRES(--web-tests) - >>> from wbia.web.app import * # NOQA - >>> import wbia - >>> web_ibs = wbia.opendb_bg_web(browser=True, start_job_queue=False, url_suffix='/api/test/helloworld/?test0=0') # start_job_queue=False) - >>> print('web_ibs = %r' % (web_ibs,)) - >>> print('Server will run until control c') - >>> web_ibs.terminate2() - - Example1: >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia >>> import requests >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... web_port = ibs.get_web_port_via_scan() - ... if web_port is None: - ... raise ValueError('IA web server is not running on any expected port') - ... domain = 'http://127.0.0.1:%s' % (web_port, ) - ... url = domain + '/api/test/helloworld/?test0=0' + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/test/helloworld/?test0=0') ... payload = { ... 'test1' : 'test1', ... 'test2' : None, # NOTICE test2 DOES NOT SHOW UP ... } - ... resp = requests.post(url, data=payload) - ... print(resp) + ... resp = client.post('/api/test/helloworld/', data=payload) + """ logger.info('+------------ HELLO WORLD ------------') logger.info('Args: %r' % (args,)) From 85c0baa35f7f38a9a918128e05182531fb941a90 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 8 Jan 2021 16:04:01 -0800 Subject: [PATCH 173/294] Fix doctest for wbia.web.app module to use opendb_with_web --- wbia/web/app.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/wbia/web/app.py b/wbia/web/app.py index 9c397cd926..1368673d1f 100644 --- a/wbia/web/app.py +++ b/wbia/web/app.py @@ -37,14 +37,13 @@ def tst_html_error(): r""" This test will show what our current errors look like - CommandLine: - python -m wbia.web.app --exec-tst_html_error - Example: - >>> # DISABLE_DOCTEST - >>> from wbia.web.app import * # NOQA >>> import wbia - >>> web_ibs = wbia.opendb_bg_web(browser=True, start_job_queue=False, url_suffix='/api/image/imagesettext/?__format__=True') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... 
resp = client.get('/api/image/imagesettext/?__format__=True') + >>> print(resp) + + """ pass From d2067725efad86f5d9e585e1cc515b0bb166734a Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Fri, 8 Jan 2021 21:44:56 -0800 Subject: [PATCH 174/294] Label tests that use the job-engine --- wbia/web/apis_engine.py | 2 ++ wbia/web/apis_query.py | 4 ++++ wbia/web/job_engine.py | 24 +++++------------------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/wbia/web/apis_engine.py b/wbia/web/apis_engine.py index 022293f0de..c1646b8a0d 100644 --- a/wbia/web/apis_engine.py +++ b/wbia/web/apis_engine.py @@ -262,6 +262,7 @@ def start_identify_annots( Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.apis_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') @@ -452,6 +453,7 @@ def start_identify_annots_query( Example: >>> # DISABLE_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.apis_engine import * # NOQA >>> import wbia >>> #domain = 'localhost' diff --git a/wbia/web/apis_query.py b/wbia/web/apis_query.py index fb63bd39c1..56048f308d 100644 --- a/wbia/web/apis_query.py +++ b/wbia/web/apis_query.py @@ -325,6 +325,9 @@ def review_graph_match_html( Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) + >>> # DISABLE_DOCTEST + >>> # Disabled because this test uses opendb_bg_web, which hangs the test runner and leaves zombie processes >>> from wbia.web.apis_query import * # NOQA >>> import wbia >>> web_ibs = wbia.opendb_bg_web('testdb1') # , domain='http://52.33.105.88') @@ -377,6 +380,7 @@ def review_graph_match_html( Example2: >>> # DISABLE_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> # This starts off using web to get information, but finishes the rest in python >>> from wbia.web.apis_query import * # NOQA >>> import wbia diff --git a/wbia/web/job_engine.py b/wbia/web/job_engine.py index f61ff4ed66..85ee10e3dd 100644 --- a/wbia/web/job_engine.py +++ b/wbia/web/job_engine.py @@ -162,6 +162,7 @@ def initialize_job_manager(ibs): Example: >>> # DISABLE_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> ibs = wbia.opendb(defaultdb='testdb1') @@ -174,24 +175,6 @@ def initialize_job_manager(ibs): >>> ibs.close_job_manager() >>> print('Closing success.') - Example: - >>> # xdoctest: +REQUIRES(--web-tests) - >>> from wbia.web.job_engine import * # NOQA - >>> import wbia - >>> import requests - >>> with wbia.opendb_bg_web(db='testdb1', managed=True) as web_instance: - ... web_port = ibs.get_web_port_via_scan() - ... if web_port is None: - ... raise ValueError('IA web server is not running on any expected port') - ... baseurl = 'http://127.0.1.1:%s' % (web_port, ) - ... _payload = {'image_attrs_list': [], 'annot_attrs_list': []} - ... payload = ut.map_dict_vals(ut.to_json, _payload) - ... resp1 = requests.post(baseurl + '/api/test/helloworld/?f=b', data=payload) - ... #resp2 = requests.post(baseurl + '/api/image/json/', data=payload) - ... #print(resp2) - ... #json_dict = resp2.json() - ... #text = json_dict['response'] - ... 
#print(text) """ ibs.job_manager = ut.DynStruct() @@ -265,6 +248,7 @@ def get_job_id_list(ibs): Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') @@ -317,6 +301,7 @@ def get_job_status(ibs, jobid=None): Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') @@ -363,7 +348,8 @@ def get_job_metadata(ibs, jobid): Example: >>> # xdoctest: +REQUIRES(--slow) - >>> # WEB_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) + >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') From 5edd071c2ababc2f43591c7f656d98bee67df2c0 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 11 Jan 2021 14:29:51 -0800 Subject: [PATCH 175/294] Small linting fixes --- wbia/core_annots.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wbia/core_annots.py b/wbia/core_annots.py index 8c30143736..ff2fa52b0c 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2711,7 +2711,9 @@ def get_annot_lrudfb_bools(ibs, aid_list): 'front' in view, 'back' in view, ] - if view is not None else [False] * 6 for view in views + if view is not None + else [False] * 6 + for view in views ] return bool_arrays From 014510fa550dcde71555434ac9b42ddb230d1cde Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 11 Jan 2021 15:21:24 -0800 Subject: [PATCH 176/294] Add Megan detector models V1 for Argentina and Kenya --- wbia/algo/detect/lightnet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wbia/algo/detect/lightnet.py b/wbia/algo/detect/lightnet.py index 790539aedd..d856cc55cb 100644 --- a/wbia/algo/detect/lightnet.py +++ b/wbia/algo/detect/lightnet.py @@ -71,6 +71,8 @@ 'candidacy': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py', 'ggr2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.ggr2.py', 'snow_leopard_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.snow_leopard.v0.py', + 'megan_argentina_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.argentina.v1.py', + 'megan_kenya_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.kenya.v1.py', None: 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py', 'training_kit': 'https://wildbookiarepository.azureedge.net/data/lightnet-training-kit.zip', } From 33d4640899afbc2ac3d37c005b2b99736a04d3e5 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 11 Jan 2021 15:57:55 -0800 Subject: [PATCH 177/294] Peg version of Jedi to fix this error ipython/ipython/issues/12742 --- devops/provision/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index 1e4afa1928..2d7d11a808 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -188,7 +188,9 @@ RUN set -ex \ && /virtualenv/env3/bin/pip install \ 'tensorflow-gpu==1.15.4' \ 'keras==2.2.5' \ - 'h5py<3.0.0' + 'h5py<3.0.0' \ + && /virtualenv/env3/bin/pip install \ + 'jedi==0.17.2' RUN set -ex \ && 
/virtualenv/env3/bin/pip freeze | grep wbia \ From 869ae21684f4a021716d5d08e55931e125bcf754 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 13 Jan 2021 13:16:19 -0800 Subject: [PATCH 178/294] Added envirronment variable for Tensorflow --- devops/_config/setup.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/devops/_config/setup.sh b/devops/_config/setup.sh index cf22d64331..11276bcbc5 100755 --- a/devops/_config/setup.sh +++ b/devops/_config/setup.sh @@ -44,3 +44,6 @@ chown -R ${HOST_USER}:${HOST_USER} /wbia/wbia-plugin-pie/ if [ ! -d "/data/docker" ]; then ln -s -T /data/db /data/docker fi + +# Allow Tensorflow to use GPU memory more dynamically +export TF_FORCE_GPU_ALLOW_GROWTH=true From f8b4018565c78199edbd6bddf376c43daefb5e5a Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 13 Jan 2021 15:00:45 -0800 Subject: [PATCH 179/294] Add grey_whale_v0 --- wbia/algo/detect/densenet.py | 1 + wbia/algo/detect/lightnet.py | 1 + 2 files changed, 2 insertions(+) diff --git a/wbia/algo/detect/densenet.py b/wbia/algo/detect/densenet.py index e80ccd333e..d0942cc6c9 100644 --- a/wbia/algo/detect/densenet.py +++ b/wbia/algo/detect/densenet.py @@ -71,6 +71,7 @@ 'flukebook_v1': 'https://wildbookiarepository.azureedge.net/models/classifier2.flukebook.v1.zip', 'rightwhale_v5': 'https://wildbookiarepository.azureedge.net/models/labeler.rightwhale.v5.zip', 'snow_leopard_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.snow_leopard.v0.zip', + 'grey_whale_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.whale_grey.v0.zip', } diff --git a/wbia/algo/detect/lightnet.py b/wbia/algo/detect/lightnet.py index d856cc55cb..63ede7770a 100644 --- a/wbia/algo/detect/lightnet.py +++ b/wbia/algo/detect/lightnet.py @@ -73,6 +73,7 @@ 'snow_leopard_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.snow_leopard.v0.py', 'megan_argentina_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.argentina.v1.py', 'megan_kenya_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.kenya.v1.py', + 'grey_whale_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_grey.v0.py', None: 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py', 'training_kit': 'https://wildbookiarepository.azureedge.net/data/lightnet-training-kit.zip', } From 8a6723a9f578f618b1440b5dd70ecef18a8a80a6 Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 9 Oct 2020 00:17:35 +0100 Subject: [PATCH 180/294] Handle duplicate columns in SQLDatabaseController.get_where Specifically, this test was failing: ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 10, abs: 2585, in >>> prop_list = table.get_row_data(tbl_rowids, colnames, **kwargs) File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2741, in get_row_data tries_left, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2794, in _resolve_any_external_data for uri in prop_listT[extern_colx]: IndexError: list index out of range DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/dtool/depcache_table.py::DependencyCacheTable.get_row_data:0 /wbia/wildbook-ia/wbia/dtool/depcache_table.py:2585: IndexError ``` The test was getting `size_1`, `size_0`, `size_1`, `chip_extern_uri`, `chip_extern_uri` from the `chip` table. 
SQLAlchemy generated sql that looks like: ``` SELECT chip.size_1, chip.size_0, chip.chip_extern_uri FROM chip WHERE chip_rowid = :_identifier ``` removing all the duplicate columns, causing the `IndexError` above. --- wbia/dtool/sql_control.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 912466f544..22e920a20b 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -20,6 +20,7 @@ import sqlalchemy import utool as ut from deprecated import deprecated +from sqlalchemy.engine import RowProxy from sqlalchemy.schema import Table from sqlalchemy.sql import bindparam, text, ClauseElement @@ -1174,7 +1175,24 @@ def get_where( **kwargs, ) - return val_list + # This code is specifically for handling duplication in colnames + # because sqlalchemy removes them. + # e.g. select field1, field1, field2 from table; + # becomes + # select field1, field2 from table; + # so the items in val_list only have 2 values + # but the caller isn't expecting it so it causes problems + returned_columns = tuple([c.name for c in stmt.columns]) + if colnames == returned_columns: + return val_list + + result = [] + for val in val_list: + if isinstance(val, RowProxy): + result.append(tuple(val[returned_columns.index(c)] for c in colnames)) + else: + result.append(val) + return result def exists_where_eq( self, From 4c212c54606f22fb3fe7b2110735a33d3634ad56 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 16:58:06 +0100 Subject: [PATCH 181/294] Try translating sqlite uri to a postgres uri For example, translate from `sqlite:////wbia/wildbook-ia/testdb1/_ibsdb/_ibeis_cache/chipcache4.sqlite` to `postgresql://wbia@db/testdb1` The `testdb1` database needed to be created by "postgres" on the db container: ``` psql -U postgres -c 'CREATE DATABASE testdb1 WITH OWNER wbia;' ``` --- wbia/dtool/sql_control.py | 52 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 22e920a20b..a7f1fde406 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -68,6 +68,38 @@ METADATA_TABLE_COLUMN_NAMES = list(METADATA_TABLE_COLUMNS.keys()) +def sqlite_uri_to_postgres_uri_schema(uri): + from wbia.init.sysres import get_workdir + + workdir = os.path.normpath(os.path.abspath(get_workdir())) + base_db_uri = os.getenv('WBIA_BASE_DB_URI') + namespace = None + if base_db_uri and uri.startswith(f'sqlite:///{workdir}'): + # Change sqlite uri to postgres uri + + # Remove sqlite:///{workdir} from uri + # -> /NAUT_test/_ibsdb/_ibeis_cache/chipcache4.sqlite + sqlite_db_path = uri[len(f'sqlite:///{workdir}') :] + + # ['', 'naut_test', '_ibsdb', '_ibeis_cache', 'chipcache4.sqlite'] + sqlite_db_path_parts = sqlite_db_path.lower().split(os.path.sep) + + if len(sqlite_db_path_parts) > 2: + # naut_test + db_name = sqlite_db_path_parts[1] + # chipcache4 + namespace = os.path.splitext(sqlite_db_path_parts[-1])[0] + + # postgresql://wbia@db/naut_test + uri = f'{base_db_uri}/{db_name}' + else: + raise RuntimeError( + f'uri={uri} sqlite_db_path={sqlite_db_path} sqlite_db_path_parts={sqlite_db_path_parts}' + ) + + return (uri, namespace) + + def _unpacker(results): """ HELPER: Unpacks results if unpack_scalars is True. 
""" if not results: # Check for None or empty list @@ -430,6 +462,10 @@ def __len__(self): def __init_engine(self): """Create the SQLAlchemy Engine""" + if os.getenv('POSTGRES'): + uri, schema = sqlite_uri_to_postgres_uri_schema(self.uri) + self.uri = uri + self.schema = schema self._engine = sqlalchemy.create_engine( self.uri, # The echo flag is a shortcut to set up SQLAlchemy logging @@ -497,6 +533,9 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): def connect(self): """Create a connection for the instance or use the existing connection""" self._connection = self._engine.connect() + if self._engine.dialect.name == 'postgresql': + self._connection.execute(f'CREATE SCHEMA IF NOT EXISTS {self.schema}') + self.connection.execute(text('SET SCHEMA :schema'), schema=self.schema) return self._connection @property @@ -523,11 +562,16 @@ def _create_connection(self): uri = self.uri if self.readonly: uri += '?mode=ro' + if os.getenv('POSTGRES'): + uri, schema = sqlite_uri_to_postgres_uri_schema(uri) engine = sqlalchemy.create_engine( uri, echo=False, ) connection = engine.connect() + if engine.dialect.name == 'postgresql': + connection.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') + connection.execute(text('SET SCHEMA :schema'), schema=schema) # Keep track of what thead this was started in threadid = threading.current_thread() @@ -660,6 +704,9 @@ def reboot(self): # ??? May be better to use the `dispose()` method? self.__init_engine() self.connection = self._engine.connect() + if self._engine.dialect.name == 'postgresql': + self.connection.execute(f'CREATE SCHEMA IF NOT EXISTS {self.schema}') + self.connection.execute(text('SET SCHEMA :schema'), schema=self.schema) def backup(self, backup_filepath): """ @@ -729,8 +776,11 @@ def _reflect_table(self, table_name): """Produces a SQLAlchemy Table object from the given ``table_name``""" # Note, this on introspects once. Repeated calls will pull the Table object # from the MetaData object. + kw = {} + if self._engine.dialect.name == 'postgresql': + kw = {'schema': self.schema} return Table( - table_name, self._sa_metadata, autoload=True, autoload_with=self._engine + table_name, self._sa_metadata, autoload=True, autoload_with=self._engine, **kw ) # ============== From 326f10beaae39202fc66f856989add5f213f6b7d Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 17:04:00 +0100 Subject: [PATCH 182/294] Catch sqlalchemy error when table doesn't exist It seems that sqlalchemy raises a different error when a table doesn't exist in postgres: ``` def do_execute(self, cursor, statement, parameters, context=None): > cursor.execute(statement, parameters) E sqlalchemy.exc.ProgrammingError: (psycopg2.errors.UndefinedTable) relation "metadata" does not exist E LINE 1: SELECT 1 FROM metadata LIMIT 1 E ^ E E [SQL: SELECT 1 FROM metadata LIMIT 1] E (Background on this error at: http://sqlalche.me/e/13/f405) /virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py:593: ProgrammingError ``` We need to catch this error as well in `SQLDatabaseController._ensure_metadata_table`. 
--- wbia/dtool/sql_control.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index a7f1fde406..c7eef6ad85 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -608,8 +608,15 @@ def _ensure_metadata_table(self): """ try: orig_table_kw = self.get_table_autogen_dict(METADATA_TABLE_NAME) - except (sqlalchemy.exc.OperationalError, NameError): + except ( + sqlalchemy.exc.OperationalError, # sqlite error + sqlalchemy.exc.ProgrammingError, # postgres error + NameError, + ): orig_table_kw = None + # Reset connection because schema was rolled back due to + # the error + self._connection = None meta_table_kw = ut.odict( [ From 0dcd435ee0ea0c4aef1861dbaa30ca991a3172ac Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 26 Oct 2020 22:38:14 +0000 Subject: [PATCH 183/294] Create custom sql types (domains) in postgres We have non-standard sql types like "list", "json", "ndarray" etc. Create these types in postgresql so we can use them in create table statements. --- wbia/dtool/sql_control.py | 2 ++ wbia/dtool/types.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index c7eef6ad85..3a349d2b73 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -27,6 +27,7 @@ from wbia.dtool import lite from wbia.dtool.dump import dumps from wbia.dtool.types import Integer, TYPE_TO_SQLTYPE +from wbia.dtool.types import initialize_postgresql_types print, rrr, profile = ut.inject2(__name__) @@ -536,6 +537,7 @@ def connect(self): if self._engine.dialect.name == 'postgresql': self._connection.execute(f'CREATE SCHEMA IF NOT EXISTS {self.schema}') self.connection.execute(text('SET SCHEMA :schema'), schema=self.schema) + initialize_postgresql_types(self.connection, self.schema) return self._connection @property diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index e5f893226c..eeec5f9566 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -5,6 +5,8 @@ import numpy as np from utool.util_cache import from_json, to_json +import sqlalchemy +from sqlalchemy.sql import text from sqlalchemy.types import Integer as SAInteger from sqlalchemy.types import TypeDecorator, UserDefinedType @@ -40,6 +42,7 @@ class JSONCodeableType(UserDefinedType): # Abstract properties base_py_type = None col_spec = None + postgresql_base_type = 'json' def get_col_spec(self, **kw): return self.col_spec @@ -68,6 +71,7 @@ class NumPyPicklableType(UserDefinedType): # Abstract properties base_py_types = None col_spec = None + postgresql_base_type = 'bytea' def get_col_spec(self, **kw): return self.col_spec @@ -180,3 +184,18 @@ def process(value): # SQL type (e.g. 
'DICT') to SQLAlchemy type: SQL_TYPE_TO_SA_TYPE = {cls().get_col_spec(): cls for cls in _USER_DEFINED_TYPES} SQL_TYPE_TO_SA_TYPE['INTEGER'] = Integer + + +def initialize_postgresql_types(conn, schema): + domain_names = conn.execute( + """\ + SELECT domain_name FROM information_schema.domains + WHERE domain_schema = (select current_schema)""" + ).fetchall() + for type_name, cls in SQL_TYPE_TO_SA_TYPE.items(): + if type_name not in domain_names and hasattr(cls, 'postgresql_base_type'): + base_type = cls.postgresql_base_type + try: + conn.execute(f'CREATE DOMAIN {type_name} AS {base_type}') + except sqlalchemy.exc.ProgrammingError: + conn.execute(text('SET SCHEMA :schema'), schema=schema) From e87c83c5fac698b6f99368f51579ef0495a8cc08 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 17:09:10 +0100 Subject: [PATCH 184/294] Implement postgres version of getting all table names The code gets all the table names by querying a table called `sqlite_master`, which is obviously not available in postgres: ``` def do_execute(self, cursor, statement, parameters, context=None): > cursor.execute(statement, parameters) E sqlalchemy.exc.ProgrammingError: (psycopg2.errors.UndefinedTable) relation "sqlite_master" does not exist E LINE 1: SELECT name FROM sqlite_master WHERE type='table' E ^ E E [SQL: SELECT name FROM sqlite_master WHERE type='table'] E (Background on this error at: http://sqlalche.me/e/13/f405) /virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py:593: ProgrammingError ``` The postgres version is `information_schema.tables`. --- wbia/dtool/sql_control.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 3a349d2b73..45082da533 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2443,9 +2443,21 @@ def invalidate_tables_cache(self): def get_table_names(self, lazy=False): """ Conveinience: """ if not lazy or self._tablenames is None: - result = self.connection.execute( - "SELECT name FROM sqlite_master WHERE type='table'" - ) + dialect = self.connection.engine.dialect.name + if dialect == 'sqlite': + stmt = "SELECT name FROM sqlite_master WHERE type='table'" + params = {} + elif dialect == 'postgresql': + stmt = text( + """\ + SELECT table_name FROM information_schema.tables + WHERE table_type='BASE TABLE' + AND table_schema = :schema""" + ) + params = {'schema': self.schema} + else: + raise RuntimeError(f'Unknown dialect {dialect}') + result = self.connection.execute(stmt, **params) tablename_list = result.fetchall() self._tablenames = {str(tablename[0]) for tablename in tablename_list} return self._tablenames From 2fb3332017616d83f14fd91939c5df420d97b77a Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 17:12:12 +0100 Subject: [PATCH 185/294] Implement postgres version of getting table column info `PRAGMA TABLE_INFO` isn't available in postgres so we need a different way of getting the table column info. 
``` def _discovery_table_columns(inspector, table_name): """Discover the original column type information in a _dialect_ specific way""" dialect = inspector.engine.dialect.name with inspector.engine.connect() as conn: if dialect == 'sqlite': # See also, https://sqlite.org/pragma.html#pragma_table_info result = conn.execute(f"PRAGMA TABLE_INFO('{table_name}')") #: column-id, name, data-type, nullable, default-value, is-primary-key info_rows = result.fetchall() names_to_types = {info[1]: info[2] for info in info_rows} else: raise RuntimeError( > "Unknown dialect ('{dialect}'), can't introspect column information." ) E RuntimeError: Unknown dialect ('{dialect}'), can't introspect column information. wbia/dtool/events.py:27: RuntimeError ``` --- wbia/dtool/events.py | 27 ++++++++++++++++++++++++++- wbia/dtool/sql_control.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/events.py b/wbia/dtool/events.py index 06e0447fc3..8ec92988ff 100644 --- a/wbia/dtool/events.py +++ b/wbia/dtool/events.py @@ -7,6 +7,7 @@ """ from sqlalchemy import event from sqlalchemy.schema import Table +from sqlalchemy.sql import text from .types import SQL_TYPE_TO_SA_TYPE @@ -22,9 +23,33 @@ def _discovery_table_columns(inspector, table_name): #: column-id, name, data-type, nullable, default-value, is-primary-key info_rows = result.fetchall() names_to_types = {info[1]: info[2] for info in info_rows} + elif dialect == 'postgresql': + result = conn.execute( + text( + """SELECT + row_number() over () - 1, + column_name, + coalesce(domain_name, data_type), + is_nullable, + column_default, + column_name = ( + SELECT column_name + FROM information_schema.table_constraints + NATURAL JOIN information_schema.constraint_column_usage + WHERE table_name = :table_name + AND constraint_type = 'PRIMARY KEY' + LIMIT 1 + ) AS pk + FROM information_schema.columns + WHERE table_name = :table_name""" + ), + table_name=table_name, + ) + info_rows = result.fetchall() + names_to_types = {info[1]: info[2] for info in info_rows} else: raise RuntimeError( - "Unknown dialect ('{dialect}'), can't introspect column information." + f"Unknown dialect ('{dialect}'), can't introspect column information." ) return names_to_types diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 45082da533..60790bd9f0 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2581,7 +2581,34 @@ def get_columns(self, tablename): """ # check if the table exists first. Throws an error if it does not exist. 
self.connection.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') - result = self.connection.execute("PRAGMA TABLE_INFO('" + tablename + "')") + dialect = self._engine.dialect.name + if dialect == 'sqlite': + stmt = f"PRAGMA TABLE_INFO('{tablename}')" + params = {} + elif dialect == 'postgresql': + stmt = text( + """SELECT + row_number() over () - 1, + column_name, + coalesce(domain_name, data_type), + is_nullable, + column_default, + column_name = ( + SELECT column_name + FROM information_schema.table_constraints + NATURAL JOIN information_schema.constraint_column_usage + WHERE table_name = :table_name + AND constraint_type = 'PRIMARY KEY' + AND table_schema = :table_schema + LIMIT 1 + ) AS pk + FROM information_schema.columns + WHERE table_name = :table_name + AND table_schema = :table_schema""" + ) + params = {'table_name': tablename, 'table_schema': self.schema} + + result = self.connection.execute(stmt, **params) colinfo_list = result.fetchall() colrichinfo_list = [SQLColumnRichInfo(*colinfo) for colinfo in colinfo_list] return colrichinfo_list From d03cf4d9bcef8820c11a3616e578ae5a676da2f2 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 17:14:24 +0100 Subject: [PATCH 186/294] Implement postgres version of upserting metadata `INSERT OR REPLACE` doesn't exist in postgres: ``` def do_execute(self, cursor, statement, parameters, context=None): > cursor.execute(statement, parameters) E sqlalchemy.exc.ProgrammingError: (psycopg2.errors.SyntaxError) syntax error at or near "OR" E LINE 1: INSERT OR REPLACE INTO metadata (metadata_key, metadata_valu... E ^ E E [SQL: INSERT OR REPLACE INTO metadata (metadata_key, metadata_value) VALUES (%(key)s, %(value)s)] E [parameters: {'key': 'metadata_docstr', 'value': '\n The table that stores permanently all of the metadata about the\n database (tables, etc)'}] E (Background on this error at: http://sqlalche.me/e/13/f405) /virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py:593: ProgrammingError ``` Instead we need to use `INSERT ... ON CONFLICT ...` to do upsert. 
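For reference, a side-by-side sketch of the two upsert statements (table and column names taken from the metadata example in the diff below); postgres resolves `ON CONFLICT (metadata_key)` against the unique constraint or index on that column:

```
# Sketch of the dialect-specific upsert statements, simplified from the diff.
sqlite_upsert = (
    'INSERT OR REPLACE INTO metadata (metadata_key, metadata_value) '
    'VALUES (:key, :value)'
)
# Postgres has no INSERT OR REPLACE; ON CONFLICT needs a unique
# constraint or index on metadata_key to resolve against.
postgres_upsert = (
    'INSERT INTO metadata (metadata_key, metadata_value) '
    'VALUES (:key, :value) '
    'ON CONFLICT (metadata_key) DO UPDATE '
    'SET metadata_value = EXCLUDED.metadata_value'
)
```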
--- wbia/dtool/sql_control.py | 79 ++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 60790bd9f0..4d584fd9fc 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -243,10 +243,23 @@ def version(self): def version(self, value): if not value: raise ValueError(value) - stmt = text( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value)' - 'VALUES (:key, :value)' - ) + dialect = self.ctrlr.connection.engine.dialect.name + if dialect == 'sqlite': + stmt = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value)' + 'VALUES (:key, :value)' + ) + elif dialect == 'postgresql': + stmt = text( + f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :value) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + ) + else: + raise RuntimeError(f'Unknown dialect {dialect}') params = {'key': 'database_version', 'value': value} self.ctrlr.executeone(stmt, params) @@ -271,10 +284,23 @@ def init_uuid(self, value): raise ValueError(value) elif isinstance(value, uuid.UUID): value = str(value) - stmt = text( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) ' - 'VALUES (:key, :value)' - ) + dialect = self.ctrlr.connection.engine.dialect.name + if dialect == 'sqlite': + stmt = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) ' + 'VALUES (:key, :value)' + ) + elif dialect == 'postgresql': + stmt = text( + f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :value) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + ) + else: + raise RuntimeError(f'Unknown dialect {dialect}') params = {'key': 'database_init_uuid', 'value': value} self.ctrlr.executeone(stmt, params) @@ -363,11 +389,23 @@ def __setattr__(self, name, value): key = self._get_key_name(name) # Insert or update the record - # FIXME postgresql (4-Aug-12020) 'insert or replace' is not valid for postgresql - statement = text( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} ' - f'(metadata_key, metadata_value) VALUES (:key, :value)' - ) + dialect = self.ctrlr.connection.engine.dialect.name + if dialect == 'sqlite': + statement = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} ' + f'(metadata_key, metadata_value) VALUES (:key, :value)' + ) + elif dialect == 'postgresql': + statement = text( + f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :value) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + ) + else: + raise RuntimeError(f'Unknown dialect {dialect}') params = { 'key': key, 'value': value, @@ -1790,7 +1828,20 @@ def set_metadata_val(self, key, val): 'tablename': METADATA_TABLE_NAME, 'columns': 'metadata_key, metadata_value', } - op_fmtstr = 'INSERT OR REPLACE INTO {tablename} ({columns}) VALUES (:key, :val)' + dialect = self._engine.dialect.name + if dialect == 'sqlite': + op_fmtstr = ( + 'INSERT OR REPLACE INTO {tablename} ({columns}) VALUES (:key, :val)' + ) + elif dialect == 'postgresql': + op_fmtstr = f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :val) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + else: + raise RuntimeError(f'Unknown dialect {dialect}') operation = 
text(op_fmtstr.format(**fmtkw)) params = {'key': key, 'val': val} self.executeone(operation, params, verbose=False) From 4e9d3b16cabd56b36ff64b41cbebab9c758ae4c7 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 21:32:33 +0100 Subject: [PATCH 187/294] Change integer primary key fields to serial in postgresql In postgresql, to create auto-increment fields, we can use `SERIAL`. ``` E sqlalchemy.exc.IntegrityError: (psycopg2.errors.NotNullViolation) null value in column "metadata_rowid" violates not-null constraint E DETAIL: Failing row contains (null, metadata_docstr, E The table that stores permanently all of the ...). E E [SQL: INSERT INTO metadata E (metadata_key, metadata_value) E VALUES (%(key)s, %(value)s) E ON CONFLICT (metadata_key) DO UPDATE E SET metadata_value = EXCLUDED.metadata_value] E [parameters: {'key': 'metadata_docstr', 'value': '\n The table that stores permanently all of the metadata about the\n database (tables, etc)'}] E (Background on this error at: http://sqlalche.me/e/13/gkpj) /virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py:593: IntegrityError ``` --- wbia/dtool/sql_control.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 4d584fd9fc..9dc7ee12d3 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1913,6 +1913,13 @@ def __make_column_definition(self, name: str, definition: str) -> str: raise ValueError(f'name cannot be an empty string paired with {definition}') elif not definition: raise ValueError(f'definition cannot be an empty string paired with {name}') + if self._engine.dialect.name == 'postgresql': + if ( + name.endswith('rowid') + and 'INTEGER' in definition + and 'PRIMARY KEY' in definition + ): + definition = definition.replace('INTEGER', 'SERIAL') return f'{name} {definition}' def _make_add_table_sqlstr( From b01cd3179bf541bb34f31bea5fa233c4bfcfd01b Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 20 Oct 2020 21:38:53 +0100 Subject: [PATCH 188/294] Only optimize database in SQLDatabaseController if using sqlite The database optimization code is very specific to sqlite. We may come up with something for postgresql later but just skip for now. 
``` def do_execute(self, cursor, statement, parameters, context=None): > cursor.execute(statement, parameters) E sqlalchemy.exc.ProgrammingError: (psycopg2.errors.SyntaxError) syntax error at or near "PRAGMA" E LINE 1: PRAGMA cache_size = 10000; E ^ E E [SQL: PRAGMA cache_size = 10000;] E (Background on this error at: http://sqlalche.me/e/13/f405) /virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py:593: ProgrammingError ``` --- wbia/dtool/sql_control.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 9dc7ee12d3..5181ddb1be 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -779,6 +779,8 @@ def backup(self, backup_filepath): connection.close() def optimize(self): + if self._engine.dialect.name != 'sqlite': + return # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html#pragma-cache_size # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html if VERBOSE_SQL: @@ -797,24 +799,32 @@ def optimize(self): # self.connection.execute('PRAGMA default_cache_size = 0;') def shrink_memory(self): + if self._engine.dialect.name != 'sqlite': + return logger.info('[sql] shrink_memory') transaction = self.connection.begin() self.connection.execute('PRAGMA shrink_memory;') transaction.commit() def vacuum(self): + if self._engine.dialect.name != 'sqlite': + return logger.info('[sql] vaccum') transaction = self.connection.begin() self.connection.execute('VACUUM;') transaction.commit() def integrity(self): + if self._engine.dialect.name != 'sqlite': + return logger.info('[sql] vaccum') transaction = self.connection.begin() self.connection.execute('PRAGMA integrity_check;') transaction.commit() def squeeze(self): + if self._engine.dialect.name != 'sqlite': + return logger.info('[sql] squeeze') self.shrink_memory() self.vacuum() From 47fccb5f5b4f8f2676abd58285061cc5bd2da5e0 Mon Sep 17 00:00:00 2001 From: karen chan Date: Fri, 23 Oct 2020 16:52:07 +0100 Subject: [PATCH 189/294] Include table name in unique constraints Constraint names in postgres appear to need to be unique across all tables, so if two tables have columns with the same names they end up with identical constraint names, which causes an error. 
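The naming scheme can be shown with a standalone sketch (the helper name here is made up; the real change is in `__make_unique_constraint` in the diff below). Prefixing the constraint name with the table name keeps the generated names distinct even when two tables share column names:

```python
def make_unique_constraint(table_name, columns):
    """Sketch: build a UNIQUE constraint whose name is qualified by the table."""
    name = '_'.join(['unique', table_name] + list(columns))
    return f"CONSTRAINT {name} UNIQUE ({', '.join(columns)})"

# Both tables use the same columns, but the constraint names no longer collide:
print(make_unique_constraint('classifier', ['images_rowid', 'config_rowid']))
print(make_unique_constraint('classifier_two', ['images_rowid', 'config_rowid']))
```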
``` def do_execute(self, cursor, statement, parameters, context=None): > cursor.execute(statement, parameters) E sqlalchemy.exc.ProgrammingError: (psycopg2.errors.DuplicateTable) relation "unique_images_rowid_config_rowid" already exists E E [SQL: CREATE TABLE IF NOT EXISTS classifier_two ( classifier_two_rowid SERIAL PRIMARY KEY, images_rowid INTEGER NOT NULL, config_rowid INTEGER DEFAULT 0, scores JSON, classes TEXT[], CONSTRAINT unique_images_rowid_config_rowid UNIQUE (images_rowid, config_rowid) )] E (Background on this error at: http://sqlalche.me/e/13/f405) ``` --- wbia/dtool/sql_control.py | 6 +++--- wbia/tests/dtool/test_sql_control.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 5181ddb1be..d9fd0382d2 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1906,14 +1906,14 @@ def add_column(self, tablename, colname, coltype): operation = op_fmtstr.format(**fmtkw) self.executeone(operation, [], verbose=False) - def __make_unique_constraint(self, column_or_columns): + def __make_unique_constraint(self, table_name, column_or_columns): """Creates a SQL ``CONSTRAINT`` clause for ``UNIQUE`` column data""" if not isinstance(column_or_columns, (list, tuple)): columns = [column_or_columns] else: # Cast as list incase it's a tuple, b/c tuple + list = error columns = list(column_or_columns) - constraint_name = '_'.join(['unique'] + columns) + constraint_name = '_'.join(['unique', table_name] + columns) columns_listing = ', '.join(columns) return f'CONSTRAINT {constraint_name} UNIQUE ({columns_listing})' @@ -1967,7 +1967,7 @@ def _make_add_table_sqlstr( # Make a list of constraints to place on the table # superkeys = [(, ...), ...] constraint_list = [ - self.__make_unique_constraint(x) + self.__make_unique_constraint(tablename, x) for x in metadata_keyval.get('superkeys') or [] ] constraint_list = ut.unique_ordered(constraint_list) diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index d492b72552..748a94fb92 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -119,7 +119,7 @@ def test_make_add_table_sqlstr(self): 'indexer_id INTEGER NOT NULL, ' 'config_id INTEGER DEFAULT 0, ' 'data TEXT, ' - 'CONSTRAINT unique_meta_labeler_id_indexer_id_config_id ' + 'CONSTRAINT unique_foobars_meta_labeler_id_indexer_id_config_id ' 'UNIQUE (meta_labeler_id, indexer_id, config_id) )' ) assert sql.text == expected @@ -160,7 +160,7 @@ def test_add_table(self): ('PrimaryKeyConstraint', None, ['bars_id']), ( 'UniqueConstraint', - 'unique_meta_labeler_id_indexer_id_config_id', + 'unique_bars_meta_labeler_id_indexer_id_config_id', ['meta_labeler_id', 'indexer_id', 'config_id'], ), ] From 66fa9840ce409e4d68bad6802a3bab4a783b8d52 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 27 Oct 2020 15:04:34 +0000 Subject: [PATCH 190/294] Skip database backup if not sqlite The database backup code is based on copying files so it's only suitable for sqlite databases. We'll need to implement something for postgres in the future. 
--- wbia/control/_sql_helpers.py | 4 ++++ wbia/dtool/sql_control.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/wbia/control/_sql_helpers.py b/wbia/control/_sql_helpers.py index bc8439d655..8ccfe98fd3 100644 --- a/wbia/control/_sql_helpers.py +++ b/wbia/control/_sql_helpers.py @@ -359,6 +359,10 @@ def update_schema_version( clearbackup = False FIXME: AN SQL HELPER FUNCTION SHOULD BE AGNOSTIC TO CONTROLER OBJECTS """ + if db._engine.dialect.name != 'sqlite': + # Backup is based on copying files so if we're not using sqlite, skip + # backup + dobackup = False def _check_superkeys(): all_tablename_list = db.get_table_names() diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index d9fd0382d2..9a276838e8 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -759,6 +759,9 @@ def backup(self, backup_filepath): """ backup_filepath = dst_fpath """ + if self._engine.dialect.name == 'postgresql': + # TODO postgresql backup + return # Create a brand new conenction to lock out current thread and any others connection, uri = self._create_connection() # Start Exclusive transaction, lock out all other writers from making database changes From 5e3faf890bfcfe0cc3debfb3b04bafbc31a144a4 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 27 Oct 2020 15:27:35 +0000 Subject: [PATCH 191/294] Add sqlite built-in rowid column to tables if using postgres Tables are created with a built-in `rowid` column in sqlite by default. We use the `rowid` column in queries but it's not available in postgres so we need to create this column in tables. --- wbia/dtool/sql_control.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 9a276838e8..7b61ca89aa 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1953,6 +1953,11 @@ def _make_add_table_sqlstr( if not coldef_list: raise ValueError(f'empty coldef_list specified for {tablename}') + if self._engine.dialect.name == 'postgresql' and 'rowid' not in [ + name for name, _ in coldef_list + ]: + coldef_list = [('rowid', 'SERIAL UNIQUE')] + list(coldef_list) + # Check for invalid keyword arguments bad_kwargs = set(metadata_keyval.keys()) - set(METADATA_TABLE_COLUMN_NAMES) if len(bad_kwargs) > 0: From d577e0ff377849773863176b7971bfdc1f2f788f Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 27 Oct 2020 15:36:28 +0000 Subject: [PATCH 192/294] Drop table cascade for postgresql There are sequences created for tables in order to do serial fields (e.g. rowid) and when the table is dropped, postgresql raises a warning saying that there are objects depending on the table so we need to drop everything by adding "CASCADE". --- wbia/dtool/sql_control.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 7b61ca89aa..7533bc3e22 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2243,8 +2243,10 @@ def drop_table(self, tablename, invalidate_cache=True): # Technically insecure call, but all entries are statically inputted by # the database's owner, who could delete or alter the entire database # anyway. 
- operation = text(f'DROP TABLE IF EXISTS {tablename}') - self.executeone(operation, []) + operation = f'DROP TABLE IF EXISTS {tablename}' + if self.uri.startswith('postgresql'): + operation = f'{operation} CASCADE' + self.executeone(text(operation), []) # Delete table's metadata key_list = [tablename + '_' + suffix for suffix in METADATA_TABLE_COLUMN_NAMES] From 02ca5da5334b53e879aed322c21e4155b64fa541 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 27 Oct 2020 22:36:25 +0000 Subject: [PATCH 193/294] Drop and create postgres databases as part of set_up_db fixture "DROP DATABASE" and "CREATE DATABASE" cannot run inside a transaction block for postgresql so it's necessary to add `.execution_options(isolation_level='AUTOCOMMIT')` before `.execute()`. --- wbia/conftest.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/wbia/conftest.py b/wbia/conftest.py index 22cd04a077..31991c5db2 100644 --- a/wbia/conftest.py +++ b/wbia/conftest.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +import sqlalchemy from wbia.dbio import ingest_database from wbia.init.sysres import ( @@ -114,6 +115,15 @@ def set_up_db(request): # FIXME (16-Jul-12020) this fixture does not cleanup after itself to preserve exiting usage behavior for dbname in TEST_DBNAMES: delete_dbdir(dbname) + if os.getenv('POSTGRES'): + engine = sqlalchemy.create_engine(os.getenv('WBIA_BASE_DB_URI')) + engine.execution_options(isolation_level='AUTOCOMMIT').execute( + f'DROP DATABASE IF EXISTS {dbname}' + ) + engine.execution_options(isolation_level='AUTOCOMMIT').execute( + f'CREATE DATABASE {dbname}' + ) + engine.dispose() # Set up DBs ingest_database.ingest_standard_database('testdb1') From 8c24d87596afb1a35937cdf323d78099ead003d0 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 28 Oct 2020 00:00:56 +0000 Subject: [PATCH 194/294] Implement sqlite STRFTIME('%s', ...) in postgresql In sqlite, we have some fields default to: ``` CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER) ``` which needs to be changed in postgresql to: ``` CAST(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'UTC') AS INTEGER) ``` ``` [!update_schema_version] The database update failed, and no backup was made. : (psycopg2.errors.UndefinedFunction) function strftime(unknown, unknown, unknown) does not exist HINT: No function matches the given name and argument types. You might need to add explicit type casts. 
[SQL: CREATE TABLE IF NOT EXISTS reviews ( rowid SERIAL UNIQUE, review_rowid SERIAL PRIMARY KEY, annot_1_rowid INTEGER NOT NULL, annot_2_rowid INTEGER NOT NULL, review_count INTEGER NOT NULL, review_decision INTEGER NOT NULL, review_time_posix INTEGER DEFAULT (CAST(STRFTIME('%%s', 'NOW', 'UTC') AS INTEGER)), review_identity TEXT, review_tags TEXT, CONSTRAINT unique_reviews_annot_1_rowid_annot_2_rowid_review_count UNIQUE (annot_1_rowid, annot_2_rowid, review_count) )] ``` --- wbia/control/STAGING_SCHEMA.py | 62 ++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/wbia/control/STAGING_SCHEMA.py b/wbia/control/STAGING_SCHEMA.py index b165166f53..7cbd69ce69 100644 --- a/wbia/control/STAGING_SCHEMA.py +++ b/wbia/control/STAGING_SCHEMA.py @@ -40,21 +40,27 @@ @profile def update_1_0_0(db, ibs=None): + columns = [ + ('review_rowid', 'INTEGER PRIMARY KEY'), + ('annot_1_rowid', 'INTEGER NOT NULL'), + ('annot_2_rowid', 'INTEGER NOT NULL'), + ('review_count', 'INTEGER NOT NULL'), + ('review_decision', 'INTEGER NOT NULL'), + ( + 'review_time_posix', + """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", + ), # this should probably be UCT + ('review_identity', 'TEXT'), + ('review_tags', 'TEXT'), + ] + if db._engine.dialect.name == 'postgresql': + columns[5] = ( + 'review_time_posix', + "INTEGER DEFAULT (CAST(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'UTC') AS INTEGER))", + ) db.add_table( const.REVIEW_TABLE, - ( - ('review_rowid', 'INTEGER PRIMARY KEY'), - ('annot_1_rowid', 'INTEGER NOT NULL'), - ('annot_2_rowid', 'INTEGER NOT NULL'), - ('review_count', 'INTEGER NOT NULL'), - ('review_decision', 'INTEGER NOT NULL'), - ( - 'review_time_posix', - """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", - ), # this should probably be UCT - ('review_identity', 'TEXT'), - ('review_tags', 'TEXT'), - ), + columns, superkeys=[('annot_1_rowid', 'annot_2_rowid', 'review_count')], docstr=""" Used to store completed user review states of two matched annotations @@ -104,20 +110,26 @@ def update_1_0_3(db, ibs=None): def update_1_1_0(db, ibs=None): + columns = [ + ('test_rowid', 'INTEGER PRIMARY KEY'), + ('test_uuid', 'UUID'), + ('test_user_identity', 'TEXT'), + ('test_challenge_json', 'TEXT'), + ('test_response_json', 'TEXT'), + ('test_result', 'INTEGER'), + ( + 'test_time_posix', + """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", + ), # this should probably be UCT + ] + if db._engine.dialect.name == 'postgresql': + columns[6] = ( + 'test_time_posix', + "INTEGER DEFAULT (CAST(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'UTC') AS INTEGER))", + ) db.add_table( const.TEST_TABLE, - ( - ('test_rowid', 'INTEGER PRIMARY KEY'), - ('test_uuid', 'UUID'), - ('test_user_identity', 'TEXT'), - ('test_challenge_json', 'TEXT'), - ('test_response_json', 'TEXT'), - ('test_result', 'INTEGER'), - ( - 'test_time_posix', - """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", - ), # this should probably be UCT - ), + columns, superkeys=[('test_uuid',)], docstr=""" Used to store tests given to the user, their responses, and their results From 93f7315398d7b676f87f77302043f7b6bf9ec9ca Mon Sep 17 00:00:00 2001 From: karen chan Date: Sun, 1 Nov 2020 23:19:24 +0000 Subject: [PATCH 195/294] Map postgresql types to sqlalchemy types Postgresql types use lowercase so they need to be registered separately from the sqlite types that use uppercase. 
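The idea reduces to registering each custom column type under both spellings. This toy sketch (with stand-in classes, not the real dtool types) shows the lookup working for either case:

```python
class UUID:
    def get_col_spec(self):
        return 'UUID'

class Dict:
    def get_col_spec(self):
        return 'DICT'

user_defined_types = (UUID, Dict)
sql_type_to_sa_type = {cls().get_col_spec(): cls for cls in user_defined_types}
# postgresql reports type names in lowercase, so register those spellings too
sql_type_to_sa_type.update(
    {cls().get_col_spec().lower(): cls for cls in user_defined_types}
)
assert sql_type_to_sa_type['uuid'] is sql_type_to_sa_type['UUID'] is UUID
```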
--- wbia/dtool/types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index eeec5f9566..1f204b2e14 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -183,7 +183,12 @@ def process(value): _USER_DEFINED_TYPES = (Dict, List, NDArray, Number, UUID) # SQL type (e.g. 'DICT') to SQLAlchemy type: SQL_TYPE_TO_SA_TYPE = {cls().get_col_spec(): cls for cls in _USER_DEFINED_TYPES} +# Map postgresql types to SQLAlchemy types (postgresql type names are lowercase) +SQL_TYPE_TO_SA_TYPE.update( + {cls().get_col_spec().lower(): cls for cls in _USER_DEFINED_TYPES} +) SQL_TYPE_TO_SA_TYPE['INTEGER'] = Integer +SQL_TYPE_TO_SA_TYPE['integer'] = Integer def initialize_postgresql_types(conn, schema): From 681ddd6746cdc8eee630df15644c6f9c6fb819c2 Mon Sep 17 00:00:00 2001 From: karen chan Date: Sun, 1 Nov 2020 23:20:17 +0000 Subject: [PATCH 196/294] Adjust how uuid types work in postgresql Postgresql has a built in uuid type and it works differently from the sqlite one which is more like a binary field so we need to adjust the way `UUID` class works. --- wbia/dtool/types.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index 1f204b2e14..ea187c422a 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -158,12 +158,20 @@ def bind_processor(self, dialect): def process(value): if value is None: return value + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + + if dialect.name == 'sqlite': + return value.bytes_le + elif dialect.name == 'postgresql': + return value else: if not isinstance(value, uuid.UUID): return uuid.UUID(value).bytes_le else: # hexstring return value.bytes_le + raise RuntimeError(f'Unknown dialect {dialect.name}') return process @@ -171,11 +179,15 @@ def result_processor(self, dialect, coltype): def process(value): if value is None: return value - else: + if dialect.name == 'sqlite': if not isinstance(value, uuid.UUID): return uuid.UUID(bytes_le=value) else: return value + elif dialect.name == 'postgresql': + return value + else: + raise RuntimeError(f'Unknown dialect {dialect.name}') return process From a04ef310ba5aeb587f4a549770960b86e22cd40f Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 2 Nov 2020 14:30:36 +0000 Subject: [PATCH 197/294] Update postgres sequence owner when modifying tables The way we modify tables is by creating a new table, dropping the original table and then renaming the new table to the original table. This turns out to be more complicated in postgresql. We have lots of "self-increment" ("serial" in postgresql) fields, e.g. rowid. In postgresql, these are done by having a "sequence", e.g. `images_rowid_seq`. When we create a new table and use the existing sequence and then try to drop the old table with "cascade", the sequence is dropped and the `images_rowid` field is not "self-incrementing" anymore. If we don't drop the old table with "cascade": ``` testdb1=# drop table images; ERROR: cannot drop table images because other objects depend on it DETAIL: default for table images_temp2603a7eb column rowid depends on sequence images_rowid_seq HINT: Use DROP ... CASCADE to drop the dependent objects too. ``` The way to fix this is by changing the sequences to be owned by the new table. That way the old table can be dropped and the "self-incrementing" property is retained in the new table. 
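The fix boils down to one extra statement per serial column before the old table is dropped. A hedged sketch (the helper name and example identifiers are illustrative; in the controller the statement runs through `connection.execute`):

```python
from sqlalchemy import text

def reown_sequence(connection, sequence, new_table, column):
    """Move a serial column's sequence onto the replacement table so that
    DROP TABLE ... CASCADE on the old table does not delete the sequence."""
    connection.execute(
        text(f'ALTER SEQUENCE {sequence} OWNED BY {new_table}.{column}')
    )

# e.g. reown_sequence(conn, 'images_rowid_seq', 'images_temp2603a7eb', 'rowid')
# ...then: DROP TABLE images CASCADE; ALTER TABLE images_temp2603a7eb RENAME TO images;
```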
--- wbia/dtool/sql_control.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 7533bc3e22..d5a72df776 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2084,6 +2084,14 @@ def modify_table( colname_list = ut.take_column(coldef_list, 0) coltype_list = ut.take_column(coldef_list, 1) + # Find all dependent sequences so we can change the owners of the + # sequences to the new table (for postgresql) + dependent_sequences = [ + (colname, re.search(r"nextval\('([^']*)'", coldef).group(1)) + for colname, coldef in self.get_coldef_list(tablename) + if 'nextval' in coldef + ] + colname_original_list = colname_list[:] colname_dict = {colname: colname for colname in colname_list} colmap_dict = {} @@ -2172,6 +2180,17 @@ def modify_table( self.add_table(tablename_temp, coldef_list, **metadata_keyval2) + # Change owners of sequences from old table to new table + if self._engine.dialect.name == 'postgresql': + new_colnames = [name for name, _ in coldef_list] + for colname, sequence in dependent_sequences: + if colname in new_colnames: + self.executeone( + text( + f'ALTER SEQUENCE {sequence} OWNED BY {tablename_temp}.{colname}' + ) + ) + # Copy data src_list = [] dst_list = [] From c84b62cd8b71f80a8ffa7908c526742f6ba1754f Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 2 Nov 2020 20:07:39 +0000 Subject: [PATCH 198/294] Adjust modify_table insert index for postgresql Postgresql tables have an additional "rowid" field which is built-in in sqlite so when columns are inserted with a specific index to a table, we need to + 1 to the index so the columns are in the same order as sqlite. --- wbia/dtool/sql_control.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index d5a72df776..ba2fb22119 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2118,6 +2118,9 @@ def modify_table( '[sql] WARNING: multiple index inserted add ' 'columns, may cause alignment issues' ) + if self._engine.dialect.name == 'postgresql': + # adjust for the additional "rowid" field + src += 1 colname_list.insert(src, dst) coltype_list.insert(src, type_) insert = True From ac80edefd9bd219a92f41788b491927663adcc7d Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 2 Nov 2020 20:09:33 +0000 Subject: [PATCH 199/294] Always include default value in SQLDatabaseController.get_coldef_list For some reason, default values are not included if a field is a primary field. This doesn't work for postgresql because we have rowid primary keys which gets their default value from a sequence. 
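The behavioural change is the `elif` that became an `if`. A small standalone sketch (hypothetical helper, simplified inputs) makes the difference visible for a serial primary key:

```python
def build_coldef(col_type, is_primary, not_null, default):
    """Sketch of the corrected branching in get_coldef_list."""
    if is_primary:
        col_type += ' PRIMARY KEY'
    elif not_null:
        col_type += ' NOT NULL'
    if default is not None:  # previously an elif, which dropped defaults on primary keys
        col_type += f' DEFAULT ({default})'
    return col_type

# postgres rowid primary keys keep their sequence default:
print(build_coldef('INTEGER', True, False, "nextval('images_rowid_seq'::regclass)"))
```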
--- wbia/dtool/sql_control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index ba2fb22119..4d9ac44176 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2375,7 +2375,7 @@ def get_coldef_list(self, tablename): col_type += ' PRIMARY KEY' elif column[3] == 1: col_type += ' NOT NULL' - elif column[4] is not None: + if column[4] is not None: default_value = six.text_type(column[4]) # HACK: add parens if the value contains parens in the future # all default values should contain parens From 9b80f4c54c0628fc132b62cca4842b9a06686991 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 3 Nov 2020 14:29:35 +0000 Subject: [PATCH 200/294] Add function to compare column definition lists The column definitions returned by sqlite and postgresql are slightly different, for example postgresql types are all lowercase so we need to normalize the column definitions a bit before comparing them. --- wbia/dtool/depcache_table.py | 10 +++++++--- wbia/dtool/sql_control.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 9a1defc607..2a43bcbe7a 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -32,7 +32,7 @@ from six.moves import zip, range from wbia.dtool import sqlite3 as lite -from wbia.dtool.sql_control import SQLDatabaseController +from wbia.dtool.sql_control import SQLDatabaseController, compare_coldef_lists from wbia.dtool.types import TYPE_TO_SQLTYPE import time @@ -159,7 +159,9 @@ def ensure_config_table(db): else: current_state = db.get_table_autogen_dict(CONFIG_TABLE) new_state = config_addtable_kw - if current_state['coldef_list'] != new_state['coldef_list']: + if not compare_coldef_lists( + current_state['coldef_list'], new_state['coldef_list'] + ): if predrop_grace_period(CONFIG_TABLE): db.drop_all_tables() db.add_table(**new_state) @@ -1996,7 +1998,9 @@ def initialize(self, _debug=None): self.clear_table() current_state = self.db.get_table_autogen_dict(self.tablename) - if current_state['coldef_list'] != new_state['coldef_list']: + if not compare_coldef_lists( + current_state['coldef_list'], new_state['coldef_list'] + ): logger.info('WARNING TABLE IS MODIFIED') if predrop_grace_period(self.tablename): self.clear_table() diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 4d9ac44176..f28486ce15 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -101,6 +101,24 @@ def sqlite_uri_to_postgres_uri_schema(uri): return (uri, namespace) +def compare_coldef_lists(coldef_list1, coldef_list2): + # Remove "rowid" which is added to postgresql tables + coldef_list1 = [(name, coldef) for name, coldef in coldef_list1 if name != 'rowid'] + coldef_list2 = [(name, coldef) for name, coldef in coldef_list2 if name != 'rowid'] + if len(coldef_list1) != len(coldef_list2): + return False + for i in range(len(coldef_list1)): + name1, coldef1 = coldef_list1[i] + name2, coldef2 = coldef_list2[i] + if name1 != name2: + return False + coldef1 = re.sub(r' DEFAULT \(nextval\(.*', '', coldef1) + coldef2 = re.sub(r' DEFAULT \(nextval\(.*', '', coldef2) + if coldef1.lower() != coldef2.lower(): + return False + return True + + def _unpacker(results): """ HELPER: Unpacks results if unpack_scalars is True. 
""" if not results: # Check for None or empty list From de33c495d6b23841cf4f486b5fa505ca72318094 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 4 Nov 2020 23:28:32 +0000 Subject: [PATCH 201/294] Call sqlalchemy create_engine once per uri Creating an engine per SQLDatabaseController appears to cause this error when using postgres: ``` > conn = _connect(dsn, connection_factory=connection_factory, **kwasync) E sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) FATAL: sorry, too many clients already E E (Background on this error at: http://sqlalche.me/e/13/e3q8) /virtualenv/env3/lib/python3.7/site-packages/psycopg2/__init__.py:127: OperationalError ``` Fixed by sharing engines across instances of SQLDatabaseController. --- wbia/dtool/sql_control.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index f28486ce15..97c2797a63 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -69,6 +69,22 @@ METADATA_TABLE_COLUMN_NAMES = list(METADATA_TABLE_COLUMNS.keys()) +def create_engine(uri, POSTGRESQL_POOL_SIZE=20, ENGINES={}): + kw = { + # The echo flag is a shortcut to set up SQLAlchemy logging + 'echo': False, + } + if uri.startswith('sqlite:') and ':memory:' in uri: + # Don't share engines for in memory sqlite databases + return sqlalchemy.create_engine(uri, **kw) + if uri not in ENGINES: + if uri.startswith('postgresql:'): + # pool_size is not available for sqlite + kw['pool_size'] = POSTGRESQL_POOL_SIZE + ENGINES[uri] = sqlalchemy.create_engine(uri, **kw) + return ENGINES[uri] + + def sqlite_uri_to_postgres_uri_schema(uri): from wbia.init.sysres import get_workdir @@ -523,11 +539,7 @@ def __init_engine(self): uri, schema = sqlite_uri_to_postgres_uri_schema(self.uri) self.uri = uri self.schema = schema - self._engine = sqlalchemy.create_engine( - self.uri, - # The echo flag is a shortcut to set up SQLAlchemy logging - echo=False, - ) + self._engine = create_engine(self.uri) @classmethod def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): @@ -622,10 +634,7 @@ def _create_connection(self): uri += '?mode=ro' if os.getenv('POSTGRES'): uri, schema = sqlite_uri_to_postgres_uri_schema(uri) - engine = sqlalchemy.create_engine( - uri, - echo=False, - ) + engine = create_engine(uri) connection = engine.connect() if engine.dialect.name == 'postgresql': connection.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') From f3227fbd2b04f02f7f09b18cdca9e71034c3372f Mon Sep 17 00:00:00 2001 From: karen chan Date: Sat, 7 Nov 2020 22:39:12 +0000 Subject: [PATCH 202/294] Load downloaded test sqlite databases in postgresql Use pgloader (installed from apt-get) to copy sqlite databases to postgresql. There is a bit of extra code to move the tables to the right schema/namespace in postgresql and to include the sqlite built-in rowid columns. 
--- wbia/dtool/copy_sqlite_to_postgres.py | 148 ++++++++++++++++++++++++++ wbia/init/sysres.py | 5 + 2 files changed, 153 insertions(+) create mode 100644 wbia/dtool/copy_sqlite_to_postgres.py diff --git a/wbia/dtool/copy_sqlite_to_postgres.py b/wbia/dtool/copy_sqlite_to_postgres.py new file mode 100644 index 0000000000..3d08a896d5 --- /dev/null +++ b/wbia/dtool/copy_sqlite_to_postgres.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +""" +Copy sqlite database into a postgresql database using pgloader (from +apt-get) +""" +import os +import re +import subprocess +import tempfile + +import sqlalchemy +from wbia.dtool.sql_control import ( + create_engine, + sqlite_uri_to_postgres_uri_schema, +) + + +def get_sqlite_db_paths(parent_dir): + for dirpath, dirnames, filenames in os.walk(parent_dir): + for filename in filenames: + if ( + filename.endswith('.sqlite') or filename.endswith('.sqlite3') + ) and 'backup' not in filename: + yield os.path.join(dirpath, filename) + + +def add_rowids(engine): + connection = engine.connect() + create_table_stmts = connection.execute( + """\ + SELECT name, sql FROM sqlite_master + WHERE name NOT LIKE 'sqlite_%'""" + ).fetchall() + for table, stmt in create_table_stmts: + # Create a new table with suffix "_with_rowid" + new_table = f'{table}_with_rowid' + stmt = re.sub( + r'CREATE TABLE [^ ]* \(', + f'CREATE TABLE {new_table} (rowid INTEGER NOT NULL UNIQUE, ', + stmt, + ) + connection.execute(stmt) + connection.execute(f'INSERT INTO {new_table} SELECT rowid, * FROM {table}') + + +def remove_rowids(engine): + connection = engine.connect() + create_table_stmts = connection.execute( + """\ + SELECT name, sql FROM sqlite_master + WHERE name LIKE '%_with_rowid'""" + ).fetchall() + for table, stmt in create_table_stmts: + connection.execute(f'DROP TABLE {table}') + + +def before_pgloader(engine, schema): + connection = engine.connect() + for domain, base_type in ( + ('dict', 'json'), + ('list', 'json'), + ('ndarray', 'bytea'), + ('numpy', 'bytea'), + ): + try: + connection.execute(f'CREATE DOMAIN {domain} AS {base_type}') + except sqlalchemy.exc.ProgrammingError: + # sqlalchemy.exc.ProgrammingError: + # (psycopg2.errors.DuplicateObject) type "dict" already + # exists + pass + + +def run_pgloader(sqlite_db_path, postgres_uri, tempdir): + # create the pgloader source file + fname = os.path.join(tempdir, 'sqlite.load') + with open(fname, 'w') as f: + f.write( + f"""\ +LOAD DATABASE + FROM '{sqlite_db_path}' + INTO {postgres_uri} + + WITH include drop, + create tables, + create indexes, + reset sequences + + SET work_mem to '16MB', + maintenance_work_mem to '512 MB' + + CAST type uuid to uuid using sql-server-uniqueidentifier-to-uuid + + INCLUDING ONLY TABLE NAMES LIKE '%_with_rowid'; +""" + ) + subprocess.check_output(['pgloader', fname]) + + +def after_pgloader(engine, schema): + connection = engine.connect() + connection.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') + table_pkeys = connection.execute( + """\ + SELECT table_name, column_name + FROM information_schema.table_constraints + NATURAL JOIN information_schema.constraint_column_usage + WHERE table_schema = 'public' + AND constraint_type = 'PRIMARY KEY'""" + ).fetchall() + for (table_name, pkey) in table_pkeys: + new_table_name = table_name.replace('_with_rowid', '') + # Rename tables from "images_with_rowid" to "images" + connection.execute(f'ALTER TABLE {table_name} RENAME TO {new_table_name}') + # Create sequences for rowid fields + for column_name in ('rowid', pkey): + seq_name = 
f'{new_table_name}_{column_name}_seq' + connection.execute(f'CREATE SEQUENCE {seq_name}') + connection.execute( + f"SELECT setval('{seq_name}', (SELECT max({column_name}) FROM {new_table_name}))" + ) + connection.execute( + f"ALTER TABLE {new_table_name} ALTER COLUMN {column_name} SET DEFAULT nextval('{seq_name}')" + ) + connection.execute( + f'ALTER SEQUENCE {seq_name} OWNED BY {new_table_name}.{column_name}' + ) + # Set schema / namespace to "_ibeis_database" for example + connection.execute(f'ALTER TABLE {new_table_name} SET SCHEMA {schema}') + + +def copy_sqlite_to_postgres(parent_dir): + with tempfile.TemporaryDirectory() as tempdir: + for sqlite_db_path in get_sqlite_db_paths(parent_dir): + # create new tables with sqlite built-in rowid column + sqlite_engine = create_engine(f'sqlite:///{sqlite_db_path}') + add_rowids(sqlite_engine) + + try: + uri, schema = sqlite_uri_to_postgres_uri_schema( + f'sqlite:///{os.path.realpath(sqlite_db_path)}' + ) + engine = create_engine(uri) + before_pgloader(engine, schema) + run_pgloader(sqlite_db_path, uri, tempdir) + after_pgloader(engine, schema) + finally: + remove_rowids(sqlite_engine) diff --git a/wbia/init/sysres.py b/wbia/init/sysres.py index 98afdfc811..cd310a9c23 100644 --- a/wbia/init/sysres.py +++ b/wbia/init/sysres.py @@ -11,6 +11,7 @@ import ubelt as ub from six.moves import input, zip, map from wbia import constants as const +from wbia.dtool.copy_sqlite_to_postgres import copy_sqlite_to_postgres (print, rrr, profile) = ut.inject2(__name__) @@ -421,6 +422,8 @@ def ensure_pz_mtest(): mtest_zipped_url = const.ZIPPED_URLS.PZ_MTEST mtest_dir = ut.grab_zipped_url(mtest_zipped_url, ensure=True, download_dir=workdir) logger.info('have mtest_dir=%r' % (mtest_dir,)) + if os.getenv('POSTGRES'): + copy_sqlite_to_postgres(mtest_dir) # update the the newest database version import wbia @@ -881,6 +884,8 @@ def ensure_db_from_url(zipped_db_url): dbdir = ut.grab_zipped_url( zipped_url=zipped_db_url, ensure=True, download_dir=workdir ) + if os.getenv('POSTGRES'): + copy_sqlite_to_postgres(dbdir) logger.info('have %s=%r' % (zipped_db_url, dbdir)) return dbdir From 2d3d32d43af6a3ac6670a1a4e63bb6ce0925a7ff Mon Sep 17 00:00:00 2001 From: karen chan Date: Sat, 7 Nov 2020 22:41:38 +0000 Subject: [PATCH 203/294] Map bigint postgresql type to Integer in dtool/types.py There are some "bigint" fields in sqlite used as boolean fields and caused this kind of error: ``` def do_execute(self, cursor, statement, parameters, context=None): > cursor.execute(statement, parameters) E sqlalchemy.exc.ProgrammingError: (psycopg2.errors.DatatypeMismatch) column "annot_exemplar_flag" is of type bigint but expression is of type boolean E LINE 1: ...beis_database.annotations SET annot_exemplar_flag=true WHERE... E ^ E HINT: You will need to rewrite or cast the expression. 
E E [SQL: UPDATE _ibeis_database.annotations SET annot_exemplar_flag=%(e0)s WHERE rowid = %(_identifier)s] E [parameters: {'e0': True, '_identifier': 1}] E (Background on this error at: http://sqlalche.me/e/13/f405) ``` --- wbia/dtool/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index ea187c422a..dc76dbc778 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -201,6 +201,7 @@ def process(value): ) SQL_TYPE_TO_SA_TYPE['INTEGER'] = Integer SQL_TYPE_TO_SA_TYPE['integer'] = Integer +SQL_TYPE_TO_SA_TYPE['bigint'] = Integer def initialize_postgresql_types(conn, schema): From 4ea0573a38601b4a3050778d7e0e23c6239d125d Mon Sep 17 00:00:00 2001 From: karen chan Date: Sat, 7 Nov 2020 23:09:32 +0000 Subject: [PATCH 204/294] Change column names to lowercase when inserting in postgres Postgres table column names are lowercase causing `SQLDatabaseController._add` to fail because the parameterized values do not match the column names expected by the insert statement: ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 9, abs: 525, in >>> assert np.all(v.vecs[0] == v.vecs[1]) File "/wbia/wildbook-ia/wbia/_wbia_object.py", line 567, in __getattr__ miss_data = _rowid_getter(miss_rowids) File "/wbia/wildbook-ia/wbia/_wbia_object.py", line 160, in _rowid_getter data = ibs_callable(rowids, config2_=self._config) File "/wbia/wbia-utool/utool/util_decor.py", line 556, in wrp_asivo return func(self, input_, *args, **kwargs) File "/wbia/wildbook-ia/wbia/control/accessor_decors.py", line 487, in getter_vector_wrp return func(*args, **kwargs) File "/wbia/wildbook-ia/wbia/control/manual_feat_funcs.py", line 176, in get_annot_vecs 'feat', aid_list, 'vecs', config=config2_, ensure=ensure, eager=eager File "/wbia/wbia-utool/utool/util_decor.py", line 489, in wrp_asi2 return func(self, *args, **kwargs) File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 885, in get tbl_rowids = self.get_rowids(tablename, input_tuple, **rowid_kw) File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 722, in get_rowids **_kwargs, File "/wbia/wildbook-ia/wbia/dtool/depcache_control.py", line 634, in get_parent_rowids _parent_rowids, config=config_, recompute=_recompute, **_kwargs File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2377, in get_rowid config=config, File "/wbia/wildbook-ia/wbia/dtool/depcache_table.py", line 2155, in ensure_rows nInput=nChunkInput, File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 954, in _add result = self.connection.execute(insert_stmt.values(vals)) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1011, in execute return meth(self, multiparams, params) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/elements.py", line 298, in _execute_on_connection return connection._execute_clauseelement(self, multiparams, params) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1121, in _execute_clauseelement else None, File "", line 1, in File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/elements.py", line 481, in compile return self._compiler(dialect, bind=bind, **kw) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/elements.py", line 487, in _compiler return dialect.statement_compiler(dialect, self, **kw) File 
"/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/compiler.py", line 592, in __init__ Compiled.__init__(self, dialect, statement, **kwargs) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/compiler.py", line 322, in __init__ self.string = self.process(self.statement, **compile_kwargs) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/compiler.py", line 352, in process return obj._compiler_dispatch(self, **kwargs) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/visitors.py", line 96, in _compiler_dispatch return meth(self, **kw) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/compiler.py", line 2430, in visit_insert self, insert_stmt, crud.ISINSERT, **kw File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/crud.py", line 64, in _setup_crud_params return _get_crud_params(compiler, stmt, **kw) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/crud.py", line 179, in _get_crud_params % (", ".join("%s" % c for c in check)) sqlalchemy.exc.CompileError: Unconsumed column names: M DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/_wbia_object.py::ObjectView1D:0 ``` --- wbia/dtool/sql_control.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 97c2797a63..dfaf77cdce 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -936,6 +936,12 @@ def _add(self, tblname, colnames, params_iter, unpack_scalars=True, **kwargs): parameterized_values = [ {col: val for col, val in zip(colnames, params)} for params in params_iter ] + if self._engine.dialect.name == 'postgresql': + # postgresql column names are lowercase + parameterized_values = [ + {col.lower(): val for col, val in params.items()} + for params in parameterized_values + ] table = self._reflect_table(tblname) # It would be possible to do one insert, From d84270f2d9824bf6f85e46930e762bfe9f63ab09 Mon Sep 17 00:00:00 2001 From: karen chan Date: Sat, 7 Nov 2020 23:43:20 +0000 Subject: [PATCH 205/294] Fix rowid integer casting code in SQLDatabaseController ``` id_iter = [id_ is not None and int(id_) or id_ for id_ in id_iter] ``` if `id_` is `np.int64(0)`, `int(id_)` returns 0 and `int(id_) or id_` returns `id_` so it is not casted to an integer correctly. 
``` Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 7, abs: 626, in >>> result = set_annot_pair_as_negative_match(ibs, aid1, aid2, dryrun) File "/wbia/wildbook-ia/wbia/annotmatch_funcs.py", line 696, in set_annot_pair_as_negative_match next_nids = ibs.make_next_nids(num=1) File "/wbia/wildbook-ia/wbia/other/ibsfuncs.py", line 3020, in make_next_nids location_text=location_text, File "/wbia/wildbook-ia/wbia/other/ibsfuncs.py", line 3104, in make_next_name species_code = ibs.get_species_codes(species_rowid) File "/wbia/wbia-utool/utool/util_decor.py", line 452, in wrp_asi ret = func(self, [input_], *args, **kwargs) File "/wbia/wildbook-ia/wbia/control/accessor_decors.py", line 470, in wrp_getter return func(*args, **kwargs) File "/wbia/wildbook-ia/wbia/control/manual_species_funcs.py", line 556, in get_species_codes const.SPECIES_TABLE, (SPECIES_CODE,), species_rowid_list File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1382, in get tblname, colnames, params_iter, where_clause, eager=eager, **kwargs File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1228, in get_where **kwargs, File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1671, in executemany value = self.executeone(operation, params, keepwrap=keepwrap) File "/wbia/wildbook-ia/wbia/dtool/sql_control.py", line 1621, in executeone results = self.connection.execute(operation, params) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1011, in execute return meth(self, multiparams, params) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/sql/elements.py", line 298, in _execute_on_connection return connection._execute_clauseelement(self, multiparams, params) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1130, in _execute_clauseelement distilled_params, File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1317, in _execute_context e, statement, parameters, cursor, context File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1511, in _handle_dbapi_exception sqlalchemy_exception, with_traceback=exc_info[2], from_=e File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/util/compat.py", line 182, in raise_ raise exception File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1277, in _execute_context cursor, statement, parameters, context File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py", line 593, in do_execute cursor.execute(statement, parameters) sqlalchemy.exc.ProgrammingError: (psycopg2.ProgrammingError) can't adapt type 'numpy.int64' [SQL: SELECT _ibeis_database.species.species_code FROM _ibeis_database.species WHERE rowid = %(_identifier)s] [parameters: {'_identifier': 0}] (Background on this error at: http://sqlalche.me/e/13/f405) DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/annotmatch_funcs.py::set_annot_pair_as_negative_match:0 ``` --- wbia/dtool/sql_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index dfaf77cdce..2b791a63f0 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -1656,7 +1656,7 @@ def set( if id_colname == 'rowid': # Cast all item values to in, in case values are numpy.integer* # Strangely allow for None values - id_list = [id_ is not 
None and int(id_) or id_ for id_ in id_list] + id_list = [id_ if id_ is None else int(id_) for id_ in id_list] else: # b/c rowid doesn't really exist as a column id_column = table.c[id_colname] where_clause = where_clause.bindparams( @@ -1681,7 +1681,7 @@ def delete(self, tblname, id_list, id_colname='rowid', **kwargs): if id_colname == 'rowid': # Cast all item values to in, in case values are numpy.integer* # Strangely allow for None values - id_list = [id_ is not None and int(id_) or id_ for id_ in id_list] + id_list = [id_ if id_ is None else int(id_) for id_ in id_list] else: # b/c rowid doesn't really exist as a column id_column = table.c[id_colname] where_clause = where_clause.bindparams( From e6db56320cc340790a74db82b046ba06122e5a06 Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 10 Nov 2020 16:28:53 +0000 Subject: [PATCH 206/294] Fix sql equality "=" instead of "==" in postgres ``` Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 5, abs: 214, in >>> daid_list = ibs.get_valid_aids(species=wbia.const.TEST_SPECIES.ZEB_PLAIN) File "/wbia/wildbook-ia/wbia/control/manual_annot_funcs.py", line 806, in get_valid_aids min_timedelta=min_timedelta, File "/wbia/wildbook-ia/wbia/control/manual_annot_funcs.py", line 877, in filter_annotation_set aid_list = ibs.filter_aids_to_species(aid_list, species) File "/wbia/wildbook-ia/wbia/other/ibsfuncs.py", line 5104, in filter_aids_to_species aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1003, in execute return self._execute_text(object_, multiparams, params) File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1178, in _execute_text parameters, File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1317, in _execute_context e, statement, parameters, cursor, context File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1511, in _handle_dbapi_exception sqlalchemy_exception, with_traceback=exc_info[2], from_=e File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/util/compat.py", line 182, in raise_ raise exception File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/base.py", line 1277, in _execute_context cursor, statement, parameters, context File "/virtualenv/env3/lib/python3.7/site-packages/sqlalchemy/engine/default.py", line 593, in do_execute cursor.execute(statement, parameters) sqlalchemy.exc.ProgrammingError: (psycopg2.errors.UndefinedFunction) operator does not exist: integer == integer LINE 1: ...ELECT rowid from annotations WHERE (species_rowid == 1) AND ... ^ HINT: No operator matches the given name and argument type(s). You might need to add explicit type casts. 
[SQL: SELECT rowid from annotations WHERE (species_rowid == 1) AND rowid IN (1,2,3,4,5,6,7,8,9,10,11,12,13)] (Background on this error at: http://sqlalche.me/e/13/f405) DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/algo/hots/neighbor_index_cache.py::build_nnindex_cfgstr:0 /wbia/wildbook-ia/wbia/algo/hots/neighbor_index_cache.py:214: ProgrammingError ``` --- wbia/other/ibsfuncs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 3bbf2c8293..41a459ec2e 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -5146,7 +5146,7 @@ def filter_aids_to_species(ibs, aid_list, species, speedhack=True): species_rowid = ibs.get_species_rowids_from_text(species) if speedhack: list_repr = ','.join(map(str, aid_list)) - operation = 'SELECT rowid from annotations WHERE (species_rowid == {species_rowid}) AND rowid IN ({aids})' + operation = 'SELECT rowid from annotations WHERE (species_rowid = {species_rowid}) AND rowid IN ({aids})' operation = operation.format(aids=list_repr, species_rowid=species_rowid) aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) else: From 704a2e559da5287b50bbcf05eb4a76ed382bbbea Mon Sep 17 00:00:00 2001 From: karen chan Date: Tue, 10 Nov 2020 21:59:45 +0000 Subject: [PATCH 207/294] Adjust JSONCodeableType to work with postgres It seems postgres returns the value already decoded so there's no need to decode it again: ``` DOCTEST TRACEBACK Traceback (most recent call last): File "/virtualenv/env3/lib/python3.7/site-packages/xdoctest/doctest_example.py", line 598, in run exec(code, test_globals) File "", line rel: 5, abs: 49, in >>> cm_list = qreq_.execute() File "/wbia/wildbook-ia/wbia/algo/hots/query_request.py", line 1322, in execute invalidate_supercache=invalidate_supercache, File "/wbia/wildbook-ia/wbia/algo/hots/match_chips4.py", line 129, in submit_query_request invalidate_supercache=invalidate_supercache, File "/wbia/wildbook-ia/wbia/algo/hots/match_chips4.py", line 320, in execute_query_and_save_L1 qaid2_cm = execute_query2(qreq_, verbose, save_qcache, batch_size, use_supercache) ... File "/wbia/wildbook-ia/wbia/dtool/types.py", line 58, in process return from_json(value) File "/wbia/wbia-utool/utool/util_cache.py", line 677, in from_json val = json.loads(json_str, object_hook=object_hook) File "/usr/lib/python3.7/json/__init__.py", line 341, in loads raise TypeError(f'the JSON object must be str, bytes or bytearray, ' TypeError: the JSON object must be str, bytes or bytearray, not dict DOCTEST REPRODUCTION CommandLine: pytest /wbia/wildbook-ia/wbia/expt/test_result.py::build_cmsinfo:0 /wbia/wildbook-ia/wbia/expt/test_result.py:49: TypeError ``` --- wbia/dtool/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index dc76dbc778..4734a16763 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -60,6 +60,9 @@ def result_processor(self, dialect, coltype): def process(value): if value is None: return value + elif dialect.name == 'postgresql': + # postgresql doesn't need the value to be json decoded + return value else: return from_json(value) From c1b3772e3aef39343c77fda0e0dad42c5385d4b3 Mon Sep 17 00:00:00 2001 From: karen chan Date: Thu, 12 Nov 2020 16:15:31 +0000 Subject: [PATCH 208/294] Ignore rowid in SQLDatabaseController.merge_databases_new `rowid` is built-in to sqlite tables and we use it extensively when querying throughout the codebase. 
It has been added to the postgres tables as well but that causes code to fail because the code doesn't expect `rowid` to be returned. So we need to remove the `rowid` column and adjust indices. --- wbia/dtool/sql_control.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 2b791a63f0..ab56f214d7 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -3230,6 +3230,12 @@ def find_depth(tablename, dependency_digraph): extern_tablename_list, extern_primarycolnames_list, ) = new_transferdata + if column_names[0] == 'rowid': + # This is a postgresql database, ignore the rowid column + # which is built-in to sqlite + column_names = column_names[1:] + column_list = column_list[1:] + extern_colx_list = [i - 1 for i in extern_colx_list] # FIXME: extract the primary rowid column a little bit nicer assert column_names[0].endswith('_rowid') old_rowid_list = column_list[0] From 379f420ed25cce4cf4639d50d0965f2929c2e736 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 18 Nov 2020 14:49:01 +0000 Subject: [PATCH 209/294] Adjust test expectations to work with postgres and sqlite Sometimes we get different ids from postgres and sqlite and the ids are not always sorted in postgres so we need to adjust the tests for both to work. --- wbia/algo/hots/neighbor_index_cache.py | 6 +++--- wbia/control/manual_name_funcs.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/wbia/algo/hots/neighbor_index_cache.py b/wbia/algo/hots/neighbor_index_cache.py index 1f8d8356c4..cb1dc76bc2 100644 --- a/wbia/algo/hots/neighbor_index_cache.py +++ b/wbia/algo/hots/neighbor_index_cache.py @@ -338,14 +338,14 @@ def request_augmented_wbia_nnindexer( >>> ZEB_PLAIN = wbia.const.TEST_SPECIES.ZEB_PLAIN >>> ibs = wbia.opendb('testdb1') >>> use_memcache, max_covers, verbose = True, None, True - >>> daid_list = ibs.get_valid_aids(species=ZEB_PLAIN)[0:6] + >>> daid_list = sorted(ibs.get_valid_aids(species=ZEB_PLAIN))[0:6] >>> qreq_ = ibs.new_query_request(daid_list, daid_list) >>> qreq_.qparams.min_reindex_thresh = 1 >>> min_reindex_thresh = qreq_.qparams.min_reindex_thresh >>> # CLEAR CACHE for clean test >>> clear_uuid_cache(qreq_) >>> # LOAD 3 AIDS INTO CACHE - >>> aid_list = ibs.get_valid_aids(species=ZEB_PLAIN)[0:3] + >>> aid_list = sorted(ibs.get_valid_aids(species=ZEB_PLAIN))[0:3] >>> # Should fallback >>> nnindexer = request_augmented_wbia_nnindexer(qreq_, aid_list) >>> # assert the fallback @@ -630,7 +630,7 @@ def group_daids_by_cached_nnindexer( >>> # STEP 0: CLEAR THE CACHE >>> clear_uuid_cache(qreq_) >>> # STEP 1: ASSERT EMPTY INDEX - >>> daid_list = ibs.get_valid_aids(species=ZEB_PLAIN)[0:3] + >>> daid_list = sorted(ibs.get_valid_aids(species=ZEB_PLAIN))[0:3] >>> uncovered_aids, covered_aids_list = group_daids_by_cached_nnindexer( ... 
qreq_, daid_list, min_reindex_thresh, max_covers) >>> result1 = uncovered_aids, covered_aids_list diff --git a/wbia/control/manual_name_funcs.py b/wbia/control/manual_name_funcs.py index 4afdc5a0fb..3c4a2d65cd 100644 --- a/wbia/control/manual_name_funcs.py +++ b/wbia/control/manual_name_funcs.py @@ -617,7 +617,7 @@ def get_name_exemplar_aids(ibs, nid_list): >>> aid_list = ibs.get_valid_aids() >>> nid_list = ibs.get_annot_name_rowids(aid_list) >>> exemplar_aids_list = ibs.get_name_exemplar_aids(nid_list) - >>> result = exemplar_aids_list + >>> result = [sorted(i) for i in exemplar_aids_list] >>> print(result) [[], [2, 3], [2, 3], [], [5, 6], [5, 6], [7], [8], [], [10], [], [12], [13]] """ @@ -659,7 +659,7 @@ def get_name_gids(ibs, nid_list): >>> ibs = wbia.opendb('testdb1') >>> nid_list = ibs._get_all_known_name_rowids() >>> gids_list = ibs.get_name_gids(nid_list) - >>> result = gids_list + >>> result = [sorted(gids) for gids in gids_list] >>> print(result) [[2, 3], [5, 6], [7], [8], [10], [12], [13]] """ @@ -1042,9 +1042,8 @@ def get_name_rowids_from_text(ibs, name_text_list, ensure=True): >>> result += str(ibs._get_all_known_name_rowids()) >>> print('----') >>> ibs.print_name_table() + >>> assert result == f'{name_rowid_list}\n[1, 2, 3, 4, 5, 6, 7]' >>> print(result) - [8, 9, 0, 10, 11, 0] - [1, 2, 3, 4, 5, 6, 7] """ if ensure: name_rowid_list = ibs.add_names(name_text_list) From 536c2c0b1212dc2a1ff922f0a4993c2e365d9dda Mon Sep 17 00:00:00 2001 From: karen chan Date: Mon, 30 Nov 2020 17:12:01 +0000 Subject: [PATCH 210/294] Add postgresql service to github testing action --- .github/workflows/testing.yml | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index f534f32bf5..65b85af15f 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -47,10 +47,28 @@ jobs: os: [ubuntu-latest] # Disable "macos-latest" for now # For speed, we choose one version and that should be the lowest common denominator python-version: [3.6, 3.7, 3.8] + postgres: ['', postgres] + + services: + db: + image: postgres:10 + env: + POSTGRES_PASSWORD: wbia + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 steps: # Checkout and env setup - uses: actions/checkout@v2 + - name: Install pgloader + if: matrix.postgres + run: sudo apt-get install pgloader - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -65,12 +83,15 @@ jobs: # Install and test - name: Install the project - run: pip install -e .[tests] + run: pip install -e .[tests,postgres] - name: Test with pytest run: | mkdir -p data/work python -m wbia --set-workdir data/work --preload-exit pytest --slow --web-tests + env: + WBIA_BASE_DB_URI: postgresql://postgres:wbia@localhost:5432 + POSTGRES: ${{ matrix.postgres }} on-failure: # This is not in the 'test' job itself because it would otherwise notify once per matrix combination. From 30a424f529eafd28bf1aa5955dfe9d7d2eadaa45 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Thu, 10 Dec 2020 20:37:52 -0800 Subject: [PATCH 211/294] Use a database connect and release pattern Connect to the database when needed and then release the connection when finished. The goal is to minimize the number of open connections. 
This effectively releases the connections to the connection pool so that they may be used by the next process that needs a connection. --- wbia/annotmatch_funcs.py | 30 +-- wbia/control/DB_SCHEMA.py | 3 +- wbia/control/manual_annot_funcs.py | 11 +- wbia/control/manual_image_funcs.py | 27 ++- wbia/control/manual_imageset_funcs.py | 13 +- wbia/control/manual_name_funcs.py | 14 +- wbia/control/manual_review_funcs.py | 37 +-- wbia/dtool/sql_control.py | 311 +++++++++++--------------- wbia/other/ibsfuncs.py | 28 ++- wbia/tests/dtool/test_sql_control.py | 165 +++++++------- 10 files changed, 297 insertions(+), 342 deletions(-) diff --git a/wbia/annotmatch_funcs.py b/wbia/annotmatch_funcs.py index 579ebaefb2..5e23533925 100644 --- a/wbia/annotmatch_funcs.py +++ b/wbia/annotmatch_funcs.py @@ -45,14 +45,15 @@ def get_annotmatch_rowids_from_aid1(ibs, aid1_list, eager=True, nInput=None): params_iter = zip(aid1_list) if True: # HACK IN INDEX - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS aid1_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid1}); - """.format( - ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, - annot_rowid1=manual_annotmatch_funcs.ANNOT_ROWID1, + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS aid1_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid1}); + """.format( + ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, + annot_rowid1=manual_annotmatch_funcs.ANNOT_ROWID1, + ) ) - ) where_colnames = [manual_annotmatch_funcs.ANNOT_ROWID1] annotmatch_rowid_list = ibs.db.get_where_eq( ibs.const.ANNOTMATCH_TABLE, @@ -82,14 +83,15 @@ def get_annotmatch_rowids_from_aid2( nInput = len(aid2_list) if True: # HACK IN INDEX - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS aid2_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid2}); - """.format( - ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, - annot_rowid2=manual_annotmatch_funcs.ANNOT_ROWID2, + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS aid2_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid2}); + """.format( + ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, + annot_rowid2=manual_annotmatch_funcs.ANNOT_ROWID2, + ) ) - ) colnames = (manual_annotmatch_funcs.ANNOTMATCH_ROWID,) # FIXME: col_rowid is not correct params_iter = zip(aid2_list) diff --git a/wbia/control/DB_SCHEMA.py b/wbia/control/DB_SCHEMA.py index de4c5643bd..12047052ef 100644 --- a/wbia/control/DB_SCHEMA.py +++ b/wbia/control/DB_SCHEMA.py @@ -2167,7 +2167,8 @@ def dump_schema_sql(): db = dt.SQLDatabaseController.from_uri(':memory:') DB_SCHEMA_CURRENT.update_current(db) - dump_str = dumps(db.connection) + with db.connect() as conn: + dump_str = dumps(conn) logger.info(dump_str) for tablename in db.get_table_names(): diff --git a/wbia/control/manual_annot_funcs.py b/wbia/control/manual_annot_funcs.py index 6ad08986fb..070da45f0a 100644 --- a/wbia/control/manual_annot_funcs.py +++ b/wbia/control/manual_annot_funcs.py @@ -2484,11 +2484,12 @@ def get_annot_part_rowids(ibs, aid_list, is_staged=False): """ # FIXME: This index should when the database is defined. 
# Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS aid_to_part_rowids ON parts (annot_rowid); - """ - ) + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS aid_to_part_rowids ON parts (annot_rowid); + """ + ) # The index maxes the following query very efficient part_rowids_list = ibs.db.get( ibs.const.PART_TABLE, diff --git a/wbia/control/manual_image_funcs.py b/wbia/control/manual_image_funcs.py index 75c6b4a019..7170088b66 100644 --- a/wbia/control/manual_image_funcs.py +++ b/wbia/control/manual_image_funcs.py @@ -2180,13 +2180,14 @@ def get_image_imgsetids(ibs, gid_list): if NEW_INDEX_HACK: # FIXME: This index should when the database is defined. # Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS gs_to_gids ON {GSG_RELATION_TABLE} ({IMAGE_ROWID}); - """.format( - GSG_RELATION_TABLE=const.GSG_RELATION_TABLE, IMAGE_ROWID=IMAGE_ROWID + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS gs_to_gids ON {GSG_RELATION_TABLE} ({IMAGE_ROWID}); + """.format( + GSG_RELATION_TABLE=const.GSG_RELATION_TABLE, IMAGE_ROWID=IMAGE_ROWID + ) ) - ) colnames = ('imageset_rowid',) imgsetids_list = ibs.db.get( const.GSG_RELATION_TABLE, @@ -2284,11 +2285,12 @@ def get_image_aids(ibs, gid_list, is_staged=False, __check_staged__=True): # FIXME: This index should when the database is defined. # Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS gid_to_aids ON annotations (image_rowid); - """ - ) + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS gid_to_aids ON annotations (image_rowid); + """ + ) # The index maxes the following query very efficient if __check_staged__: @@ -2328,7 +2330,8 @@ def get_image_aids(ibs, gid_list, is_staged=False, __check_staged__=True): """.format( input_str=input_str, ANNOTATION_TABLE=const.ANNOTATION_TABLE ) - pair_list = ibs.db.connection.execute(opstr).fetchall() + with ibs.db.connect() as conn: + pair_list = conn.execute(opstr).fetchall() aidscol = np.array(ut.get_list_column(pair_list, 0)) gidscol = np.array(ut.get_list_column(pair_list, 1)) unique_gids, groupx = vt.group_indices(gidscol) diff --git a/wbia/control/manual_imageset_funcs.py b/wbia/control/manual_imageset_funcs.py index cf15ae3c82..fb534d6ee5 100644 --- a/wbia/control/manual_imageset_funcs.py +++ b/wbia/control/manual_imageset_funcs.py @@ -543,13 +543,14 @@ def get_imageset_gids(ibs, imgsetid_list): if NEW_INDEX_HACK: # FIXME: This index should when the database is defined. 
# Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS gids_to_gs ON {GSG_RELATION_TABLE} (imageset_rowid); - """.format( - GSG_RELATION_TABLE=const.GSG_RELATION_TABLE + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS gids_to_gs ON {GSG_RELATION_TABLE} (imageset_rowid); + """.format( + GSG_RELATION_TABLE=const.GSG_RELATION_TABLE + ) ) - ) gids_list = ibs.db.get( const.GSG_RELATION_TABLE, ('image_rowid',), diff --git a/wbia/control/manual_name_funcs.py b/wbia/control/manual_name_funcs.py index 3c4a2d65cd..2d2e56d3d8 100644 --- a/wbia/control/manual_name_funcs.py +++ b/wbia/control/manual_name_funcs.py @@ -489,11 +489,12 @@ def get_name_aids(ibs, nid_list, enable_unknown_fix=True, is_staged=False): # FIXME: This index should when the database is defined. # Ensure that an index exists on the image column of the annotation table # logger.info(len(nid_list_)) - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS nid_to_aids ON annotations (name_rowid); - """ - ) + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS nid_to_aids ON annotations (name_rowid); + """ + ) aids_list = ibs.db.get( const.ANNOTATION_TABLE, (ANNOT_ROWID,), @@ -516,7 +517,8 @@ def get_name_aids(ibs, nid_list, enable_unknown_fix=True, is_staged=False): """.format( input_str=input_str, ANNOTATION_TABLE=const.ANNOTATION_TABLE ) - pair_list = ibs.db.connection.execute(opstr).fetchall() + with ibs.db.connect() as conn: + pair_list = conn.execute(opstr).fetchall() aidscol = np.array(ut.get_list_column(pair_list, 0)) nidscol = np.array(ut.get_list_column(pair_list, 1)) unique_nids, groupx = vt.group_indices(nidscol) diff --git a/wbia/control/manual_review_funcs.py b/wbia/control/manual_review_funcs.py index ab37ae5552..f1ad370b1b 100644 --- a/wbia/control/manual_review_funcs.py +++ b/wbia/control/manual_review_funcs.py @@ -57,24 +57,25 @@ def hack_create_aidpair_index(ibs): CREATE INDEX IF NOT EXISTS {index_name} ON {table} ({index_cols}); """ ) - sqlcmd = sqlfmt.format( - index_name='aidpair_to_rowid', - table=ibs.const.REVIEW_TABLE, - index_cols=','.join([REVIEW_AID1, REVIEW_AID2]), - ) - ibs.staging.connection.execute(sqlcmd) - sqlcmd = sqlfmt.format( - index_name='aid1_to_rowids', - table=ibs.const.REVIEW_TABLE, - index_cols=','.join([REVIEW_AID1]), - ) - ibs.staging.connection.execute(sqlcmd) - sqlcmd = sqlfmt.format( - index_name='aid2_to_rowids', - table=ibs.const.REVIEW_TABLE, - index_cols=','.join([REVIEW_AID2]), - ) - ibs.staging.connection.execute(sqlcmd) + with ibs.staging.connect() as conn: + sqlcmd = sqlfmt.format( + index_name='aidpair_to_rowid', + table=ibs.const.REVIEW_TABLE, + index_cols=','.join([REVIEW_AID1, REVIEW_AID2]), + ) + conn.execute(sqlcmd) + sqlcmd = sqlfmt.format( + index_name='aid1_to_rowids', + table=ibs.const.REVIEW_TABLE, + index_cols=','.join([REVIEW_AID1]), + ) + conn.execute(sqlcmd) + sqlcmd = sqlfmt.format( + index_name='aid2_to_rowids', + table=ibs.const.REVIEW_TABLE, + index_cols=','.join([REVIEW_AID2]), + ) + conn.execute(sqlcmd) @register_ibs_method diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index ab56f214d7..479809e47d 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -11,9 +11,9 @@ import os import parse import re -import threading import uuid from collections.abc import Mapping, MutableMapping +from contextlib import contextmanager from os.path import join, exists import six @@ -277,7 
+277,7 @@ def version(self): def version(self, value): if not value: raise ValueError(value) - dialect = self.ctrlr.connection.engine.dialect.name + dialect = self.ctrlr._engine.dialect.name if dialect == 'sqlite': stmt = text( f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value)' @@ -318,7 +318,7 @@ def init_uuid(self, value): raise ValueError(value) elif isinstance(value, uuid.UUID): value = str(value) - dialect = self.ctrlr.connection.engine.dialect.name + dialect = self.ctrlr._engine.dialect.name if dialect == 'sqlite': stmt = text( f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) ' @@ -423,7 +423,7 @@ def __setattr__(self, name, value): key = self._get_key_name(name) # Insert or update the record - dialect = self.ctrlr.connection.engine.dialect.name + dialect = self.ctrlr._engine.dialect.name if dialect == 'sqlite': statement = text( f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} ' @@ -584,9 +584,6 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): self._sa_metadata.reflect(bind=self._engine) self._tablenames = None - # FIXME (31-Jul-12020) rename to private attribute - self.thread_connections = {} - self._connection = None if not self.readonly: # Ensure the metadata table is initialized. @@ -599,71 +596,15 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): return self + @contextmanager def connect(self): - """Create a connection for the instance or use the existing connection""" - self._connection = self._engine.connect() - if self._engine.dialect.name == 'postgresql': - self._connection.execute(f'CREATE SCHEMA IF NOT EXISTS {self.schema}') - self.connection.execute(text('SET SCHEMA :schema'), schema=self.schema) - initialize_postgresql_types(self.connection, self.schema) - return self._connection - - @property - def connection(self): - """Create a connection or reuse the existing connection""" - # TODO (31-Jul-12020) Grab the correct connection for the thread. 
- if self._connection is not None: - conn = self._connection - else: - conn = self.connect() - return conn - - def _create_connection(self): - path = self.uri.replace('sqlite://', '') - if not exists(path): - logger.info('[sql] Initializing new database: %r' % (self.uri,)) - if self.readonly: - raise AssertionError('Cannot open a new database in readonly mode') - # Open the SQL database connection with support for custom types - # lite.enable_callback_tracebacks(True) - - # References: - # http://stackoverflow.com/questions/10205744/opening-sqlite3-database-from-python-in-read-only-mode - uri = self.uri - if self.readonly: - uri += '?mode=ro' - if os.getenv('POSTGRES'): - uri, schema = sqlite_uri_to_postgres_uri_schema(uri) - engine = create_engine(uri) - connection = engine.connect() - if engine.dialect.name == 'postgresql': - connection.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') - connection.execute(text('SET SCHEMA :schema'), schema=schema) - - # Keep track of what thead this was started in - threadid = threading.current_thread() - self.thread_connections[threadid] = connection - - return connection, uri - - def close(self): - self.connection.close() - self.thread_connections = {} - - # def reconnect(db): - # # Call this if we move into a new thread - # assert db.fname != ':memory:', 'cant reconnect to mem' - # connection, uri = db._create_connection() - # db.connection = connection - # db.cur = db.connection.cursor() - - def thread_connection(self): - threadid = threading.current_thread() - if threadid in self.thread_connections: - connection = self.thread_connections[threadid] - else: - connection, uri = self._create_connection() - return connection + """Create a connection instance to wrap a SQL execution block as a context manager""" + with self._engine.connect() as conn: + if self._engine.dialect.name == 'postgresql': + conn.execute(f'CREATE SCHEMA IF NOT EXISTS {self.schema}') + conn.execute(text('SET SCHEMA :schema'), schema=self.schema) + initialize_postgresql_types(conn, self.schema) + yield conn @profile def _ensure_metadata_table(self): @@ -789,68 +730,60 @@ def backup(self, backup_filepath): if self._engine.dialect.name == 'postgresql': # TODO postgresql backup return - # Create a brand new conenction to lock out current thread and any others - connection, uri = self._create_connection() + else: + # Assert the database file exists, and copy to backup path + path = self.uri.replace('sqlite://', '') + if not exists(path): + raise IOError( + 'Could not backup the database as the URI does not exist: %r' + % (self.uri,) + ) # Start Exclusive transaction, lock out all other writers from making database changes - transaction = connection.begin() - connection.isolation_level = 'EXCLUSIVE' - connection.execute('BEGIN EXCLUSIVE') - # Assert the database file exists, and copy to backup path - path = self.uri.replace('sqlite://', '') - if exists(path): + with self.connect() as conn: + conn.execute('BEGIN EXCLUSIVE') ut.copy(path, backup_filepath) - else: - raise IOError( - 'Could not backup the database as the URI does not exist: %r' % (uri,) - ) - # Commit the transaction, releasing the lock - transaction.commit() - # Close the connection - connection.close() def optimize(self): if self._engine.dialect.name != 'sqlite': return # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html#pragma-cache_size # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html - if VERBOSE_SQL: - logger.info('[sql] running sql pragma optimizions') - # self.connection.execute('PRAGMA 
cache_size = 0;') - # self.connection.execute('PRAGMA cache_size = 1024;') - # self.connection.execute('PRAGMA page_size = 1024;') - # logger.info('[sql] running sql pragma optimizions') - self.connection.execute('PRAGMA cache_size = 10000;') # Default: 2000 - self.connection.execute('PRAGMA temp_store = MEMORY;') - self.connection.execute('PRAGMA synchronous = OFF;') - # self.connection.execute('PRAGMA synchronous = NORMAL;') - # self.connection.execute('PRAGMA synchronous = FULL;') # Default - # self.connection.execute('PRAGMA parser_trace = OFF;') - # self.connection.execute('PRAGMA busy_timeout = 1;') - # self.connection.execute('PRAGMA default_cache_size = 0;') + logger.info('[sql] running sql pragma optimizions') + + with self.connect() as conn: + # conn.execute('PRAGMA cache_size = 0;') + # conn.execute('PRAGMA cache_size = 1024;') + # conn.execute('PRAGMA page_size = 1024;') + # logger.info('[sql] running sql pragma optimizions') + conn.execute('PRAGMA cache_size = 10000;') # Default: 2000 + conn.execute('PRAGMA temp_store = MEMORY;') + conn.execute('PRAGMA synchronous = OFF;') + # conn.execute('PRAGMA synchronous = NORMAL;') + # conn.execute('PRAGMA synchronous = FULL;') # Default + # conn.execute('PRAGMA parser_trace = OFF;') + # conn.execute('PRAGMA busy_timeout = 1;') + # conn.execute('PRAGMA default_cache_size = 0;') def shrink_memory(self): if self._engine.dialect.name != 'sqlite': return logger.info('[sql] shrink_memory') - transaction = self.connection.begin() - self.connection.execute('PRAGMA shrink_memory;') - transaction.commit() + with self.connect() as conn: + conn.execute('PRAGMA shrink_memory;') def vacuum(self): if self._engine.dialect.name != 'sqlite': return logger.info('[sql] vaccum') - transaction = self.connection.begin() - self.connection.execute('VACUUM;') - transaction.commit() + with self.connect() as conn: + conn.execute('VACUUM;') def integrity(self): if self._engine.dialect.name != 'sqlite': return logger.info('[sql] vaccum') - transaction = self.connection.begin() - self.connection.execute('PRAGMA integrity_check;') - transaction.commit() + with self.connect() as conn: + conn.execute('PRAGMA integrity_check;') def squeeze(self): if self._engine.dialect.name != 'sqlite': @@ -950,16 +883,17 @@ def _add(self, tblname, colnames, params_iter, unpack_scalars=True, **kwargs): insert_stmt = sqlalchemy.insert(table) primary_keys = [] - with self.connection.begin(): # new nested database transaction - for vals in parameterized_values: - result = self.connection.execute(insert_stmt.values(vals)) - - pk = result.inserted_primary_key - if unpack_scalars: - # Assumption at the time of writing this is that the primary key is the SQLite rowid. - # Therefore, we can assume the primary key is a single column value. - pk = pk[0] - primary_keys.append(pk) + with self.connect() as conn: + with conn.begin(): # new nested database transaction + for vals in parameterized_values: + result = conn.execute(insert_stmt.values(vals)) + + pk = result.inserted_primary_key + if unpack_scalars: + # Assumption at the time of writing this is that the primary key is the SQLite rowid. + # Therefore, we can assume the primary key is a single column value. 
+ pk = pk[0] + primary_keys.append(pk) return primary_keys def add_cleanly( @@ -1442,7 +1376,8 @@ def get( columns = ', '.join(colnames) ids_listing = ', '.join(map(str, id_iter)) operation = f'SELECT {columns} FROM {tblname} WHERE rowid in ({ids_listing}) ORDER BY rowid ASC' - results = self.connection.execute(operation).fetchall() + with self.connect() as conn: + results = conn.execute(operation).fetchall() import numpy as np # ??? Why order the results if they are going to be sorted here? @@ -1663,10 +1598,11 @@ def set( bindparam(id_param_name, type_=id_column.type) ) stmt = stmt.where(where_clause) - for i, id in enumerate(id_list): - params = {id_param_name: id} - params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) - self.connection.execute(stmt, **params) + with self.connect() as conn: + for i, id in enumerate(id_list): + params = {id_param_name: id} + params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) + conn.execute(stmt, **params) def delete(self, tblname, id_list, id_colname='rowid', **kwargs): """Deletes rows from a SQL table (``tblname``) by ID, @@ -1688,8 +1624,9 @@ def delete(self, tblname, id_list, id_colname='rowid', **kwargs): bindparam(id_param_name, type_=id_column.type) ) stmt = stmt.where(where_clause) - for id in id_list: - self.connection.execute(stmt, {id_param_name: id}) + with self.connect() as conn: + for id in id_list: + conn.execute(stmt, {id_param_name: id}) def delete_rowids(self, tblname, rowid_list, **kwargs): """ deletes the the rows in rowid_list """ @@ -1757,49 +1694,51 @@ def executeone( f"'operation' is a '{type(operation)}'" ) # FIXME (12-Sept-12020) Allows passing through '?' (question mark) parameters. - results = self.connection.execute(operation, params) - - # BBB (12-Sept-12020) Retaining insertion rowid result - # FIXME postgresql (12-Sept-12020) This won't work in postgres. - # Maybe see if ResultProxy.inserted_primary_key will work - if 'insert' in str(operation).lower(): # cast in case it's an SQLAlchemy object - # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. - return [results.lastrowid] - elif not results.returns_rows: - return None - else: - if isinstance(operation, sqlalchemy.sql.selectable.Select): - # This code is specifically for handling duplication in colnames - # because sqlalchemy removes them. - # e.g. select field1, field1, field2 from table; - # becomes - # select field1, field2 from table; - # so the items in val_list only have 2 values - # but the caller isn't expecting it so it causes problems - returned_columns = tuple([c.name for c in operation.columns]) - raw_columns = tuple([c.name for c in operation._raw_columns]) - if raw_columns != returned_columns: - results_ = [] - for r in results: - results_.append( - tuple(r[returned_columns.index(c)] for c in raw_columns) - ) - results = results_ + with self.connect() as conn: + results = conn.execute(operation, params) - values = list( - [ - # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. - row[0] if not keepwrap and len(row) == 1 else row - for row in results - ] - ) - # FIXME (28-Sept-12020) No rows results in an empty list. This behavior does not - # match the resulting expectations of `fetchone`'s DBAPI spec. - # If executeone is the shortcut of `execute` and `fetchone`, - # the expectation should be to return according to DBAPI spec. 
- if use_fetchone_behavior and not values: # empty list - values = None - return values + # BBB (12-Sept-12020) Retaining insertion rowid result + # FIXME postgresql (12-Sept-12020) This won't work in postgres. + # Maybe see if ResultProxy.inserted_primary_key will work + if ( + 'insert' in str(operation).lower() + ): # cast in case it's an SQLAlchemy object + # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. + return [results.lastrowid] + elif not results.returns_rows: + return None + else: + if isinstance(operation, sqlalchemy.sql.selectable.Select): + # This code is specifically for handling duplication in colnames + # because sqlalchemy removes them. + # e.g. select field1, field1, field2 from table; + # becomes + # select field1, field2 from table; + # so the items in val_list only have 2 values + # but the caller isn't expecting it so it causes problems + returned_columns = tuple([c.name for c in operation.columns]) + raw_columns = tuple([c.name for c in operation._raw_columns]) + if raw_columns != returned_columns: + results_ = [] + for r in results: + results_.append( + tuple(r[returned_columns.index(c)] for c in raw_columns) + ) + results = results_ + values = list( + [ + # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. + row[0] if not keepwrap and len(row) == 1 else row + for row in results + ] + ) + # FIXME (28-Sept-12020) No rows results in an empty list. This behavior does not + # match the resulting expectations of `fetchone`'s DBAPI spec. + # If executeone is the shortcut of `execute` and `fetchone`, + # the expectation should be to return according to DBAPI spec. + if use_fetchone_behavior and not values: # empty list + values = None + return values def executemany( self, operation, params_iter, unpack_scalars=True, keepwrap=False, **kwargs @@ -1823,15 +1762,16 @@ def executemany( ) results = [] - with self.connection.begin(): - for params in params_iter: - value = self.executeone(operation, params, keepwrap=keepwrap) - # Should only be used when the user wants back on value. - # Let the error bubble up if used wrong. - # Deprecated... Do not depend on the unpacking behavior. - if unpack_scalars: - value = _unpacker(value) - results.append(value) + with self.connect() as conn: + with conn.begin(): + for params in params_iter: + value = self.executeone(operation, params, keepwrap=keepwrap) + # Should only be used when the user wants back on value. + # Let the error bubble up if used wrong. + # Deprecated... Do not depend on the unpacking behavior. + if unpack_scalars: + value = _unpacker(value) + results.append(value) return results def print_dbg_schema(self): @@ -2576,7 +2516,7 @@ def invalidate_tables_cache(self): def get_table_names(self, lazy=False): """ Conveinience: """ if not lazy or self._tablenames is None: - dialect = self.connection.engine.dialect.name + dialect = self._engine.dialect.name if dialect == 'sqlite': stmt = "SELECT name FROM sqlite_master WHERE type='table'" params = {} @@ -2590,8 +2530,9 @@ def get_table_names(self, lazy=False): params = {'schema': self.schema} else: raise RuntimeError(f'Unknown dialect {dialect}') - result = self.connection.execute(stmt, **params) - tablename_list = result.fetchall() + with self.connect() as conn: + result = conn.execute(stmt, **params) + tablename_list = result.fetchall() self._tablenames = {str(tablename[0]) for tablename in tablename_list} return self._tablenames @@ -2713,7 +2654,8 @@ def get_columns(self, tablename): ] """ # check if the table exists first. 
Throws an error if it does not exist. - self.connection.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') + with self.connect() as conn: + conn.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') dialect = self._engine.dialect.name if dialect == 'sqlite': stmt = f"PRAGMA TABLE_INFO('{tablename}')" @@ -2741,8 +2683,9 @@ def get_columns(self, tablename): ) params = {'table_name': tablename, 'table_schema': self.schema} - result = self.connection.execute(stmt, **params) - colinfo_list = result.fetchall() + with self.connect() as conn: + result = conn.execute(stmt, **params) + colinfo_list = result.fetchall() colrichinfo_list = [SQLColumnRichInfo(*colinfo) for colinfo in colinfo_list] return colrichinfo_list diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 41a459ec2e..ca06abde57 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -3697,15 +3697,16 @@ def get_database_species_count(ibs, aid_list=None, BATCH_SIZE=25000): ) for batch in range(int(len(aid_list) / BATCH_SIZE) + 1): aids = aid_list[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE] - results = ibs.db.connection.execute(stmt, {'aids': aids}) - - for row in results: - species_text = row.species_text - if species_text is None: - species_text = const.UNKNOWN - species_count[species_text] = ( - species_count.get(species_text, 0) + row.num_annots - ) + with ibs.db.connect() as conn: + results = conn.execute(stmt, {'aids': aids}) + + for row in results: + species_text = row.species_text + if species_text is None: + species_text = const.UNKNOWN + species_count[species_text] = ( + species_count.get(species_text, 0) + row.num_annots + ) return species_count @@ -4912,7 +4913,8 @@ def filter_aids_to_quality(ibs, aid_list, minqual, unknown_ok=True, speedhack=Tr else: operation = 'SELECT rowid from annotations WHERE annot_quality NOTNULL AND annot_quality>={minqual_int} AND rowid IN ({aids})' operation = operation.format(aids=list_repr, minqual_int=minqual_int) - aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) + with ibs.db.connect() as conn: + aid_list_ = ut.take_column(conn.execute(operation).fetchall(), 0) else: qual_flags = list( ibs.get_quality_filterflags(aid_list, minqual, unknown_ok=unknown_ok) @@ -5006,7 +5008,8 @@ def filter_aids_without_name(ibs, aid_list, invert=False, speedhack=True): 'SELECT rowid from annotations WHERE name_rowid>0 AND rowid IN (%s)' % (list_repr,) ) - aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) + with ibs.db.connect() as conn: + aid_list_ = ut.take_column(conn.execute(operation).fetchall(), 0) else: flag_list = ibs.is_aid_unknown(aid_list) if not invert: @@ -5148,7 +5151,8 @@ def filter_aids_to_species(ibs, aid_list, species, speedhack=True): list_repr = ','.join(map(str, aid_list)) operation = 'SELECT rowid from annotations WHERE (species_rowid = {species_rowid}) AND rowid IN ({aids})' operation = operation.format(aids=list_repr, species_rowid=species_rowid) - aid_list_ = ut.take_column(ibs.db.connection.execute(operation).fetchall(), 0) + with ibs.db.connect() as conn: + aid_list_ = ut.take_column(conn.execute(operation).fetchall(), 0) else: species_rowid_list = ibs.get_annot_species_rowids(aid_list) is_valid_species = [sid == species_rowid for sid in species_rowid_list] diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 748a94fb92..8e32abb679 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -6,12 +6,10 @@ import pytest import 
sqlalchemy.exc from sqlalchemy import MetaData, Table -from sqlalchemy.engine import Connection from sqlalchemy.sql import select, text from wbia.dtool.sql_control import ( METADATA_TABLE_COLUMNS, - TIMEOUT, SQLDatabaseController, ) @@ -41,16 +39,6 @@ def make_table_definition(name, depends_on=[]): return definition -def test_instantiation(ctrlr): - # Check for basic connection information - assert ctrlr.uri == 'sqlite:///:memory:' - assert ctrlr.timeout == TIMEOUT - - # Check for a connection, that would have been made during instantiation - assert isinstance(ctrlr.connection, Connection) - assert not ctrlr.connection.closed - - def test_instantiation_with_table_reflection(tmp_path): db_file = (tmp_path / 'testing.db').resolve() creating_ctrlr = SQLDatabaseController.from_uri(f'sqlite:///{db_file}') @@ -171,7 +159,7 @@ def test_add_table(self): assert sorted(found_constraint_info) == expected_constraint_info # Check for metadata entries - results = self.ctrlr.connection.execute( + results = self.ctrlr._engine.execute( select([metadata.c.metadata_key, metadata.c.metadata_value]).where( metadata.c.metadata_key.like('bars_%') ) @@ -200,7 +188,7 @@ def test_rename_table(self): self.reflect_table(new_table_name, md) # Check for metadata entries have been renamed. - results = self.ctrlr.connection.execute( + results = self.ctrlr._engine.execute( select([metadata.c.metadata_key, metadata.c.metadata_value]).where( metadata.c.metadata_key.like(f'{new_table_name}_%') ) @@ -232,7 +220,7 @@ def test_drop_table(self): self.reflect_table(table_name, md) # Check for metadata entries have been renamed. - results = self.ctrlr.connection.execute( + results = self.ctrlr._engine.execute( select([metadata.c.metadata_key, metadata.c.metadata_value]).where( metadata.c.metadata_key.like(f'{table_name}_%') ) @@ -258,7 +246,7 @@ def test_drop_all_table(self): self.reflect_table(name, md) # Check for the absents of metadata for the removed tables. 
- results = self.ctrlr.connection.execute(select([metadata.c.metadata_key])) + results = self.ctrlr._engine.execute(select([metadata.c.metadata_key])) expected_metadata_rows = [ ('database_init_uuid',), ('database_version',), @@ -302,7 +290,7 @@ def fixture(self, ctrlr, monkeypatch): unprefixed_name = key.split('_')[-1] if METADATA_TABLE_COLUMNS[unprefixed_name]['is_coded_data']: value = repr(value) - self.ctrlr.connection.execute(insert_stmt, key=key, value=value) + self.ctrlr._engine.execute(insert_stmt, key=key, value=value) def monkey_get_table_names(self, *args, **kwargs): return ['foo', 'metadata'] @@ -360,7 +348,7 @@ def test_setting_to_none(self): assert new_value == value # Also check the table does not have the record - assert not self.ctrlr.connection.execute( + assert not self.ctrlr._engine.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" ).fetchone() @@ -381,7 +369,7 @@ def test_deleter(self): assert self.ctrlr.metadata.foo.docstr is None # Also check the table does not have the record - assert not self.ctrlr.connection.execute( + assert not self.ctrlr._engine.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" ).fetchone() @@ -484,7 +472,7 @@ def test_delitem_for_table(self): assert self.ctrlr.metadata.foo.docstr is None # Also check the table does not have the record - assert not self.ctrlr.connection.execute( + assert not self.ctrlr._engine.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" ).fetchone() @@ -513,7 +501,7 @@ def fixture(self, ctrlr): self.ctrlr = ctrlr def make_table(self, name): - self.ctrlr.connection.execute( + self.ctrlr._engine.execute( f'CREATE TABLE IF NOT EXISTS {name} ' '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' ) @@ -530,7 +518,7 @@ def populate_table(self, name): i, i * 2.01, ) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + self.ctrlr._engine.execute(insert_stmt, x=x, y=y, z=z) class TestExecutionAPI(BaseAPITestCase): @@ -587,7 +575,7 @@ def test_executeone_on_insert(self): assert result == [11] # the result list with one unwrapped value # Check for the actual value associated with the resulting id - inserted_value = self.ctrlr.connection.execute( + inserted_value = self.ctrlr._engine.execute( text(f'SELECT id, y FROM {table_name} WHERE rowid = :rowid'), rowid=result[0], ).fetchone() @@ -632,7 +620,7 @@ def test_executemany_transaction(self): results = self.ctrlr.executemany(insert, params) # Check for results - results = self.ctrlr.connection.execute(f'select count(*) from {table_name}') + results = self.ctrlr._engine.execute(f'select count(*) from {table_name}') assert results.fetchone()[0] == 0 def test_executeone_for_single_column(self): @@ -670,7 +658,7 @@ def test_add(self): # Verify the resulting ids assert ids == [i + 1 for i in range(0, len(parameter_values))] # Verify addition of records - results = self.ctrlr.connection.execute(f'SELECT id, x, y, z FROM {table_name}') + results = self.ctrlr._engine.execute(f'SELECT id, x, y, z FROM {table_name}') expected = [(i + 1, x, y, z) for i, (x, y, z) in enumerate(parameter_values)] assert results.fetchall() == expected @@ -762,14 +750,15 @@ def test_get_all(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, 
z=z) - # Build the expect results of the testing target - results = self.ctrlr.connection.execute(f'SELECT id, x, z FROM {table_name}') - rows = results.fetchall() - row_mapping = {row[0]: row[1:] for row in rows if row[1]} + # Build the expect results of the testing target + results = conn.execute(f'SELECT id, x, z FROM {table_name}') + rows = results.fetchall() + row_mapping = {row[0]: row[1:] for row in rows if row[1]} # Call the testing target data = self.ctrlr.get(table_name, ['x', 'z']) @@ -784,9 +773,10 @@ def test_get_by_id(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target requested_ids = [2, 4, 6] @@ -794,10 +784,11 @@ def test_get_by_id(self): # Build the expect results of the testing target sql_array = ', '.join([str(id) for id in requested_ids]) - results = self.ctrlr.connection.execute( - f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' - ) - expected = results.fetchall() + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' + ) + expected = results.fetchall() # Verify getting assert data == expected @@ -808,9 +799,10 @@ def test_get_by_numpy_array_of_ids(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target requested_ids = np.array([2, 4, 6]) @@ -818,10 +810,11 @@ def test_get_by_numpy_array_of_ids(self): # Build the expect results of the testing target sql_array = ', '.join([str(id) for id in requested_ids]) - results = self.ctrlr.connection.execute( - f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' - ) - expected = results.fetchall() + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' + ) + expected = results.fetchall() # Verify getting assert data == expected @@ -835,9 +828,10 @@ def test_get_as_unique(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) # Call the testing target # The table has a INTEGER PRIMARY KEY, which essentially maps to the rowid @@ -847,11 +841,12 @@ def test_get_as_unique(self): # Build the expect results of the testing target sql_array = ', '.join([str(id) for id in requested_ids]) - results = self.ctrlr.connection.execute( - f'SELECT x FROM {table_name} WHERE id in ({sql_array})' - ) - # ... recall that the controller unpacks single values - expected = [row[0] for row in results] + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT x FROM {table_name} WHERE id in ({sql_array})' + ) + # ... 
recall that the controller unpacks single values + expected = [row[0] for row in results] # Verify getting assert data == expected @@ -865,14 +860,13 @@ def test_setting(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) - results = self.ctrlr.connection.execute( - f'SELECT id, CAST((y%2) AS BOOL) FROM {table_name}' - ) - rows = results.fetchall() + results = conn.execute(f'SELECT id, CAST((y%2) AS BOOL) FROM {table_name}') + rows = results.fetchall() ids = [row[0] for row in rows if row[1]] # Call the testing target @@ -882,11 +876,12 @@ def test_setting(self): # Verify setting sql_array = ', '.join([str(id) for id in ids]) - results = self.ctrlr.connection.execute( - f'SELECT id, x, z FROM {table_name} ' f'WHERE id in ({sql_array})' - ) - expected = sorted(map(lambda a: tuple([a] + ['even', 0.0]), ids)) - set_rows = sorted(results) + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT id, x, z FROM {table_name} ' f'WHERE id in ({sql_array})' + ) + expected = sorted(map(lambda a: tuple([a] + ['even', 0.0]), ids)) + set_rows = sorted(results) assert set_rows == expected @@ -898,14 +893,13 @@ def test_delete(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) - results = self.ctrlr.connection.execute( - f'SELECT id, CAST((y % 2) AS BOOL) FROM {table_name}' - ) - rows = results.fetchall() + results = conn.execute(f'SELECT id, CAST((y % 2) AS BOOL) FROM {table_name}') + rows = results.fetchall() del_ids = [row[0] for row in rows if row[1]] remaining_ids = sorted([row[0] for row in rows if not row[1]]) @@ -913,8 +907,9 @@ def test_delete(self): self.ctrlr.delete(table_name, del_ids, 'id') # Verify the deletion - results = self.ctrlr.connection.execute(f'SELECT id FROM {table_name}') - assert sorted([r[0] for r in results]) == remaining_ids + with self.ctrlr.connect() as conn: + results = conn.execute(f'SELECT id FROM {table_name}') + assert sorted([r[0] for r in results]) == remaining_ids def test_delete_rowid(self): # Make a table for records @@ -923,14 +918,15 @@ def test_delete_rowid(self): # Create some dummy records insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') - for i in range(0, 10): - x, y, z = (str(i), i, i * 2.01) - self.ctrlr.connection.execute(insert_stmt, x=x, y=y, z=z) + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) - results = self.ctrlr.connection.execute( - f'SELECT rowid, CAST((y % 2) AS BOOL) FROM {table_name}' - ) - rows = results.fetchall() + results = conn.execute( + f'SELECT rowid, CAST((y % 2) AS BOOL) FROM {table_name}' + ) + rows = results.fetchall() del_ids = [row[0] for row in rows if row[1]] remaining_ids = sorted([row[0] for row in rows if not row[1]]) @@ -938,5 +934,6 @@ def test_delete_rowid(self): self.ctrlr.delete_rowids(table_name, del_ids) # Verify the deletion - results = 
self.ctrlr.connection.execute(f'SELECT rowid FROM {table_name}') - assert sorted([r[0] for r in results]) == remaining_ids + with self.ctrlr.connect() as conn: + results = conn.execute(f'SELECT rowid FROM {table_name}') + assert sorted([r[0] for r in results]) == remaining_ids From c89bc6edd8b8cc4eb31883b3a6db209219e6181a Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Wed, 16 Dec 2020 21:53:33 -0800 Subject: [PATCH 212/294] Fix ImportError in SQLAlchemy 1.4 `RowProxy` is now just `Row`, but the legacy behavior is available as `LegacyRow`. I'm using `LegacyRow` because I'm not sure of the exact difference at the moment. See https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#change-4710-core --- wbia/dtool/sql_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 479809e47d..3db54b6a69 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -20,7 +20,7 @@ import sqlalchemy import utool as ut from deprecated import deprecated -from sqlalchemy.engine import RowProxy +from sqlalchemy.engine import LegacyRow from sqlalchemy.schema import Table from sqlalchemy.sql import bindparam, text, ClauseElement @@ -1265,7 +1265,7 @@ def get_where( result = [] for val in val_list: - if isinstance(val, RowProxy): + if isinstance(val, LegacyRow): result.append(tuple(val[returned_columns.index(c)] for c in colnames)) else: result.append(val) From f4515da79acffe242f5125573236e53872eedb97 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Dec 2020 14:30:57 -0800 Subject: [PATCH 213/294] Remove crufty database version check This block of code will alway execute as True. I'd guess it dates back to a time before the codebase did automatic database upgrades on startup. 
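As a rough illustration of the reasoning (the version string below is hypothetical, not read from a real database):

```
from distutils.version import LooseVersion

# Stand-in for ibs.db.get_db_version(); per the note above, every database the
# controller opens today reports a version at or above 1.3.4, so the removed
# guard always took the True branch.
current_version = '1.6.9'
assert LooseVersion(current_version) >= LooseVersion('1.3.4')
```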
--- wbia/control/manual_image_funcs.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/wbia/control/manual_image_funcs.py b/wbia/control/manual_image_funcs.py index 7170088b66..12afb194a1 100644 --- a/wbia/control/manual_image_funcs.py +++ b/wbia/control/manual_image_funcs.py @@ -405,15 +405,11 @@ def add_images( ) # - # Execute SQL Add - from distutils.version import LooseVersion - - if LooseVersion(ibs.db.get_db_version()) >= LooseVersion('1.3.4'): - colnames = IMAGE_COLNAMES + ('image_original_path', 'image_location_code') - params_list = [ - tuple(params) + (gpath, location_for_names) if params is not None else None - for params, gpath in zip(params_list, gpath_list) - ] + colnames = IMAGE_COLNAMES + ('image_original_path', 'image_location_code') + params_list = [ + tuple(params) + (gpath, location_for_names) if params is not None else None + for params, gpath in zip(params_list, gpath_list) + ] all_gid_list = ibs.db.add_cleanly( const.IMAGE_TABLE, colnames, params_list, ibs.get_image_gids_from_uuid From 401e1a684cc01bafc258e13e6f49d7ebcae5f4e3 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 29 Dec 2020 17:20:30 -0800 Subject: [PATCH 214/294] Classify the test as a web-test --- wbia/web/job_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wbia/web/job_engine.py b/wbia/web/job_engine.py index 85ee10e3dd..d125e17f8b 100644 --- a/wbia/web/job_engine.py +++ b/wbia/web/job_engine.py @@ -347,6 +347,7 @@ def get_job_metadata(ibs, jobid): python -m wbia.web.job_engine --exec-get_job_metadata:0 --fg Example: + >>> # xdoctest: +REQUIRES(--web-tests) >>> # xdoctest: +REQUIRES(--slow) >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> # xdoctest: +REQUIRES(--web-tests) From e44bad81ad7bd13d4e4bbfed6398b66ab888ab03 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Tue, 29 Dec 2020 18:51:33 -0800 Subject: [PATCH 215/294] Correct sql equality comparison Equal `=` works for both postgres and sqlite sql dialects. 
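For context, a minimal illustration (not part of the patch) of the two spellings:

```
# SQLite accepts either operator, but PostgreSQL rejects '==' (no such operator
# is defined), so the portable form is the plain '='.
sqlite_only = 'SELECT rowid FROM annotations WHERE annot_quality == -1'
portable = 'SELECT rowid FROM annotations WHERE annot_quality = -1'
```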
--- wbia/other/ibsfuncs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index ca06abde57..1a1812f23f 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -4909,7 +4909,7 @@ def filter_aids_to_quality(ibs, aid_list, minqual, unknown_ok=True, speedhack=Tr list_repr = ','.join(map(str, aid_list)) minqual_int = const.QUALITY_TEXT_TO_INT[minqual] if unknown_ok: - operation = 'SELECT rowid from annotations WHERE (annot_quality ISNULL OR annot_quality==-1 OR annot_quality>={minqual_int}) AND rowid IN ({aids})' + operation = 'SELECT rowid from annotations WHERE (annot_quality ISNULL OR annot_quality=-1 OR annot_quality>={minqual_int}) AND rowid IN ({aids})' else: operation = 'SELECT rowid from annotations WHERE annot_quality NOTNULL AND annot_quality>={minqual_int} AND rowid IN ({aids})' operation = operation.format(aids=list_repr, minqual_int=minqual_int) From 37a28f1ab498ec1008bc8b648617b590bbe0c1b7 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 28 Dec 2020 23:09:55 -0800 Subject: [PATCH 216/294] Include pgloader in the wildbook-ia image --- Dockerfile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Dockerfile b/Dockerfile index 716adaf0e5..8ebc823e77 100644 --- a/Dockerfile +++ b/Dockerfile @@ -67,6 +67,10 @@ RUN set -x \ libxext6 \ #: opencv2 dependency libgl1 \ + #: required to stop prompting by pgloader + libssl1.0.0 \ + #: sqlite->postgres dependency + pgloader \ #: dev debug dependency #: python3-dev required to build 'annoy' python3-dev \ From aaeb9d0791943a9953ad22339165410f6e949c37 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 9 Jan 2021 02:14:46 -0800 Subject: [PATCH 217/294] Append testdb_assigner to list of test databases --- wbia/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wbia/conftest.py b/wbia/conftest.py index 31991c5db2..bc39a81347 100644 --- a/wbia/conftest.py +++ b/wbia/conftest.py @@ -27,6 +27,7 @@ 'testdb2', 'testdb_guiall', 'wd_peter2', + 'testdb_assigner', ) From a139d905d9ce4cd43bb93ac8de6a2ee6629d93b8 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 9 Jan 2021 02:15:58 -0800 Subject: [PATCH 218/294] Cache ensure_testdb_assigner This is mostly fine when run in the sqlite context. Calling this multiple times in the postgres context causes issues because the database is already loaded, but this code has no way of knowing that. --- wbia/init/sysres.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wbia/init/sysres.py b/wbia/init/sysres.py index cd310a9c23..8057d41a28 100644 --- a/wbia/init/sysres.py +++ b/wbia/init/sysres.py @@ -6,7 +6,9 @@ """ import logging import os +from functools import lru_cache from os.path import exists, join, realpath + import utool as ut import ubelt as ub from six.moves import input, zip, map @@ -864,6 +866,7 @@ def ensure_testdb_orientation(): return ensure_db_from_url(const.ZIPPED_URLS.ORIENTATION) +@lru_cache(maxsize=None) def ensure_testdb_assigner(): return ensure_db_from_url(const.ZIPPED_URLS.ASSIGNER) From a6f1e7f2613a9ad0e0be9f75172927b8e7ec3f81 Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 9 Jan 2021 02:19:13 -0800 Subject: [PATCH 219/294] Only query for sqlite tables The testdb_assigner database has an index in it. In this code that means it was trying to recreate the index, but without the correct sql syntax. 
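A self-contained illustration of the problem, separate from the patch itself (the table and index names are made up):

```
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE annots (id INTEGER PRIMARY KEY, name TEXT)')
conn.execute('CREATE INDEX annots_by_name ON annots (name)')

# sqlite_master lists indexes as well as tables ...
everything = conn.execute(
    "SELECT type, name FROM sqlite_master WHERE name NOT LIKE 'sqlite_%'"
).fetchall()
# -> [('table', 'annots'), ('index', 'annots_by_name')]

# ... so the migration needs the extra type filter to replay only CREATE TABLE
# statements and skip the CREATE INDEX rows.
tables_only = conn.execute(
    "SELECT type, name FROM sqlite_master "
    "WHERE name NOT LIKE 'sqlite_%' AND type = 'table'"
).fetchall()
# -> [('table', 'annots')]
```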
--- wbia/dtool/copy_sqlite_to_postgres.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wbia/dtool/copy_sqlite_to_postgres.py b/wbia/dtool/copy_sqlite_to_postgres.py index 3d08a896d5..a9cf0ff9e5 100644 --- a/wbia/dtool/copy_sqlite_to_postgres.py +++ b/wbia/dtool/copy_sqlite_to_postgres.py @@ -29,7 +29,8 @@ def add_rowids(engine): create_table_stmts = connection.execute( """\ SELECT name, sql FROM sqlite_master - WHERE name NOT LIKE 'sqlite_%'""" + WHERE name NOT LIKE 'sqlite_%' AND type = 'table' + """ ).fetchall() for table, stmt in create_table_stmts: # Create a new table with suffix "_with_rowid" From f102df80b65bbeab8c174da92265014a6aafed6d Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Sat, 9 Jan 2021 02:22:47 -0800 Subject: [PATCH 220/294] Include add_rowids in the try..except block Because add_rowids modifies the database, when it errors the original database is left in an unexpected changed state for the next run. --- wbia/dtool/copy_sqlite_to_postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wbia/dtool/copy_sqlite_to_postgres.py b/wbia/dtool/copy_sqlite_to_postgres.py index a9cf0ff9e5..39015f0434 100644 --- a/wbia/dtool/copy_sqlite_to_postgres.py +++ b/wbia/dtool/copy_sqlite_to_postgres.py @@ -135,9 +135,9 @@ def copy_sqlite_to_postgres(parent_dir): for sqlite_db_path in get_sqlite_db_paths(parent_dir): # create new tables with sqlite built-in rowid column sqlite_engine = create_engine(f'sqlite:///{sqlite_db_path}') - add_rowids(sqlite_engine) try: + add_rowids(sqlite_engine) uri, schema = sqlite_uri_to_postgres_uri_schema( f'sqlite:///{os.path.realpath(sqlite_db_path)}' ) From 8580b0e0af54b028e2166e9052ad97875938d5cd Mon Sep 17 00:00:00 2001 From: Michael Mulich Date: Mon, 11 Jan 2021 19:56:30 -0800 Subject: [PATCH 221/294] Define exception for an out-of-sync table When the table defined in code differs from that in the database, this will raise an error instead of deleting the table. Seems like a not so good idea for a server application to delete it's own data. I can understand desire from the GUI app perspective, but it doesn't make sense for an unsupervised process. 
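Condensed into a sketch (the helper name `check_in_sync` and the column definitions are invented for illustration; the real logic lives in `ensure_config_table` and `initialize` in the diff below):

```
from wbia.dtool.depcache_table import TableOutOfSyncError
from wbia.dtool.sql_control import compare_coldef_lists

def check_in_sync(db, tablename, current_coldefs, expected_coldefs):
    """Raise instead of dropping the table when the definitions diverge."""
    mismatch = compare_coldef_lists(current_coldefs, expected_coldefs)
    if mismatch:  # None when the definitions agree
        current_coldef, new_coldef = mismatch
        raise TableOutOfSyncError(
            db,
            tablename,
            f'Current schema: {current_coldef} Expected schema: {new_coldef}',
        )
```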
--- wbia/dtool/depcache_table.py | 52 ++++++++++++++++++++++++++---------- wbia/dtool/sql_control.py | 8 +++--- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 2a43bcbe7a..b0934e1b28 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -59,6 +59,26 @@ GRACE_PERIOD = ut.get_argval('--grace', type_=int, default=0) +class TableOutOfSyncError(Exception): + """Raised when the code's table definition doesn't match the defition in the database""" + + def __init__(self, db, tablename, extended_msg): + db_name = db._engine.url.database + + if getattr(db, 'schema', None): + under_schema = f"under schema '{db.schema}' " + else: + # Not a table under a schema + under_schema = '' + msg = ( + f"database '{db_name}' " + + under_schema + + f"with table '{tablename}' does not match the code definition; " + f"it's likely the database needs upgraded; {extended_msg}" + ) + super().__init__(msg) + + class ExternType(ub.NiceRepr): """ Type to denote an external resource not saved in an SQL table @@ -159,14 +179,16 @@ def ensure_config_table(db): else: current_state = db.get_table_autogen_dict(CONFIG_TABLE) new_state = config_addtable_kw - if not compare_coldef_lists( + results = compare_coldef_lists( current_state['coldef_list'], new_state['coldef_list'] - ): - if predrop_grace_period(CONFIG_TABLE): - db.drop_all_tables() - db.add_table(**new_state) - else: - raise NotImplementedError('Need to be able to modify tables') + ) + if results: + current_coldef, new_coldef = results + raise TableOutOfSyncError( + db, + CONFIG_TABLE, + f'Current schema: {current_coldef} Expected schema: {new_coldef}', + ) @ut.reloadable_class @@ -1998,14 +2020,16 @@ def initialize(self, _debug=None): self.clear_table() current_state = self.db.get_table_autogen_dict(self.tablename) - if not compare_coldef_lists( + results = compare_coldef_lists( current_state['coldef_list'], new_state['coldef_list'] - ): - logger.info('WARNING TABLE IS MODIFIED') - if predrop_grace_period(self.tablename): - self.clear_table() - else: - raise NotImplementedError('Need to be able to modify tables') + ) + if results: + current_coldef, new_coldef = results + raise TableOutOfSyncError( + self.db, + self.tablename, + f'Current schema: {current_coldef} Expected schema: {new_coldef}', + ) def _get_addtable_kw(self): """ diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 3db54b6a69..656e91aa1f 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -122,17 +122,17 @@ def compare_coldef_lists(coldef_list1, coldef_list2): coldef_list1 = [(name, coldef) for name, coldef in coldef_list1 if name != 'rowid'] coldef_list2 = [(name, coldef) for name, coldef in coldef_list2 if name != 'rowid'] if len(coldef_list1) != len(coldef_list2): - return False + return coldef_list1, coldef_list2 for i in range(len(coldef_list1)): name1, coldef1 = coldef_list1[i] name2, coldef2 = coldef_list2[i] if name1 != name2: - return False + return coldef_list1, coldef_list2 coldef1 = re.sub(r' DEFAULT \(nextval\(.*', '', coldef1) coldef2 = re.sub(r' DEFAULT \(nextval\(.*', '', coldef2) if coldef1.lower() != coldef2.lower(): - return False - return True + return coldef_list1, coldef_list2 + return def _unpacker(results): From 8ffcbfeb68c8ce9ecb80cbb7db53d72699abccc3 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 13 Jan 2021 20:18:38 +0000 Subject: [PATCH 222/294] Compare lowercase column names in compare_coldef_lists We have a table that 
has "M" (uppercase) as the column name in sqlite which became "m" (lowercase) in postgresql. They should be considered the same. --- wbia/dtool/sql_control.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 656e91aa1f..e993c9b27f 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -119,8 +119,18 @@ def sqlite_uri_to_postgres_uri_schema(uri): def compare_coldef_lists(coldef_list1, coldef_list2): # Remove "rowid" which is added to postgresql tables - coldef_list1 = [(name, coldef) for name, coldef in coldef_list1 if name != 'rowid'] - coldef_list2 = [(name, coldef) for name, coldef in coldef_list2 if name != 'rowid'] + # Remove "default nextval" for postgresql auto-increment fields as sqlite + # doesn't need it + coldef_list1 = [ + (name.lower(), re.sub(r' default \(nextval\(.*', '', coldef.lower())) + for name, coldef in coldef_list1 + if name != 'rowid' + ] + coldef_list2 = [ + (name.lower(), re.sub(r' default \(nextval\(.*', '', coldef.lower())) + for name, coldef in coldef_list2 + if name != 'rowid' + ] if len(coldef_list1) != len(coldef_list2): return coldef_list1, coldef_list2 for i in range(len(coldef_list1)): @@ -128,9 +138,7 @@ def compare_coldef_lists(coldef_list1, coldef_list2): name2, coldef2 = coldef_list2[i] if name1 != name2: return coldef_list1, coldef_list2 - coldef1 = re.sub(r' DEFAULT \(nextval\(.*', '', coldef1) - coldef2 = re.sub(r' DEFAULT \(nextval\(.*', '', coldef2) - if coldef1.lower() != coldef2.lower(): + if coldef1 != coldef2: return coldef_list1, coldef_list2 return From 7a8ecb058585bb9035015a9d58745a088485c267 Mon Sep 17 00:00:00 2001 From: karen chan Date: Wed, 13 Jan 2021 20:35:53 +0000 Subject: [PATCH 223/294] Change postgres column info is_nullable to 0 or 1 0 for nullable and 1 for "not null", that's what sqlite returns. 
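The intended mapping, restated in Python for clarity (the helper below is illustrative only; the actual change is the SQL CASE expression in the diff):

```
def notnull_flag(is_nullable):
    """Fold information_schema's 'YES'/'NO' into SQLite's 0/1 notnull convention."""
    return 0 if is_nullable == 'YES' else 1

assert notnull_flag('YES') == 0  # nullable column
assert notnull_flag('NO') == 1   # NOT NULL column
```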
--- wbia/dtool/events.py | 2 +- wbia/dtool/sql_control.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/wbia/dtool/events.py b/wbia/dtool/events.py index 8ec92988ff..8eba20e80f 100644 --- a/wbia/dtool/events.py +++ b/wbia/dtool/events.py @@ -30,7 +30,7 @@ def _discovery_table_columns(inspector, table_name): row_number() over () - 1, column_name, coalesce(domain_name, data_type), - is_nullable, + CASE WHEN is_nullable = 'YES' THEN 0 ELSE 1 END, column_default, column_name = ( SELECT column_name diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index e993c9b27f..0ba70c107b 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -2674,7 +2674,7 @@ def get_columns(self, tablename): row_number() over () - 1, column_name, coalesce(domain_name, data_type), - is_nullable, + CASE WHEN is_nullable = 'YES' THEN 0 ELSE 1 END, column_default, column_name = ( SELECT column_name From 32d01aa73723af5fc4fd10d37adb14cbbc8e7344 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 13 Jan 2021 13:17:29 -0800 Subject: [PATCH 224/294] WIP --- devops/provision/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index 2d7d11a808..667d0a71fa 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -59,7 +59,7 @@ RUN set -ex \ # Clone third-party WBIA plug-in repositories RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-curvrank.git \ + && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v1 \ && cd /wbia/wbia-plugin-curvrank/wbia_curvrank \ && git fetch origin \ && git checkout develop From 2ae5594119af5a42c2e40422cce2d25d94c4769d Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Thu, 14 Jan 2021 13:56:02 -0800 Subject: [PATCH 225/294] Add version 2 of Curvrank alongside version 1 --- devops/Dockerfile | 15 +++++++++------ devops/provision/Dockerfile | 18 +++++++++++++++--- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/devops/Dockerfile b/devops/Dockerfile index bc213ae98a..b457f66179 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -12,7 +12,9 @@ RUN set -ex \ 'cd {} && cd .. && echo $(pwd) && (git stash && git pull && git stash pop || git reset --hard origin/develop)' \ && find /wbia/wildbook* -name '.git' -type d -print0 | xargs -0 -i /bin/bash -c \ 'cd {} && cd .. 
&& echo $(pwd) && (git stash && git pull && git stash pop || git reset --hard origin/develop)' \ - && cd /wbia/wbia-plugin-curvrank/wbia_curvrank \ + && cd /wbia/wbia-plugin-curvrank-v1/wbia_curvrank \ + && git stash && git pull && git stash pop || git reset --hard origin/develop \ + && cd /wbia/wbia-plugin-curvrank-v2/wbia_curvrank_v2 \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-kaggle7/wbia_kaggle7 \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ @@ -27,11 +29,12 @@ RUN set -ex \ && /virtualenv/env3/bin/python -c "import wbia_cnn; from wbia_cnn.__main__ import main; main()" \ && /virtualenv/env3/bin/python -c "import wbia_pie; from wbia_pie.__main__ import main; main()" \ && /virtualenv/env3/bin/python -c "import wbia_orientation; from wbia_orientation.__main__ import main; main()" \ - && /virtualenv/env3/bin/python -c "import wbia_flukematch; from wbia_flukematch.plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_curvrank; from wbia_curvrank._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_flukematch; from wbia_flukematch.plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_curvrank; from wbia_curvrank._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_curvrank_v2; from wbia_curvrank_v2._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ && find /wbia/wbia* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ && find /wbia/wbia* -name '*.so' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ && find /wbia/wildbook* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index 667d0a71fa..c64bc548d7 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -59,8 +59,15 @@ RUN set -ex \ # Clone third-party WBIA plug-in repositories RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v1 \ - && cd /wbia/wbia-plugin-curvrank/wbia_curvrank \ + && git clone --recursive --branch develop-curvrank-v1 https://github.com/WildbookOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v1 \ + && cd /wbia/wbia-plugin-curvrank-v1/wbia_curvrank \ + && git fetch origin \ + && git checkout develop + +RUN set -ex \ + && cd /wbia \ + && git clone --recursive --branch develop-curvrank-v2 https://github.com/WildbookOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v2 \ + && cd /wbia/wbia-plugin-curvrank-v2/wbia_curvrank_v2 \ && git fetch origin \ && git checkout develop @@ -152,7 +159,12 @@ RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ && pip install -e .' RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ - && cd /wbia/wbia-plugin-curvrank \ + && cd /wbia/wbia-plugin-curvrank-v1 \ + && ./unix_build.sh \ + && pip install -e .' + +RUN /bin/bash -xc '. 
/virtualenv/env3/bin/activate \ + && cd /wbia/wbia-plugin-curvrank-v2 \ && ./unix_build.sh \ && pip install -e .' From 83566ff7ed6661fc0076115b5caae79bdc6e16cf Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Wed, 20 Jan 2021 21:51:42 -0800 Subject: [PATCH 226/294] Add new configurations for CurvRankV2 --- wbia/control/IBEISControl.py | 5 +++++ wbia/control/manual_annot_funcs.py | 8 ++++++++ wbia/other/ibsfuncs.py | 1 + wbia/scripts/specialdraw.py | 6 ++++++ wbia/viz/viz_matches.py | 4 ++++ wbia/web/apis_engine.py | 2 ++ wbia/web/apis_microsoft.py | 10 ++++++++-- wbia/web/apis_query.py | 6 ++++++ 8 files changed, 40 insertions(+), 2 deletions(-) diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index d8ec608881..25840680aa 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -117,6 +117,11 @@ (('--no-curvrank', '--nocurvrank'), 'wbia_curvrank._plugin'), ] +if ut.get_argflag('--curvrank-v2'): + AUTOLOAD_PLUGIN_MODNAMES += [ + (('--no-curvrank-v2', '--nocurvrankv2'), 'wbia_curvrank_v2._plugin'), + ] + if ut.get_argflag('--deepsense'): AUTOLOAD_PLUGIN_MODNAMES += [ (('--no-deepsense', '--nodeepsense'), 'wbia_deepsense._plugin'), diff --git a/wbia/control/manual_annot_funcs.py b/wbia/control/manual_annot_funcs.py index 070da45f0a..3b7aac13ef 100644 --- a/wbia/control/manual_annot_funcs.py +++ b/wbia/control/manual_annot_funcs.py @@ -2329,6 +2329,14 @@ def set_annot_viewpoints( message = 'Could not purge CurvRankDorsal cache for viewpoint' # raise RuntimeError(message) logger.info(message) + try: + ibs.wbia_plugin_curvrank_v2_delete_cache_optimized( + update_aid_list, 'CurvRankTwoDorsal' + ) + except Exception: + message = 'Could not purge CurvRankTwoDorsal cache for viewpoint' + # raise RuntimeError(message) + logger.info(message) try: ibs.wbia_plugin_curvrank_delete_cache_optimized( update_aid_list, 'CurvRankFinfindrHybridDorsal' diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index 1a1812f23f..ab86e92ee6 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -1310,6 +1310,7 @@ def check_cache_purge(ibs, ttl_days=90, dryrun=True, squeeze=True): './_ibsdb/_ibeis_cache/match_thumbs', './_ibsdb/_ibeis_cache/qres_new', './_ibsdb/_ibeis_cache/curvrank', + './_ibsdb/_ibeis_cache/curvrank_v2', './_ibsdb/_ibeis_cache/pie_neighbors', ] diff --git a/wbia/scripts/specialdraw.py b/wbia/scripts/specialdraw.py index ee6feccb24..0262a46536 100644 --- a/wbia/scripts/specialdraw.py +++ b/wbia/scripts/specialdraw.py @@ -289,6 +289,9 @@ def double_depcache_graph(): 'CurvRankDorsal': 'CurvRank (Dorsal) Distance', 'CurvRankFinfindrHybridDorsal': 'CurvRank + FinFindR Hybrid (Dorsal) Distance', 'CurvRankFluke': 'CurvRank (Fluke) Distance', + 'CurvRankTwo': 'CurvRank V2 Distance', + 'CurvRankTwoDorsal': 'CurvRank V2 (Dorsal) Distance', + 'CurvRankTwoFluke': 'CurvRank V2 (Fluke) Distance', 'Deepsense': 'Deepsense Distance', 'Pie': 'Pie Distance', 'Finfindr': 'Finfindr Distance', @@ -316,6 +319,9 @@ def double_depcache_graph(): 'CurvRankDorsal': 'curvrank_distance_dorsal', 'CurvRankFinfindrHybridDorsal': 'curvrank_finfindr_hybrid_distance_dorsal', 'CurvRankFluke': 'curvrank_distance_fluke', + 'CurvRankTwo': 'curvrank_two_distance', + 'CurvRankTwoDorsal': 'curvrank_two_distance_dorsal', + 'CurvRankTwoFluke': 'curvrank_two_distance_fluke', 'Deepsense': 'deepsense_distance', 'Pie': 'pie_distance', 'Finfindr': 'finfindr_distance', diff --git a/wbia/viz/viz_matches.py b/wbia/viz/viz_matches.py index 01710ad9b4..78b7e52837 100644 --- 
a/wbia/viz/viz_matches.py +++ b/wbia/viz/viz_matches.py @@ -41,6 +41,8 @@ def get_query_annot_pair_info( 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', 'deepsense', 'finfindr', 'kaggle7', @@ -86,6 +88,8 @@ def get_data_annot_pair_info( 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', 'deepsense', 'finfindr', 'kaggle7', diff --git a/wbia/web/apis_engine.py b/wbia/web/apis_engine.py index c1646b8a0d..aa7700d929 100644 --- a/wbia/web/apis_engine.py +++ b/wbia/web/apis_engine.py @@ -535,6 +535,8 @@ def sanitize(state): 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', ): curvrank_daily_tag = query_config_dict.get('curvrank_daily_tag', '') if len(curvrank_daily_tag) > 144: diff --git a/wbia/web/apis_microsoft.py b/wbia/web/apis_microsoft.py index 5e11535d6a..904f348751 100644 --- a/wbia/web/apis_microsoft.py +++ b/wbia/web/apis_microsoft.py @@ -1060,7 +1060,7 @@ def microsoft_identify( $ref: "#/definitions/Annotation" - name: algorithm in: formData - description: The algorithm you with to run ID with. Must be one of "HotSpotter", "CurvRank", "Finfindr", or "Deepsense" + description: The algorithm you with to run ID with. Must be one of "HotSpotter", "CurvRank", "CurvRankTwo", "Finfindr", or "Deepsense" required: true type: string - name: callback_url @@ -1104,11 +1104,13 @@ def microsoft_identify( assert algorithm in [ 'hotspotter', 'curvrank', + 'curvrank_v2', + 'curvrankv2', 'deepsense', 'finfindr', 'kaggle7', 'kaggleseven', - ], 'Must specify the algorithm for ID as HotSpotter, CurvRank, Deepsense, Finfindr, Kaggle7' + ], 'Must specify the algorithm for ID as HotSpotter, CurvRank, CurvRankTwo, Deepsense, Finfindr, Kaggle7' parameter = 'callback_url' assert callback_url is None or isinstance( @@ -1140,6 +1142,10 @@ def microsoft_identify( query_config_dict = { 'pipeline_root': 'CurvRankFluke', } + elif algorithm in ['curvrank_v2', 'curvrankv2']: + query_config_dict = { + 'pipeline_root': 'CurvRankTwoFluke', + } elif algorithm in ['deepsense']: query_config_dict = { 'pipeline_root': 'Deepsense', diff --git a/wbia/web/apis_query.py b/wbia/web/apis_query.py index 56048f308d..d8316fe153 100644 --- a/wbia/web/apis_query.py +++ b/wbia/web/apis_query.py @@ -437,6 +437,8 @@ def review_graph_match_html( 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', 'deepsense', 'finfindr', 'kaggle7', @@ -562,6 +564,10 @@ def review_query_chips_test(**kwargs): query_config_dict = {'pipeline_root': 'CurvRankFinfindrHybridDorsal'} elif 'use_curvrank_fluke' in request.args: query_config_dict = {'pipeline_root': 'CurvRankFluke'} + elif 'use_curvrank_v2_dorsal' in request.args: + query_config_dict = {'pipeline_root': 'CurvRankTwoDorsal'} + elif 'use_curvrank_v2_fluke' in request.args: + query_config_dict = {'pipeline_root': 'CurvRankTwoFluke'} elif 'use_deepsense' in request.args: query_config_dict = {'pipeline_root': 'Deepsense'} elif 'use_finfindr' in request.args: From 96eed03191853767d32bad9e675a560e59a76af8 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Fri, 22 Jan 2021 17:00:17 -0800 Subject: [PATCH 227/294] Omit cpython so files --- devops/Dockerfile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/devops/Dockerfile b/devops/Dockerfile index b457f66179..3e3c24a420 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -35,10 +35,10 @@ 
RUN set -ex \ && /virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ - && find /wbia/wbia* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wbia* -name '*.so' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wildbook* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wildbook* -name '*.so' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' + && find /wbia/wbia* -name '*.a' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wbia* -name '*.so' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wildbook* -name '*.a' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wildbook* -name '*.so' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ ########################################################################################## From bbc71267df26b3b4ae15285f672e447fcae222df Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Sun, 24 Jan 2021 23:22:33 -0800 Subject: [PATCH 228/294] Small bug fixes and upgrade to DOCKER_BUILDKIT --- .github/workflows/docker-publish.yaml | 2 +- .github/workflows/nightly.yml | 2 +- devops/Dockerfile | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker-publish.yaml b/.github/workflows/docker-publish.yaml index ec85279fd2..2671f3ec62 100644 --- a/.github/workflows/docker-publish.yaml +++ b/.github/workflows/docker-publish.yaml @@ -23,7 +23,7 @@ jobs: # Build images - name: Build images - run: bash devops/build.sh ${{ matrix.images }} + run: DOCKER_BUILDKIT=1 bash devops/build.sh ${{ matrix.images }} # Log into container registries - name: Log into Docker Hub diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 704544af17..e6c34c0986 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -34,7 +34,7 @@ jobs: # Build images - name: Build images - run: bash devops/build.sh ${{ matrix.images }} + run: DOCKER_BUILDKIT=1 bash devops/build.sh ${{ matrix.images }} # Log into image registries - name: Log into Docker Hub diff --git a/devops/Dockerfile b/devops/Dockerfile index 3e3c24a420..63f72db5f4 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -38,7 +38,7 @@ RUN set -ex \ && find /wbia/wbia* -name '*.a' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ && find /wbia/wbia* -name '*.so' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ && find /wbia/wildbook* -name '*.a' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wildbook* -name '*.so' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wildbook* -name '*.so' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' ########################################################################################## From b2a19c3ee41752875c47990cb8de26b44e34b13f Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 25 Jan 
2021 10:11:23 -0800 Subject: [PATCH 229/294] Updated proper checkout --- devops/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/devops/Dockerfile b/devops/Dockerfile index 63f72db5f4..8421ffc71b 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -20,6 +20,8 @@ RUN set -ex \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-orientation/ \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ + && cd /wbia/wildbook-ia/ \ + && git checkout curvrank-v2 \ && find /wbia -name '.git' -type d -print0 | xargs -0 rm -rf \ && find /wbia -name '_skbuild' -type d -print0 | xargs -0 rm -rf From 098f9a40d8f2b23b753869b7b9f503cb32db1ae3 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Mon, 25 Jan 2021 10:13:48 -0800 Subject: [PATCH 230/294] Use Docker BuildKit universally --- .github/workflows/docker-publish.yaml | 2 +- .github/workflows/nightly.yml | 2 +- devops/build.sh | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker-publish.yaml b/.github/workflows/docker-publish.yaml index 2671f3ec62..ec85279fd2 100644 --- a/.github/workflows/docker-publish.yaml +++ b/.github/workflows/docker-publish.yaml @@ -23,7 +23,7 @@ jobs: # Build images - name: Build images - run: DOCKER_BUILDKIT=1 bash devops/build.sh ${{ matrix.images }} + run: bash devops/build.sh ${{ matrix.images }} # Log into container registries - name: Log into Docker Hub diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index e6c34c0986..704544af17 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -34,7 +34,7 @@ jobs: # Build images - name: Build images - run: DOCKER_BUILDKIT=1 bash devops/build.sh ${{ matrix.images }} + run: bash devops/build.sh ${{ matrix.images }} # Log into image registries - name: Log into Docker Hub diff --git a/devops/build.sh b/devops/build.sh index 2bb012968c..43354f7e87 100755 --- a/devops/build.sh +++ b/devops/build.sh @@ -5,6 +5,8 @@ set -ex # See https://stackoverflow.com/a/246128/176882 export ROOT_LOC="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +export DOCKER_BUILDKIT=1 + # Change to the script's root directory location cd ${ROOT_LOC} From 2fbe6e7f1a1d569985b3e77001e01c5c0256d733 Mon Sep 17 00:00:00 2001 From: Jason Parham Date: Tue, 26 Jan 2021 12:54:45 -0800 Subject: [PATCH 231/294] Updated WildbookOrg to WildMeOrg, added new research area and updated dbinfo --- README.rst | 70 ++++---- _dev/super_setup_old.py | 44 +++-- devops/Dockerfile | 2 +- devops/install.ubuntu.sh | 34 ++-- devops/provision/Dockerfile | 40 ++--- devops/publish.sh | 2 +- docs/index.rst | 2 +- setup.py | 4 +- super_setup.py | 36 ++-- wbia/__init__.py | 1 + wbia/algo/graph/mixin_loops.py | 28 ++- wbia/algo/preproc/preproc_occurrence.py | 2 +- wbia/control/IBEISControl.py | 1 + wbia/control/manual_meta_funcs.py | 1 + wbia/dtool/sql_control.py | 25 ++- wbia/other/dbinfo.py | 229 +++++++++++++++++++----- wbia/other/ibsfuncs.py | 4 +- wbia/research/__init__.py | 7 + wbia/research/metrics.py | 60 +++++++ wbia/web/routes.py | 18 ++ wbia/web/templates/index.html | 2 +- 21 files changed, 433 insertions(+), 179 deletions(-) create mode 100644 wbia/research/__init__.py create mode 100644 wbia/research/metrics.py diff --git a/README.rst b/README.rst index a61e18a7ec..722436665f 100644 --- a/README.rst +++ b/README.rst @@ -12,7 +12,7 @@ use in computer vision algorithms.
It aims to compute who an animal is, what species an animal is, and where an animal is with the ultimate goal being to ask important why biological questions. -This project is the Machine Learning (ML) / computer vision component of the WildBook project: See https://github.com/WildbookOrg/. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ +This project is the Machine Learning (ML) / computer vision component of the WildBook project: See https://github.com/WildMeOrg/. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ Currently the system is build around and SQLite database, a web GUI, and matplotlib visualizations. Algorithms employed are: convolutional neural network @@ -53,7 +53,7 @@ We highly recommend using a Python virtual environment: https://docs.python-guid Documentation ~~~~~~~~~~~~~ -The documentation is built and available online at `wildbookorg.github.io/wildbook-ia/ `_. However, if you need to build a local copy of the source, the following instructions can be used. +The documentation is built and available online at `WildMeOrg.github.io/wildbook-ia/ `_. However, if you need to build a local copy of the source, the following instructions can be used. .. code:: bash @@ -87,76 +87,76 @@ This project depends on an array of other repositories for functionality. 
First Party Toolkits (Required) -* https://github.com/WildbookOrg/wbia-utool +* https://github.com/WildMeOrg/wbia-utool -* https://github.com/WildbookOrg/wbia-vtool +* https://github.com/WildMeOrg/wbia-vtool First Party Dependencies for Third Party Libraries (Required) -* https://github.com/WildbookOrg/wbia-tpl-pyhesaff +* https://github.com/WildMeOrg/wbia-tpl-pyhesaff -* https://github.com/WildbookOrg/wbia-tpl-pyflann +* https://github.com/WildMeOrg/wbia-tpl-pyflann -* https://github.com/WildbookOrg/wbia-tpl-pydarknet +* https://github.com/WildMeOrg/wbia-tpl-pydarknet -* https://github.com/WildbookOrg/wbia-tpl-pyrf +* https://github.com/WildMeOrg/wbia-tpl-pyrf First Party Plug-ins (Optional) -* https://github.com/WildbookOrg/wbia-plugin-cnn +* https://github.com/WildMeOrg/wbia-plugin-cnn -* https://github.com/WildbookOrg/wbia-plugin-flukematch +* https://github.com/WildMeOrg/wbia-plugin-flukematch -* https://github.com/WildbookOrg/wbia-plugin-deepsense +* https://github.com/WildMeOrg/wbia-plugin-deepsense -* https://github.com/WildbookOrg/wbia-plugin-finfindr +* https://github.com/WildMeOrg/wbia-plugin-finfindr -* https://github.com/WildbookOrg/wbia-plugin-curvrank +* https://github.com/WildMeOrg/wbia-plugin-curvrank - + https://github.com/WildbookOrg/wbia-tpl-curvrank + + https://github.com/WildMeOrg/wbia-tpl-curvrank -* https://github.com/WildbookOrg/wbia-plugin-kaggle7 +* https://github.com/WildMeOrg/wbia-plugin-kaggle7 - + https://github.com/WildbookOrg/wbia-tpl-kaggle7 + + https://github.com/WildMeOrg/wbia-tpl-kaggle7 -* https://github.com/WildbookOrg/wbia-plugin-2d-orientation +* https://github.com/WildMeOrg/wbia-plugin-2d-orientation - + https://github.com/WildbookOrg/wbia-tpl-2d-orientation + + https://github.com/WildMeOrg/wbia-tpl-2d-orientation -* https://github.com/WildbookOrg/wbia-plugin-lca +* https://github.com/WildMeOrg/wbia-plugin-lca - + https://github.com/WildbookOrg/wbia-tpl-lca + + https://github.com/WildMeOrg/wbia-tpl-lca Deprecated Toolkits (Deprecated) -* https://github.com/WildbookOrg/wbia-deprecate-ubelt +* https://github.com/WildMeOrg/wbia-deprecate-ubelt -* https://github.com/WildbookOrg/wbia-deprecate-dtool +* https://github.com/WildMeOrg/wbia-deprecate-dtool -* https://github.com/WildbookOrg/wbia-deprecate-guitool +* https://github.com/WildMeOrg/wbia-deprecate-guitool -* https://github.com/WildbookOrg/wbia-deprecate-plottool +* https://github.com/WildMeOrg/wbia-deprecate-plottool -* https://github.com/WildbookOrg/wbia-deprecate-detecttools +* https://github.com/WildMeOrg/wbia-deprecate-detecttools -* https://github.com/WildbookOrg/wbia-deprecate-plugin-humpbacktl +* https://github.com/WildMeOrg/wbia-deprecate-plugin-humpbacktl -* https://github.com/WildbookOrg/wbia-deprecate-tpl-lightnet +* https://github.com/WildMeOrg/wbia-deprecate-tpl-lightnet -* https://github.com/WildbookOrg/wbia-deprecate-tpl-brambox +* https://github.com/WildMeOrg/wbia-deprecate-tpl-brambox Plug-in Templates (Reference) -* https://github.com/WildbookOrg/wbia-plugin-template +* https://github.com/WildMeOrg/wbia-plugin-template -* https://github.com/WildbookOrg/wbia-plugin-id-example +* https://github.com/WildMeOrg/wbia-plugin-id-example Miscellaneous (Reference) -* https://github.com/WildbookOrg/wbia-pypkg-build +* https://github.com/WildMeOrg/wbia-pypkg-build -* https://github.com/WildbookOrg/wbia-project-website +* https://github.com/WildMeOrg/wbia-project-website -* https://github.com/WildbookOrg/wbia-aws-codedeploy +* https://github.com/WildMeOrg/wbia-aws-codedeploy Citation -------- 
@@ -259,8 +259,8 @@ To run doctests with `+REQUIRES(--web-tests)` do: pytest --web-tests -.. |Build| image:: https://img.shields.io/github/workflow/status/WildbookOrg/wildbook-ia/Build%20and%20upload%20to%20PyPI/master - :target: https://github.com/WildbookOrg/wildbook-ia/actions?query=branch%3Amaster+workflow%3A%22Build+and+upload+to+PyPI%22 +.. |Build| image:: https://img.shields.io/github/workflow/status/WildMeOrg/wildbook-ia/Build%20and%20upload%20to%20PyPI/master + :target: https://github.com/WildMeOrg/wildbook-ia/actions?query=branch%3Amaster+workflow%3A%22Build+and+upload+to+PyPI%22 :alt: Build and upload to PyPI (master) .. |Pypi| image:: https://img.shields.io/pypi/v/wildbook-ia.svg diff --git a/_dev/super_setup_old.py b/_dev/super_setup_old.py index 67cfb72393..c1d826578c 100755 --- a/_dev/super_setup_old.py +++ b/_dev/super_setup_old.py @@ -11,7 +11,7 @@ export CODE_DIR=~/code mkdir $CODE_DIR cd $CODE_DIR -git clone https://github.com/WildbookOrg/wbia.git +git clone https://github.com/WildMeOrg/wbia.git cd wbia python super_setup.py --bootstrap @@ -311,7 +311,7 @@ def ensure_utool(CODE_DIR, pythoncmd): WIN32 = sys.platform.startswith('win32') # UTOOL_BRANCH = ' -b ' UTOOL_BRANCH = 'next' - UTOOL_REPO = 'https://github.com/WildbookOrg/utool.git' + UTOOL_REPO = 'https://github.com/WildMeOrg/utool.git' print('WARNING: utool is not found') print('Attempting to get utool. Enter (y) to continue') @@ -370,8 +370,8 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): # IBEIS project repos # ----------- # if True: - # jon_repo_base = 'https://github.com/WildbookOrg' - # jason_repo_base = 'https://github.com/WildbookOrg' + # jon_repo_base = 'https://github.com/WildMeOrg' + # jason_repo_base = 'https://github.com/WildMeOrg' # else: # jon_repo_base = 'https://github.com/wildme' # jason_repo_base = 'https://github.com/wildme' @@ -381,12 +381,12 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): wbia_rman = ut.RepoManager( [ - 'https://github.com/WildbookOrg/utool.git', - # 'https://github.com/WildbookOrg/sandbox_utools.git', - 'https://github.com/WildbookOrg/vtool_ibeis.git', - 'https://github.com/WildbookOrg/dtool_ibeis.git', + 'https://github.com/WildMeOrg/utool.git', + # 'https://github.com/WildMeOrg/sandbox_utools.git', + 'https://github.com/WildMeOrg/vtool_ibeis.git', + 'https://github.com/WildMeOrg/dtool_ibeis.git', 'https://github.com/Erotemic/ubelt.git', - 'https://github.com/WildbookOrg/detecttools.git', + 'https://github.com/WildMeOrg/detecttools.git', ], CODE_DIR, label='core', @@ -399,24 +399,24 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): tpl_rman.add_repo(cv_repo) if WITH_GUI: - wbia_rman.add_repos(['https://github.com/WildbookOrg/plottool_ibeis.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/plottool_ibeis.git']) if WITH_QT: - wbia_rman.add_repos(['https://github.com/WildbookOrg/guitool_ibeis.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/guitool_ibeis.git']) tpl_rman.add_repo(ut.Repo(modname=('PyQt4', 'PyQt5', 'PyQt'))) if WITH_CUSTOM_TPL: flann_repo = ut.Repo( - 'https://github.com/WildbookOrg/flann.git', CODE_DIR, modname='pyflann' + 'https://github.com/WildMeOrg/flann.git', CODE_DIR, modname='pyflann' ) wbia_rman.add_repo(flann_repo) - wbia_rman.add_repos(['https://github.com/WildbookOrg/hesaff.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/hesaff.git']) if WITH_CNN: wbia_rman.add_repos( [ - 'https://github.com/WildbookOrg/wbia_cnn.git', - 'https://github.com/WildbookOrg/pydarknet.git', + 
'https://github.com/WildMeOrg/wbia_cnn.git', + 'https://github.com/WildMeOrg/pydarknet.git', ] ) # NEW CNN Dependencies @@ -433,28 +433,26 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): ) if WITH_FLUKEMATCH: - wbia_rman.add_repos( - ['https://github.com/WildbookOrg/ibeis-flukematch-module.git'] - ) + wbia_rman.add_repos(['https://github.com/WildMeOrg/ibeis-flukematch-module.git']) if WITH_CURVRANK: - wbia_rman.add_repos(['https://github.com/WildbookOrg/ibeis-curvrank-module.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/ibeis-curvrank-module.git']) if WITH_PYRF: - wbia_rman.add_repos(['https://github.com/WildbookOrg/pyrf.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/pyrf.git']) if False: # Depricated wbia_rman.add_repos( [ - # 'https://github.com/WildbookOrg/pybing.git', + # 'https://github.com/WildMeOrg/pybing.git', # 'https://github.com/aweinstock314/cyth.git', # 'https://github.com/hjweide/pygist', ] ) # Add main repo (Must be checked last due to dependency issues) - wbia_rman.add_repos(['https://github.com/WildbookOrg/wbia.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/wbia.git']) # ----------- # Custom third party build/install scripts @@ -1005,7 +1003,7 @@ def GET_ARGFLAG(arg, *args, **kwargs): def move_wildme(wbia_rman, fmt): - wildme_user = 'WildbookOrg' + wildme_user = 'WildMeOrg' wildme_remote = 'wildme' for repo in wbia_rman.repos: diff --git a/devops/Dockerfile b/devops/Dockerfile index 8421ffc71b..f9adfab13b 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -50,7 +50,7 @@ LABEL autoheal=true ARG VERSION="3.3.0" -ARG VCS_URL="https://github.com/WildbookOrg/wildbook-ia" +ARG VCS_URL="https://github.com/WildMeOrg/wildbook-ia" ARG VCS_REF="develop" diff --git a/devops/install.ubuntu.sh b/devops/install.ubuntu.sh index ff8a6922cc..6d2bc741dc 100644 --- a/devops/install.ubuntu.sh +++ b/devops/install.ubuntu.sh @@ -181,35 +181,35 @@ pip install pygraphviz --install-option="--include-path=/usr/include/graphviz" - cp -r ${VIRTUAL_ENV}/lib/python3.7/site-packages/cv2 /tmp/cv2 cd ${CODE} -git clone --branch develop https://github.com/WildbookOrg/wildbook-ia.git -git clone --branch develop https://github.com/WildbookOrg/wbia-utool.git -git clone --branch develop https://github.com/WildbookOrg/wbia-vtool.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyhesaff.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyflann.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pydarknet.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyrf.git -git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-brambox -git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-lightnet -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-cnn.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-flukematch.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-finfindr.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-deepsense.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-pie.git +git clone --branch develop https://github.com/WildMeOrg/wildbook-ia.git +git clone --branch develop https://github.com/WildMeOrg/wbia-utool.git +git clone --branch develop https://github.com/WildMeOrg/wbia-vtool.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyhesaff.git +git clone --branch develop 
https://github.com/WildMeOrg/wbia-tpl-pyflann.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pydarknet.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyrf.git +git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-brambox +git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-lightnet +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-cnn.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-flukematch.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-finfindr.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-deepsense.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-pie.git cd ${CODE} -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-curvrank.git +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-curvrank.git cd wbia-plugin-curvrank/wbia_curvrank git fetch origin git checkout develop cd ${CODE} -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-kaggle7.git +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-kaggle7.git cd wbia-plugin-kaggle7/wbia_kaggle7 git fetch origin git checkout develop cd ${CODE} -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-lca.git +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-lca.git cd ${CODE}/wbia-utool ./run_developer_setup.sh diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index c64bc548d7..2799c67308 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -3,7 +3,7 @@ ARG WBIA_DEPENDENCIES_IMAGE=wildme/wbia-dependencies:latest FROM ${WBIA_DEPENDENCIES_IMAGE} as org.wildme.wbia.provision # Wildbook IA version -ARG VCS_URL="https://github.com/WildbookOrg/wildbook-ia" +ARG VCS_URL="https://github.com/WildMeOrg/wildbook-ia" ARG VCS_REF="develop" @@ -23,57 +23,57 @@ RUN set -ex \ # Clone WBIA toolkit repositories RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-utool.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-vtool.git + && git clone --branch develop https://github.com/WildMeOrg/wbia-utool.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-vtool.git # Clone WBIA third-party toolkit repositories RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyhesaff.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyflann.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pydarknet.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyrf.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyhesaff.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyflann.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pydarknet.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyrf.git \ # Depricated - && git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-brambox \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-lightnet + && git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-brambox \ + && git clone --branch develop 
https://github.com/WildMeOrg/wbia-deprecate-tpl-lightnet # Clone first-party WBIA plug-in repositories RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-cnn.git + && git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-cnn.git RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-flukematch.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-finfindr.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-deepsense.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-pie.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-lca.git + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-flukematch.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-finfindr.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-deepsense.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-pie.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-lca.git RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-orientation.git + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-orientation.git -# git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-2d-orientation.git +# git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-2d-orientation.git # Clone third-party WBIA plug-in repositories RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop-curvrank-v1 https://github.com/WildbookOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v1 \ + && git clone --recursive --branch develop-curvrank-v1 https://github.com/WildMeOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v1 \ && cd /wbia/wbia-plugin-curvrank-v1/wbia_curvrank \ && git fetch origin \ && git checkout develop RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop-curvrank-v2 https://github.com/WildbookOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v2 \ + && git clone --recursive --branch develop-curvrank-v2 https://github.com/WildMeOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v2 \ && cd /wbia/wbia-plugin-curvrank-v2/wbia_curvrank_v2 \ && git fetch origin \ && git checkout develop RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-kaggle7.git \ + && git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-kaggle7.git \ && cd /wbia/wbia-plugin-kaggle7/wbia_kaggle7 \ && git fetch origin \ && git checkout develop diff --git a/devops/publish.sh b/devops/publish.sh index 5819d21940..5fc04b563c 100755 --- a/devops/publish.sh +++ b/devops/publish.sh @@ -22,7 +22,7 @@ REGISTRY=${REGISTRY:-} IMAGES=${@:-wbia-base wbia-dependencies wbia-provision wbia wildbook-ia} # Set the image prefix if [ -n "$REGISTRY" ]; then - IMG_PREFIX="${REGISTRY}/wildbookorg/wildbook-ia/" + IMG_PREFIX="${REGISTRY}/wildme/wildbook-ia/" else IMG_PREFIX="wildme/" fi diff --git a/docs/index.rst b/docs/index.rst index 0fc7fbed6a..e30cbf242a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,7 +12,7 @@ Wildbook's Image Analysis is colloquially known as Wildbook-IA and by developers The Wildbook-IA application is used for the storage, management and analysis 
of images and derived data used by computer vision algorithms. It aims to compute who an animal is, what species an animal is, and where an animal is with the ultimate goal being to ask important why biological questions. -This project is the Machine Learning (ML) / computer vision component of the `WildBook project `_. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ +This project is the Machine Learning (ML) / computer vision component of the `WildBook project `_. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ Currently the system is build around and SQLite database, a web UI, and matplotlib visualizations. Algorithms employed are: convolutional neural network detection and localization and classification, hessian-affine keypoint detection, SIFT keypoint description, LNBNN identification using approximate nearest neighbors. diff --git a/setup.py b/setup.py index d6cac9a1a5..b37eccd60b 100755 --- a/setup.py +++ b/setup.py @@ -193,7 +193,7 @@ def gen_packages_items(): 'J. Wrona', ] AUTHOR_EMAIL = 'dev@wildme.org' -URL = 'https://github.com/WildbookOrg/wildbook-ia' +URL = 'https://github.com/WildMeOrg/wildbook-ia' LICENSE = 'Apache License 2.0' DESCRIPTION = 'Wildbook IA (WBIA) - Machine learning service for the WildBook project' KEYWORDS = [ @@ -271,7 +271,7 @@ def gen_packages_items(): 'Programming Language :: Python :: 3 :: Only', ], project_urls={ # Optional - 'Bug Reports': 'https://github.com/WildbookOrg/wildbook-ia/issues', + 'Bug Reports': 'https://github.com/WildMeOrg/wildbook-ia/issues', 'Funding': 'https://www.wildme.org/donate/', 'Say Thanks!': 'https://community.wildbook.org', 'Source': URL, diff --git a/super_setup.py b/super_setup.py index 39a939efb3..d8fee14cf1 100755 --- a/super_setup.py +++ b/super_setup.py @@ -23,24 +23,24 @@ REPOS = [ # (, , ) - ('WildbookOrg/wildbook-ia', 'wildbook-ia', 'develop'), - ('WildbookOrg/wbia-utool', 'wbia-utool', 'develop'), - ('WildbookOrg/wbia-vtool', 'wbia-vtool', 'develop'), - ('WildbookOrg/wbia-tpl-pyhesaff', 'wbia-tpl-pyhesaff', 'develop'), - ('WildbookOrg/wbia-tpl-pyflann', 'wbia-tpl-pyflann', 'develop'), - ('WildbookOrg/wbia-tpl-pydarknet', 'wbia-tpl-pydarknet', 'develop'), - ('WildbookOrg/wbia-tpl-pyrf', 'wbia-tpl-pyrf', 'develop'), - ('WildbookOrg/wbia-deprecate-tpl-brambox', 'wbia-deprecate-tpl-brambox', 'develop'), - ('WildbookOrg/wbia-deprecate-tpl-lightnet', 'wbia-deprecate-tpl-lightnet', 'develop'), - ('WildbookOrg/wbia-plugin-cnn', 'wbia-plugin-cnn', 'develop'), - ('WildbookOrg/wbia-plugin-flukematch', 'wbia-plugin-flukematch', 'develop'), - ('WildbookOrg/wbia-plugin-curvrank', 'wbia-plugin-curvrank', 'develop'), - ('WildbookOrg/wbia-plugin-deepsense', 'wbia-plugin-deepsense', 'develop'), - ('WildbookOrg/wbia-plugin-finfindr', 'wbia-plugin-finfindr', 'develop'), - ('WildbookOrg/wbia-plugin-kaggle7', 'wbia-plugin-kaggle7', 'develop'), - 
('WildbookOrg/wbia-plugin-pie', 'wbia-plugin-pie', 'develop'), - ('WildbookOrg/wbia-plugin-lca', 'wbia-plugin-lca', 'develop'), - # ('WildbookOrg/wbia-plugin-2d-orientation', 'wbia-plugin-2d-orientation', 'develop'), + ('WildMeOrg/wildbook-ia', 'wildbook-ia', 'develop'), + ('WildMeOrg/wbia-utool', 'wbia-utool', 'develop'), + ('WildMeOrg/wbia-vtool', 'wbia-vtool', 'develop'), + ('WildMeOrg/wbia-tpl-pyhesaff', 'wbia-tpl-pyhesaff', 'develop'), + ('WildMeOrg/wbia-tpl-pyflann', 'wbia-tpl-pyflann', 'develop'), + ('WildMeOrg/wbia-tpl-pydarknet', 'wbia-tpl-pydarknet', 'develop'), + ('WildMeOrg/wbia-tpl-pyrf', 'wbia-tpl-pyrf', 'develop'), + ('WildMeOrg/wbia-deprecate-tpl-brambox', 'wbia-deprecate-tpl-brambox', 'develop'), + ('WildMeOrg/wbia-deprecate-tpl-lightnet', 'wbia-deprecate-tpl-lightnet', 'develop'), + ('WildMeOrg/wbia-plugin-cnn', 'wbia-plugin-cnn', 'develop'), + ('WildMeOrg/wbia-plugin-flukematch', 'wbia-plugin-flukematch', 'develop'), + ('WildMeOrg/wbia-plugin-curvrank', 'wbia-plugin-curvrank', 'develop'), + ('WildMeOrg/wbia-plugin-deepsense', 'wbia-plugin-deepsense', 'develop'), + ('WildMeOrg/wbia-plugin-finfindr', 'wbia-plugin-finfindr', 'develop'), + ('WildMeOrg/wbia-plugin-kaggle7', 'wbia-plugin-kaggle7', 'develop'), + ('WildMeOrg/wbia-plugin-pie', 'wbia-plugin-pie', 'develop'), + ('WildMeOrg/wbia-plugin-lca', 'wbia-plugin-lca', 'develop'), + # ('WildMeOrg/wbia-plugin-2d-orientation', 'wbia-plugin-2d-orientation', 'develop'), ] diff --git a/wbia/__init__.py b/wbia/__init__.py index e6945edd7c..f79b02020c 100644 --- a/wbia/__init__.py +++ b/wbia/__init__.py @@ -85,6 +85,7 @@ from wbia.init import main_helpers from wbia import algo + from wbia import research from wbia import expt from wbia import templates diff --git a/wbia/algo/graph/mixin_loops.py b/wbia/algo/graph/mixin_loops.py index c98be883b6..a2755a2b43 100644 --- a/wbia/algo/graph/mixin_loops.py +++ b/wbia/algo/graph/mixin_loops.py @@ -6,6 +6,7 @@ import ubelt as ub import pandas as pd import itertools as it +import threading import wbia.constants as const from wbia.algo.graph.state import POSTV, NEGTV, INCMP, NULL from wbia.algo.graph.refresh import RefreshCriteria @@ -22,6 +23,9 @@ class InfrLoops(object): Algorithm control flow loops """ + def __init__(infr): + infr.gen_lock = threading.Lock() + def main_gen(infr, max_loops=None, use_refresh=True): """ The main outer loop. @@ -166,6 +170,8 @@ def main_gen(infr, max_loops=None, use_refresh=True): if infr.params['inference.enabled']: infr.assert_consistency_invariant() + return 'finished' + def hardcase_review_gen(infr): """ Subiterator for hardcase review @@ -480,6 +486,7 @@ def main_loop(infr, max_loops=None, use_refresh=True): or assert not any(infr.main_gen()) maybe this is fine. 
""" + raise RuntimeError() infr.start_id_review(max_loops=max_loops, use_refresh=use_refresh) # To automatically run through the loop just exhaust the generator result = next(infr._gen) @@ -632,14 +639,29 @@ def continue_review(infr): infr.print('continue_review', 10) if infr._gen is None: return None - try: - user_request = next(infr._gen) - except StopIteration: + + hungry, finished, attempt = True, False, 0 + while hungry: + try: + attempt += 1 + with infr.gen_lock: + user_request = next(infr._gen) + hungry = False + except StopIteration: + pass + if attempt >= 100: + finished = True + if isinstance(user_request, str) and user_request in ['finished']: + hungry = False + finished = True + + if finished: review_finished = infr.callbacks.get('review_finished', None) if review_finished is not None: review_finished() infr._gen = None user_request = None + return user_request def qt_edge_reviewer(infr, edge=None): diff --git a/wbia/algo/preproc/preproc_occurrence.py b/wbia/algo/preproc/preproc_occurrence.py index 20be05310f..d71d0f2916 100644 --- a/wbia/algo/preproc/preproc_occurrence.py +++ b/wbia/algo/preproc/preproc_occurrence.py @@ -33,7 +33,7 @@ def wbia_compute_occurrences(ibs, gid_list, config=None, verbose=None): TODO: FIXME: good example of autogen doctest return failure """ if config is None: - config = {'use_gps': False, 'seconds_thresh': 600} + config = {'use_gps': True, 'seconds_thresh': 60 * 5} # from wbia.algo import Config # config = Config.OccurrenceConfig().asdict() occur_labels, occur_gids = compute_occurrence_groups( diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 25840680aa..02f0bd7612 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -68,6 +68,7 @@ 'wbia.other.detectgrave', 'wbia.other.detecttrain', 'wbia.init.filter_annots', + 'wbia.research.metrics', 'wbia.control.manual_featweight_funcs', 'wbia.control._autogen_party_funcs', 'wbia.control.manual_annotmatch_funcs', diff --git a/wbia/control/manual_meta_funcs.py b/wbia/control/manual_meta_funcs.py index 81cc4f1e05..db550e6e0c 100644 --- a/wbia/control/manual_meta_funcs.py +++ b/wbia/control/manual_meta_funcs.py @@ -1011,6 +1011,7 @@ def _init_config(ibs): except IOError as ex: logger.error('*** failed to load general config', exc_info=ex) general_config = {} + ut.save_cPkl(config_fpath, general_config, verbose=ut.VERBOSE) current_species = general_config.get('current_species', None) logger.info('[_init_config] general_config.current_species = %r' % (current_species,)) # diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 0ba70c107b..1f678298b3 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -29,6 +29,8 @@ from wbia.dtool.types import Integer, TYPE_TO_SQLTYPE from wbia.dtool.types import initialize_postgresql_types +import tqdm + print, rrr, profile = ut.inject2(__name__) logger = logging.getLogger('wbia') @@ -43,6 +45,8 @@ TIMEOUT = 600 # Wait for up to 600 seconds for the database to return from a locked state +BATCH_SIZE = int(1e4) + SQLColumnRichInfo = collections.namedtuple( 'SQLColumnRichInfo', ('column_id', 'name', 'type_', 'notnull', 'dflt_value', 'pk') ) @@ -1042,7 +1046,7 @@ def get_where_eq( where_colnames, unpack_scalars=True, op='AND', - BATCH_SIZE=250000, + batch_size=BATCH_SIZE, **kwargs, ): """Executes a SQL select where the given parameters match/equal @@ -1068,7 +1072,7 @@ def get_where_eq( id_iter=(p[0] for p in params_iter), id_colname=where_colnames[0], unpack_scalars=unpack_scalars, - 
BATCH_SIZE=BATCH_SIZE, + batch_size=batch_size, **kwargs, ) params_iter = list(params_iter) @@ -1092,7 +1096,7 @@ def get_where_eq( **kwargs, ) - params_per_batch = int(BATCH_SIZE / len(params_iter[0])) + params_per_batch = int(batch_size / len(params_iter[0])) result_map = {} stmt = sqlalchemy.select( [table.c[c] for c in tuple(where_colnames) + tuple(colnames)] @@ -1102,7 +1106,10 @@ def get_where_eq( sqlalchemy.sql.bindparam('params', expanding=True) ) ) - for batch in range(int(len(params_iter) / params_per_batch) + 1): + batch_list = list(range(int(len(params_iter) / params_per_batch) + 1)) + for batch in tqdm.tqdm( + batch_list, disable=len(batch_list) <= 1, desc='[db.get(%s)]' % (tblname,) + ): val_list = self.executeone( stmt, { @@ -1334,7 +1341,7 @@ def get( id_colname='rowid', eager=True, assume_unique=False, - BATCH_SIZE=250000, + batch_size=BATCH_SIZE, **kwargs, ): """Get rows of data by ID @@ -1412,10 +1419,14 @@ def get( id_column = table.c[id_colname] stmt = sqlalchemy.select([id_column] + [table.c[c] for c in colnames]) stmt = stmt.where(id_column.in_(bindparam('value', expanding=True))) - for batch in range(int(len(id_iter) / BATCH_SIZE) + 1): + + batch_list = list(range(int(len(id_iter) / batch_size) + 1)) + for batch in tqdm.tqdm( + batch_list, disable=len(batch_list) <= 1, desc='[db.get(%s)]' % (tblname,) + ): val_list = self.executeone( stmt, - {'value': id_iter[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE]}, + {'value': id_iter[batch * batch_size : (batch + 1) * batch_size]}, ) for val in val_list: diff --git a/wbia/other/dbinfo.py b/wbia/other/dbinfo.py index f500ddddb7..8be148de28 100644 --- a/wbia/other/dbinfo.py +++ b/wbia/other/dbinfo.py @@ -51,11 +51,12 @@ def print_qd_info(ibs, qaid_list, daid_list, verbose=False): def get_dbinfo( ibs, verbose=True, - with_imgsize=False, - with_bytes=False, - with_contrib=False, - with_agesex=False, + with_imgsize=True, + with_bytes=True, + with_contrib=True, + with_agesex=True, with_header=True, + with_reviews=True, short=False, tag='dbinfo', aid_list=None, @@ -224,16 +225,16 @@ def get_dbinfo( ibs.check_name_mapping_consistency(nx2_aids) - if False: + if True: # Occurrence Info def compute_annot_occurrence_ids(ibs, aid_list): from wbia.algo.preproc import preproc_occurrence + import utool as ut gid_list = ibs.get_annot_gids(aid_list) gid2_aids = ut.group_items(aid_list, gid_list) - config = {'seconds_thresh': 4 * 60 * 60} flat_imgsetids, flat_gids = preproc_occurrence.wbia_compute_occurrences( - ibs, gid_list, config=config, verbose=False + ibs, gid_list, verbose=False ) occurid2_gids = ut.group_items(flat_gids, flat_imgsetids) occurid2_aids = { @@ -283,19 +284,19 @@ def break_annots_into_encounters(aids): # ave_enc_time = [np.mean(times) for lbl, times in ut.group_items(posixtimes, labels).items()] # ut.square_pdist(ave_enc_time) - try: - am_rowids = ibs.get_annotmatch_rowids_between_groups([valid_aids], [valid_aids])[ - 0 - ] - aid_pairs = ibs.filter_aidpairs_by_tags(min_num=0, am_rowids=am_rowids) - undirected_tags = ibs.get_aidpair_tags( - aid_pairs.T[0], aid_pairs.T[1], directed=False - ) - tagged_pairs = list(zip(aid_pairs.tolist(), undirected_tags)) - tag_dict = ut.groupby_tags(tagged_pairs, undirected_tags) - pair_tag_info = ut.map_dict_vals(len, tag_dict) - except Exception: - pair_tag_info = {} + # try: + # am_rowids = ibs.get_annotmatch_rowids_between_groups([valid_aids], [valid_aids])[ + # 0 + # ] + # aid_pairs = ibs.filter_aidpairs_by_tags(min_num=0, am_rowids=am_rowids) + # undirected_tags = 
ibs.get_aidpair_tags( + # aid_pairs.T[0], aid_pairs.T[1], directed=False + # ) + # tagged_pairs = list(zip(aid_pairs.tolist(), undirected_tags)) + # tag_dict = ut.groupby_tags(tagged_pairs, undirected_tags) + # pair_tag_info = ut.map_dict_vals(len, tag_dict) + # except Exception: + # pair_tag_info = {} # logger.info(ut.repr2(pair_tag_info)) @@ -558,14 +559,108 @@ def fix_tag_list(tag_list): contributor_rowids = ibs.get_valid_contributor_rowids() num_contributors = len(contributor_rowids) + if verbose: + logger.info('Checking Review Info') + + # Get reviewer statistics + def get_review_decision_stats(ibs, rid_list): + review_decision_list = ibs.get_review_decision_str(rid_list) + review_decision_to_rids = ut.group_items(rid_list, review_decision_list) + review_decision_stats = { + key: len(val) for key, val in review_decision_to_rids.items() + } + return review_decision_stats + + def get_review_identity(rid_list): + review_identity_list = ibs.get_review_identity(rid_list) + review_identity_list = [ + value.replace('user:web', 'human:web') + .replace('web:None', 'web') + .replace('auto_clf', 'vamp') + .replace(':', '[') + + ']' + for value in review_identity_list + ] + return review_identity_list + + def get_review_identity_stats(ibs, rid_list): + review_identity_list = get_review_identity(rid_list) + review_identity_to_rids = ut.group_items(rid_list, review_identity_list) + review_identity_stats = { + key: len(val) for key, val in review_identity_to_rids.items() + } + return review_identity_to_rids, review_identity_stats + + def get_review_participation(review_aids_list, value_list): + review_participation_dict = {} + for review_aids, value in zip(review_aids_list, value_list): + for value_ in [value, 'Any']: + if value_ not in review_participation_dict: + review_participation_dict[value_] = {} + for aid in review_aids: + if aid not in review_participation_dict[value_]: + review_participation_dict[value_][aid] = 0 + review_participation_dict[value_][aid] += 1 + + for value in review_participation_dict: + values = list(review_participation_dict[value].values()) + mean = np.mean(values) + std = np.std(values) + thresh = int(np.around(mean + 2 * std)) + values = [ + '%02d+' % (thresh,) if value >= thresh else '%02d' % (value,) + for value in values + ] + review_participation_dict[value] = ut.dict_hist(values) + review_participation_dict[value]['AVG'] = '%0.1f +/- %0.1f' % ( + mean, + std, + ) + + return review_participation_dict + + valid_rids = ibs._get_all_review_rowids() + + review_decision_stats = get_review_decision_stats(ibs, valid_rids) + review_identity_to_rids, review_identity_stats = get_review_identity_stats( + ibs, valid_rids + ) + + review_identity_to_decision_stats = { + key: get_review_decision_stats(ibs, aids) + for key, aids in six.iteritems(review_identity_to_rids) + } + + review_aids_list = ibs.get_review_aid_tuple(valid_rids) + review_decision_list = ibs.get_review_decision_str(valid_rids) + review_identity_list = get_review_identity(valid_rids) + review_decision_participation_dict = get_review_participation( + review_aids_list, review_decision_list + ) + review_identity_participation_dict = get_review_participation( + review_aids_list, review_identity_list + ) + + review_tags_list = ibs.get_review_tags(valid_rids) + review_tag_list = [ + review_tag if review_tag is None else '+'.join(sorted(review_tag)) + for review_tag in review_tags_list + ] + + review_tag_to_rids = ut.group_items(valid_rids, review_tag_list) + review_tag_stats = {key: len(val) for key, val in 
review_tag_to_rids.items()} + + ut.embed() + # print - num_tabs = 5 + num_tabs = 30 def align2(str_): return ut.align(str_, ':', ' :') def align_dict2(dict_): - str_ = ut.repr2(dict_, si=True) + # str_ = ut.repr2(dict_, si=True) + str_ = ut.repr3(dict_, si=True) return align2(str_) header_block_lines = [('+============================')] + ( @@ -626,17 +721,6 @@ def align_dict2(dict_): else [] ) - occurrence_block_lines = ( - [ - ('--' * num_tabs), - # ('# Occurrence Per Name (Resights) = %s' % (align_dict2(resight_name_stats),)), - # ('# Annots per Encounter (Singlesights) = %s' % (align_dict2(singlesight_annot_stats),)), - ('# Pair Tag Info (annots) = %s' % (align_dict2(pair_tag_info),)), - ] - if not short - else [] - ) - annot_per_qualview_block_lines = [ None if short else '# Annots per Viewpoint = %s' % align_dict2(viewcode2_nAnnots), None if short else '# Annots per Quality = %s' % align_dict2(qualtext2_nAnnots), @@ -644,23 +728,51 @@ def align_dict2(dict_): annot_per_agesex_block_lines = ( [ - '# Annots per Age = %s' % align_dict2(agetext2_nAnnots), - '# Annots per Sex = %s' % align_dict2(sextext2_nAnnots), + ('# Annots per Age = %s' % align_dict2(agetext2_nAnnots)), + ('# Annots per Sex = %s' % align_dict2(sextext2_nAnnots)), ] if not short and with_agesex else [] ) - contributor_block_lines = ( + occurrence_block_lines = ( [ - '# Images per contributor = ' + align_dict2(contributor_tag_to_nImages), - '# Annots per contributor = ' + align_dict2(contributor_tag_to_nAnnots), - '# Quality per contributor = ' - + ut.repr2(contributor_tag_to_qualstats, sorted_=True), - '# Viewpoint per contributor = ' - + ut.repr2(contributor_tag_to_viewstats, sorted_=True), + ('--' * num_tabs), + ( + '# Occurrence Per Name (Resights) = %s' + % (align_dict2(resight_name_stats),) + ), + ( + '# Annots per Encounter (Singlesights) = %s' + % (align_dict2(singlesight_annot_stats),) + ), + # ('# Pair Tag Info (annots) = %s' % (align_dict2(pair_tag_info),)), ] - if with_contrib + if not short + else [] + ) + + reviews_block_lines = ( + [ + ('--' * num_tabs), + ('# Reviews = %d' % len(valid_rids)), + ('# Reviews per Decision = %s' % align_dict2(review_decision_stats)), + ('# Reviews per Reviewer = %s' % align_dict2(review_identity_stats)), + ( + '# Review Breakdown = %s' + % align_dict2(review_identity_to_decision_stats) + ), + ('# Reviews with Tag = %s' % align_dict2(review_tag_stats)), + ( + '# Review Participation #1 = %s' + % align_dict2(review_decision_participation_dict) + ), + ( + '# Review Participation #2 = %s' + % align_dict2(review_identity_participation_dict) + ), + ] + if with_reviews else [] ) @@ -677,6 +789,30 @@ def align_dict2(dict_): else ('Img Time Stats = %s' % (align2(unixtime_statstr),)), ] + contributor_block_lines = ( + [ + ('--' * num_tabs), + ( + '# Images per contributor = ' + + align_dict2(contributor_tag_to_nImages) + ), + ( + '# Annots per contributor = ' + + align_dict2(contributor_tag_to_nAnnots) + ), + ( + '# Quality per contributor = ' + + align_dict2(contributor_tag_to_qualstats) + ), + ( + '# Viewpoint per contributor = ' + + align_dict2(contributor_tag_to_viewstats) + ), + ] + if with_contrib + else [] + ) + info_str_lines = ( header_block_lines + bytes_block_lines @@ -684,16 +820,17 @@ def align_dict2(dict_): + name_block_lines + annot_block_lines + annot_per_basic_block_lines - + occurrence_block_lines + annot_per_qualview_block_lines + annot_per_agesex_block_lines + + occurrence_block_lines + + reviews_block_lines + img_block_lines - + contributor_block_lines + 
imgsize_stat_lines + + contributor_block_lines + [('L============================')] ) info_str = '\n'.join(ut.filter_Nones(info_str_lines)) - info_str2 = ut.indent(info_str, '[{tag}]'.format(tag=tag)) + info_str2 = ut.indent(info_str, '[{tag}] '.format(tag=tag)) if verbose: logger.info(info_str2) locals_ = locals() diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index ab86e92ee6..4e4a432d05 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -945,9 +945,7 @@ def check_name_mapping_consistency(ibs, nx2_aids): """ checks that all the aids grouped in a name ahave the same name """ # DEBUGGING CODE try: - from wbia import ibsfuncs - - _nids_list = ibsfuncs.unflat_map(ibs.get_annot_name_rowids, nx2_aids) + _nids_list = unflat_map(ibs.get_annot_name_rowids, nx2_aids) assert all(map(ut.allsame, _nids_list)) except Exception as ex: # THESE SHOULD BE CONSISTENT BUT THEY ARE NOT!!? diff --git a/wbia/research/__init__.py b/wbia/research/__init__.py new file mode 100644 index 0000000000..6b831212de --- /dev/null +++ b/wbia/research/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from wbia.research import metrics # NOQA + +import utool as ut + +ut.noinject(__name__, '[wbia.research.__init__]', DEBUG=False) diff --git a/wbia/research/metrics.py b/wbia/research/metrics.py new file mode 100644 index 0000000000..90f4ca5465 --- /dev/null +++ b/wbia/research/metrics.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +""" +developer convenience functions for ibs + +TODO: need to split up into sub modules: + consistency_checks + feasibility_fixes + move the export stuff to dbio + + python -m utool.util_inspect check_module_usage --pat="ibsfuncs.py" + + then there are also convineience functions that need to be ordered at least + within this file +""" +import logging +import utool as ut +from wbia.control import controller_inject +from wbia import annotmatch_funcs # NOQA +import pytz + + +PST = pytz.timezone('US/Pacific') + + +# Inject utool function +(print, rrr, profile) = ut.inject2(__name__, '[research]') +logger = logging.getLogger('wbia') + + +# Must import class before injection +CLASS_INJECT_KEY, register_ibs_method = controller_inject.make_ibs_register_decorator( + __name__ +) + + +register_api = controller_inject.get_wbia_flask_api(__name__) + + +@register_ibs_method +def research_print_metrics(ibs, tag='metrics'): + imageset_rowid_list = ibs.get_valid_imgsetids(is_special=False) + imageset_text_list = ibs.get_imageset_text(imageset_rowid_list) + + global_gid_list = [] + global_cid_list = [] + for imageset_rowid, imageset_text in zip(imageset_rowid_list, imageset_text_list): + imageset_text_ = imageset_text.strip().split(',') + if len(imageset_text_) == 3: + ggr, car, person = imageset_text_ + if ggr in ['GGR', 'GGR2']: + gid_list = ibs.get_imageset_gids(imageset_rowid) + global_gid_list += gid_list + cid = ibs.add_contributors([imageset_text])[0] + global_cid_list += [cid] * len(gid_list) + + assert len(global_gid_list) == len(set(global_gid_list)) + + ibs.set_image_contributor_rowid(global_gid_list, global_cid_list) + + ibs.print_dbinfo() diff --git a/wbia/web/routes.py b/wbia/web/routes.py index cd58fcb500..ebcfb6cb14 100644 --- a/wbia/web/routes.py +++ b/wbia/web/routes.py @@ -4665,6 +4665,7 @@ def turk_identification_graph_refer( annot_uuid_list=annot_uuid_list, hogwild_species=species, creation_imageset_rowid_list=[imgsetid], + kaia=True, ) elif option in ['rosemary']: imgsetid_ = ibs.get_imageset_imgsetids_from_text('RosemaryLoopsData') @@ -4803,6 +4804,7 @@ def 
turk_identification_graph( hogwild_species=None, creation_imageset_rowid_list=None, kaia=False, + census=False, **kwargs, ): """ @@ -4965,6 +4967,22 @@ def turk_identification_graph( 'redun.neg': 2, 'redun.pos': 2, } + elif kaia: + logger.info('[routes] Graph is in CA-mode') + query_config_dict = { + 'autoreview.enabled': True, + 'autoreview.prioritize_nonpos': True, + 'inference.enabled': True, + 'ranking.enabled': True, + 'ranking.ntop': 20, + 'redun.enabled': True, + 'redun.enforce_neg': True, + 'redun.enforce_pos': True, + 'redun.neg.only_auto': False, + 'redun.neg': 3, + 'redun.pos': 3, + 'algo.hardcase': True, + } else: logger.info('[routes] Graph is not in hardcase-mode') query_config_dict = {} diff --git a/wbia/web/templates/index.html b/wbia/web/templates/index.html index 1ea2493c45..a0d762bd40 100644 --- a/wbia/web/templates/index.html +++ b/wbia/web/templates/index.html @@ -20,7 +20,7 @@

Welcome to IBEIS


For more information: http://wbia.org/
- To view the code repository: https://github.com/WildbookOrg/wbia + To view the code repository: https://github.com/WildMeOrg/wildbook-ia
To view the API settings: {{ request.url }}{{ url_for('api_root').lstrip('/') }}
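
Note on the CurvRank v2 configuration added in PATCH 226: the new back-end is selected the same way as the existing ones, by passing a `query_config_dict` whose `pipeline_root` names the ranking algorithm (`CurvRankTwoDorsal` / `CurvRankTwoFluke`). The snippet below is only a hedged restatement of the mapping built in the `apis_microsoft.py` and `apis_query.py` hunks above; the helper name `pick_query_config` is hypothetical and does not exist in the repository.

```
# Hypothetical helper (illustration only, not wildbook-ia code): map the algorithm
# strings accepted by the web APIs to the pipeline_root values used in PATCH 226.
def pick_query_config(algorithm):
    algorithm = algorithm.lower()
    if algorithm in ('hotspotter',):
        return {}  # default HotSpotter ranking pipeline
    if algorithm in ('curvrank',):
        return {'pipeline_root': 'CurvRankFluke'}
    if algorithm in ('curvrank_v2', 'curvrankv2'):
        return {'pipeline_root': 'CurvRankTwoFluke'}  # new CurvRank v2 pipeline
    if algorithm in ('deepsense',):
        return {'pipeline_root': 'Deepsense'}
    raise ValueError('unsupported algorithm: %r' % (algorithm,))


# Example: the dorsal variants added to apis_query.py follow the same pattern.
assert pick_query_config('CurvRankV2') == {'pipeline_root': 'CurvRankTwoFluke'}
```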
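Note on the `sql_control.py` change in PATCH 231: `get` and `get_where_eq` now fetch rows in fixed-size batches (default `BATCH_SIZE = int(1e4)`) with a tqdm progress bar, instead of issuing one huge `IN (...)` query. The following standalone sketch illustrates that chunking pattern under assumed conditions: it uses an in-memory SQLite table with made-up columns and data, and is not the controller code itself.

```
import sqlalchemy
from sqlalchemy import Column, Integer, MetaData, Table, Text, bindparam, select
import tqdm

BATCH_SIZE = 3  # wildbook-ia uses int(1e4); tiny here so the example actually batches

# Hypothetical in-memory table standing in for a real wbia table
engine = sqlalchemy.create_engine('sqlite://')
metadata = MetaData()
annot = Table('annot', metadata,
              Column('rowid', Integer, primary_key=True),
              Column('name', Text))
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(annot.insert(), [{'rowid': i, 'name': 'animal_%d' % i} for i in range(1, 11)])

    id_iter = list(range(1, 11))
    # Same expanding-bindparam SELECT ... WHERE rowid IN (...) pattern as the patch
    stmt = select([annot.c.rowid, annot.c.name]).where(
        annot.c.rowid.in_(bindparam('value', expanding=True))
    )

    rows = []
    batch_list = list(range(int(len(id_iter) / BATCH_SIZE) + 1))
    # Progress bar is suppressed when everything fits in one batch, mirroring
    # the `disable=len(batch_list) <= 1` behaviour introduced by the patch.
    for batch in tqdm.tqdm(batch_list, disable=len(batch_list) <= 1):
        chunk = id_iter[batch * BATCH_SIZE:(batch + 1) * BATCH_SIZE]
        if not chunk:
            continue  # the trailing batch can be empty; an empty expanding IN () errors
        rows.extend(conn.execute(stmt, {'value': chunk}).fetchall())

    assert len(rows) == len(id_iter)
```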