diff --git a/.dockerfiles/docker-entrypoint.sh b/.dockerfiles/docker-entrypoint.sh index a2d4bb2b74..e5e0454a2c 100755 --- a/.dockerfiles/docker-entrypoint.sh +++ b/.dockerfiles/docker-entrypoint.sh @@ -6,6 +6,8 @@ set -e if [ -d /code ]; then echo "*** $0 --- Uninstalling wildbook-ia" pip uninstall -y wildbook-ia + echo "*** $0 --- Uninstalling sentry_sdk (in development)" + pip uninstall -y sentry_sdk echo "*** $0 --- Installing development version of wildbook-ia at /code" pushd /code && pip install -e ".[tests,postgres]" && popd fi diff --git a/.dockerfiles/init-db.sh b/.dockerfiles/init-db.sh index ca9c97904f..fed804169c 100644 --- a/.dockerfiles/init-db.sh +++ b/.dockerfiles/init-db.sh @@ -3,7 +3,7 @@ set -e psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL - CREATE USER $DB_USER WITH PASSWORD '$DB_PASSWORD'; + CREATE USER $DB_USER WITH SUPERUSER PASSWORD '$DB_PASSWORD'; CREATE DATABASE $DB_NAME; GRANT ALL PRIVILEGES ON DATABASE $DB_NAME TO $DB_USER; EOSQL diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 704544af17..553ce7640f 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -5,6 +5,9 @@ on: push: paths: - '.github/workflows/nightly.yml' + - 'devops/**' + - 'Dockerfile' + - '.dockerfiles/*' pull_request: paths: - '.github/workflows/nightly.yml' @@ -19,6 +22,8 @@ jobs: name: DevOps nightly image build runs-on: ubuntu-latest strategy: + max-parallel: 2 + fail-fast: false matrix: images: - wildbook-ia diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index ca2bd0186d..de89a9ba08 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -9,6 +9,8 @@ jobs: name: Build on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: + max-parallel: 2 + fail-fast: false matrix: os: [ubuntu-latest, macos-latest] python-version: [3.7] diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index e198616c1d..5ad35a0bfe 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -9,6 +9,8 @@ jobs: lint: runs-on: ubuntu-latest strategy: + max-parallel: 3 + fail-fast: false matrix: # For speed, we choose one version and that should be the lowest common denominator python-version: [3.6, 3.7, 3.8] @@ -43,14 +45,34 @@ jobs: test: runs-on: ${{ matrix.os }} strategy: + max-parallel: 6 + fail-fast: false matrix: os: [ubuntu-latest] # Disable "macos-latest" for now # For speed, we choose one version and that should be the lowest common denominator python-version: [3.6, 3.7, 3.8] + postgres-uri: ['', 'postgresql://postgres:wbia@localhost:5432/postgres'] + + services: + db: + image: postgres:10 + env: + POSTGRES_PASSWORD: wbia + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 steps: # Checkout and env setup - uses: actions/checkout@v2 + - name: Install pgloader + if: matrix.postgres-uri + run: sudo apt-get install pgloader - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -65,12 +87,14 @@ jobs: # Install and test - name: Install the project - run: pip install -e .[tests] + run: pip install -e .[tests,postgres] - name: Test with pytest run: | mkdir -p data/work python -m wbia --set-workdir data/work --preload-exit - pytest --slow + pytest --slow --web-tests --with-postgres-uri=$POSTGRES_URI + env: + POSTGRES_URI: ${{ 
matrix.postgres-uri }} on-failure: # This is not in the 'test' job itself because it would otherwise notify once per matrix combination. diff --git a/.gitignore b/.gitignore index efe3950d95..6e8990fb08 100644 --- a/.gitignore +++ b/.gitignore @@ -36,7 +36,7 @@ testdb*/ tmp.txt vsone.*.cPkl vsone.*.json -wbia/DEPCACHE*/ +wbia/dtool/DEPCACHE*/ # Translations *.mo @@ -67,3 +67,4 @@ branch.sync.sh # setuptools_scm version discovery writes the version to /_version.py **/_version.py +map.png diff --git a/Dockerfile b/Dockerfile index 716adaf0e5..bbeaa7a050 100644 --- a/Dockerfile +++ b/Dockerfile @@ -67,6 +67,10 @@ RUN set -x \ libxext6 \ #: opencv2 dependency libgl1 \ + #: required to stop prompting by pgloader + libssl1.0.0 \ + #: sqlite->postgres dependency + pgloader \ #: dev debug dependency #: python3-dev required to build 'annoy' python3-dev \ @@ -118,7 +122,12 @@ EXPOSE 5000 # Move to the workdir WORKDIR /data +# Set the "workdir" +RUN python3 -m wbia --set-workdir /data --preload-exit + COPY .dockerfiles/docker-entrypoint.sh /docker-entrypoint.sh +ENV WBIA_DB_DIR="/data/db" + ENTRYPOINT ["/docker-entrypoint.sh"] -CMD ["python3", "-m", "wbia.dev", "--dbdir", "/data/db", "--logdir", "/data/logs/", "--web", "--port", "5000", "--web-deterministic-ports", "--containerized", "--cpudark", "--production"] +CMD ["python3", "-m", "wbia.dev", "--dbdir", "$WBIA_DB_DIR", "--logdir", "/data/logs/", "--web", "--port", "5000", "--web-deterministic-ports", "--containerized", "--cpudark", "--production"] diff --git a/README.rst b/README.rst index a61e18a7ec..c9eb12e2c1 100644 --- a/README.rst +++ b/README.rst @@ -12,7 +12,7 @@ use in computer vision algorithms. It aims to compute who an animal is, what species an animal is, and where an animal is with the ultimate goal being to ask important why biological questions. -This project is the Machine Learning (ML) / computer vision component of the WildBook project: See https://github.com/WildbookOrg/. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ +This project is the Machine Learning (ML) / computer vision component of the WildBook project: See https://github.com/WildMeOrg/. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ Currently the system is build around and SQLite database, a web GUI, and matplotlib visualizations. Algorithms employed are: convolutional neural network @@ -53,7 +53,7 @@ We highly recommend using a Python virtual environment: https://docs.python-guid Documentation ~~~~~~~~~~~~~ -The documentation is built and available online at `wildbookorg.github.io/wildbook-ia/ `_. However, if you need to build a local copy of the source, the following instructions can be used. +The documentation is built and available online at `wildmeorg.github.io/wildbook-ia/ `_. However, if you need to build a local copy of the source, the following instructions can be used. .. 
code:: bash @@ -87,76 +87,76 @@ This project depends on an array of other repositories for functionality. First Party Toolkits (Required) -* https://github.com/WildbookOrg/wbia-utool +* https://github.com/WildMeOrg/wbia-utool -* https://github.com/WildbookOrg/wbia-vtool +* https://github.com/WildMeOrg/wbia-vtool First Party Dependencies for Third Party Libraries (Required) -* https://github.com/WildbookOrg/wbia-tpl-pyhesaff +* https://github.com/WildMeOrg/wbia-tpl-pyhesaff -* https://github.com/WildbookOrg/wbia-tpl-pyflann +* https://github.com/WildMeOrg/wbia-tpl-pyflann -* https://github.com/WildbookOrg/wbia-tpl-pydarknet +* https://github.com/WildMeOrg/wbia-tpl-pydarknet -* https://github.com/WildbookOrg/wbia-tpl-pyrf +* https://github.com/WildMeOrg/wbia-tpl-pyrf First Party Plug-ins (Optional) -* https://github.com/WildbookOrg/wbia-plugin-cnn +* https://github.com/WildMeOrg/wbia-plugin-cnn -* https://github.com/WildbookOrg/wbia-plugin-flukematch +* https://github.com/WildMeOrg/wbia-plugin-flukematch -* https://github.com/WildbookOrg/wbia-plugin-deepsense +* https://github.com/WildMeOrg/wbia-plugin-deepsense -* https://github.com/WildbookOrg/wbia-plugin-finfindr +* https://github.com/WildMeOrg/wbia-plugin-finfindr -* https://github.com/WildbookOrg/wbia-plugin-curvrank +* https://github.com/WildMeOrg/wbia-plugin-curvrank - + https://github.com/WildbookOrg/wbia-tpl-curvrank + + https://github.com/WildMeOrg/wbia-tpl-curvrank -* https://github.com/WildbookOrg/wbia-plugin-kaggle7 +* https://github.com/WildMeOrg/wbia-plugin-kaggle7 - + https://github.com/WildbookOrg/wbia-tpl-kaggle7 + + https://github.com/WildMeOrg/wbia-tpl-kaggle7 -* https://github.com/WildbookOrg/wbia-plugin-2d-orientation +* https://github.com/WildMeOrg/wbia-plugin-2d-orientation - + https://github.com/WildbookOrg/wbia-tpl-2d-orientation + + https://github.com/WildMeOrg/wbia-tpl-2d-orientation -* https://github.com/WildbookOrg/wbia-plugin-lca +* https://github.com/WildMeOrg/wbia-plugin-lca - + https://github.com/WildbookOrg/wbia-tpl-lca + + https://github.com/WildMeOrg/wbia-tpl-lca Deprecated Toolkits (Deprecated) -* https://github.com/WildbookOrg/wbia-deprecate-ubelt +* https://github.com/WildMeOrg/wbia-deprecate-ubelt -* https://github.com/WildbookOrg/wbia-deprecate-dtool +* https://github.com/WildMeOrg/wbia-deprecate-dtool -* https://github.com/WildbookOrg/wbia-deprecate-guitool +* https://github.com/WildMeOrg/wbia-deprecate-guitool -* https://github.com/WildbookOrg/wbia-deprecate-plottool +* https://github.com/WildMeOrg/wbia-deprecate-plottool -* https://github.com/WildbookOrg/wbia-deprecate-detecttools +* https://github.com/WildMeOrg/wbia-deprecate-detecttools -* https://github.com/WildbookOrg/wbia-deprecate-plugin-humpbacktl +* https://github.com/WildMeOrg/wbia-deprecate-plugin-humpbacktl -* https://github.com/WildbookOrg/wbia-deprecate-tpl-lightnet +* https://github.com/WildMeOrg/wbia-deprecate-tpl-lightnet -* https://github.com/WildbookOrg/wbia-deprecate-tpl-brambox +* https://github.com/WildMeOrg/wbia-deprecate-tpl-brambox Plug-in Templates (Reference) -* https://github.com/WildbookOrg/wbia-plugin-template +* https://github.com/WildMeOrg/wbia-plugin-template -* https://github.com/WildbookOrg/wbia-plugin-id-example +* https://github.com/WildMeOrg/wbia-plugin-id-example Miscellaneous (Reference) -* https://github.com/WildbookOrg/wbia-pypkg-build +* https://github.com/WildMeOrg/wbia-pypkg-build -* https://github.com/WildbookOrg/wbia-project-website +* https://github.com/WildMeOrg/wbia-project-website -* 
https://github.com/WildbookOrg/wbia-aws-codedeploy +* https://github.com/WildMeOrg/wbia-aws-codedeploy Citation -------- @@ -259,8 +259,8 @@ To run doctests with `+REQUIRES(--web-tests)` do: pytest --web-tests -.. |Build| image:: https://img.shields.io/github/workflow/status/WildbookOrg/wildbook-ia/Build%20and%20upload%20to%20PyPI/master - :target: https://github.com/WildbookOrg/wildbook-ia/actions?query=branch%3Amaster+workflow%3A%22Build+and+upload+to+PyPI%22 +.. |Build| image:: https://img.shields.io/github/workflow/status/WildMeOrg/wildbook-ia/Build%20and%20upload%20to%20PyPI/master + :target: https://github.com/WildMeOrg/wildbook-ia/actions?query=branch%3Amaster+workflow%3A%22Build+and+upload+to+PyPI%22 :alt: Build and upload to PyPI (master) .. |Pypi| image:: https://img.shields.io/pypi/v/wildbook-ia.svg diff --git a/_dev/super_setup_old.py b/_dev/super_setup_old.py index 67cfb72393..c1d826578c 100755 --- a/_dev/super_setup_old.py +++ b/_dev/super_setup_old.py @@ -11,7 +11,7 @@ export CODE_DIR=~/code mkdir $CODE_DIR cd $CODE_DIR -git clone https://github.com/WildbookOrg/wbia.git +git clone https://github.com/WildMeOrg/wbia.git cd wbia python super_setup.py --bootstrap @@ -311,7 +311,7 @@ def ensure_utool(CODE_DIR, pythoncmd): WIN32 = sys.platform.startswith('win32') # UTOOL_BRANCH = ' -b ' UTOOL_BRANCH = 'next' - UTOOL_REPO = 'https://github.com/WildbookOrg/utool.git' + UTOOL_REPO = 'https://github.com/WildMeOrg/utool.git' print('WARNING: utool is not found') print('Attempting to get utool. Enter (y) to continue') @@ -370,8 +370,8 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): # IBEIS project repos # ----------- # if True: - # jon_repo_base = 'https://github.com/WildbookOrg' - # jason_repo_base = 'https://github.com/WildbookOrg' + # jon_repo_base = 'https://github.com/WildMeOrg' + # jason_repo_base = 'https://github.com/WildMeOrg' # else: # jon_repo_base = 'https://github.com/wildme' # jason_repo_base = 'https://github.com/wildme' @@ -381,12 +381,12 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): wbia_rman = ut.RepoManager( [ - 'https://github.com/WildbookOrg/utool.git', - # 'https://github.com/WildbookOrg/sandbox_utools.git', - 'https://github.com/WildbookOrg/vtool_ibeis.git', - 'https://github.com/WildbookOrg/dtool_ibeis.git', + 'https://github.com/WildMeOrg/utool.git', + # 'https://github.com/WildMeOrg/sandbox_utools.git', + 'https://github.com/WildMeOrg/vtool_ibeis.git', + 'https://github.com/WildMeOrg/dtool_ibeis.git', 'https://github.com/Erotemic/ubelt.git', - 'https://github.com/WildbookOrg/detecttools.git', + 'https://github.com/WildMeOrg/detecttools.git', ], CODE_DIR, label='core', @@ -399,24 +399,24 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): tpl_rman.add_repo(cv_repo) if WITH_GUI: - wbia_rman.add_repos(['https://github.com/WildbookOrg/plottool_ibeis.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/plottool_ibeis.git']) if WITH_QT: - wbia_rman.add_repos(['https://github.com/WildbookOrg/guitool_ibeis.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/guitool_ibeis.git']) tpl_rman.add_repo(ut.Repo(modname=('PyQt4', 'PyQt5', 'PyQt'))) if WITH_CUSTOM_TPL: flann_repo = ut.Repo( - 'https://github.com/WildbookOrg/flann.git', CODE_DIR, modname='pyflann' + 'https://github.com/WildMeOrg/flann.git', CODE_DIR, modname='pyflann' ) wbia_rman.add_repo(flann_repo) - wbia_rman.add_repos(['https://github.com/WildbookOrg/hesaff.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/hesaff.git']) if WITH_CNN: 
wbia_rman.add_repos( [ - 'https://github.com/WildbookOrg/wbia_cnn.git', - 'https://github.com/WildbookOrg/pydarknet.git', + 'https://github.com/WildMeOrg/wbia_cnn.git', + 'https://github.com/WildMeOrg/pydarknet.git', ] ) # NEW CNN Dependencies @@ -433,28 +433,26 @@ def initialize_repo_managers(CODE_DIR, pythoncmd, PY2, PY3): ) if WITH_FLUKEMATCH: - wbia_rman.add_repos( - ['https://github.com/WildbookOrg/ibeis-flukematch-module.git'] - ) + wbia_rman.add_repos(['https://github.com/WildMeOrg/ibeis-flukematch-module.git']) if WITH_CURVRANK: - wbia_rman.add_repos(['https://github.com/WildbookOrg/ibeis-curvrank-module.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/ibeis-curvrank-module.git']) if WITH_PYRF: - wbia_rman.add_repos(['https://github.com/WildbookOrg/pyrf.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/pyrf.git']) if False: # Depricated wbia_rman.add_repos( [ - # 'https://github.com/WildbookOrg/pybing.git', + # 'https://github.com/WildMeOrg/pybing.git', # 'https://github.com/aweinstock314/cyth.git', # 'https://github.com/hjweide/pygist', ] ) # Add main repo (Must be checked last due to dependency issues) - wbia_rman.add_repos(['https://github.com/WildbookOrg/wbia.git']) + wbia_rman.add_repos(['https://github.com/WildMeOrg/wbia.git']) # ----------- # Custom third party build/install scripts @@ -1005,7 +1003,7 @@ def GET_ARGFLAG(arg, *args, **kwargs): def move_wildme(wbia_rman, fmt): - wildme_user = 'WildbookOrg' + wildme_user = 'WildMeOrg' wildme_remote = 'wildme' for repo in wbia_rman.repos: diff --git a/conftest.py b/conftest.py index adca5f617a..b81e9a0b49 100644 --- a/conftest.py +++ b/conftest.py @@ -17,3 +17,11 @@ def pytest_addoption(parser): "instead it will reuse the previous test run's db" ), ) + parser.addoption( + '--with-postgres-uri', + dest='postgres_uri', + help=( + 'used to enable tests to run against a Postgres database ' + '(note, the uri should use a superuser role)' + ), + ) diff --git a/devops/Dockerfile b/devops/Dockerfile index 7e87f845de..ee5c4788cb 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -12,27 +12,35 @@ RUN set -ex \ 'cd {} && cd .. && echo $(pwd) && (git stash && git pull && git stash pop || git reset --hard origin/develop)' \ && find /wbia/wildbook* -name '.git' -type d -print0 | xargs -0 -i /bin/bash -c \ 'cd {} && cd .. 
&& echo $(pwd) && (git stash && git pull && git stash pop || git reset --hard origin/develop)' \ - && cd /wbia/wbia-plugin-curvrank/wbia_curvrank \ + && cd /wbia/wbia-plugin-curvrank-v1/wbia_curvrank \ + && git stash && git pull && git stash pop || git reset --hard origin/develop \ + && cd /wbia/wbia-plugin-curvrank-v2/wbia_curvrank_v2 \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ && cd /wbia/wbia-plugin-kaggle7/wbia_kaggle7 \ && git stash && git pull && git stash pop || git reset --hard origin/develop \ + && cd /wbia/wbia-plugin-orientation/ \ + && git stash && git pull && git stash pop || git reset --hard origin/develop \ + && cd /wbia/wildbook-ia/ \ + && git checkout develop \ && find /wbia -name '.git' -type d -print0 | xargs -0 rm -rf \ && find /wbia -name '_skbuild' -type d -print0 | xargs -0 rm -rf # Run smoke tests RUN set -ex \ - && /virtualenv/env3/bin/python -c "import wbia; from wbia.__main__ import smoke_test; smoke_test()" \ - && /virtualenv/env3/bin/python -c "import wbia_cnn; from wbia_cnn.__main__ import main; main()" \ - && /virtualenv/env3/bin/python -c "import wbia_pie; from wbia_pie.__main__ import main; main()" \ - && /virtualenv/env3/bin/python -c "import wbia_flukematch; from wbia_flukematch.plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_curvrank; from wbia_curvrank._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ - && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ - && find /wbia/wbia* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wbia* -name '*.so' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wildbook* -name '*.a' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' \ - && find /wbia/wildbook* -name '*.so' -print0 | xargs -0 -i /bin/bash -c 'echo {} && ld -d {}' + && /virtualenv/env3/bin/python -c "import wbia; from wbia.__main__ import smoke_test; smoke_test()" \ + && /virtualenv/env3/bin/python -c "import wbia_cnn; from wbia_cnn.__main__ import main; main()" \ + && /virtualenv/env3/bin/python -c "import wbia_pie; from wbia_pie.__main__ import main; main()" \ + && /virtualenv/env3/bin/python -c "import wbia_orientation; from wbia_orientation.__main__ import main; main()" \ + && /virtualenv/env3/bin/python -c "import wbia_flukematch; from wbia_flukematch.plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_curvrank; from wbia_curvrank._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_curvrank_v2; from wbia_curvrank_v2._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_finfindr; from wbia_finfindr._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_kaggle7; from wbia_kaggle7._plugin import *" \ + && /virtualenv/env3/bin/python -c "import wbia_deepsense; from wbia_deepsense._plugin import *" \ + && find /wbia/wbia* -name '*.a' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wbia* -name '*.so' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wildbook* -name '*.a' -print | grep -v "cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' \ + && find /wbia/wildbook* -name '*.so' -print | grep -v 
"cpython-37m-x86_64-linux-gnu" | xargs -i /bin/bash -c 'echo {} && ld -d {}' ########################################################################################## @@ -42,7 +50,7 @@ LABEL autoheal=true ARG VERSION="3.3.0" -ARG VCS_URL="https://github.com/WildbookOrg/wildbook-ia" +ARG VCS_URL="https://github.com/WildMeOrg/wildbook-ia" ARG VCS_REF="develop" diff --git a/devops/_config/setup.sh b/devops/_config/setup.sh index cf22d64331..11276bcbc5 100755 --- a/devops/_config/setup.sh +++ b/devops/_config/setup.sh @@ -44,3 +44,6 @@ chown -R ${HOST_USER}:${HOST_USER} /wbia/wbia-plugin-pie/ if [ ! -d "/data/docker" ]; then ln -s -T /data/db /data/docker fi + +# Allow Tensorflow to use GPU memory more dynamically +export TF_FORCE_GPU_ALLOW_GROWTH=true diff --git a/devops/build.sh b/devops/build.sh index 2bb012968c..43354f7e87 100755 --- a/devops/build.sh +++ b/devops/build.sh @@ -5,6 +5,8 @@ set -ex # See https://stackoverflow.com/a/246128/176882 export ROOT_LOC="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +export DOCKER_BUILDKIT=1 + # Change to the script's root directory location cd ${ROOT_LOC} diff --git a/devops/install.ubuntu.sh b/devops/install.ubuntu.sh index ff8a6922cc..6d2bc741dc 100644 --- a/devops/install.ubuntu.sh +++ b/devops/install.ubuntu.sh @@ -181,35 +181,35 @@ pip install pygraphviz --install-option="--include-path=/usr/include/graphviz" - cp -r ${VIRTUAL_ENV}/lib/python3.7/site-packages/cv2 /tmp/cv2 cd ${CODE} -git clone --branch develop https://github.com/WildbookOrg/wildbook-ia.git -git clone --branch develop https://github.com/WildbookOrg/wbia-utool.git -git clone --branch develop https://github.com/WildbookOrg/wbia-vtool.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyhesaff.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyflann.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pydarknet.git -git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyrf.git -git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-brambox -git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-lightnet -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-cnn.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-flukematch.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-finfindr.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-deepsense.git -git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-pie.git +git clone --branch develop https://github.com/WildMeOrg/wildbook-ia.git +git clone --branch develop https://github.com/WildMeOrg/wbia-utool.git +git clone --branch develop https://github.com/WildMeOrg/wbia-vtool.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyhesaff.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyflann.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pydarknet.git +git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyrf.git +git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-brambox +git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-lightnet +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-cnn.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-flukematch.git +git clone --branch develop 
https://github.com/WildMeOrg/wbia-plugin-finfindr.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-deepsense.git +git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-pie.git cd ${CODE} -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-curvrank.git +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-curvrank.git cd wbia-plugin-curvrank/wbia_curvrank git fetch origin git checkout develop cd ${CODE} -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-kaggle7.git +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-kaggle7.git cd wbia-plugin-kaggle7/wbia_kaggle7 git fetch origin git checkout develop cd ${CODE} -git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-lca.git +git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-lca.git cd ${CODE}/wbia-utool ./run_developer_setup.sh diff --git a/devops/provision/Dockerfile b/devops/provision/Dockerfile index 47fe001e53..2799c67308 100644 --- a/devops/provision/Dockerfile +++ b/devops/provision/Dockerfile @@ -3,7 +3,7 @@ ARG WBIA_DEPENDENCIES_IMAGE=wildme/wbia-dependencies:latest FROM ${WBIA_DEPENDENCIES_IMAGE} as org.wildme.wbia.provision # Wildbook IA version -ARG VCS_URL="https://github.com/WildbookOrg/wildbook-ia" +ARG VCS_URL="https://github.com/WildMeOrg/wildbook-ia" ARG VCS_REF="develop" @@ -23,46 +23,57 @@ RUN set -ex \ # Clone WBIA toolkit repositories RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-utool.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-vtool.git + && git clone --branch develop https://github.com/WildMeOrg/wbia-utool.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-vtool.git # Clone WBIA third-party toolkit repositories RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyhesaff.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyflann.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pydarknet.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-tpl-pyrf.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyhesaff.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyflann.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pydarknet.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-tpl-pyrf.git \ # Depricated - && git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-brambox \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-deprecate-tpl-lightnet + && git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-brambox \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-deprecate-tpl-lightnet # Clone first-party WBIA plug-in repositories RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-cnn.git + && git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-cnn.git RUN set -ex \ && cd /wbia \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-flukematch.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-finfindr.git \ - && git clone --branch develop 
https://github.com/WildbookOrg/wbia-plugin-deepsense.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-pie.git \ - && git clone --branch develop https://github.com/WildbookOrg/wbia-plugin-lca.git + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-flukematch.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-finfindr.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-deepsense.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-pie.git \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-lca.git -# git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-2d-orientation.git +RUN set -ex \ + && cd /wbia \ + && git clone --branch develop https://github.com/WildMeOrg/wbia-plugin-orientation.git + +# git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-2d-orientation.git # Clone third-party WBIA plug-in repositories RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-curvrank.git \ - && cd /wbia/wbia-plugin-curvrank/wbia_curvrank \ + && git clone --recursive --branch develop-curvrank-v1 https://github.com/WildMeOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v1 \ + && cd /wbia/wbia-plugin-curvrank-v1/wbia_curvrank \ + && git fetch origin \ + && git checkout develop + +RUN set -ex \ + && cd /wbia \ + && git clone --recursive --branch develop-curvrank-v2 https://github.com/WildMeOrg/wbia-plugin-curvrank.git /wbia/wbia-plugin-curvrank-v2 \ + && cd /wbia/wbia-plugin-curvrank-v2/wbia_curvrank_v2 \ && git fetch origin \ && git checkout develop RUN set -ex \ && cd /wbia \ - && git clone --recursive --branch develop https://github.com/WildbookOrg/wbia-plugin-kaggle7.git \ + && git clone --recursive --branch develop https://github.com/WildMeOrg/wbia-plugin-kaggle7.git \ && cd /wbia/wbia-plugin-kaggle7/wbia_kaggle7 \ && git fetch origin \ && git checkout develop @@ -138,13 +149,22 @@ RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ && cd /wbia/wbia-plugin-lca \ && pip install -e .' +RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ + && cd /wbia/wbia-plugin-orientation \ + && pip install -e .' + RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ && cd /wbia/wbia-plugin-flukematch \ && ./unix_build.sh \ && pip install -e .' RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ - && cd /wbia/wbia-plugin-curvrank \ + && cd /wbia/wbia-plugin-curvrank-v1 \ + && ./unix_build.sh \ + && pip install -e .' + +RUN /bin/bash -xc '. /virtualenv/env3/bin/activate \ + && cd /wbia/wbia-plugin-curvrank-v2 \ && ./unix_build.sh \ && pip install -e .' 
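Aside on the provisioning change above: the image now builds two CurvRank plug-ins side by side (wbia-plugin-curvrank-v1 and wbia-plugin-curvrank-v2), and the smoke tests in devops/Dockerfile import both. A minimal standalone sketch of that same import check, assuming only the module names exercised by those smoke tests (wbia_curvrank._plugin and wbia_curvrank_v2._plugin) and nothing about the rest of the build:

    # Hypothetical standalone version of the CurvRank smoke tests above.
    # Module names come from the devops/Dockerfile RUN step; nothing else is assumed.
    import importlib

    def check_curvrank_plugins():
        results = {}
        for module_name in ('wbia_curvrank._plugin', 'wbia_curvrank_v2._plugin'):
            try:
                importlib.import_module(module_name)
                results[module_name] = 'ok'
            except ImportError as exc:
                results[module_name] = 'missing (%s)' % (exc,)
        return results

    if __name__ == '__main__':
        for name, status in check_curvrank_plugins().items():
            print('%s: %s' % (name, status))
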
@@ -180,7 +200,9 @@ RUN set -ex \ && /virtualenv/env3/bin/pip install \ 'tensorflow-gpu==1.15.4' \ 'keras==2.2.5' \ - 'h5py<3.0.0' + 'h5py<3.0.0' \ + && /virtualenv/env3/bin/pip install \ + 'jedi==0.17.2' RUN set -ex \ && /virtualenv/env3/bin/pip freeze | grep wbia \ diff --git a/devops/publish.sh b/devops/publish.sh index 5819d21940..764d32f4b3 100755 --- a/devops/publish.sh +++ b/devops/publish.sh @@ -22,7 +22,7 @@ REGISTRY=${REGISTRY:-} IMAGES=${@:-wbia-base wbia-dependencies wbia-provision wbia wildbook-ia} # Set the image prefix if [ -n "$REGISTRY" ]; then - IMG_PREFIX="${REGISTRY}/wildbookorg/wildbook-ia/" + IMG_PREFIX="${REGISTRY}/wildmeorg/wildbook-ia/" else IMG_PREFIX="wildme/" fi diff --git a/docker-compose.yml b/docker-compose.yml index 5e70c50dc2..5b04dae872 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,8 +16,10 @@ services: env_file: ./.dockerfiles/docker-compose.env ports: - "5000:5000" + # Development mounting of the code volumes: - ./:/code + - ./.dockerfiles/docker-entrypoint.sh:/docker-entrypoint.sh pgadmin: image: dpage/pgadmin4 diff --git a/docs/index.rst b/docs/index.rst index 0fc7fbed6a..e30cbf242a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,7 +12,7 @@ Wildbook's Image Analysis is colloquially known as Wildbook-IA and by developers The Wildbook-IA application is used for the storage, management and analysis of images and derived data used by computer vision algorithms. It aims to compute who an animal is, what species an animal is, and where an animal is with the ultimate goal being to ask important why biological questions. -This project is the Machine Learning (ML) / computer vision component of the `WildBook project `_. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ +This project is the Machine Learning (ML) / computer vision component of the `WildBook project `_. This project is an actively maintained fork of the popular IBEIS (Image Based Ecological Information System) software suite for wildlife conservation. The original IBEIS project is maintained by Jon Crall (@Erotemic) at https://github.com/Erotemic/ibeis. The IBEIS toolkit originally was a wrapper around HotSpotter, which original binaries can be downloaded from: http://cs.rpi.edu/hotspotter/ Currently the system is build around and SQLite database, a web UI, and matplotlib visualizations. Algorithms employed are: convolutional neural network detection and localization and classification, hessian-affine keypoint detection, SIFT keypoint description, LNBNN identification using approximate nearest neighbors. 
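Tying the testing changes earlier in this diff together: .github/workflows/testing.yml now passes --with-postgres-uri=$POSTGRES_URI to pytest, and conftest.py registers that option under dest='postgres_uri'. A minimal sketch of a fixture that could consume the option; the fixture name and skip behaviour are illustrative assumptions, not part of this diff:

    # Hypothetical fixture sketch; only the option name and dest come from conftest.py above.
    import pytest

    @pytest.fixture
    def postgres_uri(request):
        # conftest.py: parser.addoption('--with-postgres-uri', dest='postgres_uri', ...)
        uri = request.config.getoption('postgres_uri')
        if not uri:
            pytest.skip('set --with-postgres-uri to run Postgres-backed tests')
        return uri
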
diff --git a/requirements/optional.txt b/requirements/optional.txt index 9a2da3a469..4c2d02213a 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -3,6 +3,7 @@ autopep8 Cython>=0.24 + Flask-CAS flask-cors diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 25e197676f..c8953437bc 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,6 +1,6 @@ blinker boto>=2.20.1 -brambox +# wbia-brambox # already included in wbia-lightnet cachetools>=1.1.6 click colorama>=0.3.2 @@ -8,16 +8,18 @@ dateutils>=0.6.6 delorean deprecated +descartes docker flask>=0.10.1 flask-cas flask-cors gdm +geopandas + imgaug ipython>=5.0.0 -lightnet lockfile>=0.10.2 mako matplotlib>=3.3.0 @@ -42,7 +44,7 @@ pyzmq>=14.7.0 requests>=2.5.0 scikit-image>=0.12.3 -scikit-learn>=0.17.1 +scikit-learn>=0.24.0 scipy>=0.18.0 sentry-sdk>=0.10.2 @@ -55,7 +57,7 @@ simplejson>=3.6.5 sip six>=1.10.0 -sqlalchemy +sqlalchemy>=1.4.0b1 statsmodels>=0.6.1 torch @@ -65,10 +67,11 @@ tqdm ubelt >= 0.8.7 wbia-cnn>=3.0.2 +wbia-lightnet wbia-pydarknet >= 3.0.1 wbia-pyflann >= 3.1.0 wbia-pyhesaff >= 3.0.2 wbia-pyrf >= 3.0.0 -wbia-utool >= 3.3.1 +wbia-utool >= 3.3.3 wbia-vtool >= 3.2.1 diff --git a/setup.py b/setup.py index 38e7fb3874..8c5c60f4f7 100755 --- a/setup.py +++ b/setup.py @@ -193,7 +193,7 @@ def gen_packages_items(): 'J. Wrona', ] AUTHOR_EMAIL = 'dev@wildme.org' -URL = 'https://github.com/WildbookOrg/wildbook-ia' +URL = 'https://github.com/WildMeOrg/wildbook-ia' LICENSE = 'Apache License 2.0' DESCRIPTION = 'Wildbook IA (WBIA) - Machine learning service for the WildBook project' KEYWORDS = [ @@ -231,6 +231,7 @@ def gen_packages_items(): 'tests': parse_requirements('requirements/tests.txt'), 'build': parse_requirements('requirements/build.txt'), 'runtime': parse_requirements('requirements/runtime.txt'), + 'optional': parse_requirements('requirements/optional.txt'), 'postgres': parse_requirements('requirements/postgres.txt'), }, # --- VERSION --- @@ -271,7 +272,7 @@ def gen_packages_items(): 'Programming Language :: Python :: 3 :: Only', ], project_urls={ # Optional - 'Bug Reports': 'https://github.com/WildbookOrg/wildbook-ia/issues', + 'Bug Reports': 'https://github.com/WildMeOrg/wildbook-ia/issues', 'Funding': 'https://www.wildme.org/donate/', 'Say Thanks!': 'https://community.wildbook.org', 'Source': URL, @@ -279,6 +280,9 @@ def gen_packages_items(): entry_points="""\ [console_scripts] wbia-init-testdbs = wbia.cli.testdbs:main + wbia-convert-hsdb = wbia.cli.convert_hsdb:main + wbia-migrate-sqlite-to-postgres = wbia.cli.migrate_sqlite_to_postgres:main + wbia-compare-databases = wbia.cli.compare_databases:main """, ) diff --git a/super_setup.py b/super_setup.py index 39a939efb3..d8fee14cf1 100755 --- a/super_setup.py +++ b/super_setup.py @@ -23,24 +23,24 @@ REPOS = [ # (, , ) - ('WildbookOrg/wildbook-ia', 'wildbook-ia', 'develop'), - ('WildbookOrg/wbia-utool', 'wbia-utool', 'develop'), - ('WildbookOrg/wbia-vtool', 'wbia-vtool', 'develop'), - ('WildbookOrg/wbia-tpl-pyhesaff', 'wbia-tpl-pyhesaff', 'develop'), - ('WildbookOrg/wbia-tpl-pyflann', 'wbia-tpl-pyflann', 'develop'), - ('WildbookOrg/wbia-tpl-pydarknet', 'wbia-tpl-pydarknet', 'develop'), - ('WildbookOrg/wbia-tpl-pyrf', 'wbia-tpl-pyrf', 'develop'), - ('WildbookOrg/wbia-deprecate-tpl-brambox', 'wbia-deprecate-tpl-brambox', 'develop'), - ('WildbookOrg/wbia-deprecate-tpl-lightnet', 'wbia-deprecate-tpl-lightnet', 'develop'), - ('WildbookOrg/wbia-plugin-cnn', 'wbia-plugin-cnn', 'develop'), - ('WildbookOrg/wbia-plugin-flukematch', 
'wbia-plugin-flukematch', 'develop'), - ('WildbookOrg/wbia-plugin-curvrank', 'wbia-plugin-curvrank', 'develop'), - ('WildbookOrg/wbia-plugin-deepsense', 'wbia-plugin-deepsense', 'develop'), - ('WildbookOrg/wbia-plugin-finfindr', 'wbia-plugin-finfindr', 'develop'), - ('WildbookOrg/wbia-plugin-kaggle7', 'wbia-plugin-kaggle7', 'develop'), - ('WildbookOrg/wbia-plugin-pie', 'wbia-plugin-pie', 'develop'), - ('WildbookOrg/wbia-plugin-lca', 'wbia-plugin-lca', 'develop'), - # ('WildbookOrg/wbia-plugin-2d-orientation', 'wbia-plugin-2d-orientation', 'develop'), + ('WildMeOrg/wildbook-ia', 'wildbook-ia', 'develop'), + ('WildMeOrg/wbia-utool', 'wbia-utool', 'develop'), + ('WildMeOrg/wbia-vtool', 'wbia-vtool', 'develop'), + ('WildMeOrg/wbia-tpl-pyhesaff', 'wbia-tpl-pyhesaff', 'develop'), + ('WildMeOrg/wbia-tpl-pyflann', 'wbia-tpl-pyflann', 'develop'), + ('WildMeOrg/wbia-tpl-pydarknet', 'wbia-tpl-pydarknet', 'develop'), + ('WildMeOrg/wbia-tpl-pyrf', 'wbia-tpl-pyrf', 'develop'), + ('WildMeOrg/wbia-deprecate-tpl-brambox', 'wbia-deprecate-tpl-brambox', 'develop'), + ('WildMeOrg/wbia-deprecate-tpl-lightnet', 'wbia-deprecate-tpl-lightnet', 'develop'), + ('WildMeOrg/wbia-plugin-cnn', 'wbia-plugin-cnn', 'develop'), + ('WildMeOrg/wbia-plugin-flukematch', 'wbia-plugin-flukematch', 'develop'), + ('WildMeOrg/wbia-plugin-curvrank', 'wbia-plugin-curvrank', 'develop'), + ('WildMeOrg/wbia-plugin-deepsense', 'wbia-plugin-deepsense', 'develop'), + ('WildMeOrg/wbia-plugin-finfindr', 'wbia-plugin-finfindr', 'develop'), + ('WildMeOrg/wbia-plugin-kaggle7', 'wbia-plugin-kaggle7', 'develop'), + ('WildMeOrg/wbia-plugin-pie', 'wbia-plugin-pie', 'develop'), + ('WildMeOrg/wbia-plugin-lca', 'wbia-plugin-lca', 'develop'), + # ('WildMeOrg/wbia-plugin-2d-orientation', 'wbia-plugin-2d-orientation', 'develop'), ] diff --git a/wbia/__init__.py b/wbia/__init__.py index d2fb0f1715..f79b02020c 100644 --- a/wbia/__init__.py +++ b/wbia/__init__.py @@ -68,7 +68,7 @@ main_loop, opendb, opendb_in_background, - opendb_bg_web, + opendb_with_web, ) from wbia.control.IBEISControl import IBEISController from wbia.algo.hots.query_request import QueryRequest @@ -85,6 +85,7 @@ from wbia.init import main_helpers from wbia import algo + from wbia import research from wbia import expt from wbia import templates diff --git a/wbia/algo/detect/__init__.py b/wbia/algo/detect/__init__.py index 5c742d556c..416d644ada 100644 --- a/wbia/algo/detect/__init__.py +++ b/wbia/algo/detect/__init__.py @@ -5,7 +5,9 @@ from wbia.algo.detect import grabmodels from wbia.algo.detect import randomforest from wbia.algo.detect import yolo +from wbia.algo.detect import assigner +# from wbia.algo.detect import train_assigner # from wbia.algo.detect import selectivesearch # from wbia.algo.detect import ssd # from wbia.algo.detect import fasterrcnn @@ -93,6 +95,7 @@ def get_reload_subs(mod): ('grabmodels', None), ('randomforest', None), ('yolo', None), + ('assigner', None), # ('selectivesearch', None), # ('ssd', None), # ('fasterrcnn', None), diff --git a/wbia/algo/detect/assigner.py b/wbia/algo/detect/assigner.py new file mode 100644 index 0000000000..2b3fced4cc --- /dev/null +++ b/wbia/algo/detect/assigner.py @@ -0,0 +1,502 @@ +# -*- coding: utf-8 -*- +import logging + +# from os.path import expanduser, join +from wbia import constants as const +from wbia.control.controller_inject import make_ibs_register_decorator +import utool as ut +import os +from collections import defaultdict + +# illustration imports +from shutil import copy +from PIL import Image, ImageDraw +import 
wbia.plottool as pt + + +logger = logging.getLogger('wbia') + +CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) + +PARALLEL = not const.CONTAINERIZED +INPUT_SIZE = 224 + +INMEM_ASSIGNER_MODELS = {} + +SPECIES_CONFIG_MAP = { + 'wild_dog': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features', + }, + 'wild_dog_dark': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features', + }, + 'wild_dog_light': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features', + }, + 'wild_dog_puppy': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features', + }, + 'wild_dog_standard': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features', + }, + 'wild_dog_tan': { + 'model_file': '/tmp/balanced_wd.joblib', + 'model_url': 'https://wildbookiarepository.azureedge.net/models/assigner.wd_v0.joblib', + 'annot_feature_col': 'assigner_viewpoint_features', + }, +} + + +@register_ibs_method +def _are_part_annots(ibs, aid_list): + r""" + returns a boolean list representing if each aid in aid_list is a part annot. + This determination is made by the presence of a '+' in the species. + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): annot ids to split + + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> aids = ibs.get_valid_aids() + >>> result = ibs._are_part_annots(aids) + >>> print(result) + [False, False, True, True, False, True, False, True] + """ + species = ibs.get_annot_species(aid_list) + are_parts = ['+' in specie for specie in species] + return are_parts + + +def all_part_pairs(ibs, gid_list): + r""" + Returns all possible part,body pairs from aids in gid_list, in the format of + two parralel lists: the first being all parts, the second all bodies + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + gid_list (int): gids in question + + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> gids = ibs.get_valid_gids() + >>> all_part_pairs = all_part_pairs(ibs, gids) + >>> parts = all_part_pairs[0] + >>> bodies = all_part_pairs[1] + >>> all_aids = ibs.get_image_aids(gids) + >>> all_aids = [aid for aids in all_aids for aid in aids] # flatten + >>> assert (set(parts) & set(bodies)) == set({}) + >>> assert (set(parts) | set(bodies)) == set(all_aids) + >>> result = all_part_pairs + >>> print(result) + ([3, 3, 4, 4, 6, 8], [1, 2, 1, 2, 5, 7]) + """ + all_aids = ibs.get_image_aids(gid_list) + all_aids_are_parts = [ibs._are_part_annots(aids) for aids in all_aids] + all_part_aids 
= [ + [aid for (aid, part) in zip(aids, are_parts) if part] + for (aids, are_parts) in zip(all_aids, all_aids_are_parts) + ] + all_body_aids = [ + [aid for (aid, part) in zip(aids, are_parts) if not part] + for (aids, are_parts) in zip(all_aids, all_aids_are_parts) + ] + part_body_parallel_lists = [ + _all_pairs_parallel(parts, bodies) + for parts, bodies in zip(all_part_aids, all_body_aids) + ] + all_parts = [ + aid + for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[0] + ] + all_bodies = [ + aid + for part_body_parallel_list in part_body_parallel_lists + for aid in part_body_parallel_list[1] + ] + return all_parts, all_bodies + + +def _all_pairs_parallel(list_a, list_b): + # is tested by all_part_pairs above + pairs = [(a, b) for a in list_a for b in list_b] + pairs_a = [pair[0] for pair in pairs] + pairs_b = [pair[1] for pair in pairs] + return pairs_a, pairs_b + + +@register_ibs_method +def assign_parts(ibs, all_aids, cutoff_score=0.5): + r""" + Main assigner method; makes assignments on all_aids based on assigner scores. + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): aids in question + cutoff_score: the threshold for the aids' assigner scores, under which no assignments are made + + Returns: + tuple of two lists: all_assignments (a list of tuples, each tuple grouping + aids assigned to a single animal), and all_unassigned_aids, which are the aids that did not meet the cutoff_score or whose body/part + + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> aids = ibs.get_valid_aids() + >>> result = ibs.assign_parts(aids) + >>> assigned_pairs = result[0] + >>> unassigned_aids = result[1] + >>> assigned_aids = [item for pair in assigned_pairs for item in pair] + >>> # no overlap between assigned and unassigned aids + >>> assert (set(assigned_aids) & set(unassigned_aids) == set({})) + >>> # all aids are either assigned or unassigned + >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) + >>> ([(3, 1), (6, 5), (8, 7)], [2, 4]) + """ + gids = ibs.get_annot_gids(all_aids) + gid_to_aids = defaultdict(list) + for gid, aid in zip(gids, all_aids): + gid_to_aids[gid] += [aid] + + all_assignments = [] + all_unassigned_aids = [] + + for gid in gid_to_aids.keys(): + this_pairs, this_unassigned = assign_parts_one_image( + ibs, gid_to_aids[gid], cutoff_score + ) + all_assignments += this_pairs + all_unassigned_aids += this_unassigned + + return all_assignments, all_unassigned_aids + + +@register_ibs_method +def assign_parts_one_image(ibs, aid_list, cutoff_score=0.5): + r""" + Main assigner method; makes assignments on all_aids based on assigner scores. 
+ + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): aids in question + cutoff_score: the threshold for the aids' assigner scores, under which no assignments are made + + Returns: + tuple of two lists: all_assignments (a list of tuples, each tuple grouping + aids assigned to a single animal), and all_unassigned_aids, which are the aids that did not meet the cutoff_score or whose body/part + + CommandLine: + python -m wbia.algo.detect.assigner _are_part_annots + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> gid = 1 + >>> aids = ibs.get_image_aids(gid) + >>> result = ibs.assign_parts_one_image(aids) + >>> assigned_pairs = result[0] + >>> unassigned_aids = result[1] + >>> assigned_aids = [item for pair in assigned_pairs for item in pair] + >>> # no overlap between assigned and unassigned aids + >>> assert (set(assigned_aids) & set(unassigned_aids) == set({})) + >>> # all aids are either assigned or unassigned + >>> assert (set(assigned_aids) | set(unassigned_aids) == set(aids)) + >>> ([(3, 1)], [2, 4]) + """ + all_species = ibs.get_annot_species(aid_list) + # put unsupported species into the all_unassigned_aids list + all_species_no_parts = [species.split('+')[0] for species in all_species] + assign_flag_list = [ + species in SPECIES_CONFIG_MAP.keys() for species in all_species_no_parts + ] + + unassigned_aids_noconfig = ut.filterfalse_items(aid_list, assign_flag_list) + aid_list = ut.compress(aid_list, assign_flag_list) + + are_part_aids = _are_part_annots(ibs, aid_list) + part_aids = ut.compress(aid_list, are_part_aids) + body_aids = ut.compress(aid_list, [not p for p in are_part_aids]) + + gids = ibs.get_annot_gids(list(set(part_aids)) + list(set(body_aids))) + num_images = len(set(gids)) + assert num_images <= 1, "assign_parts_one_image called on multiple images' aids" + + # parallel lists representing all possible part/body pairs + all_pairs_parallel = _all_pairs_parallel(part_aids, body_aids) + pair_parts, pair_bodies = all_pairs_parallel + + if len(pair_parts) > 0 and len(pair_bodies) > 0: + assigner_features = ibs.depc_annot.get( + 'assigner_viewpoint_features', all_pairs_parallel + ) + # send all aids to this call just so it can find the right classifier model + assigner_classifier = load_assigner_classifier(ibs, body_aids + part_aids) + + assigner_scores = assigner_classifier.predict_proba(assigner_features) + # assigner_scores is a list of [P_false, P_true] probabilities which sum to 1, so here we just pare down to the true probabilities + assigner_scores = [score[1] for score in assigner_scores] + good_pairs, unassigned_aids = _make_assignments( + pair_parts, pair_bodies, assigner_scores, cutoff_score + ) + else: + good_pairs = [] + unassigned_aids = aid_list + + unassigned_aids = unassigned_aids_noconfig + unassigned_aids + return good_pairs, unassigned_aids + + +def _make_assignments(pair_parts, pair_bodies, assigner_scores, cutoff_score=0.5): + + sorted_scored_pairs = [ + (part, body, score) + for part, body, score in sorted( + zip(pair_parts, pair_bodies, assigner_scores), + key=lambda pbscore: pbscore[2], + reverse=True, + ) + ] + + assigned_pairs = [] + assigned_parts = set() + assigned_bodies = set() + n_bodies = len(set(pair_bodies)) + n_parts = len(set(pair_parts)) + n_true_pairs = min(n_bodies, n_parts) + for part_aid, body_aid, score in sorted_scored_pairs: + assign_this_pair = ( + part_aid 
not in assigned_parts + and body_aid not in assigned_bodies + and score >= cutoff_score + ) + + if assign_this_pair: + assigned_pairs.append((part_aid, body_aid)) + assigned_parts.add(part_aid) + assigned_bodies.add(body_aid) + + if ( + len(assigned_parts) is n_true_pairs + or len(assigned_bodies) is n_true_pairs + or score > cutoff_score + ): + break + + unassigned_parts = set(pair_parts) - set(assigned_parts) + unassigned_bodies = set(pair_bodies) - set(assigned_bodies) + unassigned_aids = sorted(list(unassigned_parts) + list(unassigned_bodies)) + + return assigned_pairs, unassigned_aids + + +def load_assigner_classifier(ibs, aid_list, fallback_species='wild_dog'): + species_with_part = ibs.get_annot_species(aid_list[0]) + species = species_with_part.split('+')[0] + if species in INMEM_ASSIGNER_MODELS.keys(): + clf = INMEM_ASSIGNER_MODELS[species] + else: + if species not in SPECIES_CONFIG_MAP.keys(): + print( + 'WARNING: Assigner called for species %s which does not have an assigner modelfile specified. Falling back to the model for %s' + % species, + fallback_species, + ) + species = fallback_species + + model_url = SPECIES_CONFIG_MAP[species]['model_url'] + model_fpath = ut.grab_file_url(model_url) + from joblib import load + + clf = load(model_fpath) + + return clf + + +def illustrate_all_assignments( + ibs, + gid_to_assigner_results, + gid_to_ground_truth, + target_dir='/tmp/assigner-illustrations-2/', + limit=20, +): + + correct_dir = os.path.join(target_dir, 'correct/') + incorrect_dir = os.path.join(target_dir, 'incorrect/') + + for gid, assigned_aid_dict in gid_to_assigner_results.items()[:limit]: + ground_t_dict = gid_to_ground_truth[gid] + assigned_correctly = sorted(assigned_aid_dict['pairs']) == sorted( + ground_t_dict['pairs'] + ) + if assigned_correctly: + illustrate_assignments( + ibs, gid, assigned_aid_dict, None, correct_dir + ) # don't need to illustrate gtruth if it's identical to assignment + else: + illustrate_assignments( + ibs, gid, assigned_aid_dict, ground_t_dict, incorrect_dir + ) + + print('illustrated assignments and saved them in %s' % target_dir) + + +# works on a single gid's worth of gid_keyed_assigner_results output +def illustrate_assignments( + ibs, + gid, + assigned_aid_dict, + gtruth_aid_dict, + target_dir='/tmp/assigner-illustrations/', +): + impath = ibs.get_image_paths(gid) + imext = os.path.splitext(impath)[1] + new_fname = os.path.join(target_dir, '%s%s' % (gid, imext)) + os.makedirs(target_dir, exist_ok=True) + copy(impath, new_fname) + + with Image.open(new_fname) as image: + _draw_all_annots(ibs, image, assigned_aid_dict, gtruth_aid_dict) + image.save(new_fname) + + +def _draw_all_annots(ibs, image, assigned_aid_dict, gtruth_aid_dict): + n_pairs = len(assigned_aid_dict['pairs']) + # n_missing_pairs = 0 + # TODO: missing pair shit + n_unass = len(assigned_aid_dict['unassigned']) + n_groups = n_pairs + n_unass + colors = _pil_distinct_colors(n_groups) + + draw = ImageDraw.Draw(image) + for i, pair in enumerate(assigned_aid_dict['pairs']): + _draw_bbox(ibs, draw, pair[0], colors[i]) + _draw_bbox(ibs, draw, pair[1], colors[i]) + + for i, aid in enumerate(assigned_aid_dict['unassigned'], start=n_pairs): + _draw_bbox(ibs, draw, aid, colors[i]) + + +def _pil_distinct_colors(n_colors): + float_colors = pt.distinct_colors(n_colors) + int_colors = [tuple([int(256 * f) for f in color]) for color in float_colors] + return int_colors + + +def _draw_bbox(ibs, pil_draw, aid, color): + verts = ibs.get_annot_rotated_verts(aid) + pil_verts = [tuple(vertex) for 
vertex in verts] + pil_verts += pil_verts[:1] # for the line between the last and first vertex + pil_draw.line(pil_verts, color, width=4) + + +def gid_keyed_assigner_results(ibs, all_pairs, all_unassigned_aids): + one_from_each_pair = [p[0] for p in all_pairs] + pair_gids = ibs.get_annot_gids(one_from_each_pair) + unassigned_gids = ibs.get_annot_gids(all_unassigned_aids) + + gid_to_pairs = defaultdict(list) + for pair, gid in zip(all_pairs, pair_gids): + gid_to_pairs[gid] += [pair] + + gid_to_unassigned = defaultdict(list) + for aid, gid in zip(all_unassigned_aids, unassigned_gids): + gid_to_unassigned[gid] += [aid] + + gid_to_assigner_results = {} + for gid in set(gid_to_pairs.keys()) | set(gid_to_unassigned.keys()): + gid_to_assigner_results[gid] = { + 'pairs': gid_to_pairs[gid], + 'unassigned': gid_to_unassigned[gid], + } + + return gid_to_assigner_results + + +def gid_keyed_ground_truth(ibs, assigner_data): + test_pairs = assigner_data['test_pairs'] + test_truth = assigner_data['test_truth'] + assert len(test_pairs) == len(test_truth) + + aid_from_each_pair = [p[0] for p in test_pairs] + gids_for_pairs = ibs.get_annot_gids(aid_from_each_pair) + + gid_to_pairs = defaultdict(list) + gid_to_paired_aids = defaultdict(set) # to know which have not been in any pair + gid_to_all_aids = defaultdict(set) + for pair, is_true_pair, gid in zip(test_pairs, test_truth, gids_for_pairs): + gid_to_all_aids[gid] = gid_to_all_aids[gid] | set(pair) + if is_true_pair: + gid_to_pairs[gid] += [pair] + gid_to_paired_aids[gid] = gid_to_paired_aids[gid] | set(pair) + + gid_to_unassigned_aids = defaultdict(list) + for gid in gid_to_all_aids.keys(): + gid_to_unassigned_aids[gid] = list(gid_to_all_aids[gid] - gid_to_paired_aids[gid]) + + gid_to_assigner_results = {} + for gid in set(gid_to_pairs.keys()) | set(gid_to_unassigned_aids.keys()): + gid_to_assigner_results[gid] = { + 'pairs': gid_to_pairs[gid], + 'unassigned': gid_to_unassigned_aids[gid], + } + + return gid_to_assigner_results + + +@register_ibs_method +def assigner_testdb_ibs(): + import wbia + from wbia import sysres + + dbdir = sysres.ensure_testdb_assigner() + # dbdir = '/data/testdb_assigner' + ibs = wbia.opendb(dbdir=dbdir) + return ibs + + +if __name__ == '__main__': + r""" + CommandLine: + python -m wbia.algo.detect.assigner --allexamples + """ + import multiprocessing + + multiprocessing.freeze_support() # for win32 + import utool as ut # NOQA + + ut.doctest_funcs() diff --git a/wbia/algo/detect/densenet.py b/wbia/algo/detect/densenet.py index f49cc88663..d0942cc6c9 100644 --- a/wbia/algo/detect/densenet.py +++ b/wbia/algo/detect/densenet.py @@ -47,8 +47,7 @@ 'humpback_dorsal': 'https://wildbookiarepository.azureedge.net/models/labeler.whale_humpback.dorsal.v0.zip', 'orca_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.whale_orca.v0.zip', 'fins_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v0.zip', - 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v1.zip', - 'fins_v1-1': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v1.1.zip', + 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/labeler.fins.v1.1.zip', 'wilddog_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.wild_dog.v0.zip', 'wilddog_v1': 'https://wildbookiarepository.azureedge.net/models/labeler.wild_dog.v1.zip', 'wilddog_v2': 'https://wildbookiarepository.azureedge.net/models/labeler.wild_dog.v2.zip', @@ -72,6 +71,7 @@ 'flukebook_v1': 
'https://wildbookiarepository.azureedge.net/models/classifier2.flukebook.v1.zip', 'rightwhale_v5': 'https://wildbookiarepository.azureedge.net/models/labeler.rightwhale.v5.zip', 'snow_leopard_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.snow_leopard.v0.zip', + 'grey_whale_v0': 'https://wildbookiarepository.azureedge.net/models/labeler.whale_grey.v0.zip', } diff --git a/wbia/algo/detect/lightnet.py b/wbia/algo/detect/lightnet.py index cc0655ede2..63ede7770a 100644 --- a/wbia/algo/detect/lightnet.py +++ b/wbia/algo/detect/lightnet.py @@ -52,8 +52,9 @@ 'humpback_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_humpback.dorsal.v0.py', 'orca_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_orca.v0.py', 'fins_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v0.py', - 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.py', - 'fins_v1-1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py', + 'fins_v1_fluke': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.py', + 'fins_v1_dorsal': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py', + 'fins_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.fins.v1.1.py', 'nassau_grouper_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.grouper_nassau.v0.py', 'spotted_dolphin_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.dolphin_spotted.v0.py', 'spotted_skunk_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.skunk_spotted.v0.py', @@ -70,6 +71,9 @@ 'candidacy': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py', 'ggr2': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.ggr2.py', 'snow_leopard_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.snow_leopard.v0.py', + 'megan_argentina_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.argentina.v1.py', + 'megan_kenya_v1': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.megan.kenya.v1.py', + 'grey_whale_v0': 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.whale_grey.v0.py', None: 'https://wildbookiarepository.azureedge.net/models/detect.lightnet.candidacy.py', 'training_kit': 'https://wildbookiarepository.azureedge.net/data/lightnet-training-kit.zip', } diff --git a/wbia/algo/detect/train_assigner.py b/wbia/algo/detect/train_assigner.py new file mode 100644 index 0000000000..a76f2fc1a4 --- /dev/null +++ b/wbia/algo/detect/train_assigner.py @@ -0,0 +1,1750 @@ +# -*- coding: utf-8 -*- +# import logging +# from os.path import expanduser, join +# from wbia import constants as const +from wbia.control.controller_inject import ( + register_preprocs, + # register_subprops, + make_ibs_register_decorator, +) + +from wbia.algo.detect.assigner import ( + gid_keyed_assigner_results, + gid_keyed_ground_truth, + illustrate_all_assignments, + all_part_pairs, +) + +import utool as ut +import numpy as np +from wbia import dtool +import random + +# import os +from collections import OrderedDict + +# from collections import defaultdict +from datetime import datetime +import time + +from math import sqrt + +from sklearn import preprocessing + +# illustration imports +# from shutil import copy +# from PIL import Image, ImageDraw +# import wbia.plottool as pt + + +# import matplotlib.pyplot as 
plt +# from matplotlib.colors import ListedColormap +# from sklearn.model_selection import train_test_split +# from sklearn.preprocessing import StandardScaler +from sklearn.neural_network import MLPClassifier +from sklearn.neighbors import KNeighborsClassifier +from sklearn.svm import SVC +from sklearn.gaussian_process import GaussianProcessClassifier +from sklearn.gaussian_process.kernels import RBF +from sklearn.tree import DecisionTreeClassifier +from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier +from sklearn.naive_bayes import GaussianNB +from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis +from sklearn.model_selection import GridSearchCV + + +derived_attribute = register_preprocs['annot'] + + +CLASS_INJECT_KEY, register_ibs_method = make_ibs_register_decorator(__name__) + + +CLASSIFIER_OPTIONS = [ + { + 'name': 'Nearest Neighbors', + 'clf': KNeighborsClassifier(3), + 'param_options': { + 'n_neighbors': [3, 5, 11, 19], + 'weights': ['uniform', 'distance'], + 'metric': ['euclidean', 'manhattan'], + }, + }, + { + 'name': 'Linear SVM', + 'clf': SVC(kernel='linear', C=0.025), + 'param_options': { + 'C': [1, 10, 100, 1000], + 'kernel': ['linear'], + }, + }, + { + 'name': 'RBF SVM', + 'clf': SVC(gamma=2, C=1), + 'param_options': { + 'C': [1, 10, 100, 1000], + 'gamma': [0.001, 0.0001], + 'kernel': ['rbf'], + }, + }, + { + 'name': 'Decision Tree', + 'clf': DecisionTreeClassifier(), # max_depth=5 + 'param_options': { + 'max_depth': np.arange(1, 12), + 'max_leaf_nodes': [2, 5, 10, 20, 50, 100], + }, + }, + # { + # "name": "Random Forest", + # "clf": RandomForestClassifier(), #max_depth=5, n_estimators=10, max_features=1 + # "param_options": { + # 'bootstrap': [True, False], + # 'max_depth': [10, 50, 100, None], + # 'max_features': ['auto', 'sqrt'], + # 'min_samples_leaf': [1, 2, 4], + # 'min_samples_split': [2, 5, 10], + # 'n_estimators': [200, 1000, 2000] + # } + # }, + # { + # "name": "Neural Net", + # "clf": MLPClassifier(), #alpha=1, max_iter=1000 + # "param_options": { + # 'hidden_layer_sizes': [(10,30,10),(20,)], + # 'activation': ['tanh', 'relu'], + # 'solver': ['sgd', 'adam'], + # 'alpha': [0.0001, 0.05], + # 'learning_rate': ['constant','adaptive'], + # } + # }, + { + 'name': 'AdaBoost', + 'clf': AdaBoostClassifier(), + 'param_options': { + 'n_estimators': np.arange(10, 310, 50), + 'learning_rate': [0.01, 0.05, 0.1, 1], + }, + }, + { + 'name': 'Naive Bayes', + 'clf': GaussianNB(), + 'param_options': {}, # no hyperparams to optimize + }, + { + 'name': 'QDA', + 'clf': QuadraticDiscriminantAnalysis(), + 'param_options': {'reg_param': [0.1, 0.2, 0.3, 0.4, 0.5]}, + }, +] + + +classifier_names = [ + 'Nearest Neighbors', + 'Linear SVM', + 'RBF SVM', + 'Decision Tree', + 'Random Forest', + 'Neural Net', + 'AdaBoost', + 'Naive Bayes', + 'QDA', +] + +classifiers = [ + KNeighborsClassifier(3), + SVC(kernel='linear', C=0.025), + SVC(gamma=2, C=1), + DecisionTreeClassifier(max_depth=5), + RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1), + MLPClassifier(alpha=1, max_iter=1000), + AdaBoostClassifier(), + GaussianNB(), + QuadraticDiscriminantAnalysis(), +] + + +slow_classifier_names = 'Gaussian Process' +slow_classifiers = (GaussianProcessClassifier(1.0 * RBF(1.0)),) + + +def classifier_report(clf, name, assigner_data): + print('%s CLASSIFIER REPORT ' % name) + print(' %s: calling clf.fit' % str(datetime.now())) + clf.fit(assigner_data['data'], assigner_data['target']) + print(' %s: done training, making prediction ' % str(datetime.now())) 
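The CLASSIFIER_OPTIONS entries above are consumed by classifier_report and by the grid-search tuners defined further down. A minimal, self-contained sketch of that fit/tune/score cycle, using synthetic data in place of the assigner_data dict produced by wd_training_data (the AdaBoost grid is a trimmed copy of the one above; everything else is illustrative only):

import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import GridSearchCV

rng = np.random.RandomState(0)
X_train, y_train = rng.rand(80, 4), rng.randint(0, 2, 80)  # stand-ins for assigner_data['data'] / ['target']
X_test, y_test = rng.rand(20, 4), rng.randint(0, 2, 20)    # stand-ins for ['test'] / ['test_truth']

search = GridSearchCV(
    AdaBoostClassifier(),
    {'n_estimators': [10, 60], 'learning_rate': [0.1, 1]},  # trimmed grid
)
search.fit(X_train, y_train)
accuracy = (search.predict(X_test) == y_test).mean()
print(search.best_params_, accuracy)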
+ preds = clf.predict(assigner_data['test']) + print(' %s: done with predictions, computing accuracy' % str(datetime.now())) + agree = [pred == truth for pred, truth in zip(preds, assigner_data['test_truth'])] + accuracy = agree.count(True) / len(agree) + print(' %s accuracy' % accuracy) + print() + return accuracy + + +@register_ibs_method +def compare_ass_classifiers( + ibs, depc_table_name='assigner_viewpoint_features', print_accs=False +): + + assigner_data = ibs.wd_training_data(depc_table_name) + + accuracies = OrderedDict() + for classifier in CLASSIFIER_OPTIONS: + accuracy = classifier_report(classifier['clf'], classifier['name'], assigner_data) + accuracies[classifier['name']] = accuracy + + # handy for e.g. pasting into excel + if print_accs: + just_accuracy = [accuracies[name] for name in accuracies.keys()] + print(just_accuracy) + + return accuracies + + +@register_ibs_method +def tune_ass_classifiers(ibs, depc_table_name='assigner_viewpoint_unit_features'): + + assigner_data = ibs.wd_training_data(depc_table_name) + + accuracies = OrderedDict() + best_acc = 0 + best_clf_name = '' + best_clf_params = {} + for classifier in CLASSIFIER_OPTIONS: + print('Tuning %s' % classifier['name']) + accuracy, best_params = ibs._tune_grid_search( + classifier['clf'], classifier['param_options'], assigner_data + ) + print() + accuracies[classifier['name']] = { + 'accuracy': accuracy, + 'best_params': best_params, + } + if accuracy > best_acc: + best_acc = accuracy + best_clf_name = classifier['name'] + best_clf_params = best_params + + print( + 'best performance: %s using %s with params %s' + % (best_acc, best_clf_name, best_clf_params) + ) + + return accuracies + + +@register_ibs_method +def _tune_grid_search(ibs, clf, parameters, assigner_data=None): + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + X_train = assigner_data['data'] + y_train = assigner_data['target'] + X_test = assigner_data['test'] + y_test = assigner_data['test_truth'] + + tune_search = GridSearchCV( # TuneGridSearchCV( + clf, + parameters, + ) + + start = time.time() + tune_search.fit(X_train, y_train) + end = time.time() + print('Tune Fit Time: %s' % (end - start)) + pred = tune_search.predict(X_test) + accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) + print('Tune Accuracy: %s' % accuracy) + print('best parms : %s' % tune_search.best_params_) + + return accuracy, tune_search.best_params_ + + +@register_ibs_method +def _tune_random_search(ibs, clf, parameters, assigner_data=None): + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + X_train = assigner_data['data'] + y_train = assigner_data['target'] + X_test = assigner_data['test'] + y_test = assigner_data['test_truth'] + + tune_search = GridSearchCV( + clf, + parameters, + ) + + start = time.time() + tune_search.fit(X_train, y_train) + end = time.time() + print('Tune Fit Time: %s' % (end - start)) + pred = tune_search.predict(X_test) + accuracy = np.count_nonzero(np.array(pred) == np.array(y_test)) / len(pred) + print('Tune Accuracy: %s' % accuracy) + print('best parms : %s' % tune_search.best_params_) + + return accuracy, tune_search.best_params_ + + +# for wild dog dev +@register_ibs_method +def wd_assigner_data(ibs): + return wd_training_data('part_assignment_features') + + +@register_ibs_method +def wd_normed_assigner_data(ibs): + return wd_training_data('normalized_assignment_features') + + +@register_ibs_method +def wd_training_data( + ibs, depc_table_name='assigner_viewpoint_features', 
balance_t_f=True +): + all_aids = ibs.get_valid_aids() + ia_classes = ibs.get_annot_species(all_aids) + part_aids = [aid for aid, ia_class in zip(all_aids, ia_classes) if '+' in ia_class] + part_gids = list(set(ibs.get_annot_gids(part_aids))) + all_pairs = all_part_pairs(ibs, part_gids) + all_feats = ibs.depc_annot.get(depc_table_name, all_pairs) + names = [ibs.get_annot_names(all_pairs[0]), ibs.get_annot_names(all_pairs[1])] + ground_truth = [n1 == n2 for (n1, n2) in zip(names[0], names[1])] + + # train_feats, test_feats = train_test_split(all_feats) + # train_truth, test_truth = train_test_split(ground_truth) + pairs_in_train = ibs.gid_train_test_split( + all_pairs[0] + ) # we could pass just the pair aids or just the body aids bc gids are the same + train_feats, test_feats = split_list(all_feats, pairs_in_train) + train_truth, test_truth = split_list(ground_truth, pairs_in_train) + + all_pairs_tuple = [(part, body) for part, body in zip(all_pairs[0], all_pairs[1])] + train_pairs, test_pairs = split_list(all_pairs_tuple, pairs_in_train) + + if balance_t_f: + train_balance_flags = balance_true_false_training_pairs(train_truth) + train_truth = ut.compress(train_truth, train_balance_flags) + train_feats = ut.compress(train_feats, train_balance_flags) + train_pairs = ut.compress(train_pairs, train_balance_flags) + + test_balance_flags = balance_true_false_training_pairs(test_truth) + test_truth = ut.compress(test_truth, test_balance_flags) + test_feats = ut.compress(test_feats, test_balance_flags) + test_pairs = ut.compress(test_pairs, test_balance_flags) + + assigner_data = { + 'data': train_feats, + 'target': train_truth, + 'test': test_feats, + 'test_truth': test_truth, + 'train_pairs': train_pairs, + 'test_pairs': test_pairs, + } + + return assigner_data + + +# returns flags so we can compress other lists +def balance_true_false_training_pairs(ground_truth, seed=777): + n_true = ground_truth.count(True) + # there's always more false samples than true when we're looking at all pairs + false_indices = [i for i, ground_t in enumerate(ground_truth) if not ground_t] + import random + + random.seed(seed) + subsampled_false_indices = random.sample(false_indices, n_true) + # for quick membership check + subsampled_false_indices = set(subsampled_false_indices) + # keep all true flags, and the subsampled false ones + keep_flags = [ + gt or (i in subsampled_false_indices) for i, gt in enumerate(ground_truth) + ] + return keep_flags + + +# def train_test_split(item_list, random_seed=777, test_size=0.1): +# import random +# import math + +# random.seed(random_seed) +# sample_size = math.floor(len(item_list) * test_size) +# all_indices = list(range(len(item_list))) +# test_indices = random.sample(all_indices, sample_size) +# test_items = [item_list[i] for i in test_indices] +# train_indices = sorted(list(set(all_indices) - set(test_indices))) +# train_items = [item_list[i] for i in train_indices] +# return train_items, test_items + + +@register_ibs_method +def gid_train_test_split(ibs, aid_list, random_seed=777, test_size=0.1): + r""" + Makes a gid-wise train-test split. This avoids potential overfitting when a network + is trained on some annots from one image and tested on others from the same image. + + Args: + ibs (IBEISController): IBEIS / WBIA controller object + aid_list (int): annot ids to split + random_seed: to make this split reproducible + test_size: portion of gids reserved for test data + + Yields: + a boolean flag_list of which aids are in the training set. 
Returning the flag_list + allows the user to filter multiple lists with one gid_train_test_split call + + + CommandLine: + python -m wbia.algo.detect.train_assigner gid_train_test_split + + Example: + >>> # ENABLE_DOCTEST + >>> import utool as ut + >>> from wbia.algo.detect.assigner import * + >>> from wbia.algo.detect.train_assigner import * + >>> ibs = assigner_testdb_ibs() + >>> aids = ibs.get_valid_aids() + >>> all_gids = set(ibs.get_annot_gids(aids)) + >>> test_size = 0.34 # we want floor(test_size*3) to equal 1 + >>> aid_in_train = gid_train_test_split(ibs, aids, test_size=test_size) + >>> train_aids = ut.compress(aids, aid_in_train) + >>> aid_in_test = [not train for train in aid_in_train] + >>> test_aids = ut.compress(aids, aid_in_test) + >>> train_gids = set(ibs.get_annot_gids(train_aids)) + >>> test_gids = set(ibs.get_annot_gids(test_aids)) + >>> assert len(train_gids & test_gids) is 0 + >>> assert len(train_gids) + len(test_gids) == len(all_gids) + >>> assert len(train_gids) is 2 + >>> assert len(test_gids) is 1 + >>> result = aid_in_train # note one gid has 4 aids, the other two 2 + >>> print(result) + [False, False, False, False, True, True, True, True] + """ + print('calling gid_train_test_split') + gid_list = ibs.get_annot_gids(aid_list) + gid_set = list(set(gid_list)) + import math + + random.seed(random_seed) + n_test_gids = math.floor(len(gid_set) * test_size) + test_gids = set(random.sample(gid_set, n_test_gids)) + aid_in_train = [gid not in test_gids for gid in gid_list] + return aid_in_train + + +def split_list(item_list, is_in_first_group_list): + first_group = ut.compress(item_list, is_in_first_group_list) + is_in_second_group = [not b for b in is_in_first_group_list] + second_group = ut.compress(item_list, is_in_second_group) + return first_group, second_group + + +def check_accuracy(ibs, assigner_data=None, cutoff_score=0.5, illustrate=False): + + if assigner_data is None: + assigner_data = ibs.wd_training_data() + + all_aids = [] + for pair in assigner_data['test_pairs']: + all_aids.extend(list(pair)) + all_aids = sorted(list(set(all_aids))) + + all_pairs, all_unassigned_aids = ibs.assign_parts(all_aids, cutoff_score) + + gid_to_assigner_results = gid_keyed_assigner_results( + ibs, all_pairs, all_unassigned_aids + ) + gid_to_ground_truth = gid_keyed_ground_truth(ibs, assigner_data) + + if illustrate: + illustrate_all_assignments(ibs, gid_to_assigner_results, gid_to_ground_truth) + + correct_gids = [] + incorrect_gids = [] + gids_with_false_positives = 0 + n_false_positives = 0 + gids_with_false_negatives = 0 + gids_with_false_neg_allowing_errors = [0, 0, 0] + max_allowed_errors = len(gids_with_false_neg_allowing_errors) + n_false_negatives = 0 + gids_with_both_errors = 0 + for gid in gid_to_assigner_results.keys(): + assigned_pairs = set(gid_to_assigner_results[gid]['pairs']) + ground_t_pairs = set(gid_to_ground_truth[gid]['pairs']) + false_negatives = len(ground_t_pairs - assigned_pairs) + false_positives = len(assigned_pairs - ground_t_pairs) + n_false_negatives += false_negatives + + if false_negatives > 0: + gids_with_false_negatives += 1 + if false_negatives >= 2: + false_neg_log_index = min( + false_negatives - 2, max_allowed_errors - 1 + ) # ie, if we have 2 errors, we have a false neg even allowing 1 error, in index 0 of that list + try: + gids_with_false_neg_allowing_errors[false_neg_log_index] += 1 + except Exception: + ut.embed() + + n_false_positives += false_positives + if false_positives > 0: + gids_with_false_positives += 1 + if false_negatives > 0 
and false_positives > 0: + gids_with_both_errors += 1 + + pairs_equal = sorted(gid_to_assigner_results[gid]['pairs']) == sorted( + gid_to_ground_truth[gid]['pairs'] + ) + if pairs_equal: + correct_gids += [gid] + else: + incorrect_gids += [gid] + + n_gids = len(gid_to_assigner_results.keys()) + accuracy = len(correct_gids) / n_gids + incorrect_gids = n_gids - len(correct_gids) + acc_allowing_errors = [ + 1 - (nerrors / n_gids) for nerrors in gids_with_false_neg_allowing_errors + ] + print('accuracy with cutoff of %s: %s' % (cutoff_score, accuracy)) + for i, acc_allowing_error in enumerate(acc_allowing_errors): + print(' allowing %s errors, acc = %s' % (i + 1, acc_allowing_error)) + print( + ' %s false positives on %s error images' + % (n_false_positives, gids_with_false_positives) + ) + print( + ' %s false negatives on %s error images' + % (n_false_negatives, gids_with_false_negatives) + ) + print(' %s images with both errors' % (gids_with_both_errors)) + return accuracy + + +if __name__ == '__main__': + r""" + CommandLine: + python -m wbia.algo.detect.train_assigner --allexamples + """ + import multiprocessing + + multiprocessing.freeze_support() # for win32 + import utool as ut # NOQA + + ut.doctest_funcs() + + +# additional assigner features to explore +class PartAssignmentFeatureConfig(dtool.Config): + _param_info_list = [] + + +@derived_attribute( + tablename='part_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + int, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='part_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def compute_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
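The intersection features computed in the next few lines rely on the _bbox_intersections and _bbox_to_corner_format helpers defined near the end of this module. A hedged, standalone sketch of that geometry with invented boxes, including the clamp to zero area for disjoint boxes:

def bbox_intersect_area(box_a, box_b):
    # boxes are (xtl, ytl, w, h); convert to corner form, overlap, clamp
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    ixtl, iytl = max(ax, bx), max(ay, by)
    ixbr, iybr = min(ax + aw, bx + bw), min(ay + ah, by + bh)
    iw, ih = ixbr - ixtl, iybr - iytl
    return iw * ih if iw > 0 and ih > 0 else 0

print(bbox_intersect_area((0, 0, 10, 10), (5, 5, 10, 10)))  # 25
print(bbox_intersect_area((0, 0, 4, 4), (10, 10, 2, 2)))    # 0, disjoint boxes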
+ intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +@derived_attribute( + tablename='normalized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='normalized_assignment_features', + rm_extern_on_delete=True, + chunksize=256, +) +def normalized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
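This "normalized" variant differs from the plain feature table above mainly in the _norm_bboxes step applied a few lines earlier: coordinates are divided by image width and height so the features are comparable across image sizes. A small hedged sketch (numbers invented):

def norm_bbox(bbox, im_w, im_h):
    xtl, ytl, w, h = bbox
    return (xtl / im_w, ytl / im_h, w / im_w, h / im_h)

print(norm_bbox((50, 20, 100, 80), 1000, 500))  # (0.05, 0.04, 0.1, 0.16)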
+ intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +@derived_attribute( + tablename='standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def standardized_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
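The "standardized" variant additionally passes the ratio features through sklearn's preprocessing.scale a few lines below, which is why the chunksize comment above mentions needing accurate means and standard deviations. A hedged sketch of what that call does (values invented):

import numpy as np
from sklearn import preprocessing

ratios = [0.1, 0.2, 0.3, 0.4]
scaled = preprocessing.scale(ratios)  # shift to zero mean, divide by the (population) std
print(np.round(scaled, 3))            # [-1.342 -0.447  0.447  1.342]
print(scaled.mean(), scaled.std())    # approximately 0.0 and 1.0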
+ intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + int_area_relative_part = preprocessing.scale(int_area_relative_part) + int_area_relative_body = preprocessing.scale(int_area_relative_body) + part_area_relative_body = preprocessing.scale(part_area_relative_body) + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +# like the above but bboxes are also standardized +@derived_attribute( + tablename='mega_standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_xtl', + 'p_ytl', + 'p_w', + 'p_h', + 'b_xtl', + 'b_ytl', + 'b_w', + 'b_h', + 'int_xtl', + 'int_ytl', + 'int_w', + 'int_h', + 'intersect_area_relative_part', + 'intersect_area_relative_body', + 'part_area_relative_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='mega_standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=256000000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def mega_standardized_assignment_features( + depc, part_aid_list, body_aid_list, config=None +): + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + + part_bboxes = _standardized_bboxes(part_bboxes) + body_bboxes = _standardized_bboxes(body_bboxes) + + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + part_area_relative_body = [ + part_area / body_area for (part_area, body_area) in zip(part_areas, body_areas) + ] + + intersect_bboxes = _bbox_intersections(part_bboxes, body_bboxes) + # note that intesect w and h could be negative if there is no intersection, in which case it is the x/y distance between the annots. 
+ intersect_areas = [ + w * h if w > 0 and h > 0 else 0 for (_, _, w, h) in intersect_bboxes + ] + + int_area_relative_part = [ + int_area / part_area for int_area, part_area in zip(intersect_areas, part_areas) + ] + int_area_relative_body = [ + int_area / body_area for int_area, body_area in zip(intersect_areas, body_areas) + ] + + int_area_relative_part = preprocessing.scale(int_area_relative_part) + int_area_relative_body = preprocessing.scale(int_area_relative_body) + part_area_relative_body = preprocessing.scale(part_area_relative_body) + + result_list = list( + zip( + part_bboxes, + body_bboxes, + intersect_bboxes, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + ) + + for ( + part_bbox, + body_bbox, + intersect_bbox, + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) in result_list: + yield ( + part_bbox[0], + part_bbox[1], + part_bbox[2], + part_bbox[3], + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + intersect_bbox[0], + intersect_bbox[1], + intersect_bbox[2], + intersect_bbox[3], + int_area_relative_part, + int_area_relative_body, + part_area_relative_body, + ) + + +@derived_attribute( + tablename='theta_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='theta_assignment_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def theta_assignment_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
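The theta-aware features below use the rotated annotation vertices directly, via shapely polygons, instead of axis-aligned bounding boxes. A hedged, self-contained sketch of that geometry with invented vertex lists:

from shapely import geometry

part = geometry.Polygon([(0.1, 0.1), (0.4, 0.1), (0.4, 0.3), (0.1, 0.3)])
body = geometry.Polygon([(0.2, 0.0), (0.7, 0.0), (0.7, 0.5), (0.2, 0.5)])

inter = part.intersection(body).area   # 0.04
union = part.area + body.area - inter  # 0.27
print(inter / union)                   # IoU, roughly 0.148
print(part.distance(body))             # 0.0 because the polygons overlap
print(part.centroid.distance(body.centroid))  # centroid-to-centroid distance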
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) in result_list: + yield ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + + +@derived_attribute( + tablename='assigner_viewpoint_unit_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', + 
'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + 'part_is_left', + 'part_is_right', + 'part_is_up', + 'part_is_down', + 'part_is_front', + 'part_is_back', + 'body_is_left', + 'body_is_right', + 'body_is_up', + 'body_is_down', + 'body_is_front', + 'body_is_back', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='assigner_viewpoint_unit_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def assigner_viewpoint_unit_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' + bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + + part_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, 
part_aid_list) + body_lrudfb_vects = get_annot_lrudfb_unit_vector(ibs, part_aid_list) + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + part_lrudfb_vects, + body_lrudfb_vects, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + part_lrudfb_vect, + body_lrudfb_vect, + ) in result_list: + ans = ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + ans += tuple(part_lrudfb_vect) + ans += tuple(body_lrudfb_vect) + yield ans + + +@derived_attribute( + tablename='theta_standardized_assignment_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + ], + configclass=PartAssignmentFeatureConfig, + fname='theta_standardized_assignment_features', + rm_extern_on_delete=True, + chunksize=2560000, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def theta_standardized_assignment_features( + depc, part_aid_list, body_aid_list, config=None +): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
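The assigner_viewpoint features above append a left/right/up/down/front/back encoding produced by get_annot_lrudfb_unit_vector (defined near the end of this file): the six boolean flags are divided by the L2 norm of the set bits, with a -1 stand-in to avoid division by zero. A hedged sketch with invented flags:

from math import sqrt

def lrudfb_unit_vector(flags):
    length = sqrt(sum(flags)) or -1  # mirrors the divide-by-zero guard
    return [float(f) / length for f in flags]

print(lrudfb_unit_vector([True, False, False, False, True, False]))
# [0.707..., 0.0, 0.0, 0.0, 0.707..., 0.0]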
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + int_area_scalars = preprocessing.scale(int_area_scalars) + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + int_over_unions = preprocessing.scale(int_over_unions) + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + part_body_distances = preprocessing.scale(part_body_distances) + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + part_body_centroid_dists = preprocessing.scale(part_body_centroid_dists) + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + int_over_parts = preprocessing.scale(int_over_parts) + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + int_over_bodys = preprocessing.scale(int_over_bodys) + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + part_over_bodys = preprocessing.scale(part_over_bodys) + + # standardization + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) in result_list: + yield ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + 
int_over_body, + part_over_body, + ) + + +def get_annot_lrudfb_unit_vector(ibs, aid_list): + from wbia.core_annots import get_annot_lrudfb_bools + + bool_arrays = get_annot_lrudfb_bools(ibs, aid_list) + float_arrays = [[float(b) for b in lrudfb] for lrudfb in bool_arrays] + lrudfb_lengths = [sqrt(lrudfb.count(True)) for lrudfb in bool_arrays] + # lying just to avoid division by zero errors + lrudfb_lengths = [length if length != 0 else -1 for length in lrudfb_lengths] + unit_float_array = [ + [f / length for f in lrudfb] + for lrudfb, length in zip(float_arrays, lrudfb_lengths) + ] + + return unit_float_array + + +def _norm_bboxes(bbox_list, width_list, height_list): + normed_boxes = [ + (bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h) + for (bbox, w, h) in zip(bbox_list, width_list, height_list) + ] + return normed_boxes + + +def _norm_vertices(verts_list, width_list, height_list): + normed_verts = [ + [[x / w, y / h] for x, y in vert] + for vert, w, h in zip(verts_list, width_list, height_list) + ] + return normed_verts + + +# does this even make any sense? let's find out experimentally +def _standardized_bboxes(bbox_list): + xtls = preprocessing.scale([bbox[0] for bbox in bbox_list]) + ytls = preprocessing.scale([bbox[1] for bbox in bbox_list]) + wids = preprocessing.scale([bbox[2] for bbox in bbox_list]) + heis = preprocessing.scale([bbox[3] for bbox in bbox_list]) + standardized_bboxes = list(zip(xtls, ytls, wids, heis)) + return standardized_bboxes + + +def _bbox_intersections(bboxes_a, bboxes_b): + corner_bboxes_a = _bbox_to_corner_format(bboxes_a) + corner_bboxes_b = _bbox_to_corner_format(bboxes_b) + + intersect_xtls = [ + max(xtl_a, xtl_b) + for ((xtl_a, _, _, _), (xtl_b, _, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_ytls = [ + max(ytl_a, ytl_b) + for ((_, ytl_a, _, _), (_, ytl_b, _, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_xbrs = [ + min(xbr_a, xbr_b) + for ((_, _, xbr_a, _), (_, _, xbr_b, _)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_ybrs = [ + min(ybr_a, ybr_b) + for ((_, _, _, ybr_a), (_, _, _, ybr_b)) in zip(corner_bboxes_a, corner_bboxes_b) + ] + + intersect_widths = [ + int_xbr - int_xtl for int_xbr, int_xtl in zip(intersect_xbrs, intersect_xtls) + ] + + intersect_heights = [ + int_ybr - int_ytl for int_ybr, int_ytl in zip(intersect_ybrs, intersect_ytls) + ] + + intersect_bboxes = list( + zip(intersect_xtls, intersect_ytls, intersect_widths, intersect_heights) + ) + + return intersect_bboxes + + +def _all_centroids(verts_list_a, verts_list_b): + import shapely + + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] + polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] + intersect_polys = [ + poly1.intersection(poly2) for poly1, poly2 in zip(polys_a, polys_b) + ] + + centroids_a = [poly.centroid for poly in polys_a] + centroids_b = [poly.centroid for poly in polys_b] + centroids_int = [poly.centroid for poly in intersect_polys] + + return centroids_a, centroids_b, centroids_int + + +def _theta_aware_intersect_areas(verts_list_a, verts_list_b): + import shapely + + polys_a = [shapely.geometry.Polygon(vert) for vert in verts_list_a] + polys_b = [shapely.geometry.Polygon(vert) for vert in verts_list_b] + intersect_areas = [ + poly1.intersection(poly2).area for poly1, poly2 in zip(polys_a, polys_b) + ] + return intersect_areas + + +# converts bboxes from (xtl, ytl, w, h) to (xtl, ytl, xbr, ybr) +def _bbox_to_corner_format(bboxes): + corner_bboxes = [ + (bbox[0], bbox[1], bbox[0] + 
bbox[2], bbox[1] + bbox[3]) for bbox in bboxes + ] + return corner_bboxes + + +def _polygons_to_centroid_coords(polygon_list): + centroids = [poly.centroid for poly in polygon_list] + return centroids diff --git a/wbia/algo/graph/core.py b/wbia/algo/graph/core.py index fcaa87beca..e3b73e87ea 100644 --- a/wbia/algo/graph/core.py +++ b/wbia/algo/graph/core.py @@ -27,6 +27,8 @@ from wbia.algo.graph.state import SAME, DIFF, NULL import networkx as nx import logging +import threading + print, rrr, profile = ut.inject2(__name__) logger = logging.getLogger('wbia') @@ -1261,6 +1263,7 @@ def __init__(infr, ibs, aids=[], nids=None, autoinit=True, verbose=False): # A generator that maintains the state of the algorithm infr._gen = None + infr._gen_lock = threading.Lock() # Computer vision algorithms infr.ranker = None diff --git a/wbia/algo/graph/mixin_loops.py b/wbia/algo/graph/mixin_loops.py index c98be883b6..eb92b3b5b6 100644 --- a/wbia/algo/graph/mixin_loops.py +++ b/wbia/algo/graph/mixin_loops.py @@ -109,8 +109,11 @@ def main_gen(infr, max_loops=None, use_refresh=True): if infr.params['redun.enabled'] and infr.params['redun.enforce_pos']: infr.loop_phase = 'pos_redun_init' # Fix positive redundancy of anything within the loop - for _ in infr.pos_redun_gen(): - yield _ + try: + for _ in infr.pos_redun_gen(): + yield _ + except StopIteration: + pass infr.phase = 1 if infr.params['ranking.enabled']: @@ -166,6 +169,8 @@ def main_gen(infr, max_loops=None, use_refresh=True): if infr.params['inference.enabled']: infr.assert_consistency_invariant() + return 'finished' + def hardcase_review_gen(infr): """ Subiterator for hardcase review @@ -480,6 +485,7 @@ def main_loop(infr, max_loops=None, use_refresh=True): or assert not any(infr.main_gen()) maybe this is fine. 
""" + raise RuntimeError() infr.start_id_review(max_loops=max_loops, use_refresh=use_refresh) # To automatically run through the loop just exhaust the generator result = next(infr._gen) @@ -632,14 +638,30 @@ def continue_review(infr): infr.print('continue_review', 10) if infr._gen is None: return None - try: - user_request = next(infr._gen) - except StopIteration: + + hungry, finished, attempt = True, False, 0 + while hungry: + try: + attempt += 1 + with infr._gen_lock: + user_request = next(infr._gen) + hungry = False + except StopIteration: + pass + if attempt >= 100: + hungry = False + finished = True + if isinstance(user_request, str) and user_request in ['finished']: + hungry = False + finished = True + + if finished: review_finished = infr.callbacks.get('review_finished', None) if review_finished is not None: review_finished() infr._gen = None user_request = None + return user_request def qt_edge_reviewer(infr, edge=None): diff --git a/wbia/algo/hots/neighbor_index_cache.py b/wbia/algo/hots/neighbor_index_cache.py index 1f8d8356c4..cb1dc76bc2 100644 --- a/wbia/algo/hots/neighbor_index_cache.py +++ b/wbia/algo/hots/neighbor_index_cache.py @@ -338,14 +338,14 @@ def request_augmented_wbia_nnindexer( >>> ZEB_PLAIN = wbia.const.TEST_SPECIES.ZEB_PLAIN >>> ibs = wbia.opendb('testdb1') >>> use_memcache, max_covers, verbose = True, None, True - >>> daid_list = ibs.get_valid_aids(species=ZEB_PLAIN)[0:6] + >>> daid_list = sorted(ibs.get_valid_aids(species=ZEB_PLAIN))[0:6] >>> qreq_ = ibs.new_query_request(daid_list, daid_list) >>> qreq_.qparams.min_reindex_thresh = 1 >>> min_reindex_thresh = qreq_.qparams.min_reindex_thresh >>> # CLEAR CACHE for clean test >>> clear_uuid_cache(qreq_) >>> # LOAD 3 AIDS INTO CACHE - >>> aid_list = ibs.get_valid_aids(species=ZEB_PLAIN)[0:3] + >>> aid_list = sorted(ibs.get_valid_aids(species=ZEB_PLAIN))[0:3] >>> # Should fallback >>> nnindexer = request_augmented_wbia_nnindexer(qreq_, aid_list) >>> # assert the fallback @@ -630,7 +630,7 @@ def group_daids_by_cached_nnindexer( >>> # STEP 0: CLEAR THE CACHE >>> clear_uuid_cache(qreq_) >>> # STEP 1: ASSERT EMPTY INDEX - >>> daid_list = ibs.get_valid_aids(species=ZEB_PLAIN)[0:3] + >>> daid_list = sorted(ibs.get_valid_aids(species=ZEB_PLAIN))[0:3] >>> uncovered_aids, covered_aids_list = group_daids_by_cached_nnindexer( ... qreq_, daid_list, min_reindex_thresh, max_covers) >>> result1 = uncovered_aids, covered_aids_list diff --git a/wbia/algo/smk/smk_pipeline.py b/wbia/algo/smk/smk_pipeline.py index 3058ea7f68..2ab6d0a8f3 100644 --- a/wbia/algo/smk/smk_pipeline.py +++ b/wbia/algo/smk/smk_pipeline.py @@ -571,8 +571,7 @@ def testdata_smk(*args, **kwargs): # import sklearn.model_selection ibs, aid_list = wbia.testdata_aids(defaultdb='PZ_MTEST') nid_list = np.array(ibs.annots(aid_list).nids) - rng = ut.ensure_rng(0) - xvalkw = dict(n_splits=4, shuffle=False, random_state=rng) + xvalkw = dict(n_splits=4, shuffle=False) skf = sklearn.model_selection.StratifiedKFold(**xvalkw) train_idx, test_idx = six.next(skf.split(aid_list, nid_list)) diff --git a/wbia/algo/verif/sklearn_utils.py b/wbia/algo/verif/sklearn_utils.py index b637bf7748..94185e752e 100644 --- a/wbia/algo/verif/sklearn_utils.py +++ b/wbia/algo/verif/sklearn_utils.py @@ -512,11 +512,9 @@ def amean(x, w=None): # and BM * MK MCC? 
def matthews_corrcoef(y_true, y_pred, sample_weight=None): - from sklearn.metrics.classification import ( - _check_targets, - LabelEncoder, - confusion_matrix, - ) + from sklearn.preprocessing import LabelEncoder + from sklearn.metrics import confusion_matrix + from sklearn.metrics._classification import _check_targets y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in {'binary', 'multiclass'}: diff --git a/wbia/annotmatch_funcs.py b/wbia/annotmatch_funcs.py index 6739bc858e..5e23533925 100644 --- a/wbia/annotmatch_funcs.py +++ b/wbia/annotmatch_funcs.py @@ -45,14 +45,15 @@ def get_annotmatch_rowids_from_aid1(ibs, aid1_list, eager=True, nInput=None): params_iter = zip(aid1_list) if True: # HACK IN INDEX - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS aid1_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid1}); - """.format( - ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, - annot_rowid1=manual_annotmatch_funcs.ANNOT_ROWID1, + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS aid1_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid1}); + """.format( + ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, + annot_rowid1=manual_annotmatch_funcs.ANNOT_ROWID1, + ) ) - ).fetchall() where_colnames = [manual_annotmatch_funcs.ANNOT_ROWID1] annotmatch_rowid_list = ibs.db.get_where_eq( ibs.const.ANNOTMATCH_TABLE, @@ -82,14 +83,15 @@ def get_annotmatch_rowids_from_aid2( nInput = len(aid2_list) if True: # HACK IN INDEX - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS aid2_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid2}); - """.format( - ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, - annot_rowid2=manual_annotmatch_funcs.ANNOT_ROWID2, + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS aid2_to_am ON {ANNOTMATCH_TABLE} ({annot_rowid2}); + """.format( + ANNOTMATCH_TABLE=ibs.const.ANNOTMATCH_TABLE, + annot_rowid2=manual_annotmatch_funcs.ANNOT_ROWID2, + ) ) - ).fetchall() colnames = (manual_annotmatch_funcs.ANNOTMATCH_ROWID,) # FIXME: col_rowid is not correct params_iter = zip(aid2_list) diff --git a/wbia/cli/compare_databases.py b/wbia/cli/compare_databases.py new file mode 100644 index 0000000000..d57473ddb5 --- /dev/null +++ b/wbia/cli/compare_databases.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +import logging +import sys +import click + +from wbia.dtool.copy_sqlite_to_postgres import ( + SqliteDatabaseInfo, + PostgresDatabaseInfo, + compare_databases, + DEFAULT_CHECK_PC, + DEFAULT_CHECK_MIN, + DEFAULT_CHECK_MAX, +) + + +logger = logging.getLogger('wbia') + + +@click.command() +@click.option( + '--db-dir', + multiple=True, + help='SQLite databases location', +) +@click.option( + '--sqlite-uri', + multiple=True, + help='SQLite database URI (e.g. sqlite:////path.sqlite3)', +) +@click.option( + '--pg-uri', + multiple=True, + help='Postgres connection URI (e.g. 
postgresql://user:pass@host)', +) +@click.option( + '--check-pc', + type=float, + default=DEFAULT_CHECK_PC, + help=f'Percentage of table to check, default {DEFAULT_CHECK_PC} ({int(DEFAULT_CHECK_PC * 100)}% of the table)', +) +@click.option( + '--check-max', + type=int, + default=DEFAULT_CHECK_MAX, + help=f'Maximum number of rows to check, default {DEFAULT_CHECK_MAX} (0 for no limit)', +) +@click.option( + '--check-min', + type=int, + default=DEFAULT_CHECK_MIN, + help=f'Minimum number of rows to check, default {DEFAULT_CHECK_MIN}', +) +@click.option( + '-v', + '--verbose', + is_flag=True, + default=False, + help='Show debug messages', +) +def main(db_dir, sqlite_uri, pg_uri, check_pc, check_max, check_min, verbose): + if verbose: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.INFO) + + logger.addHandler(logging.StreamHandler()) + + if len(db_dir) + len(sqlite_uri) + len(pg_uri) != 2: + raise click.BadParameter('exactly 2 db_dir or sqlite_uri or pg_uri must be given') + db_infos = [] + for db_dir_ in db_dir: + db_infos.append(SqliteDatabaseInfo(db_dir_)) + for sqlite_uri_ in sqlite_uri: + db_infos.append(SqliteDatabaseInfo(sqlite_uri_)) + for pg_uri_ in pg_uri: + db_infos.append(PostgresDatabaseInfo(pg_uri_)) + exact = not (sqlite_uri and pg_uri) + differences = compare_databases( + *db_infos, + exact=exact, + check_pc=check_pc, + check_max=check_max, + check_min=check_min, + ) + if differences: + click.echo(f'Databases {db_infos[0]} and {db_infos[1]} are different:') + for line in differences: + click.echo(line) + sys.exit(1) + else: + click.echo(f'Databases {db_infos[0]} and {db_infos[1]} are the same') + + +if __name__ == '__main__': + main() diff --git a/wbia/cli/convert_hsdb.py b/wbia/cli/convert_hsdb.py new file mode 100644 index 0000000000..a0f8bf7594 --- /dev/null +++ b/wbia/cli/convert_hsdb.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +"""Script to convert hotspotter database (HSDB) to a WBIA compatible database""" +import sys + +import click + +from wbia.dbio.ingest_hsdb import is_hsdb, is_succesful_convert, convert_hsdb_to_wbia + + +@click.command() +@click.option( + '--db-dir', required=True, type=click.Path(exists=True), help='database location' +) +def main(db_dir): + """Convert hotspotter database (HSDB) to a WBIA compatible database""" + click.echo(f'⏳ working on {db_dir}') + if is_hsdb(db_dir): + click.echo('✅ confirmed hotspotter database') + else: + click.echo('❌ not a hotspotter database') + sys.exit(1) + if is_succesful_convert(db_dir): + click.echo('✅ already converted hotspotter database') + sys.exit(0) + + convert_hsdb_to_wbia( + db_dir, + ensure=True, + verbose=True, + ) + + if is_succesful_convert(db_dir): + click.echo('✅ successfully converted database') + else: + click.echo('❌ unsuccessfully converted... 
further investigation necessary') + sys.exit(1) + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/wbia/cli/migrate_sqlite_to_postgres.py b/wbia/cli/migrate_sqlite_to_postgres.py new file mode 100644 index 0000000000..9e43f80e2c --- /dev/null +++ b/wbia/cli/migrate_sqlite_to_postgres.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +import logging +import re +import subprocess +import sys +from pathlib import Path + +import click +import sqlalchemy + +from wbia.dtool.copy_sqlite_to_postgres import ( + copy_sqlite_to_postgres, + SqliteDatabaseInfo, + PostgresDatabaseInfo, + compare_databases, +) + + +logger = logging.getLogger('wbia') + + +@click.command() +@click.option( + '--db-dir', required=True, type=click.Path(exists=True), help='database location' +) +@click.option( + '--db-uri', + required=True, + help='Postgres connection URI (e.g. postgres://user:pass@host)', +) +@click.option( + '--force', + is_flag=True, + default=False, + help='Delete all tables in the public schema in postgres', +) +@click.option( + '-v', + '--verbose', + is_flag=True, + default=False, + help='Show debug messages', +) +def main(db_dir, db_uri, force, verbose): + """""" + # Set up logging + if verbose: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler()) + + logger.info(f'using {db_dir} ...') + + # Create the database if it doesn't exist + engine = sqlalchemy.create_engine(db_uri) + try: + engine.connect() + except sqlalchemy.exc.OperationalError as e: + m = re.search(r'database "([^"]*)" does not exist', str(e)) + if m: + dbname = m.group(1) + engine = sqlalchemy.create_engine(db_uri.rsplit('/', 1)[0]) + logger.info(f'Creating "{dbname}"...') + engine.execution_options(isolation_level='AUTOCOMMIT').execute( + f'CREATE DATABASE {dbname}' + ) + else: + raise + finally: + engine.dispose() + + # Check that the database hasn't already been migrated. 
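# Editorial aside, not part of the patch: exercising the new
# wbia.cli.compare_databases entry point in-process with click's test runner.
# Exactly two of --db-dir/--sqlite-uri/--pg-uri must be given; the path and
# URI below are made up.
from click.testing import CliRunner
from wbia.cli.compare_databases import main as compare_main

runner = CliRunner()
result = runner.invoke(compare_main, [
    '--db-dir', '/data/db/testdb1',
    '--pg-uri', 'postgresql://wbia:wbia@localhost:5432/testdb1',
    '--check-pc', '0.2',      # sample 20% of each table's rows
])
print(result.exit_code)       # 0 if the databases match, 1 if they differ
print(result.output)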
+ db_infos = [ + SqliteDatabaseInfo(Path(db_dir)), + PostgresDatabaseInfo(db_uri), + ] + differences = compare_databases(*db_infos) + + if not differences: + logger.info('Database already migrated') + sys.exit(0) + + # Make sure there are no tables in the public schema in postgresql + # because we're using it as the workspace for the migration + if 'public' in db_infos[1].get_schema(): + table_names = [ + t for schema, t in db_infos[1].get_table_names() if schema == 'public' + ] + if not force: + click.echo( + f'Tables in public schema in postgres database: {", ".join(table_names)}' + ) + click.echo('Use --force to remove the tables in public schema') + sys.exit(1) + else: + click.echo(f'Dropping all tables in public schema: {", ".join(table_names)}') + for table_name in table_names: + db_infos[1].engine.execute(f'DROP TABLE {table_name} CASCADE') + + # Migrate + problems = {} + with click.progressbar(length=100000, show_eta=True) as bar: + for path, completed_future, db_size, total_size in copy_sqlite_to_postgres( + Path(db_dir), db_uri + ): + try: + completed_future.result() + except Exception as exc: + logger.info( + f'\nfailed while processing {str(path)}\n{completed_future.exception()}' + ) + problems[path] = exc + else: + logger.info(f'\nfinished processing {str(path)}') + finally: + bar.update(int(db_size / total_size * bar.length)) + + # Report problems + for path, exc in problems.items(): + logger.info('*' * 60) + logger.info(f'There was a problem migrating {str(path)}') + logger.exception(exc) + if isinstance(exc, subprocess.CalledProcessError): + logger.info('-' * 30) + logger.info(exc.stdout.decode()) + + # Verify the migration + differences = compare_databases(*db_infos) + + if differences: + logger.info(f'Databases {db_infos[0]} and {db_infos[1]} are different:') + for line in differences: + logger.info(line) + sys.exit(1) + else: + logger.info(f'Database {db_infos[0]} successfully migrated to {db_infos[1]}') + + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/wbia/cli/testdbs.py b/wbia/cli/testdbs.py index 1575a1ae17..20dbb4cef9 100644 --- a/wbia/cli/testdbs.py +++ b/wbia/cli/testdbs.py @@ -6,7 +6,13 @@ import click from wbia.dbio import ingest_database -from wbia.init.sysres import get_workdir +from wbia.init.sysres import ( + ensure_nauts, + ensure_pz_mtest, + ensure_testdb2, + ensure_wilddogs, + get_workdir, +) @click.command() @@ -18,6 +24,10 @@ def main(force_replace): dbs = { # : 'testdb1': lambda: ingest_database.ingest_standard_database('testdb1'), + 'PZ_MTEST': ensure_pz_mtest, + 'NAUT_test': ensure_nauts, + 'wd_peter2': ensure_wilddogs, + 'testdb2': ensure_testdb2, } for db in dbs: diff --git a/wbia/conftest.py b/wbia/conftest.py index 22cd04a077..e58996648f 100644 --- a/wbia/conftest.py +++ b/wbia/conftest.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- # See also conftest.py documentation at https://docs.pytest.org/en/stable/fixture.html#conftest-py-sharing-fixture-functions """This module is implicitly used by ``pytest`` to load testing configuration and fixtures.""" -import os -from functools import wraps from pathlib import Path import pytest +import sqlalchemy from wbia.dbio import ingest_database from wbia.init.sysres import ( @@ -26,80 +25,100 @@ 'testdb2', 'testdb_guiall', 'wd_peter2', + 'testdb_assigner', + # Not a populated database, but used by wbia.dbio.export_subset:merge_databases + 'testdb_dst', ) -# Global marker for determining the availablity of postgres -# set by db_uri fixture and used by requires_postgresql decorator 
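# Editorial aside, not part of the patch: driving the new
# wbia.cli.migrate_sqlite_to_postgres script the same way. --force drops any
# leftover tables in postgres' public schema before the copy, and a zero exit
# code means the post-migration comparison found no differences. The path and
# URI are made up.
from click.testing import CliRunner
from wbia.cli.migrate_sqlite_to_postgres import main as migrate_main

runner = CliRunner()
result = runner.invoke(migrate_main, [
    '--db-dir', '/data/db/testdb1',
    '--db-uri', 'postgresql://wbia:wbia@localhost:5432/testdb1',
    '--force',
    '--verbose',
])
print(result.exit_code, result.output)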
-_POSTGRES_AVAILABLE = None - - -# -# Decorators -# - +@pytest.fixture +def enable_wildbook_signal(): + """This sets the ``ENABLE_WILDBOOK_SIGNAL`` to False""" + # TODO (16-Jul-12020) Document ENABLE_WILDBOOK_SIGNAL + # ??? what is ENABLE_WILDBOOK_SIGNAL used for? + import wbia -def requires_postgresql(func): - """Test decorator to mark a test that requires postgresql + setattr(wbia, 'ENABLE_WILDBOOK_SIGNAL', False) - Usage: - @requires_postgresql - def test_postgres_thing(): - # testing logic that requires postgres... - assert True +@pytest.fixture(scope='session', autouse=True) +def postgres_base_uri(request): + """The base URI connection string to postgres. + This should contain all necessary connection information except the database name. """ - # Firstly, skip if psycopg2 is not installed - try: - import psycopg2 # noqa: - except ImportError: - pytest.mark.skip('psycopg2 is not installed')(func) + uri = request.config.getoption('postgres_uri') + if not uri: + # Not set, return None; indicates the tests are not to use postgres + return None + + # If the URI contains a database name, we need to remove it + from sqlalchemy.engine.url import make_url, URL + + url = make_url(uri) + url_kwargs = { + 'drivername': url.drivername, + 'username': url.username, + 'password': url.password, + 'host': url.host, + 'port': url.port, + # Purposely remove database and query. + # 'database': None, + # 'query': None, + } + base_uri = str(URL.create(**url_kwargs)) + return base_uri + + +class MonkeyPatchedGetWbiaDbUri: + """Creates a monkey patched version of ``wbia.init.sysres.get_wbia_db_uri`` + to set the testing URI. - @wraps(func) - def wrapper(*args, **kwargs): - # We'll only know if we can connect to postgres during execution. - if not _POSTGRES_AVAILABLE: # see db_uri fixture for value definition - pytest.skip('requires a postgresql connection URI') - return func(*args, **kwargs) + """ - return wrapper + def __init__(self, base_uri: str): + self.base_uri = base_uri + def __call__(self, db_dir: str): + """The monkeypatch of ``wbia.init.sysres.get_wbia_db_uri``""" + uri = None + # Reminder, base_uri could be None if running tests under sqlite + if self.base_uri: + db_name = self.get_db_name_from_db_dir(Path(db_dir)) + uri = self.replace_uri_database(self.base_uri, db_name) + return uri -# -# Fixtures -# + def get_db_name_from_db_dir(self, db_dir: Path): + """Discover the database name from the given ``db_dir``""" + from wbia.init.sysres import get_workdir + db_dir = db_dir.resolve() # just in case + work_dir = Path(get_workdir()).resolve() -@pytest.fixture(scope='session', autouse=True) -def db_uri(): - """The DB URI to use with the tests. - This value comes from ``WBIA_TESTING_BASE_DB_URI``. - """ - # TODO (28-Aug-12020) Should we depend on the user supplying this value? - # Perhaps not at this level? Fail if not specified? - uri = os.getenv('WBIA_TESTING_BASE_DB_URI', '') + # Can we discover the database name? 
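# Editorial aside, not part of the patch: the URI juggling performed by the
# postgres_base_uri fixture and MonkeyPatchedGetWbiaDbUri above. Assumes
# SQLAlchemy 1.4+; the connection details are made up.
from sqlalchemy.engine.url import URL, make_url

url = make_url('postgresql://wbia:wbia@localhost:5432/wbia')

# Drop the database name (and query) to obtain the reusable base URI.
base_uri = str(URL.create(
    drivername=url.drivername,
    username=url.username,
    password=url.password,
    host=url.host,
    port=url.port,
))
print(base_uri)  # postgresql://wbia:wbia@localhost:5432

# Re-attach a per-test database name (the monkeypatched get_wbia_db_uri
# derives this name from the test database directory).
print(str(make_url(base_uri).set(database='pz_mtest')))
# -> postgresql://wbia:wbia@localhost:5432/pz_mtest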
+ # if not db_dir.is_relative_to(workdir): # >= Python 3.9 + if not str(work_dir) in str(db_dir): + raise ValueError( + 'Strange circumstances have lead us to a place of ' + f"incongruity where the '{db_dir}' is not within '{work_dir}'" + ) - # Set postgres availablity marker - global _POSTGRES_AVAILABLE - _POSTGRES_AVAILABLE = uri.startswith('postgres') + # lowercase because database names are case insensitive + return db_dir.name.lower() - return uri + def replace_uri_database(self, uri: str, db_name: str): + """Replace the database name in the given ``uri`` with the given ``db_name``""" + from sqlalchemy.engine.url import make_url + url = make_url(uri) + url = url._replace(database=db_name) -@pytest.fixture -def enable_wildbook_signal(): - """This sets the ``ENABLE_WILDBOOK_SIGNAL`` to False""" - # TODO (16-Jul-12020) Document ENABLE_WILDBOOK_SIGNAL - # ??? what is ENABLE_WILDBOOK_SIGNAL used for? - import wbia - - setattr(wbia, 'ENABLE_WILDBOOK_SIGNAL', False) + return str(url) @pytest.fixture(scope='session', autouse=True) @pytest.mark.usefixtures('enable_wildbook_signal') -def set_up_db(request): +def set_up_db(request, postgres_base_uri): """ Sets up the testing databases. This fixture is set to run automatically any any test run of wbia. @@ -114,6 +133,20 @@ def set_up_db(request): # FIXME (16-Jul-12020) this fixture does not cleanup after itself to preserve exiting usage behavior for dbname in TEST_DBNAMES: delete_dbdir(dbname) + if postgres_base_uri: + engine = sqlalchemy.create_engine(postgres_base_uri) + engine.execution_options(isolation_level='AUTOCOMMIT').execute( + f'DROP DATABASE IF EXISTS {dbname}' + ) + engine.execution_options(isolation_level='AUTOCOMMIT').execute( + f'CREATE DATABASE {dbname}' + ) + engine.dispose() + + # Monkey patch the global URI getter + from wbia.init import sysres + + setattr(sysres, 'get_wbia_db_uri', MonkeyPatchedGetWbiaDbUri(postgres_base_uri)) # Set up DBs ingest_database.ingest_standard_database('testdb1') diff --git a/wbia/constants.py b/wbia/constants.py index 61ec041698..4fd9621504 100644 --- a/wbia/constants.py +++ b/wbia/constants.py @@ -354,6 +354,7 @@ class ZIPPED_URLS(object): # NOQA ORIENTATION = ( 'https://wildbookiarepository.azureedge.net/databases/testdb_orientation.zip' ) + ASSIGNER = 'https://wildbookiarepository.azureedge.net/databases/testdb_assigner.zip' K7_EXAMPLE = 'https://wildbookiarepository.azureedge.net/databases/testdb_kaggle7.zip' diff --git a/wbia/control/DB_SCHEMA.py b/wbia/control/DB_SCHEMA.py index 22d3a7a468..09d747bfaf 100644 --- a/wbia/control/DB_SCHEMA.py +++ b/wbia/control/DB_SCHEMA.py @@ -224,7 +224,12 @@ def update_1_0_0(db, ibs=None): ('feature_keypoints', 'NUMPY'), ('feature_sifts', 'NUMPY'), ), - superkeys=[('chip_rowid, config_rowid',)], + superkeys=[ + ( + 'chip_rowid', + 'config_rowid', + ) + ], docstr=""" Used to store individual chip features (ellipses)""", ) @@ -2160,9 +2165,10 @@ def dump_schema_sql(): from wbia import dtool as dt from wbia.control import DB_SCHEMA_CURRENT - db = dt.SQLDatabaseController.from_uri(':memory:') + db = dt.SQLDatabaseController('sqlite:///', 'dump') DB_SCHEMA_CURRENT.update_current(db) - dump_str = dumps(db.connection) + with db.connect() as conn: + dump_str = dumps(conn) logger.info(dump_str) for tablename in db.get_table_names(): diff --git a/wbia/control/IBEISControl.py b/wbia/control/IBEISControl.py index 0c8d446a78..8262ac6d59 100644 --- a/wbia/control/IBEISControl.py +++ b/wbia/control/IBEISControl.py @@ -41,18 +41,19 @@ import utool as ut import ubelt as ub 
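# Editorial aside, not part of the patch: the per-database reset that the
# updated set_up_db fixture above performs when a postgres base URI is
# configured. CREATE/DROP DATABASE cannot run inside a transaction, hence the
# AUTOCOMMIT execution option; the 1.x-style engine.execute() mirrors the
# patch. The URI and database name are made up.
import sqlalchemy

engine = sqlalchemy.create_engine('postgresql://wbia:wbia@localhost:5432')
autocommit = engine.execution_options(isolation_level='AUTOCOMMIT')
autocommit.execute('DROP DATABASE IF EXISTS testdb1')
autocommit.execute('CREATE DATABASE testdb1')
engine.dispose()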
from six.moves import zip -from os.path import join, split, realpath +from os.path import join, split from wbia.init import sysres -from wbia.dbio import ingest_hsdb from wbia import constants as const from wbia.control import accessor_decors, controller_inject from wbia.dtool.dump import dump +from pathlib import Path # Inject utool functions (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') + # Import modules which define injectable functions # tuples represent conditional imports with the flags in the first part of the @@ -68,6 +69,7 @@ 'wbia.other.detectgrave', 'wbia.other.detecttrain', 'wbia.init.filter_annots', + 'wbia.research.metrics', 'wbia.control.manual_featweight_funcs', 'wbia.control._autogen_party_funcs', 'wbia.control.manual_annotmatch_funcs', @@ -117,6 +119,11 @@ (('--no-curvrank', '--nocurvrank'), 'wbia_curvrank._plugin'), ] +if ut.get_argflag('--curvrank-v2'): + AUTOLOAD_PLUGIN_MODNAMES += [ + (('--no-curvrank-v2', '--nocurvrankv2'), 'wbia_curvrank_v2._plugin'), + ] + if ut.get_argflag('--deepsense'): AUTOLOAD_PLUGIN_MODNAMES += [ (('--no-deepsense', '--nodeepsense'), 'wbia_deepsense._plugin'), @@ -140,6 +147,12 @@ (('--no-2d-orient', '--no2dorient'), 'wbia_2d_orientation._plugin'), ] + +if ut.get_argflag('--orient'): + AUTOLOAD_PLUGIN_MODNAMES += [ + (('--no-orient', '--noorient'), 'wbia_orientation._plugin'), + ] + if ut.get_argflag('--pie'): AUTOLOAD_PLUGIN_MODNAMES += [ (('--no-pie', '--nopie'), 'wbia_pie._plugin'), @@ -244,7 +257,8 @@ def request_IBEISController( Example: >>> # ENABLE_DOCTEST >>> from wbia.control.IBEISControl import * # NOQA - >>> dbdir = 'testdb1' + >>> from wbia.init.sysres import get_workdir + >>> dbdir = '/'.join([get_workdir(), 'testdb1']) >>> ensure = True >>> wbaddr = None >>> verbose = True @@ -264,21 +278,15 @@ def request_IBEISController( if force_serial: assert ibs.force_serial, 'set use_cache=False in wbia.opendb' else: - # Convert hold hotspotter dirs if necessary - if check_hsdb and ingest_hsdb.check_unconverted_hsdb(dbdir): - ibs = ingest_hsdb.convert_hsdb_to_wbia( - dbdir, ensure=ensure, wbaddr=wbaddr, verbose=verbose - ) - else: - ibs = IBEISController( - dbdir=dbdir, - ensure=ensure, - wbaddr=wbaddr, - verbose=verbose, - force_serial=force_serial, - request_dbversion=request_dbversion, - request_stagingversion=request_stagingversion, - ) + ibs = IBEISController( + dbdir=dbdir, + ensure=ensure, + wbaddr=wbaddr, + verbose=verbose, + force_serial=force_serial, + request_dbversion=request_dbversion, + request_stagingversion=request_stagingversion, + ) __IBEIS_CONTROLLER_CACHE__[dbdir] = ibs return ibs @@ -325,7 +333,7 @@ class IBEISController(BASE_CLASS): @profile def __init__( - ibs, + self, dbdir=None, ensure=True, wbaddr=None, @@ -338,15 +346,15 @@ def __init__( # if verbose and ut.VERBOSE: logger.info('\n[ibs.__init__] new IBEISController') - ibs.dbname = None + self.dbname = None # an dict to hack in temporary state - ibs.const = const - ibs.readonly = None - ibs.depc_image = None - ibs.depc_annot = None - ibs.depc_part = None - # ibs.allow_override = 'override+warn' - ibs.allow_override = True + self.const = const + self.readonly = None + self.depc_image = None + self.depc_annot = None + self.depc_part = None + # self.allow_override = 'override+warn' + self.allow_override = True if force_serial is None: if ut.get_argflag(('--utool-force-serial', '--force-serial', '--serial')): force_serial = True @@ -354,60 +362,109 @@ def __init__( force_serial = not ut.in_main_process() # if const.CONTAINERIZED: # 
force_serial = True - ibs.force_serial = force_serial + self.force_serial = force_serial # observer_weakref_list keeps track of the guibacks connected to this # controller - ibs.observer_weakref_list = [] + self.observer_weakref_list = [] # not completely working decorator cache - ibs.table_cache = None - ibs._initialize_self() - ibs._init_dirs(dbdir=dbdir, ensure=ensure) + self.table_cache = None + self._initialize_self() + self._init_dirs(dbdir=dbdir, ensure=ensure) + + # Set the base URI to be used for all database connections + self.__init_base_uri() + # _send_wildbook_request will do nothing if no wildbook address is # specified - ibs._send_wildbook_request(wbaddr) - ibs._init_sql( + self._send_wildbook_request(wbaddr) + self._init_sql( request_dbversion=request_dbversion, request_stagingversion=request_stagingversion, ) - ibs._init_config() - if not ut.get_argflag('--noclean') and not ibs.readonly: - # ibs._init_burned_in_species() - ibs._clean_species() - ibs.job_manager = None + self._init_config() + if not ut.get_argflag('--noclean') and not self.readonly: + # self._init_burned_in_species() + self._clean_species() + self.job_manager = None # Hack for changing the way chips compute # by default use serial because warpAffine is weird with multiproc is_mac = 'macosx' in ut.get_plat_specifier().lower() - ibs._parallel_chips = not ibs.force_serial and not is_mac + self._parallel_chips = not self.force_serial and not is_mac - ibs.containerized = const.CONTAINERIZED - ibs.production = const.PRODUCTION + self.containerized = const.CONTAINERIZED + self.production = const.PRODUCTION - logger.info('[ibs.__init__] CONTAINERIZED: %s\n' % (ibs.containerized,)) - logger.info('[ibs.__init__] PRODUCTION: %s\n' % (ibs.production,)) + logger.info('[ibs.__init__] CONTAINERIZED: %s\n' % (self.containerized,)) + logger.info('[ibs.__init__] PRODUCTION: %s\n' % (self.production,)) # Hack to store HTTPS flag (deliver secure content in web) - ibs.https = const.HTTPS + self.https = const.HTTPS logger.info('[ibs.__init__] END new IBEISController\n') - def reset_table_cache(ibs): - ibs.table_cache = accessor_decors.init_tablecache() + def __init_base_uri(self) -> None: + """Initialize the base URI that is used for all database connections. + This sets the ``_base_uri`` attribute. + This influences the ``is_using_postgres`` property. 
+ + One of the following conditions is met in order to set the uri value: + + - ``--db-uri`` is set to a Postgres URI on the commandline + - only db-dir is set, and thus we assume a sqlite connection - def clear_table_cache(ibs, tablename=None): + """ + self._is_using_postgres_db = False + + uri = sysres.get_wbia_db_uri(self.dbdir) + if uri: + if not uri.startswith('postgresql://'): + raise RuntimeError( + "invalid use of '--db-uri'; only supports postgres uris; " + f"uri = '{uri}'" + ) + # Capture that we are using postgres + self._is_using_postgres_db = True + else: + # Assume a sqlite database + uri = f'sqlite:///{self.get_ibsdir()}' + self._base_uri = uri + + @property + def is_using_postgres_db(self) -> bool: + """Indicates whether this controller is using postgres as the database""" + return self._is_using_postgres_db + + @property + def base_uri(self): + """Base database URI without a specific database name""" + return self._base_uri + + def make_cache_db_uri(self, name): + """Given a name of the cache produce a database connection URI""" + if self.is_using_postgres_db: + # When using postgres, the base-uri is a connection to a single database + # that is used for all database needs and scoped using namespace schemas. + return self._base_uri + return f'sqlite:///{self.get_cachedir()}/{name}.sqlite' + + def reset_table_cache(self): + self.table_cache = accessor_decors.init_tablecache() + + def clear_table_cache(self, tablename=None): logger.info('[ibs] clearing table_cache[%r]' % (tablename,)) if tablename is None: - ibs.reset_table_cache() + self.reset_table_cache() else: try: - del ibs.table_cache[tablename] + del self.table_cache[tablename] except KeyError: pass - def show_depc_graph(ibs, depc, reduced=False): + def show_depc_graph(self, depc, reduced=False): depc.show_graph(reduced=reduced) - def show_depc_image_graph(ibs, **kwargs): + def show_depc_image_graph(self, **kwargs): """ CommandLine: python -m wbia.control.IBEISControl --test-show_depc_image_graph --show @@ -422,9 +479,9 @@ def show_depc_image_graph(ibs, **kwargs): >>> ibs.show_depc_image_graph(reduced=reduced) >>> ut.show_if_requested() """ - ibs.show_depc_graph(ibs.depc_image, **kwargs) + self.show_depc_graph(self.depc_image, **kwargs) - def show_depc_annot_graph(ibs, *args, **kwargs): + def show_depc_annot_graph(self, *args, **kwargs): """ CommandLine: python -m wbia.control.IBEISControl --test-show_depc_annot_graph --show @@ -439,9 +496,9 @@ def show_depc_annot_graph(ibs, *args, **kwargs): >>> ibs.show_depc_annot_graph(reduced=reduced) >>> ut.show_if_requested() """ - ibs.show_depc_graph(ibs.depc_annot, *args, **kwargs) + self.show_depc_graph(self.depc_annot, *args, **kwargs) - def show_depc_annot_table_input(ibs, tablename, *args, **kwargs): + def show_depc_annot_table_input(self, tablename, *args, **kwargs): """ CommandLine: python -m wbia.control.IBEISControl --test-show_depc_annot_table_input --show --tablename=vsone @@ -457,30 +514,30 @@ def show_depc_annot_table_input(ibs, tablename, *args, **kwargs): >>> ibs.show_depc_annot_table_input(tablename) >>> ut.show_if_requested() """ - ibs.depc_annot[tablename].show_input_graph() + self.depc_annot[tablename].show_input_graph() - def get_cachestats_str(ibs): + def get_cachestats_str(self): """ Returns info about the underlying SQL cache memory """ total_size_str = ut.get_object_size_str( - ibs.table_cache, lbl='size(table_cache): ' + self.table_cache, lbl='size(table_cache): ' ) - total_size_str = '\nlen(table_cache) = %r' % (len(ibs.table_cache)) + total_size_str = 
'\nlen(table_cache) = %r' % (len(self.table_cache)) table_size_str_list = [ ut.get_object_size_str(val, lbl='size(table_cache[%s]): ' % (key,)) - for key, val in six.iteritems(ibs.table_cache) + for key, val in six.iteritems(self.table_cache) ] cachestats_str = total_size_str + ut.indentjoin(table_size_str_list, '\n * ') return cachestats_str - def print_cachestats_str(ibs): - cachestats_str = ibs.get_cachestats_str() + def print_cachestats_str(self): + cachestats_str = self.get_cachestats_str() logger.info('IBEIS Controller Cache Stats:') logger.info(cachestats_str) return cachestats_str - def _initialize_self(ibs): + def _initialize_self(self): """ Injects code from plugin modules into the controller @@ -488,17 +545,17 @@ def _initialize_self(ibs): """ if ut.VERBOSE: logger.info('[ibs] _initialize_self()') - ibs.reset_table_cache() + self.reset_table_cache() ut.util_class.inject_all_external_modules( - ibs, + self, controller_inject.CONTROLLER_CLASSNAME, - allow_override=ibs.allow_override, + allow_override=self.allow_override, ) - assert hasattr(ibs, 'get_database_species'), 'issue with ibsfuncs' - assert hasattr(ibs, 'get_annot_pair_timedelta'), 'issue with annotmatch_funcs' - ibs.register_controller() + assert hasattr(self, 'get_database_species'), 'issue with ibsfuncs' + assert hasattr(self, 'get_annot_pair_timedelta'), 'issue with annotmatch_funcs' + self.register_controller() - def _on_reload(ibs): + def _on_reload(self): """ For utools auto reload (rrr). Called before reload @@ -506,36 +563,36 @@ def _on_reload(ibs): # Reloading breaks flask, turn it off controller_inject.GLOBAL_APP_ENABLED = False # Only warn on first load. Overrideing while reloading is ok - ibs.allow_override = True - ibs.unregister_controller() + self.allow_override = True + self.unregister_controller() # Reload dependent modules ut.reload_injected_modules(controller_inject.CONTROLLER_CLASSNAME) - def load_plugin_module(ibs, module): + def load_plugin_module(self, module): ut.inject_instance( - ibs, + self, classkey=module.CLASS_INJECT_KEY, - allow_override=ibs.allow_override, + allow_override=self.allow_override, strict=False, verbose=False, ) # We should probably not implement __del__ # see: https://docs.python.org/2/reference/datamodel.html#object.__del__ - # def __del__(ibs): - # ibs.cleanup() + # def __del__(self): + # self.cleanup() # ------------ # SELF REGISTRATION # ------------ - def register_controller(ibs): + def register_controller(self): """ registers controller with global list """ - ibs_weakref = weakref.ref(ibs) + ibs_weakref = weakref.ref(self) __ALL_CONTROLLERS__.append(ibs_weakref) - def unregister_controller(ibs): - ibs_weakref = weakref.ref(ibs) + def unregister_controller(self): + ibs_weakref = weakref.ref(self) try: __ALL_CONTROLLERS__.remove(ibs_weakref) pass @@ -546,48 +603,48 @@ def unregister_controller(ibs): # OBSERVER REGISTRATION # ------------ - def cleanup(ibs): + def cleanup(self): """ call on del? 
""" - logger.info('[ibs.cleanup] Observers (if any) notified [controller killed]') - for observer_weakref in ibs.observer_weakref_list: + logger.info('[self.cleanup] Observers (if any) notified [controller killed]') + for observer_weakref in self.observer_weakref_list: observer_weakref().notify_controller_killed() - def register_observer(ibs, observer): + def register_observer(self, observer): logger.info('[register_observer] Observer registered: %r' % observer) observer_weakref = weakref.ref(observer) - ibs.observer_weakref_list.append(observer_weakref) + self.observer_weakref_list.append(observer_weakref) - def remove_observer(ibs, observer): + def remove_observer(self, observer): logger.info('[remove_observer] Observer removed: %r' % observer) - ibs.observer_weakref_list.remove(observer) + self.observer_weakref_list.remove(observer) - def notify_observers(ibs): + def notify_observers(self): logger.info('[notify_observers] Observers (if any) notified') - for observer_weakref in ibs.observer_weakref_list: + for observer_weakref in self.observer_weakref_list: observer_weakref().notify() # ------------ - def _init_rowid_constants(ibs): + def _init_rowid_constants(self): # ADD TO CONSTANTS # THIS IS EXPLICIT IN CONST, USE THAT VERSION INSTEAD - # ibs.UNKNOWN_LBLANNOT_ROWID = const.UNKNOWN_LBLANNOT_ROWID - # ibs.UNKNOWN_NAME_ROWID = ibs.UNKNOWN_LBLANNOT_ROWID - # ibs.UNKNOWN_SPECIES_ROWID = ibs.UNKNOWN_LBLANNOT_ROWID + # self.UNKNOWN_LBLANNOT_ROWID = const.UNKNOWN_LBLANNOT_ROWID + # self.UNKNOWN_NAME_ROWID = self.UNKNOWN_LBLANNOT_ROWID + # self.UNKNOWN_SPECIES_ROWID = self.UNKNOWN_LBLANNOT_ROWID - # ibs.MANUAL_CONFIG_SUFFIX = 'MANUAL_CONFIG' - # ibs.MANUAL_CONFIGID = ibs.add_config(ibs.MANUAL_CONFIG_SUFFIX) + # self.MANUAL_CONFIG_SUFFIX = 'MANUAL_CONFIG' + # self.MANUAL_CONFIGID = self.add_config(self.MANUAL_CONFIG_SUFFIX) # duct_tape.fix_compname_configs(ibs) # duct_tape.remove_database_slag(ibs) # duct_tape.fix_nulled_yaws(ibs) lbltype_names = const.KEY_DEFAULTS.keys() lbltype_defaults = const.KEY_DEFAULTS.values() - lbltype_ids = ibs.add_lbltype(lbltype_names, lbltype_defaults) - ibs.lbltype_ids = dict(zip(lbltype_names, lbltype_ids)) + lbltype_ids = self.add_lbltype(lbltype_names, lbltype_defaults) + self.lbltype_ids = dict(zip(lbltype_names, lbltype_ids)) @profile - def _init_sql(ibs, request_dbversion=None, request_stagingversion=None): + def _init_sql(self, request_dbversion=None, request_stagingversion=None): """ Load or create sql database """ from wbia.other import duct_tape # NOQA @@ -598,286 +655,175 @@ def _init_sql(ibs, request_dbversion=None, request_stagingversion=None): # DATABASE DURING A POST UPDATE FUNCTION ROUTINE, WHICH HAS TO BE LOADED # FIRST AND DEFINED IN ORDER TO MAKE THE SUBSEQUENT WRITE CALLS TO THE # RELEVANT CACHE DATABASE - ibs._init_depcache() - ibs._init_sqldbcore(request_dbversion=request_dbversion) - ibs._init_sqldbstaging(request_stagingversion=request_stagingversion) - # ibs.db.dump_schema() - ibs._init_rowid_constants() + self._init_depcache() + self._init_sqldbcore(request_dbversion=request_dbversion) + self._init_sqldbstaging(request_stagingversion=request_stagingversion) + # self.db.dump_schema() + self._init_rowid_constants() - def _needs_backup(ibs): + def _needs_backup(self): needs_backup = not ut.get_argflag('--nobackup') - if ibs.get_dbname() == 'PZ_MTEST': + if self.get_dbname() == 'PZ_MTEST': needs_backup = False if dtool.sql_control.READ_ONLY: needs_backup = False return needs_backup @profile - def _init_sqldbcore(ibs, request_dbversion=None): 
- """ - Example: - >>> # DISABLE_DOCTEST - >>> from wbia.control.IBEISControl import * # NOQA - >>> import wbia # NOQA - >>> #ibs = wbia.opendb('PZ_MTEST') - >>> #ibs = wbia.opendb('PZ_Master0') - >>> ibs = wbia.opendb('testdb1') - >>> #ibs = wbia.opendb('PZ_Master0') - - Ignore: - aid_list = ibs.get_valid_aids() - #ibs.update_annot_visual_uuids(aid_list) - vuuid_list = ibs.get_annot_visual_uuids(aid_list) - aid_list2 = ibs.get_annot_aids_from_visual_uuid(vuuid_list) - assert aid_list2 == aid_list - # v1.3.0 testdb1:264us, PZ_MTEST:3.93ms, PZ_Master0:11.6s - %timeit ibs.get_annot_aids_from_visual_uuid(vuuid_list) - # v1.3.1 testdb1:236us, PZ_MTEST:1.83ms, PZ_Master0:140ms - - ibs.print_imageset_table(exclude_columns=['imageset_uuid']) - """ - from wbia.control import _sql_helpers + def _init_sqldbcore(self, request_dbversion=None): + """Initializes the *main* database object""" + # FIXME (12-Jan-12021) Disabled automatic schema upgrade + DB_VERSION_EXPECTED = '2.0.0' + + if self.is_using_postgres_db: + uri = self._base_uri + else: + uri = f'{self.base_uri}/{self.sqldb_fname}' + fname = Path(self.sqldb_fname).stem # filename without extension + self.db = dtool.SQLDatabaseController(uri, fname) + + # BBB (12-Jan-12021) Disabled the ability to make the database read-only + self.readonly = False + + # Upgrade the database + from wbia.control._sql_helpers import ensure_correct_version from wbia.control import DB_SCHEMA - # Before load, ensure database has been backed up for the day - backup_idx = ut.get_argval('--loadbackup', type_=int, default=None) - sqldb_fpath = None - if backup_idx is not None: - backups = _sql_helpers.get_backup_fpaths(ibs) - logger.info('backups = %r' % (backups,)) - sqldb_fpath = backups[backup_idx] - logger.info('CHOSE BACKUP sqldb_fpath = %r' % (sqldb_fpath,)) - if backup_idx is None and ibs._needs_backup(): - try: - _sql_helpers.ensure_daily_database_backup( - ibs.get_ibsdir(), ibs.sqldb_fname, ibs.backupdir - ) - except IOError as ex: - ut.printex( - ex, ('Failed making daily backup. 
' 'Run with --nobackup to disable') - ) - import utool + ensure_correct_version( + self, + self.db, + DB_VERSION_EXPECTED, + DB_SCHEMA, + verbose=True, + dobackup=not self.readonly, + ) - utool.embed() - raise - # IBEIS SQL State Database - if request_dbversion is None: - ibs.db_version_expected = '2.0.0' - else: - ibs.db_version_expected = request_dbversion - # TODO: add this functionality to SQLController - if backup_idx is None: - new_version, new_fname = dtool.sql_control.dev_test_new_schema_version( - ibs.get_dbname(), - ibs.get_ibsdir(), - ibs.sqldb_fname, - ibs.db_version_expected, - version_next='2.0.0', - ) - ibs.db_version_expected = new_version - ibs.sqldb_fname = new_fname - if sqldb_fpath is None: - assert backup_idx is None - sqldb_fpath = join(ibs.get_ibsdir(), ibs.sqldb_fname) - readonly = None + @profile + def _init_sqldbstaging(self, request_stagingversion=None): + """Initializes the *staging* database object""" + # FIXME (12-Jan-12021) Disabled automatic schema upgrade + DB_VERSION_EXPECTED = '1.2.0' + + if self.is_using_postgres_db: + uri = self._base_uri else: - readonly = True - db_uri = 'file://{}'.format(realpath(sqldb_fpath)) - ibs.db = dtool.SQLDatabaseController.from_uri(db_uri, readonly=readonly) - ibs.readonly = ibs.db.readonly - - if backup_idx is None: - # Ensure correct schema versions - _sql_helpers.ensure_correct_version( - ibs, - ibs.db, - ibs.db_version_expected, - DB_SCHEMA, - verbose=ut.VERBOSE, - dobackup=not ibs.readonly, - ) - # import sys - # sys.exit(1) + uri = f'{self.base_uri}/{self.sqlstaging_fname}' + fname = Path(self.sqlstaging_fname).stem # filename without extension + self.staging = dtool.SQLDatabaseController(uri, fname) - @profile - def _init_sqldbstaging(ibs, request_stagingversion=None): - """ - Example: - >>> # DISABLE_DOCTEST - >>> from wbia.control.IBEISControl import * # NOQA - >>> import wbia # NOQA - >>> #ibs = wbia.opendb('PZ_MTEST') - >>> #ibs = wbia.opendb('PZ_Master0') - >>> ibs = wbia.opendb('testdb1') - >>> #ibs = wbia.opendb('PZ_Master0') - - Ignore: - aid_list = ibs.get_valid_aids() - #ibs.update_annot_visual_uuids(aid_list) - vuuid_list = ibs.get_annot_visual_uuids(aid_list) - aid_list2 = ibs.get_annot_aids_from_visual_uuid(vuuid_list) - assert aid_list2 == aid_list - # v1.3.0 testdb1:264us, PZ_MTEST:3.93ms, PZ_Master0:11.6s - %timeit ibs.get_annot_aids_from_visual_uuid(vuuid_list) - # v1.3.1 testdb1:236us, PZ_MTEST:1.83ms, PZ_Master0:140ms - - ibs.print_imageset_table(exclude_columns=['imageset_uuid']) - """ - from wbia.control import _sql_helpers + # BBB (12-Jan-12021) Disabled the ability to make the database read-only + self.readonly = False + + # Upgrade the database + from wbia.control._sql_helpers import ensure_correct_version from wbia.control import STAGING_SCHEMA - # Before load, ensure database has been backed up for the day - backup_idx = ut.get_argval('--loadbackup-staging', type_=int, default=None) - sqlstaging_fpath = None - if backup_idx is not None: - backups = _sql_helpers.get_backup_fpaths(ibs) - logger.info('backups = %r' % (backups,)) - sqlstaging_fpath = backups[backup_idx] - logger.info('CHOSE BACKUP sqlstaging_fpath = %r' % (sqlstaging_fpath,)) - # HACK - if backup_idx is None and ibs._needs_backup(): - try: - _sql_helpers.ensure_daily_database_backup( - ibs.get_ibsdir(), ibs.sqlstaging_fname, ibs.backupdir - ) - except IOError as ex: - ut.printex( - ex, - ('Failed making daily backup. 
' 'Run with --nobackup to disable'), - ) - raise - # IBEIS SQL State Database - if request_stagingversion is None: - ibs.staging_version_expected = '1.2.0' - else: - ibs.staging_version_expected = request_stagingversion - # TODO: add this functionality to SQLController - if backup_idx is None: - new_version, new_fname = dtool.sql_control.dev_test_new_schema_version( - ibs.get_dbname(), - ibs.get_ibsdir(), - ibs.sqlstaging_fname, - ibs.staging_version_expected, - version_next='1.2.0', - ) - ibs.staging_version_expected = new_version - ibs.sqlstaging_fname = new_fname - if sqlstaging_fpath is None: - assert backup_idx is None - sqlstaging_fpath = join(ibs.get_ibsdir(), ibs.sqlstaging_fname) - readonly = None - else: - readonly = True - db_uri = 'file://{}'.format(realpath(sqlstaging_fpath)) - ibs.staging = dtool.SQLDatabaseController.from_uri( - db_uri, - readonly=readonly, + ensure_correct_version( + self, + self.staging, + DB_VERSION_EXPECTED, + STAGING_SCHEMA, + verbose=True, + dobackup=not self.readonly, ) - ibs.readonly = ibs.staging.readonly - - if backup_idx is None: - # Ensure correct schema versions - _sql_helpers.ensure_correct_version( - ibs, - ibs.staging, - ibs.staging_version_expected, - STAGING_SCHEMA, - verbose=ut.VERBOSE, - ) - # import sys - # sys.exit(1) @profile - def _init_depcache(ibs): + def _init_depcache(self): # Initialize dependency cache for images image_root_getters = {} - ibs.depc_image = dtool.DependencyCache( - root_tablename=const.IMAGE_TABLE, - default_fname=const.IMAGE_TABLE + '_depcache', - cache_dpath=ibs.get_cachedir(), - controller=ibs, - get_root_uuid=ibs.get_image_uuids, + self.depc_image = dtool.DependencyCache( + self, + const.IMAGE_TABLE, + self.get_image_uuids, root_getters=image_root_getters, ) - ibs.depc_image.initialize() + self.depc_image.initialize() - """ Need to reinit this sometimes if cache is ever deleted """ + # Need to reinit this sometimes if cache is ever deleted # Initialize dependency cache for annotations annot_root_getters = { - 'name': ibs.get_annot_names, - 'species': ibs.get_annot_species, - 'yaw': ibs.get_annot_yaws, - 'viewpoint_int': ibs.get_annot_viewpoint_int, - 'viewpoint': ibs.get_annot_viewpoints, - 'bbox': ibs.get_annot_bboxes, - 'verts': ibs.get_annot_verts, - 'image_uuid': lambda aids: ibs.get_image_uuids( - ibs.get_annot_image_rowids(aids) + 'name': self.get_annot_names, + 'species': self.get_annot_species, + 'yaw': self.get_annot_yaws, + 'viewpoint_int': self.get_annot_viewpoint_int, + 'viewpoint': self.get_annot_viewpoints, + 'bbox': self.get_annot_bboxes, + 'verts': self.get_annot_verts, + 'image_uuid': lambda aids: self.get_image_uuids( + self.get_annot_image_rowids(aids) ), - 'theta': ibs.get_annot_thetas, - 'occurrence_text': ibs.get_annot_occurrence_text, + 'theta': self.get_annot_thetas, + 'occurrence_text': self.get_annot_occurrence_text, } - ibs.depc_annot = dtool.DependencyCache( - # root_tablename='annot', # const.ANNOTATION_TABLE - root_tablename=const.ANNOTATION_TABLE, - default_fname=const.ANNOTATION_TABLE + '_depcache', - cache_dpath=ibs.get_cachedir(), - controller=ibs, - get_root_uuid=ibs.get_annot_visual_uuids, + self.depc_annot = dtool.DependencyCache( + self, + const.ANNOTATION_TABLE, + self.get_annot_visual_uuids, root_getters=annot_root_getters, ) # backwards compatibility - ibs.depc = ibs.depc_annot + self.depc = self.depc_annot # TODO: root_uuids should be specified as the # base_root_uuid plus a hash of the attributes that matter for the # requested computation. 
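# Editorial aside, not part of the patch: a summary of how the controller's
# connection URIs are derived after the hunks above (__init_base_uri,
# make_cache_db_uri, _init_sqldbcore/_init_sqldbstaging). Filenames and paths
# are hypothetical placeholders, not the real PATH_NAMES values.
def controller_uris(base_uri, is_postgres, cachedir):
    sqldb_fname = 'database.sqlite3'    # stands in for PATH_NAMES.sqldb
    staging_fname = 'staging.sqlite3'   # stands in for PATH_NAMES.sqlstaging
    if is_postgres:
        # A single postgres database backs main, staging, and every cache;
        # they are kept apart with namespace schemas.
        return {k: base_uri for k in ('main', 'staging', 'cache:chips')}
    return {
        'main': f'{base_uri}/{sqldb_fname}',
        'staging': f'{base_uri}/{staging_fname}',
        'cache:chips': f'sqlite:///{cachedir}/chips.sqlite',
    }

print(controller_uris('sqlite:////data/db/testdb1/_ibsdb', False,
                      '/data/db/testdb1/_ibsdb/_wbia_cache'))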
- ibs.depc_annot.initialize() + self.depc_annot.initialize() # Initialize dependency cache for parts part_root_getters = {} - ibs.depc_part = dtool.DependencyCache( - root_tablename=const.PART_TABLE, - default_fname=const.PART_TABLE + '_depcache', - cache_dpath=ibs.get_cachedir(), - controller=ibs, - get_root_uuid=ibs.get_part_uuids, + self.depc_part = dtool.DependencyCache( + self, + const.PART_TABLE, + self.get_part_uuids, root_getters=part_root_getters, ) - ibs.depc_part.initialize() + self.depc_part.initialize() - def _close_depcache(ibs): - ibs.depc_image.close() - ibs.depc_image = None - ibs.depc_annot.close() - ibs.depc_annot = None - ibs.depc_part.close() - ibs.depc_part = None + def _close_depcache(self): + self.depc_image.close() + self.depc_image = None + self.depc_annot.close() + self.depc_annot = None + self.depc_part.close() + self.depc_part = None - def disconnect_sqldatabase(ibs): + def disconnect_sqldatabase(self): logger.info('disconnecting from sql database') - ibs._close_depcache() - ibs.db.close() - ibs.db = None - ibs.staging.close() - ibs.staging = None - - def clone_handle(ibs, **kwargs): - ibs2 = IBEISController(dbdir=ibs.get_dbdir(), ensure=False) + self._close_depcache() + self.db.close() + self.db = None + self.staging.close() + self.staging = None + + def clone_handle(self, **kwargs): + ibs2 = IBEISController(dbdir=self.get_dbdir(), ensure=False) if len(kwargs) > 0: ibs2.update_query_cfg(**kwargs) - # if ibs.qreq is not None: - # ibs2._prep_qreq(ibs.qreq.qaids, ibs.qreq.daids) + # if self.qreq is not None: + # ibs2._prep_qreq(self.qreq.qaids, self.qreq.daids) return ibs2 - def backup_database(ibs): + def backup_database(self): from wbia.control import _sql_helpers - _sql_helpers.database_backup(ibs.get_ibsdir(), ibs.sqldb_fname, ibs.backupdir) + _sql_helpers.database_backup(self.get_ibsdir(), self.sqldb_fname, self.backupdir) _sql_helpers.database_backup( - ibs.get_ibsdir(), ibs.sqlstaging_fname, ibs.backupdir + self.get_ibsdir(), self.sqlstaging_fname, self.backupdir ) - def _send_wildbook_request(ibs, wbaddr, payload=None): + def daily_backup_database(self): + from wbia.control import _sql_helpers + + _sql_helpers.database_backup( + self.get_ibsdir(), self.sqldb_fname, self.backupdir, False + ) + _sql_helpers.database_backup( + self.get_ibsdir(), + self.sqlstaging_fname, + self.backupdir, + False, + ) + + def _send_wildbook_request(self, wbaddr, payload=None): import requests if wbaddr is None: @@ -898,7 +844,7 @@ def _send_wildbook_request(ibs, wbaddr, payload=None): return response def _init_dirs( - ibs, dbdir=None, dbname='testdb_1', workdir='~/wbia_workdir', ensure=True + self, dbdir=None, dbname='testdb_1', workdir='~/wbia_workdir', ensure=True ): """ Define ibs directories @@ -907,67 +853,67 @@ def _init_dirs( REL_PATHS = const.REL_PATHS if not ut.QUIET: - logger.info('[ibs._init_dirs] ibs.dbdir = %r' % dbdir) + logger.info('[self._init_dirs] self.dbdir = %r' % dbdir) if dbdir is not None: workdir, dbname = split(dbdir) - ibs.workdir = ut.truepath(workdir) - ibs.dbname = dbname - ibs.sqldb_fname = PATH_NAMES.sqldb - ibs.sqlstaging_fname = PATH_NAMES.sqlstaging + self.workdir = ut.truepath(workdir) + self.dbname = dbname + self.sqldb_fname = PATH_NAMES.sqldb + self.sqlstaging_fname = PATH_NAMES.sqlstaging # Make sure you are not nesting databases assert PATH_NAMES._ibsdb != ut.dirsplit( - ibs.workdir + self.workdir ), 'cannot work in _ibsdb internals' assert PATH_NAMES._ibsdb != dbname, 'cannot create db in _ibsdb internals' - ibs.dbdir = 
join(ibs.workdir, ibs.dbname) + self.dbdir = join(self.workdir, self.dbname) # All internal paths live in /_ibsdb # TODO: constantify these # so non controller objects (like in score normalization) have access # to these - ibs._ibsdb = join(ibs.dbdir, REL_PATHS._ibsdb) - ibs.trashdir = join(ibs.dbdir, REL_PATHS.trashdir) - ibs.cachedir = join(ibs.dbdir, REL_PATHS.cache) - ibs.backupdir = join(ibs.dbdir, REL_PATHS.backups) - ibs.logsdir = join(ibs.dbdir, REL_PATHS.logs) - ibs.chipdir = join(ibs.dbdir, REL_PATHS.chips) - ibs.imgdir = join(ibs.dbdir, REL_PATHS.images) - ibs.uploadsdir = join(ibs.dbdir, REL_PATHS.uploads) + self._ibsdb = join(self.dbdir, REL_PATHS._ibsdb) + self.trashdir = join(self.dbdir, REL_PATHS.trashdir) + self.cachedir = join(self.dbdir, REL_PATHS.cache) + self.backupdir = join(self.dbdir, REL_PATHS.backups) + self.logsdir = join(self.dbdir, REL_PATHS.logs) + self.chipdir = join(self.dbdir, REL_PATHS.chips) + self.imgdir = join(self.dbdir, REL_PATHS.images) + self.uploadsdir = join(self.dbdir, REL_PATHS.uploads) # All computed dirs live in /_ibsdb/_wbia_cache - ibs.thumb_dpath = join(ibs.dbdir, REL_PATHS.thumbs) - ibs.flanndir = join(ibs.dbdir, REL_PATHS.flann) - ibs.qresdir = join(ibs.dbdir, REL_PATHS.qres) - ibs.bigcachedir = join(ibs.dbdir, REL_PATHS.bigcache) - ibs.distinctdir = join(ibs.dbdir, REL_PATHS.distinctdir) + self.thumb_dpath = join(self.dbdir, REL_PATHS.thumbs) + self.flanndir = join(self.dbdir, REL_PATHS.flann) + self.qresdir = join(self.dbdir, REL_PATHS.qres) + self.bigcachedir = join(self.dbdir, REL_PATHS.bigcache) + self.distinctdir = join(self.dbdir, REL_PATHS.distinctdir) if ensure: - ibs.ensure_directories() + self.ensure_directories() assert dbdir is not None, 'must specify database directory' - def ensure_directories(ibs): + def ensure_directories(self): """ Makes sure the core directores for the controller exist """ _verbose = ut.VERBOSE - ut.ensuredir(ibs._ibsdb) - ut.ensuredir(ibs.cachedir, verbose=_verbose) - ut.ensuredir(ibs.backupdir, verbose=_verbose) - ut.ensuredir(ibs.logsdir, verbose=_verbose) - ut.ensuredir(ibs.workdir, verbose=_verbose) - ut.ensuredir(ibs.imgdir, verbose=_verbose) - ut.ensuredir(ibs.chipdir, verbose=_verbose) - ut.ensuredir(ibs.flanndir, verbose=_verbose) - ut.ensuredir(ibs.qresdir, verbose=_verbose) - ut.ensuredir(ibs.bigcachedir, verbose=_verbose) - ut.ensuredir(ibs.thumb_dpath, verbose=_verbose) - ut.ensuredir(ibs.distinctdir, verbose=_verbose) - ibs.get_smart_patrol_dir() + ut.ensuredir(self._ibsdb) + ut.ensuredir(self.cachedir, verbose=_verbose) + ut.ensuredir(self.backupdir, verbose=_verbose) + ut.ensuredir(self.logsdir, verbose=_verbose) + ut.ensuredir(self.workdir, verbose=_verbose) + ut.ensuredir(self.imgdir, verbose=_verbose) + ut.ensuredir(self.chipdir, verbose=_verbose) + ut.ensuredir(self.flanndir, verbose=_verbose) + ut.ensuredir(self.qresdir, verbose=_verbose) + ut.ensuredir(self.bigcachedir, verbose=_verbose) + ut.ensuredir(self.thumb_dpath, verbose=_verbose) + ut.ensuredir(self.distinctdir, verbose=_verbose) + self.get_smart_patrol_dir() # -------------- # --- DIRS ---- # -------------- @register_api('/api/core/db/name/', methods=['GET']) - def get_dbname(ibs): + def get_dbname(self): """ Returns: list_ (list): database name @@ -976,14 +922,14 @@ def get_dbname(ibs): Method: GET URL: /api/core/db/name/ """ - return ibs.dbname + return self.dbname - def get_db_name(ibs): - """ Alias for ibs.get_dbname(). """ - return ibs.get_dbname() + def get_db_name(self): + """ Alias for self.get_dbname(). 
""" + return self.get_dbname() @register_api(CORE_DB_UUID_INIT_API_RULE, methods=['GET']) - def get_db_init_uuid(ibs): + def get_db_init_uuid(self): """ Returns: UUID: The SQLDatabaseController's initialization UUID @@ -992,125 +938,125 @@ def get_db_init_uuid(ibs): Method: GET URL: /api/core/db/uuid/init/ """ - return ibs.db.get_db_init_uuid() + return self.db.get_db_init_uuid() - def get_logdir_local(ibs): - return ibs.logsdir + def get_logdir_local(self): + return self.logsdir - def get_logdir_global(ibs, local=False): + def get_logdir_global(self, local=False): if const.CONTAINERIZED: - return ibs.get_logdir_local() + return self.get_logdir_local() else: return ut.get_logging_dir(appname='wbia') - def get_dbdir(ibs): + def get_dbdir(self): """ database dir with ibs internal directory """ - return ibs.dbdir + return self.dbdir - def get_db_core_path(ibs): - return ibs.db.uri + def get_db_core_path(self): + return self.db.uri - def get_db_staging_path(ibs): - return ibs.staging.uri + def get_db_staging_path(self): + return self.staging.uri - def get_db_cache_path(ibs): - return ibs.dbcache.uri + def get_db_cache_path(self): + return self.dbcache.uri - def get_shelves_path(ibs): + def get_shelves_path(self): engine_slot = const.ENGINE_SLOT engine_slot = str(engine_slot).lower() if engine_slot in ['none', 'null', '1', 'default']: engine_shelve_dir = 'engine_shelves' else: engine_shelve_dir = 'engine_shelves_%s' % (engine_slot,) - return join(ibs.get_cachedir(), engine_shelve_dir) + return join(self.get_cachedir(), engine_shelve_dir) - def get_trashdir(ibs): - return ibs.trashdir + def get_trashdir(self): + return self.trashdir - def get_ibsdir(ibs): + def get_ibsdir(self): """ ibs internal directory """ - return ibs._ibsdb + return self._ibsdb - def get_chipdir(ibs): - return ibs.chipdir + def get_chipdir(self): + return self.chipdir - def get_probchip_dir(ibs): - return join(ibs.get_cachedir(), 'prob_chips') + def get_probchip_dir(self): + return join(self.get_cachedir(), 'prob_chips') - def get_fig_dir(ibs): + def get_fig_dir(self): """ ibs internal directory """ - return join(ibs._ibsdb, 'figures') + return join(self._ibsdb, 'figures') - def get_imgdir(ibs): + def get_imgdir(self): """ ibs internal directory """ - return ibs.imgdir + return self.imgdir - def get_uploadsdir(ibs): + def get_uploadsdir(self): """ ibs internal directory """ - return ibs.uploadsdir + return self.uploadsdir - def get_thumbdir(ibs): + def get_thumbdir(self): """ database directory where thumbnails are cached """ - return ibs.thumb_dpath + return self.thumb_dpath - def get_workdir(ibs): + def get_workdir(self): """ directory where databases are saved to """ - return ibs.workdir + return self.workdir - def get_cachedir(ibs): + def get_cachedir(self): """ database directory of all cached files """ - return ibs.cachedir + return self.cachedir - def get_match_thumbdir(ibs): - match_thumb_dir = ut.unixjoin(ibs.get_cachedir(), 'match_thumbs') + def get_match_thumbdir(self): + match_thumb_dir = ut.unixjoin(self.get_cachedir(), 'match_thumbs') ut.ensuredir(match_thumb_dir) return match_thumb_dir - def get_wbia_resource_dir(ibs): + def get_wbia_resource_dir(self): """ returns the global resource dir in .config or AppData or whatever """ resource_dir = sysres.get_wbia_resource_dir() return resource_dir - def get_detect_modeldir(ibs): + def get_detect_modeldir(self): return join(sysres.get_wbia_resource_dir(), 'detectmodels') - def get_detectimg_cachedir(ibs): + def get_detectimg_cachedir(self): """ Returns: detectimgdir 
(str): database directory of image resized for detections """ - return join(ibs.cachedir, const.PATH_NAMES.detectimg) + return join(self.cachedir, const.PATH_NAMES.detectimg) - def get_flann_cachedir(ibs): + def get_flann_cachedir(self): """ Returns: flanndir (str): database directory where the FLANN KD-Tree is stored """ - return ibs.flanndir + return self.flanndir - def get_qres_cachedir(ibs): + def get_qres_cachedir(self): """ Returns: qresdir (str): database directory where query results are stored """ - return ibs.qresdir + return self.qresdir - def get_neighbor_cachedir(ibs): - neighbor_cachedir = ut.unixjoin(ibs.get_cachedir(), 'neighborcache2') + def get_neighbor_cachedir(self): + neighbor_cachedir = ut.unixjoin(self.get_cachedir(), 'neighborcache2') return neighbor_cachedir - def get_big_cachedir(ibs): + def get_big_cachedir(self): """ Returns: bigcachedir (str): database directory where aggregate results are stored """ - return ibs.bigcachedir + return self.bigcachedir - def get_smart_patrol_dir(ibs, ensure=True): + def get_smart_patrol_dir(self, ensure=True): """ Args: ensure (bool): @@ -1133,7 +1079,7 @@ def get_smart_patrol_dir(ibs, ensure=True): >>> # verify results >>> ut.assertpath(smart_patrol_dpath, verbose=True) """ - smart_patrol_dpath = join(ibs.dbdir, const.PATH_NAMES.smartpatrol) + smart_patrol_dpath = join(self.dbdir, const.PATH_NAMES.smartpatrol) if ensure: ut.ensuredir(smart_patrol_dpath) return smart_patrol_dpath @@ -1143,51 +1089,46 @@ def get_smart_patrol_dir(ibs, ensure=True): # ------------------ @register_api('/log/current/', methods=['GET']) - def get_current_log_text(ibs): + def get_current_log_text(self): r""" - CommandLine: - python -m wbia.control.IBEISControl --exec-get_current_log_text - python -m wbia.control.IBEISControl --exec-get_current_log_text --domain http://52.33.105.88 Example: >>> # xdoctest: +REQUIRES(--web-tests) - >>> from wbia.control.IBEISControl import * # NOQA >>> import wbia - >>> import wbia.web - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/log/current/', 'get') - >>> print('\n-------Logs ----: \n' ) - >>> print(resp) - >>> print('\nL____ END LOGS ___\n') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/log/current/') + >>> resp.json + {'status': {'success': True, 'code': 200, 'message': '', 'cache': -1}, 'response': None} + """ text = ut.get_current_log_text() return text @register_api('/api/core/db/info/', methods=['GET']) - def get_dbinfo(ibs): + def get_dbinfo(self): from wbia.other import dbinfo - locals_ = dbinfo.get_dbinfo(ibs) + locals_ = dbinfo.get_dbinfo(self) return locals_['info_str'] - # return ut.repr2(dbinfo.get_dbinfo(ibs), nl=1)['infostr'] + # return ut.repr2(dbinfo.get_dbinfo(self), nl=1)['infostr'] # -------------- # --- MISC ---- # -------------- - def copy_database(ibs, dest_dbdir): + def copy_database(self, dest_dbdir): # TODO: rectify with rsync, script, and merge script. 
from wbia.init import sysres - sysres.copy_wbiadb(ibs.get_dbdir(), dest_dbdir) + sysres.copy_wbiadb(self.get_dbdir(), dest_dbdir) - def dump_database_csv(ibs): - dump_dir = join(ibs.get_dbdir(), 'CSV_DUMP') - ibs.db.dump_tables_to_csv(dump_dir=dump_dir) + def dump_database_csv(self): + dump_dir = join(self.get_dbdir(), 'CSV_DUMP') + self.db.dump_tables_to_csv(dump_dir=dump_dir) with open(join(dump_dir, '_ibsdb.dump'), 'w') as fp: - dump(ibs.db.connection, fp) + dump(self.db.connection, fp) - def get_database_icon(ibs, max_dsize=(None, 192), aid=None): + def get_database_icon(self, max_dsize=(None, 192), aid=None): r""" Args: max_dsize (tuple): (default = (None, 192)) @@ -1204,58 +1145,58 @@ def get_database_icon(ibs, max_dsize=(None, 192), aid=None): >>> from wbia.control.IBEISControl import * # NOQA >>> import wbia >>> ibs = wbia.opendb(defaultdb='testdb1') - >>> icon = ibs.get_database_icon() + >>> icon = self.get_database_icon() >>> ut.quit_if_noshow() >>> import wbia.plottool as pt >>> pt.imshow(icon) >>> ut.show_if_requested() """ - # if ibs.get_dbname() == 'Oxford': + # if self.get_dbname() == 'Oxford': # pass # else: import vtool as vt - if hasattr(ibs, 'force_icon_aid'): - aid = ibs.force_icon_aid + if hasattr(self, 'force_icon_aid'): + aid = self.force_icon_aid if aid is None: - species = ibs.get_primary_database_species() + species = self.get_primary_database_species() # Use a url to get the icon url = { - ibs.const.TEST_SPECIES.GIR_MASAI: 'http://i.imgur.com/tGDVaKC.png', - ibs.const.TEST_SPECIES.ZEB_PLAIN: 'http://i.imgur.com/2Ge1PRg.png', - ibs.const.TEST_SPECIES.ZEB_GREVY: 'http://i.imgur.com/PaUT45f.png', + self.const.TEST_SPECIES.GIR_MASAI: 'http://i.imgur.com/tGDVaKC.png', + self.const.TEST_SPECIES.ZEB_PLAIN: 'http://i.imgur.com/2Ge1PRg.png', + self.const.TEST_SPECIES.ZEB_GREVY: 'http://i.imgur.com/PaUT45f.png', }.get(species, None) if url is not None: icon = vt.imread(ut.grab_file_url(url), orient='auto') else: # HACK: (this should probably be a db setting) # use an specific aid to get the icon - aid = {'Oxford': 73, 'seaturtles': 37}.get(ibs.get_dbname(), None) + aid = {'Oxford': 73, 'seaturtles': 37}.get(self.get_dbname(), None) if aid is None: # otherwise just grab a random aid - aid = ibs.get_valid_aids()[0] + aid = self.get_valid_aids()[0] if aid is not None: - icon = ibs.get_annot_chips(aid) + icon = self.get_annot_chips(aid) icon = vt.resize_to_maxdims(icon, max_dsize) return icon - def _custom_ibsstr(ibs): + def _custom_ibsstr(self): # typestr = ut.type_str(type(ibs)).split('.')[-1] - typestr = ibs.__class__.__name__ - dbname = ibs.get_dbname() + typestr = self.__class__.__name__ + dbname = self.get_dbname() # hash_str = hex(id(ibs)) # ibsstr = '<%s(%s) at %s>' % (typestr, dbname, hash_str, ) - hash_str = ibs.get_db_init_uuid() + hash_str = self.get_db_init_uuid() ibsstr = '<%s(%s) with UUID %s>' % (typestr, dbname, hash_str) return ibsstr - def __str__(ibs): - return ibs._custom_ibsstr() + def __str__(self): + return self._custom_ibsstr() - def __repr__(ibs): - return ibs._custom_ibsstr() + def __repr__(self): + return self._custom_ibsstr() - def __getstate__(ibs): + def __getstate__(self): """ Example: >>> # ENABLE_DOCTEST @@ -1267,12 +1208,12 @@ def __getstate__(ibs): """ # Hack to allow for wbia objects to be pickled state = { - 'dbdir': ibs.get_dbdir(), + 'dbdir': self.get_dbdir(), 'machine_name': ut.get_computer_name(), } return state - def __setstate__(ibs, state): + def __setstate__(self, state): # Hack to allow for wbia objects to be pickled import wbia 
@@ -1288,20 +1229,20 @@ def __setstate__(ibs, state): if not iswarning: raise ibs2 = wbia.opendb(dbdir=dbdir, web=False) - ibs.__dict__.update(**ibs2.__dict__) + self.__dict__.update(**ibs2.__dict__) - def predict_ws_injury_interim_svm(ibs, aids): + def predict_ws_injury_interim_svm(self, aids): from wbia.scripts import classify_shark - return classify_shark.predict_ws_injury_interim_svm(ibs, aids) + return classify_shark.predict_ws_injury_interim_svm(self, aids) def get_web_port_via_scan( - ibs, url_base='127.0.0.1', port_base=5000, scan_limit=100, verbose=True + self, url_base='127.0.0.1', port_base=5000, scan_limit=100, verbose=True ): import requests api_rule = CORE_DB_UUID_INIT_API_RULE - target_uuid = ibs.get_db_init_uuid() + target_uuid = self.get_db_init_uuid() for candidate_port in range(port_base, port_base + scan_limit + 1): candidate_url = 'http://%s:%s%s' % (url_base, candidate_port, api_rule) try: diff --git a/wbia/control/STAGING_SCHEMA.py b/wbia/control/STAGING_SCHEMA.py index b165166f53..7cbd69ce69 100644 --- a/wbia/control/STAGING_SCHEMA.py +++ b/wbia/control/STAGING_SCHEMA.py @@ -40,21 +40,27 @@ @profile def update_1_0_0(db, ibs=None): + columns = [ + ('review_rowid', 'INTEGER PRIMARY KEY'), + ('annot_1_rowid', 'INTEGER NOT NULL'), + ('annot_2_rowid', 'INTEGER NOT NULL'), + ('review_count', 'INTEGER NOT NULL'), + ('review_decision', 'INTEGER NOT NULL'), + ( + 'review_time_posix', + """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", + ), # this should probably be UCT + ('review_identity', 'TEXT'), + ('review_tags', 'TEXT'), + ] + if db._engine.dialect.name == 'postgresql': + columns[5] = ( + 'review_time_posix', + "INTEGER DEFAULT (CAST(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'UTC') AS INTEGER))", + ) db.add_table( const.REVIEW_TABLE, - ( - ('review_rowid', 'INTEGER PRIMARY KEY'), - ('annot_1_rowid', 'INTEGER NOT NULL'), - ('annot_2_rowid', 'INTEGER NOT NULL'), - ('review_count', 'INTEGER NOT NULL'), - ('review_decision', 'INTEGER NOT NULL'), - ( - 'review_time_posix', - """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", - ), # this should probably be UCT - ('review_identity', 'TEXT'), - ('review_tags', 'TEXT'), - ), + columns, superkeys=[('annot_1_rowid', 'annot_2_rowid', 'review_count')], docstr=""" Used to store completed user review states of two matched annotations @@ -104,20 +110,26 @@ def update_1_0_3(db, ibs=None): def update_1_1_0(db, ibs=None): + columns = [ + ('test_rowid', 'INTEGER PRIMARY KEY'), + ('test_uuid', 'UUID'), + ('test_user_identity', 'TEXT'), + ('test_challenge_json', 'TEXT'), + ('test_response_json', 'TEXT'), + ('test_result', 'INTEGER'), + ( + 'test_time_posix', + """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", + ), # this should probably be UCT + ] + if db._engine.dialect.name == 'postgresql': + columns[6] = ( + 'test_time_posix', + "INTEGER DEFAULT (CAST(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'UTC') AS INTEGER))", + ) db.add_table( const.TEST_TABLE, - ( - ('test_rowid', 'INTEGER PRIMARY KEY'), - ('test_uuid', 'UUID'), - ('test_user_identity', 'TEXT'), - ('test_challenge_json', 'TEXT'), - ('test_response_json', 'TEXT'), - ('test_result', 'INTEGER'), - ( - 'test_time_posix', - """INTEGER DEFAULT (CAST(STRFTIME('%s', 'NOW', 'UTC') AS INTEGER))""", - ), # this should probably be UCT - ), + columns, superkeys=[('test_uuid',)], docstr=""" Used to store tests given to the user, their responses, and their results diff --git a/wbia/control/_sql_helpers.py b/wbia/control/_sql_helpers.py index 
db19f22ecb..b0fa506d12 100644 --- a/wbia/control/_sql_helpers.py +++ b/wbia/control/_sql_helpers.py @@ -47,8 +47,8 @@ def _devcheck_backups(): sorted(ut.glob(join(dbdir, '_wbia_backups'), '*staging_back*.sqlite3')) fpaths = sorted(ut.glob(join(dbdir, '_wbia_backups'), '*database_back*.sqlite3')) for fpath in fpaths: - db_uri = 'file://{}'.format(realpath(fpath)) - db = dt.SQLDatabaseController.from_uri(db_uri) + db_uri = 'sqlite:///{}'.format(realpath(fpath)) + db = dt.SQLDatabaseController(db_uri, 'PZ_Master1') logger.info('fpath = %r' % (fpath,)) num_edges = len(db.executeone('SELECT rowid from annotmatch')) logger.info('num_edges = %r' % (num_edges,)) @@ -185,8 +185,8 @@ def copy_database(src_fpath, dst_fpath): # blocked lock for all processes potentially writing to the database timeout = 12 * 60 * 60 # Allow a lock of up to 12 hours for a database backup routine if not src_fpath.startswith('file:'): - src_fpath = 'file://{}'.format(realpath(src_fpath)) - db = dtool.SQLDatabaseController.from_uri(src_fpath, timeout=timeout) + src_fpath = 'sqlite:///{}'.format(realpath(src_fpath)) + db = dtool.SQLDatabaseController(src_fpath, 'copy', timeout=timeout) db.backup(dst_fpath) @@ -359,6 +359,10 @@ def update_schema_version( clearbackup = False FIXME: AN SQL HELPER FUNCTION SHOULD BE AGNOSTIC TO CONTROLER OBJECTS """ + if db._engine.dialect.name != 'sqlite': + # Backup is based on copying files so if we're not using sqlite, skip + # backup + dobackup = False def _check_superkeys(): all_tablename_list = db.get_table_names() @@ -379,7 +383,7 @@ def _check_superkeys(): ), 'ERROR UPDATING DATABASE, SUPERKEYS of %s DROPPED!' % (tablename,) logger.info('[_SQL] update_schema_version') - db_fpath = db.uri.replace('file://', '') + db_fpath = db.uri.replace('sqlite://', '') if dobackup: db_dpath, db_fname = split(db_fpath) db_fname_noext, ext = splitext(db_fname) @@ -426,10 +430,13 @@ def _check_superkeys(): pre, update, post = db_versions[next_version] if pre is not None: pre(db, ibs=ibs) + db.invalidate_tables_cache() if update is not None: update(db, ibs=ibs) + db.invalidate_tables_cache() if post is not None: post(db, ibs=ibs) + db.invalidate_tables_cache() _check_superkeys() except Exception as ex: if dobackup: @@ -540,7 +547,7 @@ def get_nth_test_schema_version(schema_spec, n=-1): cachedir = ut.ensure_app_resource_dir('wbia_test') db_fname = 'test_%s.sqlite3' % dbname ut.delete(join(cachedir, db_fname)) - db_uri = 'file://{}'.format(realpath(join(cachedir, db_fname))) - db = SQLDatabaseController.from_uri(db_uri) + db_uri = 'sqlite:///{}'.format(realpath(join(cachedir, db_fname))) + db = SQLDatabaseController(db_uri, dbname) ensure_correct_version(None, db, version_expected, schema_spec, dobackup=False) return db diff --git a/wbia/control/controller_inject.py b/wbia/control/controller_inject.py index 068b9bd7ae..58cb7539a8 100644 --- a/wbia/control/controller_inject.py +++ b/wbia/control/controller_inject.py @@ -489,35 +489,20 @@ def translate_wbia_webcall(func, *args, **kwargs): Returns: tuple: (output, True, 200, None, jQuery_callback) - CommandLine: - python -m wbia.control.controller_inject --exec-translate_wbia_webcall - python -m wbia.control.controller_inject --exec-translate_wbia_webcall --domain http://52.33.105.88 - Example: >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.control.controller_inject import * # NOQA >>> import wbia - >>> import time - >>> import wbia.web - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... 
aids = web_ibs.send_wbia_request('/api/annot/', 'get') - ... uuid_list = web_ibs.send_wbia_request('/api/annot/uuids/', aid_list=aids, json=False) - ... failrsp = web_ibs.send_wbia_request('/api/annot/uuids/', json=False) - ... failrsp2 = web_ibs.send_wbia_request('/api/query/chips/simple_dict//', 'get', qaid_list=[0], daid_list=[0], json=False) - ... log_text = web_ibs.send_wbia_request('/api/query/chips/simple_dict/', 'get', qaid_list=[0], daid_list=[0], json=False) - >>> time.sleep(.1) - >>> print('\n---\nuuid_list = %r' % (uuid_list,)) - >>> print('\n---\nfailrsp =\n%s' % (failrsp,)) - >>> print('\n---\nfailrsp2 =\n%s' % (failrsp2,)) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... aids = client.get('/api/annot/').json + ... failrsp = client.post('/api/annot/uuids/') + ... failrsp2 = client.get('/api/query/chips/simple_dict//', data={'qaid_list': [0], 'daid_list': [0]}) + ... log_text = client.get('/api/query/chips/simple_dict/', data={'qaid_list': [0], 'daid_list': [0]}) + >>> print('\n---\nfailrsp =\n%s' % (failrsp.data,)) + >>> print('\n---\nfailrsp2 =\n%s' % (failrsp2.data,)) >>> print('Finished test') + Finished test - Ignore: - app = get_flask_app() - with app.app_context(): - #ibs = wbia.opendb('testdb1') - func = ibs.get_annot_uuids - args = tuple() - kwargs = dict() """ assert len(args) == 0, 'There should not be any args=%r' % (args,) @@ -590,7 +575,7 @@ def translate_wbia_webcall(func, *args, **kwargs): output = func(**kwargs) except TypeError: try: - output = func(ibs=ibs, **kwargs) + output = func(ibs, **kwargs) except WebException: raise except Exception as ex2: # NOQA diff --git a/wbia/control/manual_annot_funcs.py b/wbia/control/manual_annot_funcs.py index d58b963c78..3b7aac13ef 100644 --- a/wbia/control/manual_annot_funcs.py +++ b/wbia/control/manual_annot_funcs.py @@ -838,6 +838,8 @@ def filter_annotation_set( is_canonical=None, min_timedelta=None, ): + if not aid_list: # no need to filter if empty + return aid_list # -- valid aid filtering -- # filter by is_exemplar if is_exemplar is True: @@ -2327,6 +2329,14 @@ def set_annot_viewpoints( message = 'Could not purge CurvRankDorsal cache for viewpoint' # raise RuntimeError(message) logger.info(message) + try: + ibs.wbia_plugin_curvrank_v2_delete_cache_optimized( + update_aid_list, 'CurvRankTwoDorsal' + ) + except Exception: + message = 'Could not purge CurvRankTwoDorsal cache for viewpoint' + # raise RuntimeError(message) + logger.info(message) try: ibs.wbia_plugin_curvrank_delete_cache_optimized( update_aid_list, 'CurvRankFinfindrHybridDorsal' @@ -2482,11 +2492,12 @@ def get_annot_part_rowids(ibs, aid_list, is_staged=False): """ # FIXME: This index should when the database is defined. 
# Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS aid_to_part_rowids ON parts (annot_rowid); - """ - ).fetchall() + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS aid_to_part_rowids ON parts (annot_rowid); + """ + ) # The index maxes the following query very efficient part_rowids_list = ibs.db.get( ibs.const.PART_TABLE, diff --git a/wbia/control/manual_gsgrelate_funcs.py b/wbia/control/manual_gsgrelate_funcs.py index 9310832a6a..daeb66ec18 100644 --- a/wbia/control/manual_gsgrelate_funcs.py +++ b/wbia/control/manual_gsgrelate_funcs.py @@ -55,13 +55,12 @@ def get_image_gsgrids(ibs, gid_list): list_ (list): a list of imageset-image-relationship rowids for each imageid""" # TODO: Group type params_iter = ((gid,) for gid in gid_list) - where_clause = 'image_rowid=?' # list of relationships for each image - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('image_rowid',), unpack_scalars=False, ) return gsgrids_list @@ -173,10 +172,13 @@ def get_gsgr_rowid_from_superkey(ibs, gid_list, imgsetid_list): Returns: gsgrid_list (list): eg-relate-ids from info constrained to be unique (imgsetid, gid)""" colnames = ('image_rowid',) + where_colnames = ('image_rowid', 'imageset_rowid') params_iter = zip(gid_list, imgsetid_list) - where_clause = 'image_rowid=? AND imageset_rowid=?' - gsgrid_list = ibs.db.get_where( - const.GSG_RELATION_TABLE, colnames, params_iter, where_clause + gsgrid_list = ibs.db.get_where_eq( + const.GSG_RELATION_TABLE, + colnames, + params_iter, + where_colnames, ) return gsgrid_list diff --git a/wbia/control/manual_image_funcs.py b/wbia/control/manual_image_funcs.py index e655d49ea5..12afb194a1 100644 --- a/wbia/control/manual_image_funcs.py +++ b/wbia/control/manual_image_funcs.py @@ -112,7 +112,13 @@ def _get_all_image_rowids(ibs): @accessor_decors.ider @register_api('/api/image/', methods=['GET']) def get_valid_gids( - ibs, imgsetid=None, require_unixtime=False, require_gps=None, reviewed=None, **kwargs + ibs, + imgsetid=None, + imgsetid_list=(), + require_unixtime=False, + require_gps=None, + reviewed=None, + **kwargs ): r""" Args: @@ -147,8 +153,10 @@ def get_valid_gids( >>> print(result) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] """ - if imgsetid is None: + if imgsetid is None and not imgsetid_list: gid_list = ibs._get_all_gids() + elif imgsetid_list: + gid_list = ibs.get_imageset_gids(imgsetid_list) else: assert not ut.isiterable(imgsetid) gid_list = ibs.get_imageset_gids(imgsetid) @@ -397,15 +405,11 @@ def add_images( ) # - # Execute SQL Add - from distutils.version import LooseVersion - - if LooseVersion(ibs.db.get_db_version()) >= LooseVersion('1.3.4'): - colnames = IMAGE_COLNAMES + ('image_original_path', 'image_location_code') - params_list = [ - tuple(params) + (gpath, location_for_names) if params is not None else None - for params, gpath in zip(params_list, gpath_list) - ] + colnames = IMAGE_COLNAMES + ('image_original_path', 'image_location_code') + params_list = [ + tuple(params) + (gpath, location_for_names) if params is not None else None + for params, gpath in zip(params_list, gpath_list) + ] all_gid_list = ibs.db.add_cleanly( const.IMAGE_TABLE, colnames, params_list, ibs.get_image_gids_from_uuid @@ -2172,13 +2176,14 @@ def get_image_imgsetids(ibs, gid_list): if NEW_INDEX_HACK: # FIXME: This index should when the database is defined. 
# Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS gs_to_gids ON {GSG_RELATION_TABLE} ({IMAGE_ROWID}); - """.format( - GSG_RELATION_TABLE=const.GSG_RELATION_TABLE, IMAGE_ROWID=IMAGE_ROWID + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS gs_to_gids ON {GSG_RELATION_TABLE} ({IMAGE_ROWID}); + """.format( + GSG_RELATION_TABLE=const.GSG_RELATION_TABLE, IMAGE_ROWID=IMAGE_ROWID + ) ) - ).fetchall() colnames = ('imageset_rowid',) imgsetids_list = ibs.db.get( const.GSG_RELATION_TABLE, @@ -2276,11 +2281,12 @@ def get_image_aids(ibs, gid_list, is_staged=False, __check_staged__=True): # FIXME: This index should when the database is defined. # Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS gid_to_aids ON annotations (image_rowid); - """ - ).fetchall() + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS gid_to_aids ON annotations (image_rowid); + """ + ) # The index maxes the following query very efficient if __check_staged__: @@ -2320,7 +2326,8 @@ def get_image_aids(ibs, gid_list, is_staged=False, __check_staged__=True): """.format( input_str=input_str, ANNOTATION_TABLE=const.ANNOTATION_TABLE ) - pair_list = ibs.db.connection.execute(opstr).fetchall() + with ibs.db.connect() as conn: + pair_list = conn.execute(opstr).fetchall() aidscol = np.array(ut.get_list_column(pair_list, 0)) gidscol = np.array(ut.get_list_column(pair_list, 1)) unique_gids, groupx = vt.group_indices(gidscol) diff --git a/wbia/control/manual_imageset_funcs.py b/wbia/control/manual_imageset_funcs.py index 48c498128d..fb534d6ee5 100644 --- a/wbia/control/manual_imageset_funcs.py +++ b/wbia/control/manual_imageset_funcs.py @@ -543,13 +543,14 @@ def get_imageset_gids(ibs, imgsetid_list): if NEW_INDEX_HACK: # FIXME: This index should when the database is defined. # Ensure that an index exists on the image column of the annotation table - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS gids_to_gs ON {GSG_RELATION_TABLE} (imageset_rowid); - """.format( - GSG_RELATION_TABLE=const.GSG_RELATION_TABLE + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS gids_to_gs ON {GSG_RELATION_TABLE} (imageset_rowid); + """.format( + GSG_RELATION_TABLE=const.GSG_RELATION_TABLE + ) ) - ).fetchall() gids_list = ibs.db.get( const.GSG_RELATION_TABLE, ('image_rowid',), @@ -594,37 +595,34 @@ def get_imageset_gsgrids(ibs, imgsetid_list=None, gid_list=None): if imgsetid_list is not None and gid_list is None: # TODO: Group type params_iter = ((imgsetid,) for imgsetid in imgsetid_list) - where_clause = 'imageset_rowid=?' # list of relationships for each imageset - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('imageset_rowid',), unpack_scalars=False, ) elif gid_list is not None and imgsetid_list is None: # TODO: Group type params_iter = ((gid,) for gid in gid_list) - where_clause = 'image_rowid=?' # list of relationships for each imageset - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('image_rowid',), unpack_scalars=False, ) else: # TODO: Group type params_iter = ((imgsetid, gid) for imgsetid, gid in zip(imgsetid_list, gid_list)) - where_clause = 'imageset_rowid=? AND image_rowid=?' 
# list of relationships for each imageset - gsgrids_list = ibs.db.get_where( + gsgrids_list = ibs.db.get_where_eq( const.GSG_RELATION_TABLE, ('gsgr_rowid',), params_iter, - where_clause, + ('imageset_rowid', 'image_rowid'), unpack_scalars=False, ) return gsgrids_list diff --git a/wbia/control/manual_lblannot_funcs.py b/wbia/control/manual_lblannot_funcs.py index b44a7b195b..d93b24da9e 100644 --- a/wbia/control/manual_lblannot_funcs.py +++ b/wbia/control/manual_lblannot_funcs.py @@ -100,9 +100,8 @@ def get_lblannot_rowid_from_superkey(ibs, lbltype_rowid_list, value_list): """ colnames = ('lblannot_rowid',) params_iter = zip(lbltype_rowid_list, value_list) - where_clause = 'lbltype_rowid=? AND lblannot_value=?' - lblannot_rowid_list = ibs.db.get_where( - const.LBLANNOT_TABLE, colnames, params_iter, where_clause + lblannot_rowid_list = ibs.db.get_where_eq( + const.LBLANNOT_TABLE, colnames, params_iter, ('lbltype_rowid', 'lblannot_value') ) # BIG HACK FOR ENFORCING UNKNOWN LBLANNOTS HAVE ROWID 0 lblannot_rowid_list = [ @@ -169,13 +168,12 @@ def get_alr_annot_rowids_from_lblannot_rowid(ibs, lblannot_rowid_list): # FIXME: SLOW # if verbose: # logger.info(ut.get_caller_name(N=list(range(0, 20)))) - where_clause = 'lblannot_rowid=?' params_iter = [(lblannot_rowid,) for lblannot_rowid in lblannot_rowid_list] - aids_list = ibs.db.get_where( + aids_list = ibs.db.get_where_eq( const.AL_RELATION_TABLE, ('annot_rowid',), params_iter, - where_clause, + ('lblannot_rowid',), unpack_scalars=False, ) return aids_list @@ -223,9 +221,8 @@ def get_alrid_from_superkey(ibs, aid_list, lblannot_rowid_list): """ colnames = ('annot_rowid',) params_iter = zip(aid_list, lblannot_rowid_list) - where_clause = 'annot_rowid=? AND lblannot_rowid=?' - alrid_list = ibs.db.get_where( - const.AL_RELATION_TABLE, colnames, params_iter, where_clause + alrid_list = ibs.db.get_where_eq( + const.AL_RELATION_TABLE, colnames, params_iter, ('annot_rowid', 'lblannot_rowid') ) return alrid_list diff --git a/wbia/control/manual_lblimage_funcs.py b/wbia/control/manual_lblimage_funcs.py index 2fb40c89a1..fc5847e770 100644 --- a/wbia/control/manual_lblimage_funcs.py +++ b/wbia/control/manual_lblimage_funcs.py @@ -101,9 +101,8 @@ def get_lblimage_rowid_from_superkey(ibs, lbltype_rowid_list, value_list): """ colnames = ('lblimage_rowid',) params_iter = zip(lbltype_rowid_list, value_list) - where_clause = 'lbltype_rowid=? AND lblimage_value=?' - lblimage_rowid_list = ibs.db.get_where( - const.LBLIMAGE_TABLE, colnames, params_iter, where_clause + lblimage_rowid_list = ibs.db.get_where_eq( + const.LBLIMAGE_TABLE, colnames, params_iter, ('lbltype_rowid', 'lblimage_value') ) return lblimage_rowid_list @@ -172,13 +171,12 @@ def get_lblimage_gids(ibs, lblimage_rowid_list): # FIXME: SLOW # if verbose: # logger.info(ut.get_caller_name(N=list(range(0, 20)))) - where_clause = 'lblimage_rowid=?' params_iter = [(lblimage_rowid,) for lblimage_rowid in lblimage_rowid_list] - gids_list = ibs.db.get_where( + gids_list = ibs.db.get_where_eq( const.GL_RELATION_TABLE, ('image_rowid',), params_iter, - where_clause, + ('lblimage_rowid',), unpack_scalars=False, ) return gids_list @@ -226,9 +224,8 @@ def get_glrid_from_superkey(ibs, gid_list, lblimage_rowid_list): """ colnames = ('image_rowid',) params_iter = zip(gid_list, lblimage_rowid_list) - where_clause = 'image_rowid=? AND lblimage_rowid=?' 
- glrid_list = ibs.db.get_where( - const.GL_RELATION_TABLE, colnames, params_iter, where_clause + glrid_list = ibs.db.get_where_eq( + const.GL_RELATION_TABLE, colnames, params_iter, ('image_rowid', 'lblimage_rowid') ) return glrid_list @@ -242,12 +239,11 @@ def get_image_glrids(ibs, gid_list): be only of a specific lbltype/category/type """ params_iter = ((gid,) for gid in gid_list) - where_clause = 'image_rowid=?' - glrids_list = ibs.db.get_where( + glrids_list = ibs.db.get_where_eq( const.GL_RELATION_TABLE, ('glr_rowid',), params_iter, - where_clause=where_clause, + ('image_rowid',), unpack_scalars=False, ) return glrids_list diff --git a/wbia/control/manual_meta_funcs.py b/wbia/control/manual_meta_funcs.py index 5d3bb4ba33..db550e6e0c 100644 --- a/wbia/control/manual_meta_funcs.py +++ b/wbia/control/manual_meta_funcs.py @@ -900,13 +900,12 @@ def get_metadata_value(ibs, metadata_key_list, db): URL: /api/metadata/value/ """ params_iter = ((metadata_key,) for metadata_key in metadata_key_list) - where_clause = 'metadata_key=?' # list of relationships for each image - metadata_value_list = db.get_where( + metadata_value_list = db.get_where_eq( const.METADATA_TABLE, ('metadata_value',), params_iter, - where_clause, + ('metadata_key',), unpack_scalars=True, ) return metadata_value_list @@ -924,13 +923,12 @@ def get_metadata_rowid_from_metadata_key(ibs, metadata_key_list, db): """ db = db[0] # Unwrap tuple, required by @accessor_decors.getter_1to1 decorator params_iter = ((metadata_key,) for metadata_key in metadata_key_list) - where_clause = 'metadata_key=?' # list of relationships for each image - metadata_rowid_list = db.get_where( + metadata_rowid_list = db.get_where_eq( const.METADATA_TABLE, ('metadata_rowid',), params_iter, - where_clause, + ('metadata_key',), unpack_scalars=True, ) return metadata_rowid_list @@ -1011,14 +1009,11 @@ def _init_config(ibs): try: general_config = ut.load_cPkl(config_fpath, verbose=ut.VERBOSE) except IOError as ex: - if ut.VERBOSE: - ut.printex(ex, 'failed to genral load config', iswarning=True) + logger.error('*** failed to load general config', exc_info=ex) general_config = {} + ut.save_cPkl(config_fpath, general_config, verbose=ut.VERBOSE) current_species = general_config.get('current_species', None) - if ut.VERBOSE and ut.NOT_QUIET: - logger.info( - '[_init_config] general_config.current_species = %r' % (current_species,) - ) + logger.info('[_init_config] general_config.current_species = %r' % (current_species,)) # ##### # species_list = ibs.get_database_species() diff --git a/wbia/control/manual_name_funcs.py b/wbia/control/manual_name_funcs.py index cee28dbca6..2d2e56d3d8 100644 --- a/wbia/control/manual_name_funcs.py +++ b/wbia/control/manual_name_funcs.py @@ -489,11 +489,12 @@ def get_name_aids(ibs, nid_list, enable_unknown_fix=True, is_staged=False): # FIXME: This index should when the database is defined. 
# Ensure that an index exists on the image column of the annotation table # logger.info(len(nid_list_)) - ibs.db.connection.execute( - """ - CREATE INDEX IF NOT EXISTS nid_to_aids ON annotations (name_rowid); - """ - ).fetchall() + with ibs.db.connect() as conn: + conn.execute( + """ + CREATE INDEX IF NOT EXISTS nid_to_aids ON annotations (name_rowid); + """ + ) aids_list = ibs.db.get( const.ANNOTATION_TABLE, (ANNOT_ROWID,), @@ -516,7 +517,8 @@ def get_name_aids(ibs, nid_list, enable_unknown_fix=True, is_staged=False): """.format( input_str=input_str, ANNOTATION_TABLE=const.ANNOTATION_TABLE ) - pair_list = ibs.db.connection.execute(opstr).fetchall() + with ibs.db.connect() as conn: + pair_list = conn.execute(opstr).fetchall() aidscol = np.array(ut.get_list_column(pair_list, 0)) nidscol = np.array(ut.get_list_column(pair_list, 1)) unique_nids, groupx = vt.group_indices(nidscol) @@ -617,7 +619,7 @@ def get_name_exemplar_aids(ibs, nid_list): >>> aid_list = ibs.get_valid_aids() >>> nid_list = ibs.get_annot_name_rowids(aid_list) >>> exemplar_aids_list = ibs.get_name_exemplar_aids(nid_list) - >>> result = exemplar_aids_list + >>> result = [sorted(i) for i in exemplar_aids_list] >>> print(result) [[], [2, 3], [2, 3], [], [5, 6], [5, 6], [7], [8], [], [10], [], [12], [13]] """ @@ -659,7 +661,7 @@ def get_name_gids(ibs, nid_list): >>> ibs = wbia.opendb('testdb1') >>> nid_list = ibs._get_all_known_name_rowids() >>> gids_list = ibs.get_name_gids(nid_list) - >>> result = gids_list + >>> result = [sorted(gids) for gids in gids_list] >>> print(result) [[2, 3], [5, 6], [7], [8], [10], [12], [13]] """ @@ -1042,9 +1044,8 @@ def get_name_rowids_from_text(ibs, name_text_list, ensure=True): >>> result += str(ibs._get_all_known_name_rowids()) >>> print('----') >>> ibs.print_name_table() + >>> assert result == f'{name_rowid_list}\n[1, 2, 3, 4, 5, 6, 7]' >>> print(result) - [8, 9, 0, 10, 11, 0] - [1, 2, 3, 4, 5, 6, 7] """ if ensure: name_rowid_list = ibs.add_names(name_text_list) diff --git a/wbia/control/manual_part_funcs.py b/wbia/control/manual_part_funcs.py index ae75761133..2f9635f710 100644 --- a/wbia/control/manual_part_funcs.py +++ b/wbia/control/manual_part_funcs.py @@ -29,7 +29,7 @@ PART_NOTE = 'part_note' PART_NUM_VERTS = 'part_num_verts' PART_ROWID = 'part_rowid' -# PART_TAG_TEXT = 'part_tag_text' +PART_TAG_TEXT = 'part_tag_text' PART_THETA = 'part_theta' PART_VERTS = 'part_verts' PART_UUID = 'part_uuid' @@ -107,6 +107,8 @@ def filter_part_set( viewpoint='no-filter', minqual=None, ): + if not part_rowid_list: # no need to filter if empty + return part_rowid_list # -- valid part_rowid filtering -- # filter by is_staged @@ -1086,21 +1088,26 @@ def set_part_viewpoints(ibs, part_rowid_list, viewpoint_list): ibs.db.set(const.PART_TABLE, ('part_viewpoint',), val_iter, id_iter) -# @register_ibs_method -# @accessor_decors.setter -# def set_part_tag_text(ibs, part_rowid_list, part_tags_list, duplicate_behavior='error'): -# r""" part_tags_list -> part.part_tags[part_rowid_list] +@register_ibs_method +@accessor_decors.setter +def set_part_tag_text(ibs, part_rowid_list, part_tags_list, duplicate_behavior='error'): + r"""part_tags_list -> part.part_tags[part_rowid_list] -# Args: -# part_rowid_list -# part_tags_list + Args: + part_rowid_list + part_tags_list -# """ -# #logger.info('[ibs] set_part_tag_text of part_rowid_list=%r to tags=%r' % (part_rowid_list, part_tags_list)) -# id_iter = part_rowid_list -# colnames = (PART_TAG_TEXT,) -# ibs.db.set(const.PART_TABLE, colnames, part_tags_list, -# id_iter, 
duplicate_behavior=duplicate_behavior) + """ + # logger.info('[ibs] set_part_tag_text of part_rowid_list=%r to tags=%r' % (part_rowid_list, part_tags_list)) + id_iter = part_rowid_list + colnames = (PART_TAG_TEXT,) + ibs.db.set( + const.PART_TABLE, + colnames, + part_tags_list, + id_iter, + duplicate_behavior=duplicate_behavior, + ) @register_ibs_method @@ -1483,6 +1490,55 @@ def set_part_contour(ibs, part_rowid_list, contour_dict_list): ibs.db.set(const.PART_TABLE, ('part_contour_json',), val_list, id_iter) +# setting up the Wild Dog data for assigner training +# def get_corresponding_aids(ibs, part_rowid_list, from_aids=None): +# if from_aids is None: +# from_aids = ibs.get_valid_aids() + + +# def get_corresponding_aids_slow(ibs, part_rowid_list, from_aids): +# part_bboxes = ibs.get_part_bboxes(part_rowid_list) +# annot_bboxes = ibs.get_annot_bboxes(from_aids) +# annot_gids = ibs.get_annot_gids(from_aids) +# from collections import defaultdict + +# bbox_gid_to_aids = defaultdict(int) +# for aid, gid, bbox in zip(from_aids, annot_gids, annot_bboxes): +# bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] = aid +# part_gids = ibs.get_part_image_rowids(parts) +# part_rowid_to_aid = { +# part_id: bbox_gid_to_aids[(bbox[0], bbox[1], bbox[2], bbox[3], gid)] +# for part_id, gid, bbox in zip(part_rowid_list, part_gids, part_bboxes) +# } + +# part_aids = [part_rowid_to_aid[partid] for partid in parts] +# part_parent_aids = ibs.get_part_aids(part_rowid_list) + +# # parents might be non-unique so we gotta make a unique name for each parent +# parent_aid_to_part_rowids = defaultdict(list) +# for part_rowid, parent_aid in zip(part_rowid_list, part_parent_aids): +# parent_aid_to_part_rowids[parent_aid] += [part_rowid] + +# part_annot_names = [ +# ','.join(str(p) for p in parent_aid_to_part_rowids[parent_aid]) +# for parent_aid in part_parent_aids +# ] + +# # now assign names so we can associate the part annots with the non-part annots +# new_part_names = ['part-%s' % part_rowid for part_rowid in part_rowid_list] + + +# def sort_parts_by_tags(ibs, part_rowid_list): +# tags = ibs.get_part_tag_text(part_rowid_list) +# from collections import defaultdict + +# tag_to_rowids = defaultdict(list) +# for tag, part_rowid in zip(tags, part_rowid_list): +# tag_to_rowids[tag] += [part_rowid] +# parts_by_tags = [tag_to_rowdids[tag] for tag in tag_to_rowdids.keys()] +# return parts_by_tags + + # ========== # Testdata # ========== diff --git a/wbia/control/manual_review_funcs.py b/wbia/control/manual_review_funcs.py index db1ebdc00e..f1ad370b1b 100644 --- a/wbia/control/manual_review_funcs.py +++ b/wbia/control/manual_review_funcs.py @@ -57,24 +57,25 @@ def hack_create_aidpair_index(ibs): CREATE INDEX IF NOT EXISTS {index_name} ON {table} ({index_cols}); """ ) - sqlcmd = sqlfmt.format( - index_name='aidpair_to_rowid', - table=ibs.const.REVIEW_TABLE, - index_cols=','.join([REVIEW_AID1, REVIEW_AID2]), - ) - ibs.staging.connection.execute(sqlcmd).fetchall() - sqlcmd = sqlfmt.format( - index_name='aid1_to_rowids', - table=ibs.const.REVIEW_TABLE, - index_cols=','.join([REVIEW_AID1]), - ) - ibs.staging.connection.execute(sqlcmd).fetchall() - sqlcmd = sqlfmt.format( - index_name='aid2_to_rowids', - table=ibs.const.REVIEW_TABLE, - index_cols=','.join([REVIEW_AID2]), - ) - ibs.staging.connection.execute(sqlcmd).fetchall() + with ibs.staging.connect() as conn: + sqlcmd = sqlfmt.format( + index_name='aidpair_to_rowid', + table=ibs.const.REVIEW_TABLE, + index_cols=','.join([REVIEW_AID1, REVIEW_AID2]), + ) + 
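+        # run the formatted CREATE INDEX IF NOT EXISTS statement on the managed staging connection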
conn.execute(sqlcmd) + sqlcmd = sqlfmt.format( + index_name='aid1_to_rowids', + table=ibs.const.REVIEW_TABLE, + index_cols=','.join([REVIEW_AID1]), + ) + conn.execute(sqlcmd) + sqlcmd = sqlfmt.format( + index_name='aid2_to_rowids', + table=ibs.const.REVIEW_TABLE, + index_cols=','.join([REVIEW_AID2]), + ) + conn.execute(sqlcmd) @register_ibs_method @@ -566,9 +567,8 @@ def get_review_decisions_from_only(ibs, aid_list, eager=True, nInput=None): REVIEW_EVIDENCE_DECISION, ) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' % (REVIEW_AID1) - review_tuple_decisions_list = ibs.staging.get_where( - const.REVIEW_TABLE, colnames, params_iter, where_clause, unpack_scalars=False + review_tuple_decisions_list = ibs.staging.get_where_eq( + const.REVIEW_TABLE, colnames, params_iter, (REVIEW_AID1,), unpack_scalars=False ) return review_tuple_decisions_list @@ -586,9 +586,8 @@ def get_review_rowids_from_only(ibs, aid_list, eager=True, nInput=None): """ colnames = (REVIEW_ROWID,) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' % (REVIEW_AID1) - review_rowids = ibs.staging.get_where( - const.REVIEW_TABLE, colnames, params_iter, where_clause, unpack_scalars=False + review_rowids = ibs.staging.get_where_eq( + const.REVIEW_TABLE, colnames, params_iter, (REVIEW_AID1,), unpack_scalars=False ) return review_rowids @@ -612,12 +611,11 @@ def get_review_rowids_from_single(ibs, aid_list, eager=True, nInput=None): def get_review_rowids_from_aid1(ibs, aid_list, eager=True, nInput=None): colnames = (REVIEW_ROWID,) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' % (REVIEW_AID1,) - review_rowids = ibs.staging.get_where( + review_rowids = ibs.staging.get_where_eq( const.REVIEW_TABLE, colnames, params_iter, - where_clause=where_clause, + (REVIEW_AID1,), unpack_scalars=False, ) return review_rowids @@ -627,12 +625,11 @@ def get_review_rowids_from_aid1(ibs, aid_list, eager=True, nInput=None): def get_review_rowids_from_aid2(ibs, aid_list, eager=True, nInput=None): colnames = (REVIEW_ROWID,) params_iter = [(aid,) for aid in aid_list] - where_clause = '%s=?' 
% (REVIEW_AID2,) - review_rowids = ibs.staging.get_where( + review_rowids = ibs.staging.get_where_eq( const.REVIEW_TABLE, colnames, params_iter, - where_clause=where_clause, + (REVIEW_AID2,), unpack_scalars=False, ) return review_rowids diff --git a/wbia/core_annots.py b/wbia/core_annots.py index f9497cb4ac..ff2fa52b0c 100644 --- a/wbia/core_annots.py +++ b/wbia/core_annots.py @@ -2173,7 +2173,11 @@ def compute_aoi2(depc, aid_list, config=None): class OrienterConfig(dtool.Config): _param_info_list = [ - ut.ParamInfo('orienter_algo', 'deepsense', valid_values=['deepsense']), + ut.ParamInfo( + 'orienter_algo', + 'plugin:orientation', + valid_values=['deepsense', 'plugin:orientation'], + ), ut.ParamInfo('orienter_weight_filepath', None), ] _sub_config_list = [ChipConfig] @@ -2201,9 +2205,10 @@ def compute_orients_annotations(depc, aid_list, config=None): (float, str): tup CommandLine: - python -m wbia.core_annots --exec-compute_orients_annotations --deepsense + pytest wbia/core_annots.py::compute_orients_annotations:0 + python -m xdoctest /Users/jason.parham/code/wildbook-ia/wbia/core_annots.py compute_orients_annotations:1 --orient - Example: + Doctest: >>> # DISABLE_DOCTEST >>> from wbia.core_images import * # NOQA >>> import wbia @@ -2221,7 +2226,37 @@ def compute_orients_annotations(depc, aid_list, config=None): >>> theta_list = ut.take_column(result_list, 4) >>> bbox_list = list(zip(xtl_list, ytl_list, w_list, h_list)) >>> ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) + >>> print(result_list) + + Doctest: + >>> # DISABLE_DOCTEST + >>> import wbia + >>> import random + >>> import utool as ut + >>> from wbia.init import sysres + >>> import numpy as np + >>> dbdir = sysres.ensure_testdb_orientation() + >>> ibs = wbia.opendb(dbdir=dbdir) + >>> aid_list = ibs.get_valid_aids() + >>> note_list = ibs.get_annot_notes(aid_list) + >>> species_list = ibs.get_annot_species(aid_list) + >>> flag_list = [ + >>> note == 'random-01' and species == 'right_whale_head' + >>> for note, species in zip(note_list, species_list) + >>> ] + >>> aid_list = ut.compress(aid_list, flag_list) + >>> aid_list = aid_list[:10] + >>> depc = ibs.depc_annot + >>> config = {'orienter_algo': 'plugin:orientation'} + >>> # depc.delete_property('orienter', aid_list) >>> result_list = depc.get_property('orienter', aid_list, None, config=config) + >>> xtl_list = list(map(int, map(np.around, ut.take_column(result_list, 0)))) + >>> ytl_list = list(map(int, map(np.around, ut.take_column(result_list, 1)))) + >>> w_list = list(map(int, map(np.around, ut.take_column(result_list, 2)))) + >>> h_list = list(map(int, map(np.around, ut.take_column(result_list, 3)))) + >>> theta_list = ut.take_column(result_list, 4) + >>> bbox_list = list(zip(xtl_list, ytl_list, w_list, h_list)) + >>> # ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) >>> print(result_list) """ logger.info('[ibs] Process Annotation Labels') @@ -2265,6 +2300,169 @@ def compute_orients_annotations(depc, aid_list, config=None): result_gen.append(result) except Exception: raise RuntimeError('Deepsense orienter not working!') + elif config['orienter_algo'] in ['plugin:orientation']: + logger.info('[ibs] orienting using Orientation Plug-in') + try: + from wbia_orientation import _plugin # NOQA + from wbia_orientation.utils.data_manipulation import get_object_aligned_box + import vtool as vt + + species_list = ibs.get_annot_species(aid_list) + + species_dict = {} + for aid, species in zip(aid_list, species_list): + if species not in species_dict: + 
species_dict[species] = [] + species_dict[species].append(aid) + + results_dict = {} + species_key_list = sorted(species_dict.keys()) + for species in species_key_list: + species_tag = _plugin.SPECIES_MODEL_TAG_MAPPING.get(species, species) + message = 'Orientation plug-in does not support species_tag = %r' % ( + species_tag, + ) + assert species_tag in _plugin.MODEL_URLS, message + assert species_tag in _plugin.CONFIGS, message + aid_list_ = sorted(species_dict[species]) + print( + 'Computing %d orientations for species = %r' + % ( + len(aid_list_), + species, + ) + ) + + output_list, theta_list = _plugin.wbia_plugin_detect_oriented_box( + ibs, aid_list_, species_tag, plot_samples=False + ) + + for aid_, predicted_output, predicted_theta in zip( + aid_list_, output_list, theta_list + ): + xc, yc, xt, yt, w = predicted_output + predicted_verts = get_object_aligned_box(xc, yc, xt, yt, w) + predicted_verts = np.around(np.array(predicted_verts)).astype( + np.int64 + ) + predicted_verts = tuple(map(tuple, predicted_verts.tolist())) + + calculated_theta = np.arctan2(yt - yc, xt - xc) + np.deg2rad(90) + predicted_rot = vt.rotation_around_mat3x3( + calculated_theta * -1.0, xc, yc + ) + predicted_aligned_verts = vt.transform_points_with_homography( + predicted_rot, np.array(predicted_verts).T + ).T + predicted_aligned_verts = np.around(predicted_aligned_verts).astype( + np.int64 + ) + predicted_aligned_verts = tuple( + map(tuple, predicted_aligned_verts.tolist()) + ) + + predicted_bbox = vt.bboxes_from_vert_list([predicted_aligned_verts])[ + 0 + ] + ( + predicted_xtl, + predicted_ytl, + predicted_w, + predicted_h, + ) = predicted_bbox + + result = ( + predicted_xtl, + predicted_ytl, + predicted_w, + predicted_h, + calculated_theta, + # predicted_theta, + ) + results_dict[aid_] = result + + if False: + from itertools import combinations + + predicted_bbox_verts = vt.verts_list_from_bboxes_list( + [predicted_bbox] + )[0] + predicted_bbox_rot = vt.rotation_around_bbox_mat3x3( + calculated_theta, predicted_bbox + ) + predicted_bbox_rotated_verts = ( + vt.transform_points_with_homography( + predicted_bbox_rot, np.array(predicted_bbox_verts).T + ).T + ) + predicted_bbox_rotated_verts = np.around( + predicted_bbox_rotated_verts + ).astype(np.int64) + predicted_bbox_rotated_verts = tuple( + map(tuple, predicted_bbox_rotated_verts.tolist()) + ) + + gid_ = ibs.get_annot_gids(aid_) + image = ibs.get_images(gid_) + + original_bbox = ibs.get_annot_bboxes(aid_) + original_theta = ibs.get_annot_thetas(aid_) + original_verts = vt.verts_list_from_bboxes_list([original_bbox])[ + 0 + ] + original_rot = vt.rotation_around_bbox_mat3x3( + original_theta, original_bbox + ) + rotated_verts = vt.transform_points_with_homography( + original_rot, np.array(original_verts).T + ).T + rotated_verts = np.around(rotated_verts).astype(np.int64) + rotated_verts = tuple(map(tuple, rotated_verts.tolist())) + + color = (255, 0, 0) + for vert in original_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(original_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (0, 0, 255) + for vert in rotated_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(rotated_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (0, 255, 0) + for vert in predicted_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(predicted_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (255, 255, 0) + for vert in 
predicted_aligned_verts: + cv2.circle(image, vert, 20, color, -1) + + for vert1, vert2 in combinations(predicted_aligned_verts, 2): + cv2.line(image, vert1, vert2, color, 5) + + color = (0, 255, 255) + for vert in predicted_bbox_rotated_verts: + cv2.circle(image, vert, 10, color, -1) + + for vert1, vert2 in combinations(predicted_bbox_rotated_verts, 2): + cv2.line(image, vert1, vert2, color, 1) + + cv2.imwrite('/tmp/image.%d.png' % (aid_), image) + + result_gen = [] + for aid in aid_list: + result = results_dict[aid] + result_gen.append(result) + + except Exception: + raise RuntimeError('Orientation plug-in not working!') else: raise ValueError( 'specified orienter algo is not supported in config = %r' % (config,) @@ -2273,3 +2471,270 @@ def compute_orients_annotations(depc, aid_list, config=None): # yield detections for result in result_gen: yield result + + +# for assigning part-annots to body-annots of the same individual: +class PartAssignmentFeatureConfig(dtool.Config): + _param_info_list = [] + + +# just like theta_assignement_features above but with a one-hot encoding of viewpoints +# viewpoints are a boolean value for each viewpoint. will possibly need to modify this for other species +@derived_attribute( + tablename='assigner_viewpoint_features', + parents=['annotations', 'annotations'], + colnames=[ + 'p_v1_x', + 'p_v1_y', + 'p_v2_x', + 'p_v2_y', + 'p_v3_x', + 'p_v3_y', + 'p_v4_x', + 'p_v4_y', + 'p_center_x', + 'p_center_y', + 'b_xtl', + 'b_ytl', + 'b_xbr', + 'b_ybr', + 'b_center_x', + 'b_center_y', + 'int_area_scalar', + 'part_body_distance', + 'part_body_centroid_dist', + 'int_over_union', + 'int_over_part', + 'int_over_body', + 'part_over_body', + 'part_is_left', + 'part_is_right', + 'part_is_up', + 'part_is_down', + 'part_is_front', + 'part_is_back', + 'body_is_left', + 'body_is_right', + 'body_is_up', + 'body_is_down', + 'body_is_front', + 'body_is_back', + ], + coltypes=[ + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + float, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + bool, + ], + configclass=PartAssignmentFeatureConfig, + fname='assigner_viewpoint_features', + rm_extern_on_delete=True, + chunksize=256, # chunk size is huge bc we need accurate means and stdevs of various traits +) +def assigner_viewpoint_features(depc, part_aid_list, body_aid_list, config=None): + + from shapely import geometry + import math + + ibs = depc.controller + + part_gids = ibs.get_annot_gids(part_aid_list) + body_gids = ibs.get_annot_gids(body_aid_list) + assert ( + part_gids == body_gids + ), 'can only compute assignment features on aids in the same image' + parts_are_parts = ibs._are_part_annots(part_aid_list) + assert all(parts_are_parts), 'all part_aids must be part annots.' 
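+    # the paired body aids must be whole-animal annotations, so none of them may be part annots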
+ bodies_are_parts = ibs._are_part_annots(body_aid_list) + assert not any(bodies_are_parts), 'body_aids cannot be part annots' + + im_widths = ibs.get_image_widths(part_gids) + im_heights = ibs.get_image_heights(part_gids) + + part_verts = ibs.get_annot_rotated_verts(part_aid_list) + body_verts = ibs.get_annot_rotated_verts(body_aid_list) + part_verts = _norm_vertices(part_verts, im_widths, im_heights) + body_verts = _norm_vertices(body_verts, im_widths, im_heights) + part_polys = [geometry.Polygon(vert) for vert in part_verts] + body_polys = [geometry.Polygon(vert) for vert in body_verts] + intersect_polys = [ + part.intersection(body) for part, body in zip(part_polys, body_polys) + ] + intersect_areas = [poly.area for poly in intersect_polys] + # just to make int_areas more comparable via ML methods, and since all distances < 1 + int_area_scalars = [math.sqrt(area) for area in intersect_areas] + + part_bboxes = ibs.get_annot_bboxes(part_aid_list) + body_bboxes = ibs.get_annot_bboxes(body_aid_list) + part_bboxes = _norm_bboxes(part_bboxes, im_widths, im_heights) + body_bboxes = _norm_bboxes(body_bboxes, im_widths, im_heights) + part_areas = [bbox[2] * bbox[3] for bbox in part_bboxes] + body_areas = [bbox[2] * bbox[3] for bbox in body_bboxes] + union_areas = [ + part + body - intersect + for (part, body, intersect) in zip(part_areas, body_areas, intersect_areas) + ] + int_over_unions = [ + intersect / union for (intersect, union) in zip(intersect_areas, union_areas) + ] + + part_body_distances = [ + part.distance(body) for part, body in zip(part_polys, body_polys) + ] + + part_centroids = [poly.centroid for poly in part_polys] + body_centroids = [poly.centroid for poly in body_polys] + + part_body_centroid_dists = [ + part.distance(body) for part, body in zip(part_centroids, body_centroids) + ] + + int_over_parts = [ + int_area / part_area for part_area, int_area in zip(part_areas, intersect_areas) + ] + + int_over_bodys = [ + int_area / body_area for body_area, int_area in zip(body_areas, intersect_areas) + ] + + part_over_bodys = [ + part_area / body_area for part_area, body_area in zip(part_areas, body_areas) + ] + + part_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) + body_lrudfb_bools = get_annot_lrudfb_bools(ibs, part_aid_list) + + # note that here only parts have thetas, hence only returning body bboxes + result_list = list( + zip( + part_verts, + part_centroids, + body_bboxes, + body_centroids, + int_area_scalars, + part_body_distances, + part_body_centroid_dists, + int_over_unions, + int_over_parts, + int_over_bodys, + part_over_bodys, + part_lrudfb_bools, + body_lrudfb_bools, + ) + ) + + for ( + part_vert, + part_center, + body_bbox, + body_center, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + part_lrudfb_bool, + body_lrudfb_bool, + ) in result_list: + ans = ( + part_vert[0][0], + part_vert[0][1], + part_vert[1][0], + part_vert[1][1], + part_vert[2][0], + part_vert[2][1], + part_vert[3][0], + part_vert[3][1], + part_center.x, + part_center.y, + body_bbox[0], + body_bbox[1], + body_bbox[2], + body_bbox[3], + body_center.x, + body_center.y, + int_area_scalar, + part_body_distance, + part_body_centroid_dist, + int_over_union, + int_over_part, + int_over_body, + part_over_body, + ) + ans += tuple(part_lrudfb_bool) + ans += tuple(body_lrudfb_bool) + yield ans + + +# left, right, up, down, front, back booleans, useful for assigner classification and other cases where we might want 
viewpoint as an input for an ML model +def get_annot_lrudfb_bools(ibs, aid_list): + views = ibs.get_annot_viewpoints(aid_list) + bool_arrays = [ + [ + 'left' in view, + 'right' in view, + 'up' in view, + 'down' in view, + 'front' in view, + 'back' in view, + ] + if view is not None + else [False] * 6 + for view in views + ] + return bool_arrays + + +def _norm_bboxes(bbox_list, width_list, height_list): + normed_boxes = [ + (bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h) + for (bbox, w, h) in zip(bbox_list, width_list, height_list) + ] + return normed_boxes + + +def _norm_vertices(verts_list, width_list, height_list): + normed_verts = [ + [[x / w, y / h] for x, y in vert] + for vert, w, h in zip(verts_list, width_list, height_list) + ] + return normed_verts + + +if __name__ == '__main__': + import xdoctest as xdoc + + xdoc.doctest_module(__file__) diff --git a/wbia/core_parts.py b/wbia/core_parts.py index b5ec1183f0..55e6c48515 100644 --- a/wbia/core_parts.py +++ b/wbia/core_parts.py @@ -6,9 +6,13 @@ import logging import utool as ut import numpy as np + +# from wbia import dtool from wbia.control.controller_inject import register_preprocs, register_subprops from wbia import core_annots +# from wbia.constants import ANNOTATION_TABLE, PART_TABLE + (print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') diff --git a/wbia/dbio/export_subset.py b/wbia/dbio/export_subset.py index 902a0f6ab2..3e168015a0 100644 --- a/wbia/dbio/export_subset.py +++ b/wbia/dbio/export_subset.py @@ -69,6 +69,7 @@ def merge_databases(ibs_src, ibs_dst, rowid_subsets=None, localize_images=True): Example: >>> # ENABLE_DOCTEST >>> from wbia.dbio.export_subset import * # NOQA + >>> from wbia.init.sysres import get_workdir >>> import wbia >>> db1 = ut.get_argval('--db1', str, default=None) >>> db2 = ut.get_argval('--db2', str, default=None) @@ -77,18 +78,18 @@ def merge_databases(ibs_src, ibs_dst, rowid_subsets=None, localize_images=True): >>> delete_ibsdir = False >>> # Check for test mode instead of script mode >>> if db1 is None and db2 is None and dbdir1 is None and dbdir2 is None: - ... db1 = 'testdb1' - ... dbdir2 = 'testdb_dst' + ... dbdir1 = '/'.join([get_workdir(), 'testdb1']) + ... dbdir2 = '/'.join([get_workdir(), 'testdb_dst']) ... delete_ibsdir = True >>> # Open the source and destination database >>> assert db1 is not None or dbdir1 is not None >>> assert db2 is not None or dbdir2 is not None >>> ibs_src = wbia.opendb(db=db1, dbdir=dbdir1) >>> ibs_dst = wbia.opendb(db=db2, dbdir=dbdir2, allow_newdir=True, - >>> delete_ibsdir=delete_ibsdir) + ... delete_ibsdir=delete_ibsdir) >>> merge_databases(ibs_src, ibs_dst) >>> check_merge(ibs_src, ibs_dst) - >>> ibs_dst.print_dbinfo() + >>> # ibs_dst.print_dbinfo() """ # TODO: ensure images are localized # otherwise this wont work diff --git a/wbia/dtool/__init__.py b/wbia/dtool/__init__.py index dcaec258f3..76706f7d05 100644 --- a/wbia/dtool/__init__.py +++ b/wbia/dtool/__init__.py @@ -6,7 +6,9 @@ # See `_integrate_sqlite3` module for details. 
import sqlite3 -from wbia.dtool import _integrate_sqlite3 as lite +# BBB (7-Sept-12020) +import sqlite3 as lite + from wbia.dtool import base from wbia.dtool import sql_control from wbia.dtool import depcache_control @@ -24,3 +26,4 @@ from wbia.dtool.base import * # NOQA from wbia.dtool.sql_control import SQLDatabaseController from wbia.dtool.types import TYPE_TO_SQLTYPE +import wbia.dtool.events diff --git a/wbia/dtool/_integrate_sqlite3.py b/wbia/dtool/_integrate_sqlite3.py deleted file mode 100644 index a260370824..0000000000 --- a/wbia/dtool/_integrate_sqlite3.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -"""Integrates numpy types into sqlite3""" -import io -import uuid -from sqlite3 import register_adapter, register_converter - -import numpy as np -import utool as ut - - -__all__ = () - - -def _read_numpy_from_sqlite3(blob): - # INVESTIGATE: Is memory freed up correctly here? - out = io.BytesIO(blob) - out.seek(0) - # return np.load(out) - # Is this better? - arr = np.load(out) - out.close() - return arr - - -def _read_bool(b): - return None if b is None else bool(b) - - -def _write_bool(b): - return b - - -def _write_numpy_to_sqlite3(arr): - out = io.BytesIO() - np.save(out, arr) - out.seek(0) - return memoryview(out.read()) - - -def _read_uuid_from_sqlite3(blob): - try: - return uuid.UUID(bytes_le=blob) - except ValueError as ex: - ut.printex(ex, keys=['blob']) - raise - - -def _read_dict_from_sqlite3(blob): - return ut.from_json(blob) - # return uuid.UUID(bytes_le=blob) - - -def _write_dict_to_sqlite3(dict_): - return ut.to_json(dict_) - - -def _write_uuid_to_sqlite3(uuid_): - return memoryview(uuid_.bytes_le) - - -def register_numpy_dtypes(): - py_int_type = int - for dtype in ( - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ): - register_adapter(dtype, py_int_type) - register_adapter(np.float32, float) - register_adapter(np.float64, float) - - -def register_numpy(): - """ - Tell SQL how to deal with numpy arrays - Utility function allowing numpy arrays to be stored as raw blob data - """ - register_converter('NUMPY', _read_numpy_from_sqlite3) - register_converter('NDARRAY', _read_numpy_from_sqlite3) - register_adapter(np.ndarray, _write_numpy_to_sqlite3) - - -def register_uuid(): - """ Utility function allowing uuids to be stored in sqlite """ - register_converter('UUID', _read_uuid_from_sqlite3) - register_adapter(uuid.UUID, _write_uuid_to_sqlite3) - - -def register_dict(): - register_converter('DICT', _read_dict_from_sqlite3) - register_adapter(dict, _write_dict_to_sqlite3) - - -def register_list(): - register_converter('LIST', ut.from_json) - register_adapter(list, ut.to_json) - - -# def register_bool(): -# # FIXME: ensure this works -# register_converter('BOOL', _read_bool) -# register_adapter(bool, _write_bool) - - -register_numpy_dtypes() -register_numpy() -register_uuid() -register_dict() -register_list() -# register_bool() # TODO diff --git a/wbia/dtool/copy_sqlite_to_postgres.py b/wbia/dtool/copy_sqlite_to_postgres.py new file mode 100644 index 0000000000..fb082859cc --- /dev/null +++ b/wbia/dtool/copy_sqlite_to_postgres.py @@ -0,0 +1,589 @@ +# -*- coding: utf-8 -*- +""" +Copy sqlite database into a postgresql database using pgloader (from +apt-get) +""" +import logging +import re +import shutil +import subprocess +import tempfile +import typing +from concurrent.futures import as_completed, Future, ProcessPoolExecutor +from functools import wraps +from pathlib import Path + +import numpy as np +import 
sqlalchemy + +from wbia.dtool.sql_control import create_engine + + +logger = logging.getLogger('wbia') + + +MAIN_DB_FILENAME = '_ibeis_database.sqlite3' +STAGING_DB_FILENAME = '_ibeis_staging.sqlite3' +CACHE_DIRECTORY_NAME = '_ibeis_cache' +DEFAULT_CHECK_PC = 0.1 +DEFAULT_CHECK_MAX = 100 +DEFAULT_CHECK_MIN = 10 + + +class AlreadyMigratedError(Exception): + """Raised when the database has already been migrated""" + + +class SqliteDatabaseInfo: + def __init__(self, db_dir_or_db_uri): + self.engines = {} + self.metadata = {} + self.db_dir = None + if str(db_dir_or_db_uri).startswith('sqlite:///'): + self.db_uri = db_dir_or_db_uri + schema = get_schema_name_from_uri(self.db_uri) + engine = sqlalchemy.create_engine(self.db_uri) + self.engines[schema] = engine + self.metadata[schema] = sqlalchemy.MetaData(bind=engine) + else: + self.db_dir = Path(db_dir_or_db_uri) + for db_path, _ in get_sqlite_db_paths(self.db_dir): + db_uri = f'sqlite:///{db_path}' + schema = get_schema_name_from_uri(db_uri) + engine = sqlalchemy.create_engine(db_uri) + self.engines[schema] = engine + self.metadata[schema] = sqlalchemy.MetaData(bind=engine) + + def __str__(self): + if self.db_dir: + return f'' + return f'' + + def get_schema(self): + return sorted(self.engines.keys()) + + def get_table_names(self, normalized=False): + for schema in sorted(self.metadata.keys()): + metadata = self.metadata[schema] + metadata.reflect() + for table in sorted(metadata.tables, key=lambda a: a.lower()): + if normalized: + yield schema.lower(), table.lower() + else: + yield schema, table + + def get_total_rows(self, schema, table_name): + result = self.engines[schema].execute(f'SELECT count(*) FROM {table_name}') + return result.fetchone()[0] + + def get_random_rows( + self, + schema, + table_name, + percentage=DEFAULT_CHECK_PC, + max_rows=DEFAULT_CHECK_MAX, + min_rows=DEFAULT_CHECK_MIN, + ): + table = self.metadata[schema].tables[table_name] + stmt = sqlalchemy.select([sqlalchemy.column('rowid'), table]) + if percentage < 1 or max_rows > 0: + total_rows = self.get_total_rows(schema, table_name) + rows_to_check = max(int(total_rows * percentage), min_rows) + if max_rows > 0: + rows_to_check = min(rows_to_check, max_rows) + stmt = stmt.order_by(sqlalchemy.text('random()')).limit(rows_to_check) + return self.engines[schema].execute(stmt) + + def get_rows(self, schema, table_name, rowids): + table = self.metadata[schema].tables[table_name] + rowid = sqlalchemy.column('rowid') + stmt = sqlalchemy.select([rowid, table]) + if rowids: + stmt = stmt.where(rowid.in_(rowids)) + result = self.engines[schema].execute(stmt) + return result.fetchall() + + def get_column(self, schema, table_name, column_name='rowid'): + table = self.metadata[schema].tables[table_name] + if column_name == 'rowid': + column = sqlalchemy.column('rowid') + else: + column = table.c[column_name] + stmt = sqlalchemy.select(columns=[column], from_obj=table).order_by(column) + return self.engines[schema].execute(stmt) + + +class PostgresDatabaseInfo: + def __init__(self, db_uri): + self.db_uri = db_uri + self.engine = sqlalchemy.create_engine(db_uri) + self.metadata = sqlalchemy.MetaData(bind=self.engine) + + def __str__(self): + return f'' + + def get_schema(self): + schemas = self.engine.execute( + 'SELECT DISTINCT table_schema FROM information_schema.tables' + ) + return sorted( + [s[0] for s in schemas if s[0] not in ('pg_catalog', 'information_schema')] + ) + + def get_table_names(self, normalized=False): + for schema in self.get_schema(): + 
self.metadata.reflect(schema=schema) + for table in sorted(self.metadata.tables, key=lambda a: a.lower()): + if normalized: + yield tuple(table.lower().split('.', 1)) + else: + yield tuple(table.split('.', 1)) + + def get_total_rows(self, schema, table_name): + result = self.engine.execute(f'SELECT count(*) FROM {schema}.{table_name}') + return result.fetchone()[0] + + def get_random_rows( + self, + schema, + table_name, + percentage=DEFAULT_CHECK_PC, + max_rows=DEFAULT_CHECK_MAX, + min_rows=DEFAULT_CHECK_MIN, + ): + table = self.metadata.tables[f'{schema}.{table_name}'] + stmt = sqlalchemy.select([table]) + if percentage < 1 or max_rows > 0: + total_rows = self.get_total_rows(schema, table_name) + rows_to_check = max(int(total_rows * percentage), min_rows) + if max_rows > 0: + rows_to_check = min(rows_to_check, max_rows) + stmt = stmt.order_by(sqlalchemy.text('random()')).limit(rows_to_check) + return self.engine.execute(stmt) + + def get_rows(self, schema, table_name, rowids): + table = self.metadata.tables[f'{schema}.{table_name}'] + stmt = sqlalchemy.select([table]) + if rowids: + stmt = stmt.where(table.c['rowid'].in_(rowids)) + result = self.engine.execute(stmt) + return result.fetchall() + + def get_column(self, schema, table_name, column_name='rowid'): + table = self.metadata.tables[f'{schema}.{table_name}'] + column = table.c[column_name] + stmt = sqlalchemy.select(column).order_by(column) + return self.engine.execute(stmt) + + +def rows_equal(row1, row2): + """Check the rows' values for equality""" + for e1, e2 in zip(row1, row2): + if not _complex_type_equality_check(e1, e2): + return False + return True + + +def _complex_type_equality_check(v1, v2): + """Returns True on the equality of ``v1`` and ``v2``; otherwise False""" + if type(v1) != type(v2): + return False + if isinstance(v1, float): + if abs(v1 - v2) >= 0.0000000001: + return False + elif isinstance(v1, np.ndarray): + if (v1 != v2).any(): # see ndarray.any() for details + return False + elif v1 != v2: + return False + return True + + +def compare_databases( + db_info1, + db_info2, + exact=True, + check_pc=DEFAULT_CHECK_PC, + check_min=DEFAULT_CHECK_MIN, + check_max=DEFAULT_CHECK_MAX, +): + messages = [] + + # Compare schema + schema1 = set(db_info1.get_schema()) + schema2 = set(db_info2.get_schema()) + if len(schema2) < len(schema1): + # Make sure db_info1 is the one with the smaller set + db_info2, db_info1 = db_info1, db_info2 + schema2, schema1 = schema1, schema2 + + if exact and schema1 != schema2: + messages.append(f'Schema difference: {schema1} != {schema2}') + return messages + if not exact and not schema1: + messages.append(f'Schema in {db_info1} is empty') + return messages + if not exact and schema1.difference(schema2): + messages.append(f'Schema difference: {schema1} not in {schema2}') + return messages + + # Compare tables + tables1 = list(db_info1.get_table_names(normalized=True)) + tables2 = list(db_info2.get_table_names(normalized=True)) + if not exact: + tables2 = [(schema, table) for schema, table in tables2 if schema in schema1] + if tables1 != tables2: + messages.append( + 'Table names difference: ' + f'Only in {db_info1}={set(tables1).difference(tables2)} ' + f'Only in {db_info2}={set(tables2).difference(tables1)}' + ) + return messages + + # Compare number of rows + tables1 = list(db_info1.get_table_names()) + tables2 = list(db_info2.get_table_names()) + normalized_schema1 = [s.lower() for s in schema1] + if not exact: + tables2 = [ + (schema, table) for schema, table in tables2 if schema in 
normalized_schema1 + ] + table_total = {} + for (schema1, table1), (schema2, table2) in zip(tables1, tables2): + total1 = db_info1.get_total_rows(schema1, table1) + table_total[f'{schema1}.{table1}'] = total1 + total2 = db_info2.get_total_rows(schema2, table2) + if total1 != total2: + messages.append( + f'Total number of rows in "{schema2}.{table2}" difference: {total1} != {total2}' + ) + + # Compare rowid + for (schema1, table1), (schema2, table2) in zip(tables1, tables2): + rowid1 = list(db_info1.get_column(schema1, table1)) + rowid2 = list(db_info2.get_column(schema2, table2)) + if rowid1 != rowid2: + messages.append( + f'Row ids in "{schema2}.{table2}" difference: ' + f'Only in {db_info1}={set(rowid1).difference(rowid2)} ' + f'Only in {db_info2}={set(rowid2).difference(rowid1)}' + ) + return messages + + # Compare data + for (schema1, table1), (schema2, table2) in zip(tables1, tables2): + rows1 = list( + db_info1.get_random_rows( + schema1, + table1, + percentage=check_pc, + min_rows=check_min, + max_rows=check_max, + ) + ) + if table_total[f'{schema1}.{table1}'] == len(rows1): + rowids = None + else: + rowids = [row[0] for row in rows1] + rows2 = {row[0]: row for row in db_info2.get_rows(schema2, table2, rowids)} + for row in rows1: + rowid = row[0] + row2 = rows2[rowid] + if not rows_equal(row, row2): + messages.append( + f'Table "{schema2}.{table2}" data difference: {row} != {row2}' + ) + logger.debug(f'Compared {len(rows1)} rows in {schema2}.{table2}') + + return messages + + +def get_sqlite_db_paths(db_dir: Path): + """Generates a sequence of sqlite database file paths and sizes. + The sequence is sorted by database size, smallest first. + + """ + base_loc = (db_dir / '_ibsdb').resolve() + main_db = base_loc / MAIN_DB_FILENAME + staging_db = base_loc / STAGING_DB_FILENAME + cache_directory = base_loc / CACHE_DIRECTORY_NAME + paths = [] + + # churn over the cache databases + for f in cache_directory.glob('*.sqlite'): + if 'backup' in f.name: + continue + p = f.resolve() + paths.append((p, p.stat().st_size)) + + if staging_db.exists(): + # doesn't exist in test databases + paths.append((staging_db, staging_db.stat().st_size)) + paths.append((main_db, main_db.stat().st_size)) + + # Sort databases by file size, smallest first + paths.sort(key=lambda a: a[1]) + return paths + + +def get_schema_name_from_uri(uri: str): + """Derives the schema name from a sqlite URI (e.g. 
sqlite:///foo/bar/baz.sqlite)""" + db_path = Path(uri[len('sqlite:///') :]) + name = db_path.stem # filename without extension + + # special names + if name == '_ibeis_staging': + name = 'staging' + elif name == '_ibeis_database': + name = 'main' + + return name + + +def add_rowids(engine): + connection = engine.connect() + create_table_stmts = connection.execute( + """\ + SELECT name, sql FROM sqlite_master + WHERE name NOT LIKE 'sqlite_%' AND type = 'table' + """ + ).fetchall() + for table, stmt in create_table_stmts: + # Create a new table with suffix "_with_rowid" + new_table = f'{table}_with_rowid' + stmt = re.sub( + r'CREATE TABLE [^ ]* \(', + f'CREATE TABLE {new_table} (rowid INTEGER NOT NULL UNIQUE, ', + stmt, + ) + # Change "REAL" type to "DOUBLE" because "REAL" in postgresql + # only can only store 6 digits and so we'd lose precision + stmt = re.sub('REAL', 'DOUBLE', stmt) + connection.execute(stmt) + connection.execute(f'INSERT INTO {new_table} SELECT rowid, * FROM {table}') + connection.execute(f'DROP TABLE {table}') + connection.execute(f'ALTER TABLE {new_table} RENAME TO {table}') + + +def before_pgloader(engine, schema): + connection = engine.connect() + connection.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') + connection.execute(f"SET SCHEMA '{schema}'") + + for domain, base_type in ( + ('dict', 'json'), + ('list', 'json'), + ('ndarray', 'bytea'), + ('numpy', 'bytea'), + ): + try: + connection.execute(f'CREATE DOMAIN {domain} AS {base_type}') + except sqlalchemy.exc.ProgrammingError: + # sqlalchemy.exc.ProgrammingError: + # (psycopg2.errors.DuplicateObject) type "dict" already + # exists + pass + + +PGLOADER_CONFIG_TEMPLATE = """\ +LOAD DATABASE + FROM {sqlite_uri} + INTO {postgres_uri} + + WITH include drop, + create tables, + create indexes, + reset no sequences + + SET work_mem to '16MB', + maintenance_work_mem to '512 MB', + search_path to '{schema_name}' + + CAST type uuid to uuid using wbia-uuid-bytes-to-uuid, + type ndarray to ndarray using byte-vector-to-bytea, + type numpy to numpy using byte-vector-to-bytea; +""" +# Copied from the built-in sql-server-uniqueidentifier-to-uuid +# transform in pgloader 3.6.2 +# Prior to 3.6.2, the transform was for uuid in big-endian order +UUID_LOADER_LISP = """\ +(in-package :pgloader.transforms) + +(defmacro arr-to-bytes-rev (from to array) + `(loop for i from ,to downto ,from + with res = 0 + do (setf (ldb (byte 8 (* 8 (- i,from))) res) (aref ,array i)) + finally (return res))) + +(defun wbia-uuid-bytes-to-uuid (id) + (declare (type (or null (array (unsigned-byte 8) (16))) id)) + (when id + (let ((uuid + (make-instance 'uuid:uuid + :time-low (arr-to-bytes-rev 0 3 id) + :time-mid (arr-to-bytes-rev 4 5 id) + :time-high (arr-to-bytes-rev 6 7 id) + :clock-seq-var (aref id 8) + :clock-seq-low (aref id 9) + :node (uuid::arr-to-bytes 10 15 id)))) + (princ-to-string uuid)))) +""" + + +def run_pgloader(sqlite_uri: str, postgres_uri: str) -> subprocess.CompletedProcess: + """Configure and run ``pgloader``. + If there is a problem this will raise a ``CalledProcessError`` + from ``Process.check_returncode``. 
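+
+    A minimal invocation sketch; both URIs below are illustrative placeholders
+    (the sqlite path follows the ``_ibsdb`` layout used by
+    ``get_sqlite_db_paths`` and the postgres URI is a generic example, neither
+    is a value required by this module)::
+
+        run_pgloader(
+            'sqlite:////data/db/_ibsdb/_ibeis_database.sqlite3',
+            'postgresql://wbia:wbia@localhost:5432/wbia',
+        )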
+ + """ + schema_name = get_schema_name_from_uri(sqlite_uri) + + # Do all this within a self-cleaning temporary directory + with tempfile.TemporaryDirectory() as tempdir: + td = Path(tempdir) + pgloader_config = td / 'wbia.load' + with pgloader_config.open('w') as fb: + fb.write(PGLOADER_CONFIG_TEMPLATE.format(**locals())) + + wbia_uuid_loader = td / 'wbia_uuid_loader.lisp' + with wbia_uuid_loader.open('w') as fb: + fb.write(UUID_LOADER_LISP) + + proc = subprocess.run( + ['pgloader', '--load-lisp-file', str(wbia_uuid_loader), str(pgloader_config)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + proc.check_returncode() # raises subprocess.CalledProcessError + return proc + + +def after_pgloader(sqlite_engine, pg_engine, schema): + # Some "NOT NULL" weren't migrated by pgloader for some reason + connection = pg_engine.connect() + connection.execute(f"SET SCHEMA '{schema}'") + sqlite_metadata = sqlalchemy.MetaData(bind=sqlite_engine) + sqlite_metadata.reflect() + for table in sqlite_metadata.tables.values(): + for column in table.c: + if not column.primary_key and not column.nullable: + connection.execute( + f'ALTER TABLE {table.name} ALTER COLUMN {column.name} SET NOT NULL' + ) + + table_pkeys = connection.execute( + f"""\ + SELECT table_name, column_name + FROM information_schema.table_constraints + NATURAL JOIN information_schema.constraint_column_usage + WHERE table_schema = '{schema}' + AND constraint_type = 'PRIMARY KEY'""" + ).fetchall() + exclude_sequences = set() + for (table_name, pkey) in table_pkeys: + # Create sequences for rowid fields + for column_name in ('rowid', pkey): + seq_name = f'{table_name}_{column_name}_seq' + exclude_sequences.add(seq_name) + connection.execute(f'CREATE SEQUENCE {seq_name}') + connection.execute( + f"SELECT setval('{seq_name}', (SELECT max({column_name}) FROM {table_name}))" + ) + connection.execute( + f"ALTER TABLE {table_name} ALTER COLUMN {column_name} SET DEFAULT nextval('{seq_name}')" + ) + connection.execute( + f'ALTER SEQUENCE {seq_name} OWNED BY {table_name}.{column_name}' + ) + + # Reset all sequences except the ones we just created (doing it here + # instead of in pgloader because it causes a fatal error in pgloader + # when pgloader runs in parallel: + # + # Asynchronous notification "seqs" (payload: "0") received from + # server process with PID 28472.) 
+ sequences = connection.execute( + f"""\ + SELECT table_name, column_name, column_default + FROM information_schema.columns + WHERE table_schema = '{schema}' + AND column_default LIKE 'nextval%%'""" + ).fetchall() + for table_name, column_name, column_default in sequences: + seq_name = re.sub(r"nextval\('([^']*)'.*", r'\1', column_default) + if seq_name not in exclude_sequences: + connection.execute( + f"SELECT setval('{seq_name}', (SELECT max({column_name}) FROM {table_name}))" + ) + + +def drop_schema(engine, schema_name): + connection = engine.connect() + connection.execute(f'DROP SCHEMA {schema_name} CASCADE') + + +def _use_copy_of_sqlite_database(f): + """Makes a copy of the sqlite database given as the first argument in URI form""" + + @wraps(f) + def wrapper(*args): + uri = args[0] + db = Path(uri.split(':')[-1]).resolve() + with tempfile.TemporaryDirectory() as tempdir: + temp_db = Path(tempdir) / db.name + shutil.copy(db, temp_db) + new_uri = f'sqlite:///{str(temp_db)}' + return f(new_uri, *args[1:]) + + return wrapper + + +@_use_copy_of_sqlite_database +def migrate(sqlite_uri: str, postgres_uri: str): + logger.info(f'\nworking on {sqlite_uri} ...') + schema_name = get_schema_name_from_uri(sqlite_uri) + sl_info = SqliteDatabaseInfo(sqlite_uri) + pg_info = PostgresDatabaseInfo(postgres_uri) + sl_engine = create_engine(sqlite_uri) + pg_engine = create_engine(postgres_uri) + + if not compare_databases(sl_info, pg_info, exact=False): + logger.info(f'{sl_info} already migrated to {pg_info}') + # TODO: migrate missing bits here + raise AlreadyMigratedError() + + if schema_name in pg_info.get_schema(): + logger.warning(f'Dropping schema "{schema_name}"') + drop_schema(pg_engine, schema_name) + + # Add sqlite built-in rowid column to tables + add_rowids(sl_engine) + before_pgloader(pg_engine, schema_name) + run_pgloader(sqlite_uri, postgres_uri) + after_pgloader(sl_engine, pg_engine, schema_name) + + +def copy_sqlite_to_postgres( + db_dir: Path, + postgres_uri: str, + num_procs: int = 6, +) -> typing.Generator[typing.Tuple[Path, Future], None, None]: + """Copies all the sqlite databases into a single postgres database + + Args: + db_dir: the colloquial dbdir (i.e. directory containing '_ibsdb', 'smart_patrol', etc.) 
+ postgres_uri: a postgres connection uri without the database name + num_procs: number of concurrent processes to use + + """ + executor = ProcessPoolExecutor(max_workers=num_procs) + sqlite_dbs = dict(get_sqlite_db_paths(db_dir)) + total_size = sum(sqlite_dbs.values()) + migration_futures_to_paths = { + executor.submit(migrate, f'sqlite://{str(p)}', postgres_uri): p + for p in sqlite_dbs + } + for future in as_completed(migration_futures_to_paths): + path = migration_futures_to_paths[future] + db_size = sqlite_dbs[path] + yield (path, future, db_size, total_size) diff --git a/wbia/dtool/depcache_control.py b/wbia/dtool/depcache_control.py index 5a1fb1fb7e..7e6b8bfde4 100644 --- a/wbia/dtool/depcache_control.py +++ b/wbia/dtool/depcache_control.py @@ -3,7 +3,6 @@ implicit version of dependency cache from wbia/templates/template_generator """ import logging -import os.path import utool as ut import numpy as np @@ -84,15 +83,66 @@ def _wrapper(func): return _depcdecors -class _CoreDependencyCache(object): - """ - Core worker functions for the depcache - Inherited by a calss with some "nice extras - """ +class DependencyCache: + def __init__( + self, + controller, + name, + get_root_uuid, + table_name=None, + root_getters=None, + use_globals=True, + ): + """ + Args: + controller (IBEISController): main controller + name (str): name of this controller instance, which is used in naming the data storage + table_name (str): (optional) if not the same as the 'name' + get_root_uuid: ??? + root_getters: ??? + use_globals (bool): ??? (default: True) + + """ + if table_name is None: + table_name = name + + self.name = name + # Parent (ibs) controller + self.controller = controller + # Internal dictionary of dependant tables + self.cachetable_dict = {} + self.configclass_dict = {} + self.requestclass_dict = {} + self.resultclass_dict = {} + + self.root_getters = root_getters + # Root of all dependencies + self.root_tablename = table_name + # XXX Directory all cachefiles are stored in + self.cache_dpath = self.controller.get_cachedir() + + # Mapping of database connections by name + # - names populated by _register_prop + # - values populated by initialize + self._db_by_name = {} + + # Function to map a root rowid to an object + self._use_globals = use_globals + + # FIXME (20-Oct-12020) remove filesystem name + self.default_fname = f'{table_name}_cache' + + self.get_root_uuid = get_root_uuid + self.delete_exclude_tables = {} + # BBB (25-Sept-12020) `_debug` remains around to be backwards compatible + self._debug = False + + def __repr__(self): + return f'' @profile def _register_prop( - depc, + self, tablename, parents=None, colnames=None, @@ -110,15 +160,14 @@ def _register_prop( SEE: dtool.REG_PREPROC_DOC """ - if depc._debug: - logger.info('[depc] Registering tablename=%r' % (tablename,)) - logger.info('[depc] * preproc_func=%r' % (preproc_func,)) + logger.debug('[depc] Registering tablename=%r' % (tablename,)) + logger.debug('[depc] * preproc_func=%r' % (preproc_func,)) # ---------- # Sanitize inputs if isinstance(tablename, six.string_types): tablename = six.text_type(tablename) if parents is None: - parents = [depc.root] + parents = [self.root] if colnames is None: colnames = 'data' if coltypes is None: @@ -136,7 +185,9 @@ def _register_prop( raise ValueError('must specify coltypes of %s' % (tablename,)) coltypes = [np.ndarray] * len(colnames) if fname is None: - fname = depc.default_fname + # FIXME (20-Oct-12020) Base class doesn't define this property + # and thus expects the subclass to know 
it should be assigned. + fname = self.default_fname if configclass is None: # Make a default config with no parameters configclass = {} @@ -147,25 +198,25 @@ def _register_prop( # ---------- # Register a new table and configuration if requestclass is not None: - depc.requestclass_dict[tablename] = requestclass - depc.fname_to_db[fname] = None - table = depcache_table.DependencyCacheTable( - depc=depc, + self.requestclass_dict[tablename] = requestclass + self._db_by_name.setdefault(fname, None) + table = depcache_table.DependencyCacheTable.from_name( + fname, + tablename, + self, parent_tablenames=parents, - tablename=tablename, data_colnames=colnames, data_coltypes=coltypes, preproc_func=preproc_func, - fname=fname, default_to_unpack=default_to_unpack, **kwargs, ) - depc.cachetable_dict[tablename] = table - depc.configclass_dict[tablename] = configclass + self.cachetable_dict[tablename] = table + self.configclass_dict[tablename] = configclass return table @ut.apply_docstr(REG_PREPROC_DOC) - def register_preproc(depc, *args, **kwargs): + def register_preproc(self, *args, **kwargs): """ Decorator for registration of cachables """ @@ -173,114 +224,92 @@ def register_preproc(depc, *args, **kwargs): def register_preproc_wrapper(func): check_register(args, kwargs) kwargs['preproc_func'] = func - depc._register_prop(*args, **kwargs) + self._register_prop(*args, **kwargs) return func return register_preproc_wrapper - def _register_subprop(depc, tablename, propname=None, preproc_func=None): + def _register_subprop(self, tablename, propname=None, preproc_func=None): """ subproperties are always recomputeed on the fly """ - table = depc.cachetable_dict[tablename] + table = self.cachetable_dict[tablename] table.subproperties[propname] = preproc_func - def close(depc): - """ - Close all managed SQL databases - """ - for fname, db in depc.fname_to_db.items(): - db.close() + def get_db_by_name(self, name): + """Get the database (i.e. SQLController) for the given database name""" + # FIXME (20-Oct-12020) Currently handled via a mapping of 'fname' + # to database controller objects. 
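+        # The keys are the `fname` values gathered by `_register_prop`
+        # (falling back to `self.default_fname`); the values start out as
+        # None and become `SQLDatabaseController` instances once
+        # `initialize()` has run.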
+ return self._db_by_name[name] + + def close(self): + """Close all managed SQL databases""" + for db_inst in self._db_by_name.values(): + db_inst.close() @profile - def initialize(depc, _debug=None): + def initialize(self, _debug=None): """ Creates all registered tables """ - logger.info( - '[depc] Initialize %s depcache in %r' % (depc.root.upper(), depc.cache_dpath) - ) - _debug = depc._debug if _debug is None else _debug - if depc._use_globals: - reg_preproc = PREPROC_REGISTER[depc.root] - reg_subprop = SUBPROP_REGISTER[depc.root] - if ut.VERBOSE: - logger.info( - '[depc.init] Registering %d global preproc funcs' % len(reg_preproc) - ) + if self._use_globals: + reg_preproc = PREPROC_REGISTER[self.root] + reg_subprop = SUBPROP_REGISTER[self.root] + logger.info( + '[depc.init] Registering %d global preproc funcs' % len(reg_preproc) + ) for args_, _kwargs in reg_preproc: - depc._register_prop(*args_, **_kwargs) - if ut.VERBOSE: - logger.info( - '[depc.init] Registering %d global subprops ' % len(reg_subprop) - ) + self._register_prop(*args_, **_kwargs) + logger.info('[depc.init] Registering %d global subprops ' % len(reg_subprop)) for args_, _kwargs in reg_subprop: - depc._register_subprop(*args_, **_kwargs) - - ut.ensuredir(depc.cache_dpath) - - # Memory filestore - # if False: - # # http://docs.pyfilesystem.org/en/latest/getting_started.html - # pip install fs - - for fname in depc.fname_to_db.keys(): - if fname == ':memory:': - fpath = fname - else: - fname_ = ut.ensure_ext(fname, '.sqlite') - from os.path import dirname - - prefix_dpath = dirname(fname_) - if prefix_dpath: - ut.ensuredir(ut.unixjoin(depc.cache_dpath, prefix_dpath)) - fpath = ut.unixjoin(depc.cache_dpath, fname_) - # if ut.get_argflag('--clear-all-depcache'): - # ut.delete(fpath) - db_uri = 'file://{}'.format(os.path.realpath(fpath)) - db = sql_control.SQLDatabaseController.from_uri(db_uri) + self._register_subprop(*args_, **_kwargs) + + for name in self._db_by_name.keys(): + # FIXME (20-Oct-12020) 'smk/smk_agg_rvecs' is known to have issues. + # Either fix the name or find a better normalizer/slugifier. + normalized_name = name.replace('/', '__') + uri = self.controller.make_cache_db_uri(normalized_name) + db = sql_control.SQLDatabaseController(uri, normalized_name) + # ??? This seems out of place. Shouldn't this be within the depcachetable instance? depcache_table.ensure_config_table(db) - depc.fname_to_db[fname] = db - if ut.VERBOSE: - logger.info('[depc] Finished initialization') + self._db_by_name[name] = db - for table in depc.cachetable_dict.values(): - table.initialize(_debug=_debug) + for table in self.cachetable_dict.values(): + table.initialize() # HACKS: # Define injected functions for autocomplete convinience class InjectedDepc(object): pass - depc.d = InjectedDepc() - depc.w = InjectedDepc() - d = depc.d - w = depc.w + # ??? What's the significance of 'd' and 'w'? Do the mean anything? 
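+        # From the wiring below: `d` carries flat, per-table partials such as
+        # `depc.d.get_<tablename>_rowids(root_rowids)` and
+        # `depc.d.get_<tablename>_<colname>(root_rowids)`, while `w` carries a
+        # nested object per table (`depc.w.<tablename>.<funcname>`), both
+        # bound to the same underlying getters.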
+ self.d = InjectedDepc() + self.w = InjectedDepc() inject_patterns = [ - ('get_{tablename}_rowids', depc.get_rowids), - ('get_{tablename}_config_history', depc.get_config_history), + ('get_{tablename}_rowids', self.get_rowids), + ('get_{tablename}_config_history', self.get_config_history), ] - for table in depc.cachetable_dict.values(): + for table in self.cachetable_dict.values(): wobj = InjectedDepc() # Set nested version - setattr(w, table.tablename, wobj) + setattr(self.w, table.tablename, wobj) for dfmtstr, func in inject_patterns: funcname = ut.get_funcname(func) attrname = dfmtstr.format(tablename=table.tablename) - get_rowids = ut.partial(func, table.tablename) + partial_func = ut.partial(func, table.tablename) # Set flat version - setattr(d, attrname, get_rowids) + setattr(self.d, attrname, partial_func) setattr(wobj, funcname, func) dfmtstr = 'get_{tablename}_{colname}' for colname in table.data_colnames: - get_prop = ut.partial(depc.get, table.tablename, colnames=colname) + get_prop = ut.partial(self.get, table.tablename, colnames=colname) attrname = dfmtstr.format(tablename=table.tablename, colname=colname) # Set flat version - setattr(d, attrname, get_prop) + setattr(self.d, attrname, get_prop) setattr(wobj, 'get_' + colname, get_prop) # ----------------------------- # GRAPH INSPECTION - def get_dependencies(depc, tablename): + def get_dependencies(self, tablename): """ gets level dependences from root to tablename @@ -303,11 +332,11 @@ def get_dependencies(depc, tablename): ] """ try: - assert tablename in depc.cachetable_dict, 'tablename=%r does not exist' % ( + assert tablename in self.cachetable_dict, 'tablename=%r does not exist' % ( tablename, ) - root = depc.root_tablename - children_, parents_ = list(zip(*depc.get_edges())) + root = self.root_tablename + children_, parents_ = list(zip(*self.get_edges())) child_to_parents = ut.group_items(children_, parents_) if ut.VERYVERBOSE: logger.info('root = %r' % (root,)) @@ -338,7 +367,7 @@ def get_dependencies(depc, tablename): return dependency_levels - def _ensure_config(depc, tablekey, config, _debug=False): + def _ensure_config(self, tablekey, config, _debug=False): """ Creates a full table configuration with all defaults using config @@ -346,8 +375,8 @@ def _ensure_config(depc, tablekey, config, _debug=False): tablekey (str): name of the table to grab config from config (dict): may be overspecified or underspecfied """ - configclass = depc.configclass_dict.get(tablekey, None) - # requestclass = depc.requestclass_dict.get(tablekey, None) + configclass = self.configclass_dict.get(tablekey, None) + # requestclass = self.requestclass_dict.get(tablekey, None) if configclass is None: config_ = config else: @@ -365,31 +394,29 @@ def _ensure_config(depc, tablekey, config, _debug=False): if config_ is None: # Preferable way to get configs with explicit # configs - if _debug: - logger.info(' **config = %r' % (config,)) + logger.debug(' **config = %r' % (config,)) config_ = configclass(**config) - if _debug: - logger.info(' config_ = %r' % (config_,)) + logger.debug(' config_ = %r' % (config_,)) return config_ - def get_config_trail(depc, tablename, config): - graph = depc.make_graph(implicit=True) - tablename_list = ut.nx_all_nodes_between(graph, depc.root, tablename) + def get_config_trail(self, tablename, config): + graph = self.make_graph(implicit=True) + tablename_list = ut.nx_all_nodes_between(graph, self.root, tablename) tablename_list = ut.nx_topsort_nodes(graph, tablename_list) config_trail = [] for tablekey in 
tablename_list: - if tablekey in depc.configclass_dict: - config_ = depc._ensure_config(tablekey, config) + if tablekey in self.configclass_dict: + config_ = self._ensure_config(tablekey, config) config_trail.append(config_) return config_trail - def get_config_trail_str(depc, tablename, config): - config_trail = depc.get_config_trail(tablename, config) + def get_config_trail_str(self, tablename, config): + config_trail = self.get_config_trail(tablename, config) trail_cfgstr = '_'.join([x.get_cfgstr() for x in config_trail]) return trail_cfgstr def _get_parent_input( - depc, + self, tablename, root_rowids, config, @@ -401,8 +428,8 @@ def _get_parent_input( nInput=None, ): # Get ancestor rowids that are descendants of root - table = depc[tablename] - rowid_dict = depc.get_all_descendant_rowids( + table = self[tablename] + rowid_dict = self.get_all_descendant_rowids( tablename, root_rowids, config=config, @@ -411,16 +438,15 @@ def _get_parent_input( nInput=nInput, recompute=recompute, recompute_all=recompute_all, - _debug=ut.countdown_flag(_debug), levels_up=1, ) - parent_rowids = depc._get_parent_rowids(table, rowid_dict) + parent_rowids = self._get_parent_rowids(table, rowid_dict) return parent_rowids # ----------------------------- # STATE GETTERS - def rectify_input_tuple(depc, exi_inputs, input_tuple): + def rectify_input_tuple(self, exi_inputs, input_tuple): """ Standardizes inputs allowed for convinience into the expected input for get_parent_rowids. @@ -461,7 +487,7 @@ def rectify_input_tuple(depc, exi_inputs, input_tuple): rectified_input.append(x) return rectified_input - def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs): + def get_parent_rowids(self, target_tablename, input_tuple, config=None, **kwargs): """ Returns the parent rowids needed to get / compute a property of tablename @@ -494,154 +520,140 @@ def get_parent_rowids(depc, target_tablename, input_tuple, config=None, **kwargs """ _kwargs = kwargs.copy() _recompute = _kwargs.pop('recompute_all', False) - _debug = _kwargs.get('_debug', False) _hack_rootmost = _kwargs.pop('_hack_rootmost', False) - _debug = depc._debug if _debug is None else _debug if config is None: config = {} - with ut.Indenter('[GetParentID-%s]' % (target_tablename,), enabled=_debug): - if _debug: - logger.info(ut.color_text('Enter get_parent_rowids', 'blue')) - logger.info(' * target_tablename = %r' % (target_tablename,)) - logger.info(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) - logger.info(' * config = %r' % (config,)) - target_table = depc[target_tablename] - - # TODO: Expand to the appropriate given inputs - if _hack_rootmost: - # Hack: if true, we are given inputs in rootmost form - exi_inputs = target_table.rootmost_inputs - else: - # otherwise we are given inputs in totalroot form - exi_inputs = target_table.rootmost_inputs.total_expand() - if _debug: - logger.info(' * exi_inputs=%s' % (exi_inputs,)) - - rectified_input = depc.rectify_input_tuple(exi_inputs, input_tuple) - - rowid_dict = {} - for rmi, rowids in zip(exi_inputs.rmi_list, rectified_input): - rowid_dict[rmi] = rowids - - compute_edges = exi_inputs.flat_compute_rmi_edges() - if _debug: - logger.info(' * rectified_input=%s' % ut.trunc_repr(rectified_input)) - logger.info(' * compute_edges=%s' % ut.repr2(compute_edges, nl=2)) - - for count, (input_nodes, output_node) in enumerate(compute_edges, start=1): - if _debug: - ut.cprint( - ' * COMPUTING %d/%d EDGE %r -- %r' - % (count, len(compute_edges), input_nodes, output_node), - 'blue', - ) - 
tablekey = output_node.tablename - table = depc[tablekey] - input_nodes_ = input_nodes - if _debug: - logger.info( - 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) - ) - logger.info('input_nodes_ = %r' % (input_nodes_,)) - input_multi_flags = [ - node.ismulti and node in exi_inputs.rmi_list for node in input_nodes_ - ] + logger.debug('Enter get_parent_rowids') + logger.debug(' * target_tablename = %r' % (target_tablename,)) + logger.debug(' * input_tuple=%s' % (ut.trunc_repr(input_tuple),)) + logger.debug(' * config = %r' % (config,)) + target_table = self[target_tablename] - # Args currently go in like this: - # args = [..., (pid_{i,1}, pid_{i,2}, ..., pid_{i,M}), ...] - # They get converted into - # argsT = [... (pid_{1,j}, ... pid_{N,j}) ...] - # i = row, j = col - sig_multi_flags = table.get_parent_col_attr('ismulti') - parent_rowidsT = ut.take(rowid_dict, input_nodes_) - parent_rowids_ = [] - # TODO: will need to figure out which columns to zip and which - # columns to product (ie take product over ones that have 1 - # item, and zip ones that have equal amount of items) - for flag1, flag2, rowidsT in zip( - sig_multi_flags, input_multi_flags, parent_rowidsT - ): - if flag1 and flag2: - parent_rowids_.append(rowidsT) - elif flag1 and not flag2: - parent_rowids_.append([rowidsT]) - elif not flag1 and flag2: - assert len(rowidsT) == 1 - parent_rowids_.append(rowidsT[0]) - else: - parent_rowids_.append(rowidsT) - # Assume that we are either given corresponding lists or single values - # that must be broadcast. - rowlens = list(map(len, parent_rowids_)) - maxlen = max(rowlens) - parent_rowids2_ = [ - r * maxlen if len(r) == 1 else r for r in parent_rowids_ - ] - _parent_rowids = list(zip(*parent_rowids2_)) - # _parent_rowids = list(ut.product(*parent_rowids_)) - - if _debug: - logger.info( - 'parent_rowids_ = %s' - % ( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in parent_rowids_], - strvals=True, - ) - ) + # TODO: Expand to the appropriate given inputs + if _hack_rootmost: + # Hack: if true, we are given inputs in rootmost form + exi_inputs = target_table.rootmost_inputs + else: + # otherwise we are given inputs in totalroot form + exi_inputs = target_table.rootmost_inputs.total_expand() + logger.debug(' * exi_inputs=%s' % (exi_inputs,)) + + rectified_input = self.rectify_input_tuple(exi_inputs, input_tuple) + + rowid_dict = {} + for rmi, rowids in zip(exi_inputs.rmi_list, rectified_input): + rowid_dict[rmi] = rowids + + compute_edges = exi_inputs.flat_compute_rmi_edges() + logger.debug(' * rectified_input=%s' % ut.trunc_repr(rectified_input)) + logger.debug(' * compute_edges=%s' % ut.repr2(compute_edges, nl=2)) + + for count, (input_nodes, output_node) in enumerate(compute_edges, start=1): + logger.debug( + ' * COMPUTING %d/%d EDGE %r -- %r' + % (count, len(compute_edges), input_nodes, output_node), + ) + tablekey = output_node.tablename + table = self[tablekey] + input_nodes_ = input_nodes + logger.debug( + 'table.parent_id_tablenames = %r' % (table.parent_id_tablenames,) + ) + logger.debug('input_nodes_ = %r' % (input_nodes_,)) + input_multi_flags = [ + node.ismulti and node in exi_inputs.rmi_list for node in input_nodes_ + ] + + # Args currently go in like this: + # args = [..., (pid_{i,1}, pid_{i,2}, ..., pid_{i,M}), ...] + # They get converted into + # argsT = [... (pid_{1,j}, ... pid_{N,j}) ...] 
+ # i = row, j = col + sig_multi_flags = table.get_parent_col_attr('ismulti') + parent_rowidsT = ut.take(rowid_dict, input_nodes_) + parent_rowids_ = [] + # TODO: will need to figure out which columns to zip and which + # columns to product (ie take product over ones that have 1 + # item, and zip ones that have equal amount of items) + for flag1, flag2, rowidsT in zip( + sig_multi_flags, input_multi_flags, parent_rowidsT + ): + if flag1 and flag2: + parent_rowids_.append(rowidsT) + elif flag1 and not flag2: + parent_rowids_.append([rowidsT]) + elif not flag1 and flag2: + assert len(rowidsT) == 1 + parent_rowids_.append(rowidsT[0]) + else: + parent_rowids_.append(rowidsT) + # Assume that we are either given corresponding lists or single values + # that must be broadcast. + rowlens = list(map(len, parent_rowids_)) + maxlen = max(rowlens) + parent_rowids2_ = [r * maxlen if len(r) == 1 else r for r in parent_rowids_] + _parent_rowids = list(zip(*parent_rowids2_)) + # _parent_rowids = list(ut.product(*parent_rowids_)) + + logger.debug( + 'parent_rowids_ = %s' + % ( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in parent_rowids_], + strvals=True, ) - logger.info( - 'parent_rowids2_ = %s' - % ( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in parent_rowids2_], - strvals=True, - ) - ) + ) + ) + logger.debug( + 'parent_rowids2_ = %s' + % ( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in parent_rowids2_], + strvals=True, ) - logger.info( - '_parent_rowids = %s' - % ( - ut.truncate_str( - ut.repr4( - [ut.trunc_repr(ids_) for ids_ in _parent_rowids], - strvals=True, - ) - ) + ) + ) + logger.debug( + '_parent_rowids = %s' + % ( + ut.truncate_str( + ut.repr4( + [ut.trunc_repr(ids_) for ids_ in _parent_rowids], + strvals=True, ) ) + ) + ) - if _debug: - ut.cprint('-------------', 'blue') - if output_node.tablename != target_tablename: - # Get table configuration - config_ = depc._ensure_config(tablekey, config, _debug) + if output_node.tablename != target_tablename: + # Get table configuration + config_ = self._ensure_config(tablekey, config) - output_rowids = table.get_rowid( - _parent_rowids, config=config_, recompute=_recompute, **_kwargs - ) - rowid_dict[output_node] = output_rowids - # table.get_model_inputs(table.get_model_uuid(output_rowids)[0]) - else: - # We are only computing up to the parents of the table here. - parent_rowids = _parent_rowids - break - # rowids = rowid_dict[output_node] - return parent_rowids + output_rowids = table.get_rowid( + _parent_rowids, config=config_, recompute=_recompute, **_kwargs + ) + rowid_dict[output_node] = output_rowids + # table.get_model_inputs(table.get_model_uuid(output_rowids)[0]) + else: + # We are only computing up to the parents of the table here. + parent_rowids = _parent_rowids + break + # rowids = rowid_dict[output_node] + return parent_rowids - def check_rowids(depc, tablename, input_tuple, config={}): + def check_rowids(self, tablename, input_tuple, config={}): """ Returns a list of flags where True means the row has been computed and False means that it needs to be computed. """ - existing_rowids = depc.get_rowids( + existing_rowids = self.get_rowids( tablename, input_tuple, config=config, ensure=False ) flags = ut.flag_not_None_items(existing_rowids) return flags - def get_rowids(depc, tablename, input_tuple, **rowid_kw): + def get_rowids(self, tablename, input_tuple, **rowid_kw): """ Used to get tablename rowids. Ensures rows exist unless ensure=False. rowids uniquely specify parent inputs and a configuration. 
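# A minimal sketch of how the rowid getters fit together, assuming the
# `testdata_depc()` fixture and the 'chip' table used in the doctests of this
# module (both are assumptions taken from those doctests, not new API):
depc = testdata_depc()
aids = [1, 2, 3]

# Which inputs already have cached rows under the default config?
flags = depc.check_rowids('chip', aids)

# Resolve the table rowids, computing missing rows because ensure defaults
# to True; pass ensure=False to only look up what already exists.
chip_rowids = depc.get_rowids('chip', aids)

# Internal rowids can be mapped back to the root rows they came from.
assert list(depc.get_root_rowids('chip', chip_rowids)) == aids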
@@ -660,12 +672,11 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): >>> root_rowids = [1, 2, 3] >>> root_rowids2 = [(4, 5, 6, 7)] >>> root_rowids3 = root_rowids2 - >>> _debug = True >>> tablename = 'smk_match' >>> input_tuple = (root_rowids, root_rowids2, root_rowids3) >>> target_table = depc[tablename] >>> inputs = target_table.rootmost_inputs.total_expand() - >>> depc.get_rowids(tablename, input_tuple, _debug=_debug) + >>> depc.get_rowids(tablename, input_tuple) >>> depc.print_all_tables() Example: @@ -696,16 +707,14 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): >>> assert recomp_rowids == initial_rowids, 'rowids should not change due to recompute' """ target_tablename = tablename - _debug = rowid_kw.get('_debug', False) - _debug = depc._debug if _debug is None else _debug _kwargs = rowid_kw.copy() config = _kwargs.pop('config', {}) _hack_rootmost = _kwargs.pop('_hack_rootmost', False) _recompute_all = _kwargs.pop('recompute_all', False) recompute = _kwargs.pop('recompute', _recompute_all) - table = depc[target_tablename] + table = self[target_tablename] - parent_rowids = depc.get_parent_rowids( + parent_rowids = self.get_parent_rowids( target_tablename, input_tuple, config=config, @@ -713,16 +722,15 @@ def get_rowids(depc, tablename, input_tuple, **rowid_kw): **_kwargs, ) - with ut.Indenter('[GetRowId-%s]' % (target_tablename,), enabled=_debug): - config_ = depc._ensure_config(target_tablename, config, _debug) - rowids = table.get_rowid( - parent_rowids, config=config_, recompute=recompute, **_kwargs - ) + config_ = self._ensure_config(target_tablename, config) + rowids = table.get_rowid( + parent_rowids, config=config_, recompute=recompute, **_kwargs + ) return rowids @ut.accepts_scalar_input2(argx_list=[1]) def get( - depc, + self, tablename, root_rowids, colnames=None, @@ -768,7 +776,6 @@ def get( >>> depc = testdata_depc3(True) >>> exec(ut.execstr_funckw(depc.get), globals()) >>> aids = [1, 2, 3] - >>> _debug = True >>> tablename = 'labeler' >>> root_rowids = aids >>> prop_list = depc.get( @@ -785,7 +792,6 @@ def get( >>> depc = testdata_depc3(True) >>> exec(ut.execstr_funckw(depc.get), globals()) >>> aids = [1, 2, 3] - >>> _debug = True >>> tablename = 'smk_match' >>> tablename = 'vocab' >>> table = depc[tablename] @@ -804,7 +810,6 @@ def get( >>> depc = testdata_depc3(True) >>> exec(ut.execstr_funckw(depc.get), globals()) >>> aids = [1, 2, 3] - >>> _debug = True >>> depc = testdata_depc() >>> tablename = 'chip' >>> table = depc[tablename] @@ -818,90 +823,79 @@ def get( >>> prop_list3 = depc.get(tablename, root_rowids) >>> assert np.all(prop_list1[0][1] == prop_list3[0][1]), 'computed same info' """ - if tablename == depc.root_tablename: - return depc.root_getters[colnames](root_rowids) + if tablename == self.root_tablename: + return self.root_getters[colnames](root_rowids) # pass - _debug = depc._debug if _debug is None else _debug - with ut.Indenter('[GetProp-%s]' % (tablename,), enabled=_debug): - if _debug: - logger.info(' * tablename=%s' % (tablename)) - logger.info(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) - logger.info(' * colnames = %r' % (colnames,)) - logger.info(' * config = %r' % (config,)) - - if hack_paths and not ensure and not read_extern: - # HACK: should be able to not compute rows to get certain properties - from os.path import join - - # recompute_ = recompute or recompute_all - parent_rowids = depc.get_parent_rowids( - tablename, - root_rowids, - config=config, - ensure=True, - _debug=None, - recompute_all=False, - 
eager=True, - nInput=None, - ) - config_ = depc._ensure_config(tablename, config) - if _debug: - logger.info(' * (ensured) config_ = %r' % (config_,)) - table = depc[tablename] - extern_dpath = table.extern_dpath - ut.ensuredir(extern_dpath, verbose=False or table.depc._debug) - fname_list = table.get_extern_fnames( - parent_rowids, config=config_, extern_col_index=0 - ) - fpath_list = [join(extern_dpath, fname) for fname in fname_list] - return fpath_list + logger.debug(' * tablename=%s' % (tablename)) + logger.debug(' * root_rowids=%s' % (ut.trunc_repr(root_rowids))) + logger.debug(' * colnames = %r' % (colnames,)) + logger.debug(' * config = %r' % (config,)) - if nInput is None and ut.is_listlike(root_rowids): - nInput = len(root_rowids) + if hack_paths and not ensure and not read_extern: + # HACK: should be able to not compute rows to get certain properties + from os.path import join - rowid_kw = dict( + # recompute_ = recompute or recompute_all + parent_rowids = self.get_parent_rowids( + tablename, + root_rowids, config=config, - nInput=nInput, - eager=eager, - ensure=ensure, - recompute=recompute, - recompute_all=recompute_all, - _debug=_debug, + ensure=True, + recompute_all=False, + eager=True, + nInput=None, ) - - rowdata_kw = dict( - read_extern=read_extern, - _debug=_debug, - num_retries=num_retries, - eager=eager, - ensure=ensure, - nInput=nInput, + config_ = self._ensure_config(tablename, config) + logger.debug(' * (ensured) config_ = %r' % (config_,)) + table = self[tablename] + extern_dpath = table.extern_dpath + ut.ensuredir(extern_dpath) + fname_list = table.get_extern_fnames( + parent_rowids, config=config_, extern_col_index=0 ) + fpath_list = [join(extern_dpath, fname) for fname in fname_list] + return fpath_list - input_tuple = root_rowids + if nInput is None and ut.is_listlike(root_rowids): + nInput = len(root_rowids) - for trynum in range(num_retries + 1): - try: - table = depc[tablename] - # Vectorized get of properties - tbl_rowids = depc.get_rowids(tablename, input_tuple, **rowid_kw) - if _debug: - logger.info( - '[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),) - ) - prop_list = table.get_row_data(tbl_rowids, colnames, **rowdata_kw) - except depcache_table.ExternalStorageException: - logger.info('!!* Hit ExternalStorageException') - if trynum == num_retries: - raise - else: - break - if _debug: - logger.info('* return prop_list=%s' % (ut.trunc_repr(prop_list),)) + rowid_kw = dict( + config=config, + nInput=nInput, + eager=eager, + ensure=ensure, + recompute=recompute, + recompute_all=recompute_all, + ) + + rowdata_kw = dict( + read_extern=read_extern, + num_retries=num_retries, + eager=eager, + ensure=ensure, + nInput=nInput, + ) + + input_tuple = root_rowids + + for trynum in range(num_retries + 1): + try: + table = self[tablename] + # Vectorized get of properties + tbl_rowids = self.get_rowids(tablename, input_tuple, **rowid_kw) + logger.debug('[depc.get] tbl_rowids = %s' % (ut.trunc_repr(tbl_rowids),)) + prop_list = table.get_row_data(tbl_rowids, colnames, **rowdata_kw) + except depcache_table.ExternalStorageException: + logger.info('!!* Hit ExternalStorageException') + if trynum == num_retries: + raise + else: + break + logger.debug('* return prop_list=%s' % (ut.trunc_repr(prop_list),)) return prop_list def get_native( - depc, tablename, tbl_rowids, colnames=None, _debug=None, read_extern=True + self, tablename, tbl_rowids, colnames=None, _debug=None, read_extern=True ): """ Gets data using internal ids, which is faster if you have them. 
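# A small sketch contrasting the two getters, under the same `testdata_depc()`
# and 'chip' table assumptions as above: `get` starts from root rowids and
# resolves (or computes) the dependent rows, whereas `get_native` takes
# table-internal rowids directly and skips that resolution.
depc = testdata_depc()
aids = [1, 2, 3]

chip_data = depc.get('chip', aids)                   # keyed by root rowids
chip_rowids = depc.get_rowids('chip', aids)
native_data = depc.get_native('chip', chip_rowids)   # keyed by internal rowids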
@@ -939,47 +933,42 @@ def get_native( >>> print('chips = %r' % (chips,)) """ tbl_rowids = list(tbl_rowids) - _debug = depc._debug if _debug is None else _debug - with ut.Indenter('[GetNative %s]' % (tablename,), enabled=_debug): - if _debug: - logger.info(' * tablename = %r' % (tablename,)) - logger.info(' * colnames = %r' % (colnames,)) - logger.info(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) - table = depc[tablename] - # import utool - # with utool.embed_on_exception_context: - # try: - prop_list = table.get_row_data( - tbl_rowids, colnames, _debug=_debug, read_extern=read_extern - ) - # except depcache_table.ExternalStorageException: - # # This code is a bit rendant and would probably live better elsewhere - # # Also need to fix issues if more than one column specified - # extern_uris = table.get_row_data( - # tbl_rowids, colnames, _debug=_debug, read_extern=False, - # delete_on_fail=True, ensure=False) - # from os.path import exists - # error_flags = [exists(uri) for uri in extern_uris] - # redo_rowids = ut.compress(tbl_rowids, ut.not_list(error_flags)) - # parent_rowids = table.get_parent_rowids(redo_rowids) - # # config_rowids = table.get_row_cfgid(redo_rowids) - # configs = table.get_row_configs(redo_rowids) - # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' - # config = configs[0] - # table.get_rowid(parent_rowids, recompute=True, config=config) - - # # TRY ONE MORE TIME - # prop_list = table.get_row_data(tbl_rowids, colnames, _debug=_debug, - # read_extern=read_extern, - # delete_on_fail=False) + logger.debug(' * tablename = %r' % (tablename,)) + logger.debug(' * colnames = %r' % (colnames,)) + logger.debug(' * tbl_rowids=%s' % (ut.trunc_repr(tbl_rowids))) + table = self[tablename] + # import utool + # with utool.embed_on_exception_context: + # try: + prop_list = table.get_row_data(tbl_rowids, colnames, read_extern=read_extern) + # except depcache_table.ExternalStorageException: + # # This code is a bit rendant and would probably live better elsewhere + # # Also need to fix issues if more than one column specified + # extern_uris = table.get_row_data( + # tbl_rowids, colnames, read_extern=False, + # delete_on_fail=True, ensure=False) + # from os.path import exists + # error_flags = [exists(uri) for uri in extern_uris] + # redo_rowids = ut.compress(tbl_rowids, ut.not_list(error_flags)) + # parent_rowids = table.get_parent_rowids(redo_rowids) + # # config_rowids = table.get_row_cfgid(redo_rowids) + # configs = table.get_row_configs(redo_rowids) + # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' + # config = configs[0] + # table.get_rowid(parent_rowids, recompute=True, config=config) + + # # TRY ONE MORE TIME + # prop_list = table.get_row_data(tbl_rowids, colnames, + # read_extern=read_extern, + # delete_on_fail=False) return prop_list - def get_config_history(depc, tablename, root_rowids, config=None): + def get_config_history(self, tablename, root_rowids, config=None): # Vectorized get of properties - tbl_rowids = depc.get_rowids(tablename, root_rowids, config=config) - return depc[tablename].get_config_history(tbl_rowids) + tbl_rowids = self.get_rowids(tablename, root_rowids, config=config) + return self[tablename].get_config_history(tbl_rowids) - def get_root_rowids(depc, tablename, native_rowids): + def get_root_rowids(self, tablename, native_rowids): r""" Args: tablename (str): @@ -1006,55 +995,58 @@ def get_root_rowids(depc, tablename, native_rowids): >>> assert ancestor_rowids1 == root_rowids, 
'should have same root' >>> assert ancestor_rowids2 == root_rowids, 'should have same root' """ - return depc.get_ancestor_rowids(tablename, native_rowids, depc.root) + return self.get_ancestor_rowids(tablename, native_rowids, self.root) - def get_ancestor_rowids(depc, tablename, native_rowids, ancestor_tablename=None): + def get_ancestor_rowids(self, tablename, native_rowids, ancestor_tablename=None): """ ancestor_tablename = depc.root; native_rowids = cid_list; tablename = const.CHIP_TABLE """ if ancestor_tablename is None: - ancestor_tablename = depc.root - table = depc[tablename] + ancestor_tablename = self.root + table = self[tablename] ancestor_rowids = table.get_ancestor_rowids(native_rowids, ancestor_tablename) return ancestor_rowids - def new_request(depc, tablename, qaids, daids, cfgdict=None): + def new_request(self, tablename, qaids, daids, cfgdict=None): """ creates a request for data that can be executed later """ logger.info('[depc] NEW %s request' % (tablename,)) - requestclass = depc.requestclass_dict[tablename] - request = requestclass.new(depc, qaids, daids, cfgdict, tablename=tablename) + requestclass = self.requestclass_dict[tablename] + request = requestclass.new(self, qaids, daids, cfgdict, tablename=tablename) return request # ----------------------------- # STATE MODIFIERS - def delete_property(depc, tablename, root_rowids, config=None, _debug=False): + def delete_property(self, tablename, root_rowids, config=None, _debug=False): """ Deletes the rowids of `tablename` that correspond to `root_rowids` using `config`. FIXME: make this work for all configs """ - rowid_list = depc.get_rowids( - tablename, root_rowids, config=config, ensure=False, _debug=_debug + rowid_list = self.get_rowids( + tablename, + root_rowids, + config=config, + ensure=False, ) - table = depc[tablename] + table = self[tablename] num_deleted = table.delete_rows(rowid_list) return num_deleted - def delete_property_all(depc, tablename, root_rowids, _debug=False): + def delete_property_all(self, tablename, root_rowids, _debug=False): """ Deletes the rowids of `tablename` that correspond to `root_rowids` using `config`. 
FIXME: make this work for all configs """ - table = depc[tablename] + table = self[tablename] all_rowid_list = table._get_all_rowids() if len(all_rowid_list) == 0: return 0 - ancestor_rowid_list = depc.get_ancestor_rowids(tablename, all_rowid_list) + ancestor_rowid_list = self.get_ancestor_rowids(tablename, all_rowid_list) rowid_list = [] root_rowids_set = set(root_rowids) @@ -1065,90 +1057,41 @@ def delete_property_all(depc, tablename, root_rowids, _debug=False): num_deleted = table.delete_rows(rowid_list) return num_deleted - -@six.add_metaclass(ut.ReloadingMetaclass) -class DependencyCache(_CoreDependencyCache, ut.NiceRepr): - """ - Currently, to use this class a user must: - * on root modification, call depc.on_root_modified - * use decorators to register relevant functions - """ - - def __init__( - depc, - root_tablename=None, - cache_dpath='./DEPCACHE', - controller=None, - default_fname=None, - # root_asobject=None, - get_root_uuid=None, - root_getters=None, - use_globals=True, - ): - if default_fname is None: - default_fname = root_tablename + '_primary_cache' - # default_fname = ':memory:' - depc.root_getters = root_getters - # Root of all dependencies - depc.root_tablename = root_tablename - # Directory all cachefiles are stored in - depc.cache_dpath = ut.truepath(cache_dpath) - # Parent (ibs) controller - depc.controller = controller - # Internal dictionary of dependant tables - depc.cachetable_dict = {} - depc.configclass_dict = {} - depc.requestclass_dict = {} - depc.resultclass_dict = {} - # Mapping of different files properties are stored in - depc.fname_to_db = {} - # Function to map a root rowid to an object - # depc._root_asobject = root_asobject - depc._use_globals = use_globals - depc.default_fname = default_fname - if get_root_uuid is None: - logger.info('WARNING NEED UUID FUNCTION') - # HACK - get_root_uuid = ut.identity - depc.get_root_uuid = get_root_uuid - depc.delete_exclude_tables = {} - depc._debug = ut.get_argflag(('--debug-depcache', '--debug-depc')) - - def get_tablenames(depc): - return list(depc.cachetable_dict.keys()) + def get_tablenames(self): + return list(self.cachetable_dict.keys()) @property - def tables(depc): - return list(depc.cachetable_dict.values()) + def tables(self): + return list(self.cachetable_dict.values()) @property - def tablenames(depc): - return depc.get_tablenames() + def tablenames(self): + return self.get_tablenames() - def print_schemas(depc): - for fname, db in depc.fname_to_db.items(): - logger.info('fname = %r' % (fname,)) + def print_schemas(self): + for name, db in self._db_by_name.items(): + logger.info('name = %r' % (name,)) db.print_schema() - # def print_table_csv(depc, tablename): - # depc[tablename] + # def print_table_csv(self, tablename): + # self[tablename] - def print_table(depc, tablename): - depc[tablename].print_table() + def print_table(self, tablename): + self[tablename].print_table() - def print_all_tables(depc): - for tablename, table in depc.cachetable_dict.items(): + def print_all_tables(self): + for tablename, table in self.cachetable_dict.items(): table.print_table() # db = table.db # db.print_table_csv(tablename) - def print_config_tables(depc): - for fname in depc.fname_to_db: + def print_config_tables(self): + for name in self._db_by_name: logger.info('---') - logger.info('db_fname = %r' % (fname,)) - depc.fname_to_db[fname].print_table_csv('config') + logger.info('db_name = %r' % (name,)) + self._db_by_name[name].print_table_csv('config') - def get_edges(depc, data=False): + def get_edges(self, 
data=False): """ edges for networkx structure """ @@ -1188,26 +1131,26 @@ def get_edgedata(tablekey, parentkey, parent_data): edges = [ (parentkey, tablekey, get_edgedata(tablekey, parentkey, parent_data)) - for tablekey, table in depc.cachetable_dict.items() + for tablekey, table in self.cachetable_dict.items() for parentkey, parent_data in table.parents(data=True) ] else: edges = [ (parentkey, tablekey) - for tablekey, table in depc.cachetable_dict.items() + for tablekey, table in self.cachetable_dict.items() for parentkey in table.parents(data=False) ] return edges - def get_implicit_edges(depc, data=False): + def get_implicit_edges(self, data=False): """ Edges defined by subconfigurations """ # add implicit edges implicit_edges = [] # Map config classes to tablenames - _inverted_ccdict = ut.invert_dict(depc.configclass_dict) - for tablename2, configclass in depc.configclass_dict.items(): + _inverted_ccdict = ut.invert_dict(self.configclass_dict) + for tablename2, configclass in self.configclass_dict.items(): cfg = configclass() subconfigs = cfg.get_sub_config_list() if subconfigs is not None and len(subconfigs) > 0: @@ -1219,7 +1162,7 @@ def get_implicit_edges(depc, data=False): return implicit_edges @ut.memoize - def make_graph(depc, **kwargs): + def make_graph(self, **kwargs): """ Constructs a networkx representation of the dependency graph @@ -1275,13 +1218,13 @@ def make_graph(depc, **kwargs): # graph = nx.DiGraph() graph = nx.MultiDiGraph() - nodes = list(depc.cachetable_dict.keys()) - edges = depc.get_edges(data=True) + nodes = list(self.cachetable_dict.keys()) + edges = self.get_edges(data=True) graph.add_nodes_from(nodes) graph.add_edges_from(edges) if kwargs.get('implicit', True): - implicit_edges = depc.get_implicit_edges(data=True) + implicit_edges = self.get_implicit_edges(data=True) graph.add_edges_from(implicit_edges) shape_dict = { @@ -1304,8 +1247,8 @@ def make_graph(depc, **kwargs): } def _node_attrs(dict_): - props = {k: dict_['node'] for k, v in depc.cachetable_dict.items()} - props[depc.root] = dict_['root'] + props = {k: dict_['node'] for k, v in self.cachetable_dict.items()} + props[self.root] = dict_['root'] return props nx.set_node_attributes(graph, name='color', values=_node_attrs(color_dict)) @@ -1430,40 +1373,40 @@ def _node_attrs(dict_): return graph @property - def graph(depc): - return depc.make_graph() + def graph(self): + return self.make_graph() @property - def explicit_graph(depc): - return depc.make_graph(implicit=False) + def explicit_graph(self): + return self.make_graph(implicit=False) @property - def reduced_graph(depc): - return depc.make_graph(reduced=True) + def reduced_graph(self): + return self.make_graph(reduced=True) - def show_graph(depc, reduced=False, **kwargs): + def show_graph(self, reduced=False, **kwargs): """ Helper "fluff" function """ import wbia.plottool as pt - graph = depc.make_graph(reduced=reduced) + graph = self.make_graph(reduced=reduced) if ut.is_developer(): ut.ensureqt() kwargs['layout'] = 'agraph' pt.show_nx(graph, **kwargs) - def __nice__(depc): - infostr_ = 'nTables=%d' % len(depc.cachetable_dict) - return '(%s) %s' % (depc.root_tablename, infostr_) + def __nice__(self): + infostr_ = 'nTables=%d' % len(self.cachetable_dict) + return '(%s) %s' % (self.root_tablename, infostr_) - def __getitem__(depc, tablekey): - return depc.cachetable_dict[tablekey] + def __getitem__(self, tablekey): + return self.cachetable_dict[tablekey] @property - def root(depc): - return depc.root_tablename + def root(self): + return 
self.root_tablename def delete_root( - depc, + self, root_rowids, delete_extern=None, _debug=False, @@ -1486,60 +1429,60 @@ def delete_root( >>> depc = testdata_depc() >>> exec(ut.execstr_funckw(depc.delete_root), globals()) >>> root_rowids = [1] - >>> depc.delete_root(root_rowids, _debug=0) + >>> depc.delete_root(root_rowids) >>> depc.get('fgweight', [1]) - >>> depc.delete_root(root_rowids, _debug=0) + >>> depc.delete_root(root_rowids) """ - # graph = depc.make_graph(implicit=False) + # graph = self.make_graph(implicit=False) # hack # check to make sure child does not have another parent - rowid_dict = depc.get_allconfig_descendant_rowids( + rowid_dict = self.get_allconfig_descendant_rowids( root_rowids, table_config_filter ) - # children = [child for child in graph.succ[depc.root_tablename] + # children = [child for child in graph.succ[self.root_tablename] # if sum([len(e) for e in graph.pred[child].values()]) == 1] - # depc.delete_property(tablename, root_rowids, _debug=_debug) + # self.delete_property(tablename, root_rowids) num_deleted = 0 for tablename, table_rowids in rowid_dict.items(): - if tablename == depc.root: + if tablename == self.root: continue # Specific prop exclusion - delete_exclude_table_set_prop = depc.delete_exclude_tables.get(prop, []) - delete_exclude_table_set_all = depc.delete_exclude_tables.get(None, []) + delete_exclude_table_set_prop = self.delete_exclude_tables.get(prop, []) + delete_exclude_table_set_all = self.delete_exclude_tables.get(None, []) if ( tablename in delete_exclude_table_set_prop or tablename in delete_exclude_table_set_all ): continue - table = depc[tablename] + table = self[tablename] num_deleted += table.delete_rows(table_rowids, delete_extern=delete_extern) return num_deleted - def register_delete_table_exclusion(depc, tablename, prop): - if prop not in depc.delete_exclude_tables: - depc.delete_exclude_tables[prop] = set([]) - depc.delete_exclude_tables[prop].add(tablename) - args = (ut.repr3(depc.delete_exclude_tables),) + def register_delete_table_exclusion(self, tablename, prop): + if prop not in self.delete_exclude_tables: + self.delete_exclude_tables[prop] = set([]) + self.delete_exclude_tables[prop].add(tablename) + args = (ut.repr3(self.delete_exclude_tables),) logger.info('[depc] Updated delete tables: %s' % args) - def get_allconfig_descendant_rowids(depc, root_rowids, table_config_filter=None): + def get_allconfig_descendant_rowids(self, root_rowids, table_config_filter=None): import networkx as nx - # list(nx.topological_sort(nx.bfs_tree(graph, depc.root))) - # decendants = nx.descendants(graph, depc.root) + # list(nx.topological_sort(nx.bfs_tree(graph, self.root))) + # decendants = nx.descendants(graph, self.root) # raise NotImplementedError() - graph = depc.make_graph(implicit=True) - root = depc.root + graph = self.make_graph(implicit=True) + root = self.root rowid_dict = {} rowid_dict[root] = root_rowids # Find all rowids that inherit from the specific root rowids - sinks = list(ut.nx_sink_nodes(nx.bfs_tree(graph, depc.root))) + sinks = list(ut.nx_sink_nodes(nx.bfs_tree(graph, self.root))) for target_tablename in sinks: path = nx.shortest_path(graph, root, target_tablename) for parent, child in ut.itertwo(path): - child_table = depc[child] + child_table = self[child] relevant_col_attrs = [ attrs for attrs in child_table.parent_col_attrs @@ -1590,7 +1533,7 @@ def get_allconfig_descendant_rowids(depc, root_rowids, table_config_filter=None) ) return rowid_dict - def notify_root_changed(depc, root_rowids, prop, 
force_delete=False): + def notify_root_changed(self, root_rowids, prop, force_delete=False): """ this is where we are notified that a "registered" root property has changed. @@ -1600,18 +1543,18 @@ def notify_root_changed(depc, root_rowids, prop, force_delete=False): % (prop, len(root_rowids)) ) # for key in tables_depending_on(prop) - # depc.delete_property(key, root_rowids) + # self.delete_property(key, root_rowids) # TODO: check which properties were invalidated by this prop # TODO; remove invalidated properties if force_delete: - depc.delete_root(root_rowids, prop=prop) + self.delete_root(root_rowids, prop=prop) - def clear_all(depc): - logger.info('Clearning all cached data in %r' % (depc,)) - for table in depc.cachetable_dict.values(): + def clear_all(self): + logger.info('Clearning all cached data in %r' % (self,)) + for table in self.cachetable_dict.values(): table.clear_table() - def make_root_info_uuid(depc, root_rowids, info_props): + def make_root_info_uuid(self, root_rowids, info_props): """ Creates a uuid that depends on certain properties of the root object. This is used for implicit cache invalidation because, if those @@ -1626,24 +1569,24 @@ def make_root_info_uuid(depc, root_rowids, info_props): >>> info_props = ['image_uuid', 'verts', 'theta'] >>> info_props = ['image_uuid', 'verts', 'theta', 'name', 'species', 'yaw'] """ - getters = ut.dict_take(depc.root_getters, info_props) + getters = ut.dict_take(self.root_getters, info_props) infotup_list = zip(*[getter(root_rowids) for getter in getters]) info_uuid_list = [ut.augment_uuid(*tup) for tup in infotup_list] return info_uuid_list - def get_uuids(depc, tablename, root_rowids, config=None): + def get_uuids(self, tablename, root_rowids, config=None): """ # TODO: Make uuids for dependant object based on root uuid and path of # construction. 
""" - if tablename == depc.root: - uuid_list = depc.get_root_uuid(root_rowids) + if tablename == self.root: + uuid_list = self.get_root_uuid(root_rowids) return uuid_list - get_native_property = _CoreDependencyCache.get_native - get_property = _CoreDependencyCache.get + get_native_property = get_native + get_property = get - def stacked_config(depc, source, dest, config): + def stacked_config(self, source, dest, config): r""" CommandLine: python -m dtool.depcache_control stacked_config --show @@ -1664,15 +1607,15 @@ def stacked_config(depc, source, dest, config): if config is None: config = {} if source is None: - source = depc.root - graph = depc.make_graph(implicit=True) + source = self.root + graph = self.make_graph(implicit=True) requires_tables = ut.setdiff( ut.nx_all_nodes_between(graph, source, dest), [source] ) - # requires_tables = ut.setdiff(ut.nx_all_nodes_between(depc.graph, 'annotations', 'featweight'), ['annotations']) - requires_tables = ut.nx_topsort_nodes(depc.graph, requires_tables) + # requires_tables = ut.setdiff(ut.nx_all_nodes_between(self.graph, 'annotations', 'featweight'), ['annotations']) + requires_tables = ut.nx_topsort_nodes(self.graph, requires_tables) requires_configs = [ - depc.configclass_dict[tblname](**config) for tblname in requires_tables + self.configclass_dict[tblname](**config) for tblname in requires_tables ] # cfgstr_list = [cfg.get_cfgstr() for cfg in requires_configs] stacked_config = base.StackedConfig(requires_configs) diff --git a/wbia/dtool/depcache_table.py b/wbia/dtool/depcache_table.py index 55680120f4..b0934e1b28 100644 --- a/wbia/dtool/depcache_table.py +++ b/wbia/dtool/depcache_table.py @@ -32,7 +32,7 @@ from six.moves import zip, range from wbia.dtool import sqlite3 as lite -from wbia.dtool.sql_control import SQLDatabaseController +from wbia.dtool.sql_control import SQLDatabaseController, compare_coldef_lists from wbia.dtool.types import TYPE_TO_SQLTYPE import time @@ -59,6 +59,26 @@ GRACE_PERIOD = ut.get_argval('--grace', type_=int, default=0) +class TableOutOfSyncError(Exception): + """Raised when the code's table definition doesn't match the defition in the database""" + + def __init__(self, db, tablename, extended_msg): + db_name = db._engine.url.database + + if getattr(db, 'schema', None): + under_schema = f"under schema '{db.schema}' " + else: + # Not a table under a schema + under_schema = '' + msg = ( + f"database '{db_name}' " + + under_schema + + f"with table '{tablename}' does not match the code definition; " + f"it's likely the database needs upgraded; {extended_msg}" + ) + super().__init__(msg) + + class ExternType(ub.NiceRepr): """ Type to denote an external resource not saved in an SQL table @@ -159,19 +179,23 @@ def ensure_config_table(db): else: current_state = db.get_table_autogen_dict(CONFIG_TABLE) new_state = config_addtable_kw - if current_state['coldef_list'] != new_state['coldef_list']: - if predrop_grace_period(CONFIG_TABLE): - db.drop_all_tables() - db.add_table(**new_state) - else: - raise NotImplementedError('Need to be able to modify tables') + results = compare_coldef_lists( + current_state['coldef_list'], new_state['coldef_list'] + ) + if results: + current_coldef, new_coldef = results + raise TableOutOfSyncError( + db, + CONFIG_TABLE, + f'Current schema: {current_coldef} Expected schema: {new_coldef}', + ) @ut.reloadable_class class _TableConfigHelper(object): """ helper for configuration table """ - def get_parent_rowids(table, rowid_list): + def get_parent_rowids(self, rowid_list): """ Args: rowid_list 
(list): native table rowids @@ -185,12 +209,12 @@ def get_parent_rowids(table, rowid_list): >>> # Then add two items to this table, and for each item >>> # Find their parent inputs """ - parent_rowids = table.get_internal_columns( - rowid_list, table.parent_id_colnames, unpack_scalars=True, keepwrap=True + parent_rowids = self.get_internal_columns( + rowid_list, self.parent_id_colnames, unpack_scalars=True, keepwrap=True ) return parent_rowids - def get_parent_rowargs(table, rowid_list): + def get_parent_rowargs(self, rowid_list): """ Args: rowid_list (list): native table rowids @@ -204,18 +228,18 @@ def get_parent_rowargs(table, rowid_list): >>> # Then add two items to this table, and for each item >>> # Find their parent inputs """ - parent_rowids = table.get_parent_rowids(rowid_list) - parent_ismulti = table.get_parent_col_attr('ismulti') + parent_rowids = self.get_parent_rowids(rowid_list) + parent_ismulti = self.get_parent_col_attr('ismulti') if any(parent_ismulti): # If any of the parent columns are multi-indexes, then lookup the # mapping from the aggregated uuid to the expanded rowid set. parent_args = [] - model_uuids = table.get_model_uuid(rowid_list) + model_uuids = self.get_model_uuid(rowid_list) for rowid, uuid, p_id_list in zip(rowid_list, model_uuids, parent_rowids): - input_info = table.get_model_inputs(uuid) + input_info = self.get_model_inputs(uuid) fixed_args = [] for p_name, p_id, flag in zip( - table.parent_id_colnames, p_id_list, parent_ismulti + self.parent_id_colnames, p_id_list, parent_ismulti ): if flag: new_p_id = input_info[p_name + '_model_input'] @@ -231,7 +255,7 @@ def get_parent_rowargs(table, rowid_list): parent_args = parent_rowids return parent_args - def get_row_parent_rowid_map(table, rowid_list): + def get_row_parent_rowid_map(self, rowid_list): """ >>> from wbia.dtool.depcache_table import * # NOQA @@ -239,13 +263,13 @@ def get_row_parent_rowid_map(table, rowid_list): key = parent_rowid_dict.keys()[0] val = parent_rowid_dict.values()[0] """ - parent_rowids = table.get_parent_rowids(rowid_list) + parent_rowids = self.get_parent_rowids(rowid_list) parent_rowid_dict = dict( - zip(table.parent_id_tablenames, ut.list_transpose(parent_rowids)) + zip(self.parent_id_tablenames, ut.list_transpose(parent_rowids)) ) return parent_rowid_dict - def get_config_history(table, rowid_list, assume_unique=True): + def get_config_history(self, rowid_list, assume_unique=True): """ Returns the list of config objects for all properties in the dependency history of this object. Multi-edges are handled. 
Set assume_unique to @@ -260,23 +284,23 @@ def get_config_history(table, rowid_list, assume_unique=True): """ if assume_unique: rowid_list = rowid_list[0:1] - tbl_cfgids = table.get_row_cfgid(rowid_list) + tbl_cfgids = self.get_row_cfgid(rowid_list) cfgid2_rowids = ut.group_items(rowid_list, tbl_cfgids) unique_cfgids = cfgid2_rowids.keys() unique_cfgids = ut.filter_Nones(unique_cfgids) if len(unique_cfgids) == 0: return None - unique_configs = table.get_config_from_rowid(unique_cfgids) + unique_configs = self.get_config_from_rowid(unique_cfgids) - # parent_rowids = table.get_parent_rowids(rowid_list) - parent_rowargs = table.get_parent_rowargs(rowid_list) + # parent_rowids = self.get_parent_rowids(rowid_list) + parent_rowargs = self.get_parent_rowargs(rowid_list) ret_list = [unique_configs] - depc = table.depc + depc = self.depc rowargsT = ut.listT(parent_rowargs) - parent_ismulti = table.get_parent_col_attr('ismulti') + parent_ismulti = self.get_parent_col_attr('ismulti') for tblname, ismulti, ids in zip( - table.parent_id_tablenames, parent_ismulti, rowargsT + self.parent_id_tablenames, parent_ismulti, rowargsT ): if tblname == depc.root: continue @@ -288,16 +312,16 @@ def get_config_history(table, rowid_list, assume_unique=True): ret_list.extend(ancestor_configs) return ret_list - def __remove_old_configs(table): + def __remove_old_configs(self): """ table = ibs.depc['pairwise_match'] """ # developing - # c = table.db.get_table_as_pandas('config') - # t = table.db.get_table_as_pandas(table.tablename) + # c = self.db.get_table_as_pandas('config') + # t = self.db.get_table_as_pandas(self.tablename) - # config_rowids = table.db.get_all_rowids(CONFIG_TABLE) - # cfgdict_list = table.db.get( + # config_rowids = self.db.get_all_rowids(CONFIG_TABLE) + # cfgdict_list = self.db.get( # CONFIG_TABLE, colnames=(CONFIG_DICT,), id_iter=config_rowids, # id_colname=CONFIG_ROWID) # bad_rowids = [] @@ -310,10 +334,10 @@ def __remove_old_configs(table): SELECT rowid, {} from {} """ ).format(CONFIG_DICT, CONFIG_TABLE) - table.db.cur.execute(command) + self.db.cur.execute(command) bad_rowids = [] - for rowid, cfgdict in table.db.cur.fetchall(): + for rowid, cfgdict in self.db.cur.fetchall(): # MAKE GENERAL CONDITION if cfgdict['version'] < 7: bad_rowids.append(rowid) @@ -324,17 +348,17 @@ def __remove_old_configs(table): SELECT rowid from {tablename} WHERE config_rowid IN {bad_rowids} """ - ).format(tablename=table.tablename, bad_rowids=in_str) + ).format(tablename=self.tablename, bad_rowids=in_str) # logger.info(command) - table.db.cur.execute(command) - rowids = ut.flatten(table.db.cur.fetchall()) - table.delete_rows(rowids, dry=True, verbose=True, delete_extern=True) + self.db.cur.execute(command) + rowids = ut.flatten(self.db.cur.fetchall()) + self.delete_rows(rowids, dry=True, verbose=True, delete_extern=True) - def get_ancestor_rowids(table, rowid_list, target_table): - parent_rowids = table.get_parent_rowids(rowid_list) - depc = table.depc + def get_ancestor_rowids(self, rowid_list, target_table): + parent_rowids = self.get_parent_rowids(rowid_list) + depc = self.depc for tblname, ids in zip( - table.parent_id_tablenames, ut.list_transpose(parent_rowids) + self.parent_id_tablenames, ut.list_transpose(parent_rowids) ): if tblname == target_table: return ids @@ -344,14 +368,14 @@ def get_ancestor_rowids(table, rowid_list, target_table): return ancestor_ids return None # Base case - def get_row_cfgid(table, rowid_list): + def get_row_cfgid(self, rowid_list): """ >>> from wbia.dtool.depcache_table import 
* # NOQA """ - config_rowids = table.get_internal_columns(rowid_list, (CONFIG_ROWID,)) + config_rowids = self.get_internal_columns(rowid_list, (CONFIG_ROWID,)) return config_rowids - def get_row_configs(table, rowid_list): + def get_row_configs(self, rowid_list): """ Example: >>> # ENABLE_DOCTEST @@ -363,21 +387,21 @@ def get_row_configs(table, rowid_list): >>> rowid_list = depc.get_rowids('chip', [1, 2], config={}) >>> configs = table.get_row_configs(rowid_list) """ - config_rowids = table.get_row_cfgid(rowid_list) + config_rowids = self.get_row_cfgid(rowid_list) # Only look up the configs that are needed unique_config_rowids, groupxs = ut.group_indices(config_rowids) - unique_configs = table.get_config_from_rowid(unique_config_rowids) + unique_configs = self.get_config_from_rowid(unique_config_rowids) configs = ut.ungroup_unique(unique_configs, groupxs, maxval=len(rowid_list) - 1) return configs - def get_row_cfghashid(table, rowid_list): - config_rowids = table.get_row_cfgid(rowid_list) - config_hashids = table.get_config_hashid(config_rowids) + def get_row_cfghashid(self, rowid_list): + config_rowids = self.get_row_cfgid(rowid_list) + config_hashids = self.get_config_hashid(config_rowids) return config_hashids - def get_row_cfgstr(table, rowid_list): - config_rowids = table.get_row_cfgid(rowid_list) - cfgstr_list = table.db.get( + def get_row_cfgstr(self, rowid_list): + config_rowids = self.get_row_cfgid(rowid_list) + cfgstr_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_STRID,), id_iter=config_rowids, @@ -385,15 +409,15 @@ def get_row_cfgstr(table, rowid_list): ) return cfgstr_list - def get_config_rowid(table, config=None, _debug=None): + def get_config_rowid(self, config=None, _debug=None): if isinstance(config, int): config_rowid = config else: - config_rowid = table.add_config(config, _debug) + config_rowid = self.add_config(config) return config_rowid - def get_config_hashid(table, config_rowid_list): - hashid_list = table.db.get( + def get_config_hashid(self, config_rowid_list): + hashid_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_HASHID,), id_iter=config_rowid_list, @@ -401,8 +425,8 @@ def get_config_hashid(table, config_rowid_list): ) return hashid_list - def get_config_rowid_from_hashid(table, config_hashid_list): - config_rowid_list = table.db.get( + def get_config_rowid_from_hashid(self, config_hashid_list): + config_rowid_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_ROWID,), id_iter=config_hashid_list, @@ -410,43 +434,39 @@ def get_config_rowid_from_hashid(table, config_hashid_list): ) return config_rowid_list - def get_config_from_rowid(table, config_rowids): - cfgdict_list = table.db.get( + def get_config_from_rowid(self, config_rowids): + cfgdict_list = self.db.get( CONFIG_TABLE, colnames=(CONFIG_DICT,), id_iter=config_rowids, id_colname=CONFIG_ROWID, ) return [ - None if dict_ is None else table.configclass(**dict_) - for dict_ in cfgdict_list + None if dict_ is None else self.configclass(**dict_) for dict_ in cfgdict_list ] # @profile - def add_config(table, config, _debug=None): + def add_config(self, config, _debug=None): try: # assume config is AlgoRequest or TableConfig config_strid = config.get_cfgstr() except AttributeError: config_strid = ut.to_json(config) config_hashid = ut.hashstr27(config_strid) - if table.depc._debug or _debug: - logger.info('config_strid = %r' % (config_strid,)) - logger.info('config_hashid = %r' % (config_hashid,)) - get_rowid_from_superkey = table.get_config_rowid_from_hashid + logger.debug('config_strid = %r' % 
(config_strid,)) + logger.debug('config_hashid = %r' % (config_hashid,)) + get_rowid_from_superkey = self.get_config_rowid_from_hashid colnames = (CONFIG_HASHID, CONFIG_TABLENAME, CONFIG_STRID, CONFIG_DICT) if hasattr(config, 'config'): # Hack for requests config = config.config cfgdict = config.__getstate__() - param_list = [(config_hashid, table.tablename, config_strid, cfgdict)] - config_rowid_list = table.db.add_cleanly( + param_list = [(config_hashid, self.tablename, config_strid, cfgdict)] + config_rowid_list = self.db.add_cleanly( CONFIG_TABLE, colnames, param_list, get_rowid_from_superkey ) config_rowid = config_rowid_list[0] - if table.depc._debug: - logger.info('config_rowid_list = %r' % (config_rowid_list,)) - # logger.info('config_rowid = %r' % (config_rowid,)) + logger.debug('config_rowid_list = %r' % (config_rowid_list,)) return config_rowid @@ -456,11 +476,11 @@ class _TableDebugHelper(object): Contains printing and debug things """ - def print_sql_info(table): - add_op = table.db._make_add_table_sqlstr(sep='\n ', **table._get_addtable_kw()) + def print_sql_info(self): + add_op = self.db._make_add_table_sqlstr(sep='\n ', **self._get_addtable_kw()) ut.cprint(add_op, 'sql') - def print_internal_info(table, all_attrs=False): + def print_internal_info(self, all_attrs=False): """ CommandLine: python -m dtool.depcache_table --exec-print_internal_info @@ -475,77 +495,73 @@ def print_internal_info(table, all_attrs=False): >>> table.print_internal_info() """ logger.info('----') - logger.info(table) + logger.info(self) # Print the other inferred attrs logger.info( - 'table.parent_col_attrs = %s' % (ut.repr3(table.parent_col_attrs, nl=2),) + 'self.parent_col_attrs = %s' % (ut.repr3(self.parent_col_attrs, nl=2),) ) - logger.info('table.data_col_attrs = %s' % (ut.repr3(table.data_col_attrs, nl=2),)) + logger.info('self.data_col_attrs = %s' % (ut.repr3(self.data_col_attrs, nl=2),)) # Print the inferred allcol attrs ut.cprint( - 'table.internal_col_attrs = %s' - % (ut.repr3(table.internal_col_attrs, nl=1, sorted_=False)), + 'self.internal_col_attrs = %s' + % (ut.repr3(self.internal_col_attrs, nl=1, sorted_=False)), 'python', ) - add_table_kw = table._get_addtable_kw() - logger.info('table.add_table_kw = %s' % (ut.repr2(add_table_kw, nl=2),)) - table.print_sql_info() + add_table_kw = self._get_addtable_kw() + logger.info('self.add_table_kw = %s' % (ut.repr2(add_table_kw, nl=2),)) + self.print_sql_info() if all_attrs: # Print all attributes - for a in ut.get_instance_attrnames( - table, with_properties=True, default=False - ): - logger.info(' table.%s = %r' % (a, getattr(table, a))) + for a in ut.get_instance_attrnames(self, with_properties=True, default=False): + logger.info(' self.%s = %r' % (a, getattr(self, a))) - def print_table(table): - table.db.print_table_csv(table.tablename) - # if table.ismulti: - # table.print_model_manifests() + def print_table(self): + self.db.print_table_csv(self.tablename) + # if self.ismulti: + # self.print_model_manifests() - def print_info(table, with_colattrs=True, with_graphattrs=True): + def print_info(self, with_colattrs=True, with_graphattrs=True): """ debug function """ logger.info('TABLE ATTRIBUTES') - logger.info('table.tablename = %r' % (table.tablename,)) - logger.info('table.isinteractive = %r' % (table.isinteractive,)) - logger.info('table.default_onthefly = %r' % (table.default_onthefly,)) - logger.info('table.rm_extern_on_delete = %r' % (table.rm_extern_on_delete,)) - logger.info('table.chunksize = %r' % (table.chunksize,)) - 
logger.info('table.fname = %r' % (table.fname,)) - logger.info('table.docstr = %r' % (table.docstr,)) - logger.info('table.data_colnames = %r' % (table.data_colnames,)) - logger.info('table.data_coltypes = %r' % (table.data_coltypes,)) + logger.info('self.tablename = %r' % (self.tablename,)) + logger.info('self.rm_extern_on_delete = %r' % (self.rm_extern_on_delete,)) + logger.info('self.chunksize = %r' % (self.chunksize,)) + logger.info('self.fname = %r' % (self.fname,)) + logger.info('self.docstr = %r' % (self.docstr,)) + logger.info('self.data_colnames = %r' % (self.data_colnames,)) + logger.info('self.data_coltypes = %r' % (self.data_coltypes,)) if with_graphattrs: logger.info('TABLE GRAPH ATTRIBUTES') - logger.info('table.children = %r' % (table.children,)) - logger.info('table.parent = %r' % (table.parent,)) - logger.info('table.configclass = %r' % (table.configclass,)) - logger.info('table.requestclass = %r' % (table.requestclass,)) + logger.info('self.children = %r' % (self.children,)) + logger.info('self.parent = %r' % (self.parent,)) + logger.info('self.configclass = %r' % (self.configclass,)) + logger.info('self.requestclass = %r' % (self.requestclass,)) if with_colattrs: nl = 1 logger.info('TABEL COLUMN ATTRIBUTES') logger.info( - 'table.data_col_attrs = %s' % (ut.repr3(table.data_col_attrs, nl=nl),) + 'self.data_col_attrs = %s' % (ut.repr3(self.data_col_attrs, nl=nl),) ) logger.info( - 'table.parent_col_attrs = %s' % (ut.repr3(table.parent_col_attrs, nl=nl),) + 'self.parent_col_attrs = %s' % (ut.repr3(self.parent_col_attrs, nl=nl),) ) logger.info( - 'table.internal_data_col_attrs = %s' - % (ut.repr3(table.internal_data_col_attrs, nl=nl),) + 'self.internal_data_col_attrs = %s' + % (ut.repr3(self.internal_data_col_attrs, nl=nl),) ) logger.info( - 'table.internal_parent_col_attrs = %s' - % (ut.repr3(table.internal_parent_col_attrs, nl=nl),) + 'self.internal_parent_col_attrs = %s' + % (ut.repr3(self.internal_parent_col_attrs, nl=nl),) ) logger.info( - 'table.internal_col_attrs = %s' - % (ut.repr3(table.internal_col_attrs, nl=nl),) + 'self.internal_col_attrs = %s' + % (ut.repr3(self.internal_col_attrs, nl=nl),) ) - def print_schemadef(table): - logger.info('\n'.join(table.db.get_table_autogen_str(table.tablename))) + def print_schemadef(self): + logger.info('\n'.join(self.db.get_table_autogen_str(self.tablename))) - def print_configs(table): + def print_configs(self): """ CommandLine: python -m dtool.depcache_table --exec-print_configs @@ -565,30 +581,30 @@ def print_configs(table): >>> rowids = depc.get_rowids('spam', [1, 2]) >>> table.print_configs() """ - text = table.db.get_table_csv(CONFIG_TABLE) + text = self.db.get_table_csv(CONFIG_TABLE) logger.info(text) - def print_csv(table, truncate=True): - logger.info(table.db.get_table_csv(table.tablename, truncate=truncate)) + def print_csv(self, truncate=True): + logger.info(self.db.get_table_csv(self.tablename, truncate=truncate)) - def print_model_manifests(table): + def print_model_manifests(self): logger.info('manifests') - rowids = table._get_all_rowids() - uuids = table.get_model_uuid(rowids) + rowids = self._get_all_rowids() + uuids = self.get_model_uuid(rowids) for rowid, uuid in zip(rowids, uuids): logger.info('rowid = %r' % (rowid,)) - logger.info(ut.repr3(table.get_model_inputs(uuid), nl=1)) + logger.info(ut.repr3(self.get_model_inputs(uuid), nl=1)) - def _assert_self(table): - assert len(table.data_colnames) == len( - table.data_coltypes + def _assert_self(self): + assert len(self.data_colnames) == len( + 
self.data_coltypes ), 'specify same number of colnames and coltypes' - if table.preproc_func is not None: + if self.preproc_func is not None: # Check that preproc_func has a valid signature # ie (depc, parent_ids, config) - argspec = ut.get_func_argspec(table.preproc_func) + argspec = ut.get_func_argspec(self.preproc_func) args = argspec.args - if argspec.varargs and argspec.keywords: + if argspec.varargs and (hasattr(argspec, 'keywords') and argspec.keywords): assert len(args) == 1, 'varargs and kwargs must have one arg for depcache' else: if len(args) < 3: @@ -597,22 +613,20 @@ def _assert_self(table): 'preproc_func=%r for table=%s must have a ' 'depcache arg, at least one parent rowid arg, ' 'and a config arg' - ) % (table.preproc_func, table.tablename) + ) % (self.preproc_func, self.tablename) raise AssertionError(msg) rowid_args = args[1:-1] - if len(rowid_args) != len(table.parents()): - logger.info('table.preproc_func = %r' % (table.preproc_func,)) + if len(rowid_args) != len(self.parents()): + logger.info('self.preproc_func = %r' % (self.preproc_func,)) logger.info('args = %r' % (args,)) logger.info('rowid_args = %r' % (rowid_args,)) msg = ( 'preproc function for table=%s must have as many ' 'rowids %d args as parent %d' - ) % (table.tablename, len(rowid_args), len(table.parents())) + ) % (self.tablename, len(rowid_args), len(self.parents())) raise AssertionError(msg) extern_class_colattrs = [ - colattr - for colattr in table.data_col_attrs - if colattr.get('is_external_class') + colattr for colattr in self.data_col_attrs if colattr.get('is_external_class') ] for colattr in extern_class_colattrs: cls = colattr['coltype'] @@ -637,7 +651,7 @@ class _TableInternalSetup(ub.NiceRepr): """ helper that sets up column information """ @profile - def _infer_datacol(table): + def _infer_datacol(self): """ Constructs the columns needed to represent relationship to data @@ -665,7 +679,7 @@ def _infer_datacol(table): data_col_attrs = [] # Parse column datatypes - _iter = enumerate(zip(table.data_colnames, table.data_coltypes)) + _iter = enumerate(zip(self.data_colnames, self.data_coltypes)) for data_colx, (colname, coltype) in _iter: colattr = ut.odict() # Check column input subtypes @@ -732,7 +746,7 @@ def _infer_datacol(table): assert hasattr(coltype, '__getstate__') and hasattr( coltype, '__setstate__' ), ('External classes must have __getstate__ and ' '__setstate__ methods') - read_func, write_func = make_extern_io_funcs(table, coltype) + read_func, write_func = make_extern_io_funcs(self, coltype) sqltype = TYPE_TO_SQLTYPE[str] intern_colname = colname + EXTERN_SUFFIX # raise AssertionError('external class columns') @@ -747,7 +761,7 @@ def _infer_datacol(table): return data_col_attrs @profile - def _infer_parentcol(table): + def _infer_parentcol(self): """ construct columns to represent relationship to parent @@ -779,7 +793,7 @@ def _infer_parentcol(table): >>> depc.d.get_indexer_data([ >>> uuid.UUID('a01eda32-e4e0-b139-3274-e91d1b3e9ecf')]) """ - parent_tablenames = table.parent_tablenames + parent_tablenames = self.parent_tablenames parent_col_attrs = [] # Handle dependencies when a parent are pairwise between tables @@ -866,7 +880,7 @@ def _infer_parentcol(table): return parent_col_attrs @profile - def _infer_allcol(table): + def _infer_allcol(self): r""" Combine information from parentcol and datacol Build column definitions that will directly define SQL columns @@ -876,7 +890,7 @@ def _infer_allcol(table): # Append primary column colattr = ut.odict( [ - ('intern_colname', 
table.rowid_colname), + ('intern_colname', self.rowid_colname), ('sqltype', 'INTEGER PRIMARY KEY'), ('isprimary', True), ] @@ -886,7 +900,7 @@ def _infer_allcol(table): # Append parent columns ismulti = False - for parent_colattr in table.parent_col_attrs: + for parent_colattr in self.parent_col_attrs: colattr = ut.odict() colattr['intern_colname'] = parent_colattr['intern_colname'] colattr['parent_table'] = parent_colattr['parent_table'] @@ -916,26 +930,26 @@ def _infer_allcol(table): internal_col_attrs.append(colattr) # Append quick access column - # return any(table.get_parent_col_attr('ismulti')) - # if table.ismulti: + # return any(self.get_parent_col_attr('ismulti')) + # if self.ismulti: if ismulti: # Append model uuid column colattr = ut.odict() - colattr['intern_colname'] = table.model_uuid_colname + colattr['intern_colname'] = self.model_uuid_colname colattr['sqltype'] = 'UUID NOT NULL' colattr['intern_colx'] = len(internal_col_attrs) internal_col_attrs.append(colattr) # Append model uuid column colattr = ut.odict() - colattr['intern_colname'] = table.is_augmented_colname + colattr['intern_colname'] = self.is_augmented_colname colattr['sqltype'] = 'INTEGER DEFAULT 0' colattr['intern_colx'] = len(internal_col_attrs) internal_col_attrs.append(colattr) if False: # TODO: eventually enable - if table.taggable: + if self.taggable: colattr = ut.odict() colattr['intern_colname'] = 'model_tag' colattr['sqltype'] = 'TEXT' @@ -947,7 +961,7 @@ def _infer_allcol(table): pass # Append data columns - for data_colattr in table.data_col_attrs: + for data_colattr in self.data_col_attrs: colname = data_colattr['colname'] if data_colattr.get('isnested', False): for nestcol in data_colattr['nestattrs']: @@ -974,7 +988,7 @@ def _infer_allcol(table): internal_col_attrs.append(colattr) # Append extra columns - for parent_colattr in table.parent_col_attrs: + for parent_colattr in self.parent_col_attrs: for extra_colattr in parent_colattr.get('extra_cols', []): colattr = ut.odict() colattr['intern_colname'] = extra_colattr['intern_colname'] @@ -989,176 +1003,175 @@ def _infer_allcol(table): class _TableGeneralHelper(ub.NiceRepr): """ helper """ - def __nice__(table): - num_parents = len(table.parent_tablenames) - num_cols = len(table.data_colnames) + def __nice__(self): + num_parents = len(self.parent_tablenames) + num_cols = len(self.data_colnames) return '(%s) nP=%d%s nC=%d' % ( - table.tablename, + self.tablename, num_parents, - '*' if False and table.ismulti else '', + '*' if False and self.ismulti else '', num_cols, ) # @property - # def _table_colnames(table): + # def _table_colnames(self): # return @property - def extern_dpath(table): - cache_dpath = table.depc.cache_dpath - extern_dname = 'extern_' + table.tablename + def extern_dpath(self): + cache_dpath = self.depc.cache_dpath + extern_dname = 'extern_' + self.tablename extern_dpath = join(cache_dpath, extern_dname) return extern_dpath @property - def dpath(table): + def dpath(self): # assert table.ismulti, 'only valid for models' - dname = table.tablename + '_storage' - dpath = join(table.depc.cache_dpath, dname) + dname = self.tablename + '_storage' + dpath = join(self.depc.cache_dpath, dname) # ut.ensuredir(dpath) return dpath - # def dpath(table): + # def dpath(self): # from os.path import dirname - # dpath = dirname(table.db.fpath) + # dpath = dirname(self.db.fpath) # return dpath @property @ut.memoize - def ismulti(table): + def ismulti(self): # TODO: or has multi parent - return any(table.get_parent_col_attr('ismulti')) + return 
any(self.get_parent_col_attr('ismulti')) @property - def configclass(table): - return table.depc.configclass_dict[table.tablename] + def configclass(self): + return self.depc.configclass_dict[self.tablename] @property - def requestclass(table): - return table.depc.requestclass_dict.get(table.tablename, None) + def requestclass(self): + return self.depc.requestclass_dict.get(self.tablename, None) - def new_request(table, qaids, daids, cfgdict=None): - request = table.depc.new_request(table.tablename, qaids, daids, cfgdict=cfgdict) + def new_request(self, qaids, daids, cfgdict=None): + request = self.depc.new_request(self.tablename, qaids, daids, cfgdict=cfgdict) return request # --- Standard Properties @property - def internal_data_col_attrs(table): - flags = table.get_intern_col_attr('isdata') - return ut.compress(table.internal_col_attrs, flags) + def internal_data_col_attrs(self): + flags = self.get_intern_col_attr('isdata') + return ut.compress(self.internal_col_attrs, flags) @property - def internal_parent_col_attrs(table): - flags = table.get_intern_col_attr('isparent') - return ut.compress(table.internal_col_attrs, flags) + def internal_parent_col_attrs(self): + flags = self.get_intern_col_attr('isparent') + return ut.compress(self.internal_col_attrs, flags) # --- / Standard Properties @ut.memoize - def get_parent_col_attr(table, key): - return ut.dict_take_column(table.parent_col_attrs, key) + def get_parent_col_attr(self, key): + return ut.dict_take_column(self.parent_col_attrs, key) @ut.memoize - def get_intern_data_col_attr(table, key): - return ut.dict_take_column(table.internal_data_col_attrs, key) + def get_intern_data_col_attr(self, key): + return ut.dict_take_column(self.internal_data_col_attrs, key) @ut.memoize - def get_intern_parent_col_attr(table, key): - return ut.dict_take_column(table.internal_parent_col_attrs, key) + def get_intern_parent_col_attr(self, key): + return ut.dict_take_column(self.internal_parent_col_attrs, key) @ut.memoize - def get_intern_col_attr(table, key): - return ut.dict_take_column(table.internal_col_attrs, key) + def get_intern_col_attr(self, key): + return ut.dict_take_column(self.internal_col_attrs, key) @ut.memoize - def get_data_col_attr(table, key): - return ut.dict_take_column(table.data_col_attrs, key) + def get_data_col_attr(self, key): + return ut.dict_take_column(self.data_col_attrs, key) @property @ut.memoize - def parent_id_tablenames(table): + def parent_id_tablenames(self): tablenames = tuple( - [parent_colattr['parent_table'] for parent_colattr in table.parent_col_attrs] + [parent_colattr['parent_table'] for parent_colattr in self.parent_col_attrs] ) return tablenames @property @ut.memoize - def parent_id_prefix(table): + def parent_id_prefix(self): prefixes = tuple( - [parent_colattr['prefix'] for parent_colattr in table.parent_col_attrs] + [parent_colattr['prefix'] for parent_colattr in self.parent_col_attrs] ) return prefixes @property - def extern_columns(table): - colnames = table.get_data_col_attr('colname') - flags = table.get_data_col_attr('is_extern') + def extern_columns(self): + colnames = self.get_data_col_attr('colname') + flags = self.get_data_col_attr('is_extern') return ut.compress(colnames, flags) @property - def rowid_colname(table): + def rowid_colname(self): """ rowid of this table used by other dependant tables """ - return table.tablename + '_rowid' + return self.tablename + '_rowid' @property - def superkey_colnames(table): - return table.parent_id_colnames + (CONFIG_ROWID,) + def superkey_colnames(self): + 
return self.parent_id_colnames + (CONFIG_ROWID,) @property - def model_uuid_colname(table): + def model_uuid_colname(self): return 'model_uuid' @property - def is_augmented_colname(table): + def is_augmented_colname(self): return 'augment_bit' @property - def parent_id_colnames(table): - return tuple([colattr['intern_colname'] for colattr in table.parent_col_attrs]) + def parent_id_colnames(self): + return tuple([colattr['intern_colname'] for colattr in self.parent_col_attrs]) - def get_rowids_from_root(table, root_rowids, config=None): - return table.depc.get_rowids(table.tablename, root_rowids, config=config) + def get_rowids_from_root(self, root_rowids, config=None): + return self.depc.get_rowids(self.tablename, root_rowids, config=config) @property @ut.memoize - def parent(table): + def parent(self): return ut.odict( [ (parent_colattr['parent_table'], parent_colattr) - for parent_colattr in table.parent_col_attrs + for parent_colattr in self.parent_col_attrs ] ) # return tuple([parent_colattr['parent_table'] - # for parent_colattr in table.parent_col_attrs]) + # for parent_colattr in self.parent_col_attrs]) @ut.memoize - def parents(table, data=None): + def parents(self, data=None): if data: return [ (parent_colattr['parent_table'], parent_colattr) - for parent_colattr in table.parent_col_attrs + for parent_colattr in self.parent_col_attrs ] else: return [ - parent_colattr['parent_table'] - for parent_colattr in table.parent_col_attrs + parent_colattr['parent_table'] for parent_colattr in self.parent_col_attrs ] @property - def children(table): - graph = table.depc.explicit_graph - children_tablenames = list(nx.neighbors(graph, table.tablename)) + def children(self): + graph = self.depc.explicit_graph + children_tablenames = list(nx.neighbors(graph, self.tablename)) return children_tablenames @property - def ancestors(table): - graph = table.depc.explicit_graph - children_tablenames = list(nx.ancestors(graph, table.tablename)) + def ancestors(self): + graph = self.depc.explicit_graph + children_tablenames = list(nx.ancestors(graph, self.tablename)) return children_tablenames - def show_dep_subgraph(table, inter=None): + def show_dep_subgraph(self, inter=None): from wbia.plottool.interactions import ExpandableInteraction autostart = inter is None @@ -1166,8 +1179,8 @@ def show_dep_subgraph(table, inter=None): inter = ExpandableInteraction(nCols=2) import wbia.plottool as pt - graph = table.depc.explicit_graph - nodes = ut.nx_all_nodes_between(graph, None, table.tablename) + graph = self.depc.explicit_graph + nodes = ut.nx_all_nodes_between(graph, None, self.tablename) G = graph.subgraph(nodes) plot_kw = {'fontname': 'Ubuntu'} @@ -1175,14 +1188,14 @@ def show_dep_subgraph(table, inter=None): ut.partial( pt.show_nx, G, - title='Dependency Subgraph (%s)' % (table.tablename), + title='Dependency Subgraph (%s)' % (self.tablename), **plot_kw, ) ) if autostart: inter.start() - def show_input_graph(table, inter=None): + def show_input_graph(self, inter=None): """ CommandLine: python -m dtool.depcache_table show_input_graph --show @@ -1204,8 +1217,8 @@ def show_input_graph(table, inter=None): autostart = inter is None if inter is None: inter = ExpandableInteraction(nCols=2) - table.show_dep_subgraph(inter) - inputs = table.rootmost_inputs + self.show_dep_subgraph(inter) + inputs = self.rootmost_inputs inter = inputs.show_exi_graph(inter) if autostart: inter.start() @@ -1213,7 +1226,7 @@ def show_input_graph(table, inter=None): @property @ut.memoize - def expanded_input_graph(table): + def 
expanded_input_graph(self): """ CommandLine: python -m dtool.depcache_table --exec-expanded_input_graph --show --table=neighbs @@ -1241,13 +1254,13 @@ def expanded_input_graph(table): """ from wbia.dtool import input_helpers - graph = table.depc.explicit_graph.copy() - target = table.tablename + graph = self.depc.explicit_graph.copy() + target = self.tablename exi_graph = input_helpers.make_expanded_input_graph(graph, target) return exi_graph @property - def rootmost_inputs(table): + def rootmost_inputs(self): """ CommandLine: python -m dtool.depcache_table rootmost_inputs --show @@ -1269,24 +1282,24 @@ def rootmost_inputs(table): """ from wbia.dtool import input_helpers - exi_graph = table.expanded_input_graph - rootmost_inputs = input_helpers.get_rootmost_inputs(exi_graph, table) + exi_graph = self.expanded_input_graph + rootmost_inputs = input_helpers.get_rootmost_inputs(exi_graph, self) return rootmost_inputs @ut.memoize - def requestable_col_attrs(table): + def requestable_col_attrs(self): """ Maps names of requestable columns to indicies of internal columns """ requestable_col_attrs = {} - for colattr in table.internal_data_col_attrs: + for colattr in self.internal_data_col_attrs: rattr = {} colname = colattr['intern_colname'] rattr['intern_colx'] = colattr['intern_colx'] rattr['intern_colname'] = colattr['intern_colname'] requestable_col_attrs[colname] = rattr - for colattr in table.data_col_attrs: + for colattr in self.data_col_attrs: rattr = {} if colattr.get('isnested'): nest_internal_names = ut.take_column(colattr['nestattrs'], 'flat_colname') @@ -1309,11 +1322,11 @@ def requestable_col_attrs(table): return requestable_col_attrs @ut.memoize - def computable_colnames(table): + def computable_colnames(self): # These are the colnames that we expect to be computed - intern_colnames = ut.take_column(table.internal_col_attrs, 'intern_colname') + intern_colnames = ut.take_column(self.internal_col_attrs, 'intern_colname') insertable_flags = [ - not colattr.get('isprimary') for colattr in table.internal_col_attrs + not colattr.get('isprimary') for colattr in self.internal_col_attrs ] colnames = tuple(ut.compress(intern_colnames, insertable_flags)) return colnames @@ -1325,7 +1338,7 @@ class _TableComputeHelper(object): # @profile def prepare_storage( - table, dirty_parent_ids, proptup_gen, dirty_preproc_args, config_rowid, config + self, dirty_parent_ids, proptup_gen, dirty_preproc_args, config_rowid, config ): """ Converts output from ``preproc_func`` to data that can be stored in SQL @@ -1342,11 +1355,11 @@ def prepare_storage( >>> tablename = 'labeler' >>> tablename = 'indexer' >>> config = {tablename + '_param': None, 'foo': 'bar'} - >>> data = depc.get('labeler', [1, 2, 3], 'data', _debug=0) - >>> data = depc.get('labeler', [1, 2, 3], 'data', config=config, _debug=0) - >>> data = depc.get('indexer', [[1, 2, 3]], 'data', _debug=0) - >>> data = depc.get('indexer', [[1, 2, 3]], 'data', config=config, _debug=0) - >>> rowids = depc.get_rowids('indexer', [[1, 2, 3]], config=config, _debug=0) + >>> data = depc.get('labeler', [1, 2, 3], 'data') + >>> data = depc.get('labeler', [1, 2, 3], 'data', config=config) + >>> data = depc.get('indexer', [[1, 2, 3]], 'data') + >>> data = depc.get('indexer', [[1, 2, 3]], 'data', config=config) + >>> rowids = depc.get_rowids('indexer', [[1, 2, 3]], config=config) >>> table = depc[tablename] >>> model_uuid_list = table.get_internal_columns(rowids, ('model_uuid',)) >>> model_uuid = model_uuid_list[0] @@ -1358,19 +1371,19 @@ def prepare_storage( >>> 
table.print_model_manifests() >>> #ut.vd(depc.cache_dpath) """ - if table.default_to_unpack: + if self.default_to_unpack: # Hack for tables explicilty specified with a single column proptup_gen = (None if data is None else (data,) for data in proptup_gen) # Flatten nested columns - if any(table.get_data_col_attr('isnested')): - proptup_gen = table._prepare_storage_nested(proptup_gen) + if any(self.get_data_col_attr('isnested')): + proptup_gen = self._prepare_storage_nested(proptup_gen) # Write external columns - if any(table.get_data_col_attr('write_func')): - proptup_gen = table._prepare_storage_extern( + if any(self.get_data_col_attr('write_func')): + proptup_gen = self._prepare_storage_extern( dirty_parent_ids, config_rowid, config, proptup_gen ) - if table.ismulti: - manifest_dpath = table.dpath + if self.ismulti: + manifest_dpath = self.dpath ut.ensuredir(manifest_dpath) # Concatenate data with internal rowids / config-id for ids_, data_cols, args_ in zip( @@ -1380,19 +1393,19 @@ def prepare_storage( if data_cols is None: yield None else: - multi_parent_flags = table.get_parent_col_attr('ismulti') - parent_colnames = table.get_parent_col_attr('intern_colname') + multi_parent_flags = self.get_parent_col_attr('ismulti') + parent_colnames = self.get_parent_col_attr('intern_colname') multi_id_names = ut.compress(parent_colnames, multi_parent_flags) multi_ids = ut.compress(ids_, multi_parent_flags) multi_args = ut.compress(args_, multi_parent_flags) - if table.ismulti: + if self.ismulti: multi_setsizes = [] manifest_data = {} for multi_id, arg_, name in zip( multi_ids, multi_args, multi_id_names ): - assert table.ismulti, 'only valid for models' + assert self.ismulti, 'only valid for models' # TODO: need to get back to root ids manifest_data.update( **{ @@ -1410,7 +1423,7 @@ def prepare_storage( manifest_data['model_uuid'] = model_uuid manifest_data['augmented'] = False - manifest_fpath = table.get_model_manifest_fpath(model_uuid) + manifest_fpath = self.get_model_manifest_fpath(model_uuid) ut.save_json(manifest_fpath, manifest_data, pretty=1) # TODO: hash all input UUIDs and the full config together @@ -1425,7 +1438,7 @@ def prepare_storage( # fname in zip(multi_args, # multi_fpaths)])) row_tup = ( - ids_ + tuple(ids_) + (config_rowid,) + quick_access_tup + data_cols @@ -1439,37 +1452,37 @@ def prepare_storage( ) raise - def get_model_manifest_fname(table, model_uuid): + def get_model_manifest_fname(self, model_uuid): manifest_fname = 'input_manifest_%s.json' % (model_uuid,) return manifest_fname - def get_model_manifest_fpath(table, model_uuid): - manifest_fname = table.get_model_manifest_fname(model_uuid) - manifest_fpath = join(table.dpath, manifest_fname) + def get_model_manifest_fpath(self, model_uuid): + manifest_fname = self.get_model_manifest_fname(model_uuid) + manifest_fpath = join(self.dpath, manifest_fname) return manifest_fpath - def get_model_inputs(table, model_uuid): + def get_model_inputs(self, model_uuid): """ Ignore: >>> table.get_model_uuid([2]) [UUID('5b66772c-e654-dd9a-c9de-0ccc1bb6861c')] """ - assert table.ismulti, 'must be a model' - manifest_fpath = table.get_model_manifest_fpath(model_uuid) + assert self.ismulti, 'must be a model' + manifest_fpath = self.get_model_manifest_fpath(model_uuid) manifest_data = ut.load_json(manifest_fpath) return manifest_data - def get_model_uuid(table, rowids): + def get_model_uuid(self, rowids): """ Ignore: >>> table.get_model_uuid([2]) [UUID('5b66772c-e654-dd9a-c9de-0ccc1bb6861c')] """ - assert table.ismulti, 'must be a model' - 
model_uuid_list = table.get_internal_columns(rowids, ('model_uuid',)) + assert self.ismulti, 'must be a model' + model_uuid_list = self.get_internal_columns(rowids, ('model_uuid',)) return model_uuid_list - def get_model_rowids(table, model_uuid_list): + def get_model_rowids(self, model_uuid_list): """ Get the rowid of a model given its uuid @@ -1478,12 +1491,12 @@ def get_model_rowids(table, model_uuid_list): >>> table.get_model_rowids([uuid.UUID('5b66772c-e654-dd9a-c9de-0ccc1bb6861c')]) [2] """ - assert table.ismulti, 'must be a model' - colnames = (table.rowid_colname,) - andwhere_colnames = (table.model_uuid_colname,) + assert self.ismulti, 'must be a model' + colnames = (self.rowid_colname,) + andwhere_colnames = (self.model_uuid_colname,) params_iter = list(zip(model_uuid_list)) - rowid_list = table.db.get_where_eq( - table.tablename, + rowid_list = self.db.get_where_eq( + self.tablename, colnames, params_iter, andwhere_colnames, @@ -1493,13 +1506,13 @@ def get_model_rowids(table, model_uuid_list): return rowid_list @profile - def _prepare_storage_nested(table, proptup_gen): + def _prepare_storage_nested(self, proptup_gen): """ Hack for when a sql schema has tuples defined in it. Accepts nested tuples and flattens them to fit into the sql tables """ - nCols = len(table.data_colnames) - idxs1 = ut.where(table.get_data_col_attr('isnested')) + nCols = len(self.data_colnames) + idxs1 = ut.where(self.get_data_col_attr('isnested')) idxs2 = ut.index_complement(idxs1, nCols) for data in proptup_gen: if data is None: @@ -1518,12 +1531,12 @@ def _prepare_storage_nested(table, proptup_gen): # @profile def _prepare_storage_extern( - table, dirty_parent_ids, config_rowid, config, proptup_gen + self, dirty_parent_ids, config_rowid, config, proptup_gen ): """ Writes external data to disk if write function is specified. 
""" - internal_data_col_attrs = table.internal_data_col_attrs + internal_data_col_attrs = self.internal_data_col_attrs writable_flags = ut.dict_take_column(internal_data_col_attrs, 'write_func', False) extern_colattrs = ut.compress(internal_data_col_attrs, writable_flags) # extern_colnames = ut.dict_take_column(extern_colattrs, 'colname') @@ -1535,7 +1548,7 @@ def _prepare_storage_extern( extern_fnames_list = list( zip( *[ - table._get_extern_fnames( + self._get_extern_fnames( dirty_parent_ids, config_rowid, config, extern_colattr ) for extern_colattr in extern_colattrs @@ -1543,8 +1556,8 @@ def _prepare_storage_extern( ) ) # get extern cache directory and fpaths - extern_dpath = table.extern_dpath - ut.ensuredir(extern_dpath, verbose=False or table.depc._debug) + extern_dpath = self.extern_dpath + ut.ensuredir(extern_dpath) # extern_fpaths_list = [ # [join(extern_dpath, fname) for fname in fnames] # for fnames in extern_fnames_list @@ -1577,7 +1590,7 @@ def _prepare_storage_extern( data_new = tuple(ut.ungroup(grouped_items, groupxs, nCols - 1)) yield data_new - def get_extern_fnames(table, parent_rowids, config, extern_col_index=0): + def get_extern_fnames(self, parent_rowids, config, extern_col_index=0): """ convinience function around get_extern_fnames @@ -1597,25 +1610,25 @@ def get_extern_fnames(table, parent_rowids, config, extern_col_index=0): >>> fname_list = table.get_extern_fnames(parent_rowids, config) >>> print('fname_list = %r' % (fname_list,)) """ - config_rowid = table.get_config_rowid(config) + config_rowid = self.get_config_rowid(config) # depc.get_rowids(tablename, root_rowids, config) - internal_data_col_attrs = table.internal_data_col_attrs + internal_data_col_attrs = self.internal_data_col_attrs writable_flags = ut.dict_take_column(internal_data_col_attrs, 'write_func', False) extern_colattrs = ut.compress(internal_data_col_attrs, writable_flags) extern_colattr = extern_colattrs[extern_col_index] - fname_list = table._get_extern_fnames( + fname_list = self._get_extern_fnames( parent_rowids, config_rowid, config, extern_colattr ) # if False: - # root_rowids = table.depc.get_root_rowids(table.tablename, rowid_list) + # root_rowids = self.depc.get_root_rowids(self.tablename, rowid_list) # info_props = ['image_uuid', 'verts', 'theta'] - # table.depc.make_root_info_uuid(root_rowids, info_props) + # self.depc.make_root_info_uuid(root_rowids, info_props) return fname_list def _get_extern_fnames( - table, parent_rowids, config_rowid, config, extern_colattr=None + self, parent_rowids, config_rowid, config, extern_colattr=None ): """ TODO: @@ -1627,17 +1640,17 @@ def _get_extern_fnames( Args: parent_rowids (list of tuples) - list of tuples of rowids """ - config_hashid = table.get_config_hashid([config_rowid])[0] - prefix = table.tablename + config_hashid = self.get_config_hashid([config_rowid])[0] + prefix = self.tablename prefix += '_' + extern_colattr['colname'] - colattrs = table.data_col_attrs[extern_colattr['data_colx']] + colattrs = self.data_col_attrs[extern_colattr['data_colx']] # if colname is not None: # prefix += '_' + colname # TODO: Put relevant root properties into the hash of the filename # (like bbox, parent image. basically the general vuuid and suuid. 
fmtstr = '{prefix}_id={rowids}_{config_hashid}{ext}' # HACK: check if the config specifies the extension type - # extkey = table.extern_ext_config_keys.get(colname, 'ext') + # extkey = self.extern_ext_config_keys.get(colname, 'ext') if 'extern_ext' in colattrs: ext = colattrs['extern_ext'] else: @@ -1655,7 +1668,7 @@ def _get_extern_fnames( return fname_list def _compute_dirty_rows( - table, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True + self, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True ): """ dirty_preproc_args = preproc_args @@ -1674,15 +1687,15 @@ def _compute_dirty_rows( config_ = config.config if hasattr(config, 'config') else config # call registered worker function - if table.vectorized: + if self.vectorized: # Function is written in a way that only accepts multiple inputs at # once and generates output - proptup_gen = table.preproc_func(table.depc, *argsT, config=config_) + proptup_gen = self.preproc_func(self.depc, *argsT, config=config_) else: # Function is written in a way that only accepts a single row of # input at a time proptup_gen = ( - table.preproc_func(table.depc, *argrow, config=config_) + self.preproc_func(self.depc, *argrow, config=config_) for argrow in zip(*argsT) ) @@ -1697,7 +1710,7 @@ def _compute_dirty_rows( nInput, ) # Append rowids and rectify nested and external columns - dirty_params_iter = table.prepare_storage( + dirty_params_iter = self.prepare_storage( dirty_parent_ids, proptup_gen, dirty_preproc_args, config_rowid, config_ ) if DEBUG_LIST_MODE: @@ -1707,7 +1720,7 @@ def _compute_dirty_rows( return dirty_params_iter def _chunk_compute_dirty_rows( - table, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True + self, dirty_parent_ids, dirty_preproc_args, config_rowid, config, verbose=True ): """ Executes registered functions, does external storage and yeilds results @@ -1722,19 +1735,18 @@ def _chunk_compute_dirty_rows( >>> from wbia.dtool.example_depcache2 import * # NOQA >>> depc = testdata_depc3(in_memory=False) >>> depc.clear_all() - >>> data = depc.get('labeler', [1, 2, 3], 'data', _debug=True) - >>> data = depc.get('indexer', [[1, 2, 3]], 'data', _debug=True) + >>> data = depc.get('labeler', [1, 2, 3], 'data') + >>> data = depc.get('indexer', [[1, 2, 3]], 'data') >>> depc.print_all_tables() """ nInput = len(dirty_parent_ids) - chunksize = nInput if table.chunksize is None else table.chunksize + chunksize = nInput if self.chunksize is None else self.chunksize - if verbose: - logger.info( - '[deptbl.compute] nInput={}, chunksize={}, tbl={}'.format( - nInput, table.chunksize, table.tablename - ) + logger.info( + '[deptbl.compute] nInput={}, chunksize={}, tbl={}'.format( + nInput, self.chunksize, self.tablename ) + ) # Report computation progress dirty_iter = list(zip(dirty_parent_ids, dirty_preproc_args)) @@ -1742,22 +1754,10 @@ def _chunk_compute_dirty_rows( dirty_iter, chunksize, nInput, - lbl='[deptbl.compute] add %s chunk' % (table.tablename), + lbl='[deptbl.compute] add %s chunk' % (self.tablename), ) # These are the colnames that we expect to be computed - colnames = table.computable_colnames() - # def unfinished_features(): - # if table._asobject: - # # Convinience - # argsT = [table.depc.get_obj(parent, rowids) - # for parent, rowids in zip(table.parents(), - # dirty_parent_ids_chunk)] - # onthefly = None - # if table.default_onthefly or onthefly: - # assert not table.ismulti, ('cannot onthefly multi tables') - # proptup_gen = [tuple([None] * len(table.data_col_attrs)) 
- # for _ in range(len(dirty_parent_ids_chunk))] - # pass + colnames = self.computable_colnames() # CALL EXTERNAL PREPROCESSING / GENERATION FUNCTION try: # prog_iter = list(prog_iter) @@ -1767,7 +1767,7 @@ def _chunk_compute_dirty_rows( return dirty_parent_ids_chunk, dirty_preproc_args_chunk = zip(*dirty_chunk) - dirty_params_iter = table._compute_dirty_rows( + dirty_params_iter = self._compute_dirty_rows( dirty_parent_ids_chunk, dirty_preproc_args_chunk, config_rowid, @@ -1789,12 +1789,12 @@ def _chunk_compute_dirty_rows( 'error in add_rowids', keys=[ 'table', - 'table.parents()', + 'self.parents()', 'config', 'argsT', 'config_rowid', 'dirty_parent_ids', - 'table.preproc_func', + 'self.preproc_func', ], tb=True, ) @@ -1849,7 +1849,7 @@ class DependencyCacheTable( @profile def __init__( - table, + self, depc=None, parent_tablenames=None, tablename=None, @@ -1860,9 +1860,9 @@ def __init__( fname=None, asobject=False, chunksize=None, - isinteractive=False, + isinteractive=False, # no-op default_to_unpack=False, - default_onthefly=False, + default_onthefly=False, # no-op rm_extern_on_delete=False, vectorized=True, taggable=False, @@ -1871,93 +1871,167 @@ def __init__( recieves kwargs from depc._register_prop """ try: - table.db = None + self.db = None except Exception: # HACK: jedi type hinting. Need to have non-obvious condition - table.db = SQLDatabaseController() - table.fpath_to_db = {} + self.db = SQLDatabaseController() assert ( re.search('[0-9]', tablename) is None ), 'tablename=%r cannot contain numbers' % (tablename,) # parent depcache - table.depc = depc + self.depc = depc # Definitions - table.tablename = tablename - table.docstr = docstr - table.parent_tablenames = parent_tablenames - table.data_colnames = tuple(data_colnames) - table.data_coltypes = data_coltypes - table.preproc_func = preproc_func - table.fname = fname + self.tablename = tablename + self.docstr = docstr + self.parent_tablenames = parent_tablenames + self.data_colnames = tuple(data_colnames) + self.data_coltypes = data_coltypes + self.preproc_func = preproc_func + self._db_name = fname # Behavior - table.on_delete = None - table.default_to_unpack = default_to_unpack - table.vectorized = vectorized - table.taggable = taggable + self.on_delete = None + self.default_to_unpack = default_to_unpack + self.vectorized = vectorized + self.taggable = taggable - # table.store_modification_time = True + # self.store_modification_time = True # Use the filesystem to accomplish this - # table.store_access_time = True - # table.store_create_time = True - # table.store_delete_time = True - - table.chunksize = chunksize - # Developmental properties - table.subproperties = {} - table.isinteractive = isinteractive - table._asobject = asobject - table.default_onthefly = default_onthefly + # self.store_access_time = True + # self.store_create_time = True + # self.store_delete_time = True + + self.chunksize = chunksize # SQL Internals - table.sqldb_fpath = None - table.rm_extern_on_delete = rm_extern_on_delete + self.sqldb_fpath = None + self.rm_extern_on_delete = rm_extern_on_delete # Update internals - table.parent_col_attrs = table._infer_parentcol() - table.data_col_attrs = table._infer_datacol() - table.internal_col_attrs = table._infer_allcol() + self.parent_col_attrs = self._infer_parentcol() + self.data_col_attrs = self._infer_datacol() + self.internal_col_attrs = self._infer_allcol() # Check for errors if ut.SUPER_STRICT: - table._assert_self() + self._assert_self() - table._hack_chunk_cache = None + # ??? 
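# Minimal sketch of the chunking pattern in _chunk_compute_dirty_rows above:
# split the dirty inputs into fixed-size chunks, run the registered worker per
# chunk, and yield per-chunk results so the caller can insert them
# incrementally instead of holding everything in memory. compute_rows is a
# hypothetical stand-in for table._compute_dirty_rows; the caller's insert_rows
# stands in for db._add.
from itertools import islice

def ichunks(iterable, size):
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk

def chunk_compute(dirty_items, chunksize, compute_rows):
    # dirty_items: list of (parent_ids, preproc_args) pairs needing computation
    chunksize = len(dirty_items) if chunksize is None else chunksize
    for chunk in ichunks(dirty_items, chunksize):
        parent_ids, args = zip(*chunk)
        yield compute_rows(parent_ids, args), len(chunk)

# caller:
#   for rows, n_input in chunk_compute(dirty, table.chunksize, worker):
#       insert_rows(rows, nInput=n_input)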
Clearly a hack, but to what end? + self._hack_chunk_cache = None - # @profile - def initialize(table, _debug=None): + @classmethod + def from_name( + cls, + db_name, + table_name, + depcache_controller, + parent_tablenames=None, + data_colnames=None, + data_coltypes=None, + preproc_func=None, + docstr='no docstr', + asobject=False, + chunksize=None, + default_to_unpack=False, + rm_extern_on_delete=False, + vectorized=True, + taggable=False, + ): + """Build the instance based on a database and table name.""" + self = cls.__new__(cls) + + self._db_name = db_name + + # Set the table name + if re.search('[0-9]', table_name): + raise ValueError(f"tablename = '{table_name}' cannot contain numbers") + self.tablename = table_name + + # Set the parent depcache controller + self.depc = depcache_controller + + # Definitions + self.docstr = docstr + self.parent_tablenames = parent_tablenames + self.data_colnames = tuple(data_colnames) + self.data_coltypes = data_coltypes + self.preproc_func = preproc_func + #: Optional specification of the amount of blobs to modify in one SQL operation + self.chunksize = chunksize + + # FIXME (20-Oct-12020) This definition of behavior by external means is a scope issue + # Another object should not be directly manipulating this object. + #: functions defined and populated through DependencyCache._register_subprop + self.subproperties = {} + + # Behavior + self.on_delete = None + self.default_to_unpack = default_to_unpack + self.vectorized = vectorized + self.taggable = taggable + #: Flag to enable the deletion of external files on associated SQL row deletion. + self.rm_extern_on_delete = rm_extern_on_delete + + # XXX (20-Oct-12020) It's not clear if these attributes are absolutely necessary. + # Update internals + self.parent_col_attrs = self._infer_parentcol() + self.data_col_attrs = self._infer_datacol() + self.internal_col_attrs = self._infer_allcol() + # /XXX + + # Check for errors + # FIXME (20-Oct-12020) This seems like a bad idea... + # Why would you sometimes want to check and not at other times? + if ut.SUPER_STRICT: + self._assert_self() + + # ??? Clearly a hack, but to what end? 
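# The from_name classmethod being added here builds the instance with
# cls.__new__(cls) so it can bypass the legacy __init__ signature while the two
# constructors coexist. A generic sketch of that alternate-constructor pattern
# (class and attribute names below are illustrative only):
import re

class CacheTable:
    def __init__(self, depc, tablename, fname):
        self.depc = depc
        self.tablename = tablename
        self._db_name = fname

    @classmethod
    def from_name(cls, db_name, table_name, depcache_controller):
        if re.search('[0-9]', table_name):
            raise ValueError(f"tablename = '{table_name}' cannot contain numbers")
        self = cls.__new__(cls)  # allocate without running __init__
        self._db_name = db_name
        self.tablename = table_name
        self.depc = depcache_controller
        return self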
+ self._hack_chunk_cache = None + + return self + + @property + def fname(self): + """Backwards compatible name of the database this Table belongs to""" + # BBB (20-Oct-12020) 'fname' is the legacy name for the database name + return self._db_name + + def initialize(self, _debug=None): """ Ensures the SQL schema for this cache table """ - table.db = table.depc.fname_to_db[table.fname] - # logger.info('Checking sql for table=%r' % (table.tablename,)) - if not table.db.has_table(table.tablename): - if _debug or ut.VERBOSE: - logger.info('Initializing table=%r' % (table.tablename,)) - new_state = table._get_addtable_kw() - table.db.add_table(**new_state) + self.db = self.depc.get_db_by_name(self._db_name) + # logger.info('Checking sql for table=%r' % (self.tablename,)) + if not self.db.has_table(self.tablename): + logger.debug('Initializing table=%r' % (self.tablename,)) + new_state = self._get_addtable_kw() + self.db.add_table(**new_state) else: # TODO: Check for table modifications - new_state = table._get_addtable_kw() + new_state = self._get_addtable_kw() try: - current_state = table.db.get_table_autogen_dict(table.tablename) + current_state = self.db.get_table_autogen_dict(self.tablename) except Exception as ex: strict = True ut.printex( ex, - 'TABLE %s IS CORRUPTED' % (table.tablename,), + 'TABLE %s IS CORRUPTED' % (self.tablename,), iswarning=not strict, ) if strict: raise - table.clear_table() - current_state = table.db.get_table_autogen_dict(table.tablename) + self.clear_table() + current_state = self.db.get_table_autogen_dict(self.tablename) - if current_state['coldef_list'] != new_state['coldef_list']: - logger.info('WARNING TABLE IS MODIFIED') - if predrop_grace_period(table.tablename): - table.clear_table() - else: - raise NotImplementedError('Need to be able to modify tables') + results = compare_coldef_lists( + current_state['coldef_list'], new_state['coldef_list'] + ) + if results: + current_coldef, new_coldef = results + raise TableOutOfSyncError( + self.db, + self.tablename, + f'Current schema: {current_coldef} Expected schema: {new_coldef}', + ) - def _get_addtable_kw(table): + def _get_addtable_kw(self): """ Information that defines the SQL table @@ -1980,16 +2054,16 @@ def _get_addtable_kw(table): """ coldef_list = [ (colattr['intern_colname'], colattr['sqltype']) - for colattr in table.internal_col_attrs + for colattr in self.internal_col_attrs ] - superkeys = [table.superkey_colnames] + superkeys = [self.superkey_colnames] add_table_kw = ut.odict( [ - ('tablename', table.tablename), + ('tablename', self.tablename), ('coldef_list', coldef_list), - ('docstr', table.docstr), + ('docstr', self.docstr), ('superkeys', superkeys), - ('dependson', table.parents()), + ('dependson', self.parents()), ] ) return add_table_kw @@ -1998,16 +2072,16 @@ def _get_addtable_kw(table): # --- GETTERS NATIVE --- # ---------------------- - def _get_all_rowids(table): - return table.db.get_all_rowids(table.tablename) + def _get_all_rowids(self): + return self.db.get_all_rowids(self.tablename) @property - def number_of_rows(table): - return table.db.get_row_count(table.tablename) + def number_of_rows(self): + return self.db.get_row_count(self.tablename) # @profile def ensure_rows( - table, + self, parent_ids_, preproc_args, config=None, @@ -2028,31 +2102,28 @@ def ensure_rows( >>> table = depc['vsone'] >>> exec(ut.execstr_funckw(table.get_rowid), globals()) >>> config = table.configclass() - >>> _debug = 5 >>> verbose = True >>> # test duplicate inputs are detected and accounted for >>> 
parent_rowids = [(i, i) for i in list(range(100))] * 100 >>> rectify_tup = table._rectify_ids(parent_rowids) >>> (parent_ids_, preproc_args, idxs1, idxs2) = rectify_tup - >>> rowids = table.ensure_rows(parent_ids_, preproc_args, config=config, _debug=_debug) + >>> rowids = table.ensure_rows(parent_ids_, preproc_args, config=config) >>> result = ('rowids = %r' % (rowids,)) >>> print(result) """ try: - _debug = table.depc._debug if _debug is None else _debug # Get requested configuration id - config_rowid = table.get_config_rowid(config) + config_rowid = self.get_config_rowid(config) # Check which rows are already computed - initial_rowid_list = table._get_rowid(parent_ids_, config=config) + initial_rowid_list = self._get_rowid(parent_ids_, config=config) initial_rowid_list = list(initial_rowid_list) - if table.depc._debug: - logger.info( - '[deptbl.ensure] initial_rowid_list = %s' - % (ut.trunc_repr(initial_rowid_list),) - ) - logger.info('[deptbl.ensure] config_rowid = %r' % (config_rowid,)) + logger.debug( + '[deptbl.ensure] initial_rowid_list = %s' + % (ut.trunc_repr(initial_rowid_list),) + ) + logger.debug('[deptbl.ensure] config_rowid = %r' % (config_rowid,)) # Get corresponding "dirty" parent rowids isdirty_list = ut.flag_None_items(initial_rowid_list) @@ -2060,92 +2131,84 @@ def ensure_rows( num_total = len(parent_ids_) if num_dirty > 0: - with ut.Indenter('[ADD]', enabled=_debug): - if verbose or _debug: - logger.info( - 'Add %d / %d new rows to %r' - % (num_dirty, num_total, table.tablename) - ) - logger.info( - '[deptbl.add] * config_rowid = {}, config={}'.format( - config_rowid, str(config) + logger.debug( + 'Add %d / %d new rows to %r' % (num_dirty, num_total, self.tablename) + ) + logger.debug( + '[deptbl.add] * config_rowid = {}, config={}'.format( + config_rowid, str(config) + ) + ) + + dirty_parent_ids_ = ut.compress(parent_ids_, isdirty_list) + dirty_preproc_args_ = ut.compress(preproc_args, isdirty_list) + + # Process only unique items + unique_flags = ut.flag_unique_items(dirty_parent_ids_) + dirty_parent_ids = ut.compress(dirty_parent_ids_, unique_flags) + dirty_preproc_args = ut.compress(dirty_preproc_args_, unique_flags) + + # Break iterator into chunks + if False and verbose: + # check parent configs we are working with + for x, parname in enumerate(self.parents()): + if parname == self.depc.root: + continue + parent_table = self.depc[parname] + ut.take_column(parent_ids_, x) + rowid_list = ut.take_column(parent_ids_, x) + try: + parent_history = parent_table.get_config_history(rowid_list) + logger.info('parent_history = %r' % (parent_history,)) + except KeyError: + logger.info( + '[depcache_table] WARNING: config history is having troubles... 
says Jon' ) - ) - dirty_parent_ids_ = ut.compress(parent_ids_, isdirty_list) - dirty_preproc_args_ = ut.compress(preproc_args, isdirty_list) - - # Process only unique items - unique_flags = ut.flag_unique_items(dirty_parent_ids_) - dirty_parent_ids = ut.compress(dirty_parent_ids_, unique_flags) - dirty_preproc_args = ut.compress(dirty_preproc_args_, unique_flags) - - # Break iterator into chunks - if False and verbose: - # check parent configs we are working with - for x, parname in enumerate(table.parents()): - if parname == table.depc.root: - continue - parent_table = table.depc[parname] - ut.take_column(parent_ids_, x) - rowid_list = ut.take_column(parent_ids_, x) - try: - parent_history = parent_table.get_config_history( - rowid_list - ) - logger.info('parent_history = %r' % (parent_history,)) - except KeyError: - logger.info( - '[depcache_table] WARNING: config history is having troubles... says Jon' - ) - - # Gives the function a hacky cache to use between chunks - table._hack_chunk_cache = {} - gen = table._chunk_compute_dirty_rows( - dirty_parent_ids, dirty_preproc_args, config_rowid, config + # Gives the function a hacky cache to use between chunks + self._hack_chunk_cache = {} + gen = self._chunk_compute_dirty_rows( + dirty_parent_ids, dirty_preproc_args, config_rowid, config + ) + """ + colnames, dirty_params_iter, nChunkInput = next(gen) + """ + for colnames, dirty_params_iter, nChunkInput in gen: + self.db._add( + self.tablename, + colnames, + dirty_params_iter, + nInput=nChunkInput, ) - """ - colnames, dirty_params_iter, nChunkInput = next(gen) - """ - for colnames, dirty_params_iter, nChunkInput in gen: - table.db._add( - table.tablename, - colnames, - dirty_params_iter, - nInput=nChunkInput, - ) - # Remove cache when main add is done - table._hack_chunk_cache = None - if verbose or _debug: - logger.info('[deptbl.add] finished add') - # - # The requested data is clean and must now exist in the parent - # database, do a lookup to ensure the correct order. - rowid_list = table._get_rowid(parent_ids_, config=config) + # Remove cache when main add is done + self._hack_chunk_cache = None + logger.debug('[deptbl.add] finished add') + # + # The requested data is clean and must now exist in the parent + # database, do a lookup to ensure the correct order. + rowid_list = self._get_rowid(parent_ids_, config=config) else: rowid_list = initial_rowid_list - if _debug: - logger.info('[deptbl.add] rowid_list = %s' % ut.trunc_repr(rowid_list)) + logger.debug('[deptbl.add] rowid_list = %s' % ut.trunc_repr(rowid_list)) except lite.IntegrityError: if retry <= 0: raise logger.error( 'DEPC ENSURE_ROWS FOR TABLE %r FAILED DUE TO INTEGRITY ERROR (RETRY %d)!' - % (table, retry) + % (self, retry) ) retry_delay = random.uniform(retry_delay_min, retry_delay_max) logger.error('\t WAITING %0.02f SECONDS THEN RETRYING' % (retry_delay,)) time.sleep(retry_delay) retry_ = retry - 1 - rowid_list = table.ensure_rows( + rowid_list = self.ensure_rows( parent_ids_, preproc_args, config=config, verbose=verbose, - _debug=_debug, retry=retry_, retry_delay_min=retry_delay_min, retry_delay_max=retry_delay_max, @@ -2153,7 +2216,7 @@ def ensure_rows( return rowid_list - def _rectify_ids(table, parent_rowids): + def _rectify_ids(self, parent_rowids): r""" Filters any rows containing None ids and transforms many-to-one sets of rowids into hashable UUIDS. 
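A minimal sketch (not part of the patch; helper choice assumed) of what the multi-parent branch of _rectify_ids does conceptually: parent tuples containing a None are dropped, and any multi-valued parent entry is collapsed into a single hashable UUID, here via the same ut.hashable_to_uuid helper the examples use. The real implementation additionally tracks per-column UUID getters and multi-set sizes.

    import utool as ut

    def _collapse_parent(value):
        # Multi-valued parents (lists/tuples of rowids) become one UUID key;
        # scalar parent rowids pass through unchanged.
        if isinstance(value, (list, tuple)):
            return ut.hashable_to_uuid(tuple(value))
        return value

    parent_rowids = [(1, (2, 3, 5)), (4, None)]
    valid = [ids for ids in parent_rowids if None not in ids]   # drops (4, None)
    parent_ids_ = [tuple(_collapse_parent(v) for v in ids) for ids in valid]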
@@ -2199,9 +2262,9 @@ def _rectify_ids(table, parent_rowids): valid_parent_ids_ = ut.take(parent_rowids, idxs1) preproc_args = valid_parent_ids_ - if table.ismulti: + if self.ismulti: # Convert any parent-id containing multiple values into a hash of uuids - multi_parent_flags = table.get_parent_col_attr('ismulti') + multi_parent_flags = self.get_parent_col_attr('ismulti') num_parents = len(multi_parent_flags) multi_parent_colxs = ut.where(multi_parent_flags) normal_colxs = ut.index_complement(multi_parent_colxs, num_parents) @@ -2213,9 +2276,9 @@ def _rectify_ids(table, parent_rowids): ] # TODO: give each table a uuid getter function that derives from # get_root_uuids - multicol_tables = ut.take(table.parents(), multi_parent_colxs) + multicol_tables = ut.take(self.parents(), multi_parent_colxs) parent_uuid_getters = [ - table.depc.get_root_uuid if col == table.depc.root else ut.identity + self.depc.get_root_uuid if col == self.depc.root else ut.identity for col in multicol_tables ] @@ -2247,7 +2310,7 @@ def _rectify_ids(table, parent_rowids): rectify_tup = parent_ids_, preproc_args, idxs1, idxs2 return rectify_tup - def _unrectify_ids(table, rowid_list_, parent_rowids, idxs1, idxs2): + def _unrectify_ids(self, rowid_list_, parent_rowids, idxs1, idxs2): """ Ensures that output is the same length as input. Inserts necessary Nones where the original input was also None. @@ -2257,7 +2320,7 @@ def _unrectify_ids(table, rowid_list_, parent_rowids, idxs1, idxs2): return rowid_list def get_rowid( - table, + self, parent_rowids, config=None, ensure=True, @@ -2279,7 +2342,7 @@ def get_rowid( eager (bool): (default = True) nInput (int): (default = None) recompute (bool): (default = False) - _debug (None): (default = None) + _debug (None): (default = None) deprecated; no-op Returns: list: rowid_list @@ -2295,105 +2358,106 @@ def get_rowid( >>> table = depc['labeler'] >>> exec(ut.execstr_funckw(table.get_rowid), globals()) >>> config = table.configclass() - >>> _debug = True >>> parent_rowids = list(zip([1, None, None, 2])) - >>> rowids = table.get_rowid(parent_rowids, config=config, _debug=_debug) + >>> rowids = table.get_rowid(parent_rowids, config=config) >>> result = ('rowids = %r' % (rowids,)) >>> print(result) rowids = [1, None, None, 2] """ - _debug = table.depc._debug if _debug is None else _debug - if _debug: - logger.info( - '[deptbl.get_rowid] Get %s rowids via %d parent superkeys' - % (table.tablename, len(parent_rowids)) - ) - if _debug > 1: - logger.info('[deptbl.get_rowid] config = %r' % (config,)) - logger.info('[deptbl.get_rowid] ensure = %r' % (ensure,)) + logger.debug( + '[deptbl.get_rowid] Get %s rowids via %d parent superkeys' + % (self.tablename, len(parent_rowids)) + ) + logger.debug('[deptbl.get_rowid] config = %r' % (config,)) + logger.debug('[deptbl.get_rowid] ensure = %r' % (ensure,)) # Ensure inputs are in the correct format / remove Nones # Collapse multi-inputs into a UUID hash - rectify_tup = table._rectify_ids(parent_rowids) + rectify_tup = self._rectify_ids(parent_rowids) (parent_ids_, preproc_args, idxs1, idxs2) = rectify_tup # Do the getting / adding work if recompute: logger.info('REQUESTED RECOMPUTE') # get existing rowids, delete them, recompute the request - rowid_list_ = table._get_rowid( - parent_ids_, config=config, eager=True, nInput=None, _debug=_debug + rowid_list_ = self._get_rowid( + parent_ids_, + config=config, + eager=True, + nInput=None, ) rowid_list_ = list(rowid_list_) needs_recompute_rowids = ut.filter_Nones(rowid_list_) try: - 
table._recompute_and_store(needs_recompute_rowids) + self._recompute_and_store(needs_recompute_rowids) except Exception: # If the config changes, there is nothing we can do. # We have to delete the rows. - table.delete_rows(rowid_list_) + self.delete_rows(rowid_list_) if ensure or recompute: # Compute properties if they do not exist for try_num in range(num_retries): try: - rowid_list_ = table.ensure_rows( - parent_ids_, preproc_args, config=config, _debug=_debug + rowid_list_ = self.ensure_rows( + parent_ids_, + preproc_args, + config=config, ) except ExternalStorageException: if try_num == num_retries - 1: raise else: - rowid_list_ = table._get_rowid( - parent_ids_, config=config, eager=eager, nInput=nInput, _debug=_debug + rowid_list_ = self._get_rowid( + parent_ids_, + config=config, + eager=eager, + nInput=nInput, ) # Map outputs to correspond with inputs - rowid_list = table._unrectify_ids(rowid_list_, parent_rowids, idxs1, idxs2) + rowid_list = self._unrectify_ids(rowid_list_, parent_rowids, idxs1, idxs2) return rowid_list # @profile - def _get_rowid(table, parent_ids_, config=None, eager=True, nInput=None, _debug=None): + def _get_rowid(self, parent_ids_, config=None, eager=True, nInput=None): """ Returns rowids using parent superkeys. Does not add non-existing properties. """ - colnames = (table.rowid_colname,) - config_rowid = table.get_config_rowid(config=config) - _debug = table.depc._debug if _debug is None else _debug - if _debug: - logger.info('_get_rowid') - logger.info('_get_rowid table.tablename = %r ' % (table.tablename,)) - logger.info('_get_rowid parent_ids_ = %s' % (ut.trunc_repr(parent_ids_))) - logger.info('_get_rowid config = %s' % (config)) - logger.info('_get_rowid table.rowid_colname = %s' % (table.rowid_colname)) - logger.info('_get_rowid config_rowid = %s' % (config_rowid)) - andwhere_colnames = table.superkey_colnames + colnames = (self.rowid_colname,) + config_rowid = self.get_config_rowid(config=config) + logger.debug('_get_rowid') + logger.debug('_get_rowid self.tablename = %r ' % (self.tablename,)) + logger.debug('_get_rowid parent_ids_ = %s' % (ut.trunc_repr(parent_ids_))) + logger.debug('_get_rowid config = %s' % (config)) + logger.debug('_get_rowid self.rowid_colname = %s' % (self.rowid_colname)) + logger.debug('_get_rowid config_rowid = %s' % (config_rowid)) + andwhere_colnames = self.superkey_colnames params_iter = (ids_ + (config_rowid,) for ids_ in parent_ids_) # TODO: make sure things that call this can accept a generator # Then remove this next line params_iter = list(params_iter) # logger.info('**params_iter = %r' % (params_iter,)) - rowid_list = table.db.get_where_eq( - table.tablename, + rowid_list = self.db.get_where_eq( + self.tablename, colnames, params_iter, andwhere_colnames, eager=eager, nInput=nInput, ) - if _debug: - logger.info('_get_rowid rowid_list = %s' % (ut.trunc_repr(rowid_list))) + logger.debug('_get_rowid rowid_list = %s' % (ut.trunc_repr(rowid_list))) return rowid_list - def clear_table(table): + def clear_table(self): """ Deletes all data in this table """ # TODO: need to clear one-to-one dependencies as well - logger.info('Clearing data in %r' % (table,)) - table.db.drop_table(table.tablename) - table.db.add_table(**table._get_addtable_kw()) + logger.info('Clearing data in %r' % (self,)) + self.db.drop_table(self.tablename) + self.db.add_table(**self._get_addtable_kw()) # @profile - def delete_rows(table, rowid_list, delete_extern=None, dry=False, verbose=None): + def delete_rows(self, rowid_list, delete_extern=None, 
dry=False, verbose=None): """ CommandLine: python -m dtool.depcache_table --exec-delete_rows @@ -2429,30 +2493,30 @@ def delete_rows(table, rowid_list, delete_extern=None, dry=False, verbose=None): """ # import networkx as nx # from wbia.dtool.algo.preproc import preproc_feat - if table.on_delete is not None and not dry: - table.on_delete() + if self.on_delete is not None and not dry: + self.on_delete() if delete_extern is None: - delete_extern = table.rm_extern_on_delete + delete_extern = self.rm_extern_on_delete if verbose is None: verbose = False if ut.NOT_QUIET: if ut.VERBOSE: logger.info( 'Requested delete of %d rows from %s' - % (len(rowid_list), table.tablename) + % (len(rowid_list), self.tablename) ) if dry: logger.info('Dry run') # logger.info('delete_extern = %r' % (delete_extern,)) - depc = table.depc + depc = self.depc # TODO: # REMOVE EXTERNAL FILES - internal_colnames = table.get_intern_data_col_attr('intern_colname') - is_extern = table.get_intern_data_col_attr('is_external_pointer') + internal_colnames = self.get_intern_data_col_attr('intern_colname') + is_extern = self.get_intern_data_col_attr('is_external_pointer') extern_colnames = tuple(ut.compress(internal_colnames, is_extern)) if len(extern_colnames) > 0: - uris = table.get_internal_columns( + uris = self.get_internal_columns( rowid_list, extern_colnames, unpack_scalars=False, @@ -2464,7 +2528,7 @@ def delete_rows(table, rowid_list, delete_extern=None, dry=False, verbose=None): if not isinstance(uri, tuple): uri = [uri] for uri_ in uri: - absuris.append(join(table.extern_dpath, uri_)) + absuris.append(join(self.extern_dpath, uri_)) fpaths = [fpath for fpath in absuris if exists(fpath)] if delete_extern: if ut.VERBOSE or len(fpaths) > 0: @@ -2497,17 +2561,17 @@ def get_child_partial_rowids(child_table, rowid_list, parent_colnames): return child_rowids if ut.VERBOSE: - if table.children: - logger.info('Deleting from %r children' % (len(table.children),)) + if self.children: + logger.info('Deleting from %r children' % (len(self.children),)) else: logger.info('Table is a leaf node') - for child in table.children: - child_table = table.depc[child] + for child in self.children: + child_table = self.depc[child] if not child_table.ismulti: # Hack, wont work for vsone / multisets parent_colnames = ( - child_table.parent[table.tablename]['intern_colname'], + child_table.parent[self.tablename]['intern_colname'], ) child_rowids = get_child_partial_rowids( child_table, rowid_list, parent_colnames @@ -2519,24 +2583,24 @@ def get_child_partial_rowids(child_table, rowid_list, parent_colnames): if ut.VERBOSE or len(non_none_rowids) > 0: logger.info( 'Deleting %d non-None rows from %s' - % (len(non_none_rowids), table.tablename) + % (len(non_none_rowids), self.tablename) ) logger.info('...done!') # Finalize: Delete rows from this table if not dry: - table.db.delete_rowids(table.tablename, rowid_list) + self.db.delete_rowids(self.tablename, rowid_list) num_deleted = len(ut.filter_Nones(rowid_list)) else: num_deleted = 0 return num_deleted - def _resolve_requested_columns(table, requested_colnames): + def _resolve_requested_columns(self, requested_colnames): ######## # Map requested colnames flat to internal colnames ######## # Get requested column information - requestable_col_attrs = table.requestable_col_attrs() + requestable_col_attrs = self.requestable_col_attrs() requested_colattrs = ut.take(requestable_col_attrs, requested_colnames) # Make column indicies iterable for grouping intern_colxs = [ @@ -2550,7 +2614,7 @@ def 
_resolve_requested_columns(table, requested_colnames): extern_colattrs = ut.compress(requested_colattrs, isextern_flags) extern_resolve_colxs = ut.compress(nested_offsets_start, isextern_flags) extern_read_funcs = ut.take_column(extern_colattrs, 'read_func') - intern_colnames_ = ut.take_column(table.internal_col_attrs, 'intern_colname') + intern_colnames_ = ut.take_column(self.internal_col_attrs, 'intern_colname') intern_colnames = ut.unflat_take(intern_colnames_, intern_colxs) # TODO: this can be cleaned up @@ -2564,7 +2628,7 @@ def _resolve_requested_columns(table, requested_colnames): # @profile def get_row_data( - table, + self, tbl_rowids, colnames=None, _debug=None, @@ -2633,18 +2697,16 @@ def get_row_data( >>> data = table.get_row_data(tbl_rowids, 'chip', read_extern=False, ensure=False) >>> data = table.get_row_data(tbl_rowids, 'chip', read_extern=False, ensure=True) """ - _debug = table.depc._debug if _debug is None else _debug - if _debug: - logger.info( - ('Get col of tablename=%r, colnames=%r with ' 'tbl_rowids=%s') - % (table.tablename, colnames, ut.trunc_repr(tbl_rowids)) - ) + logger.debug( + ('Get col of tablename=%r, colnames=%r with ' 'tbl_rowids=%s') + % (self.tablename, colnames, ut.trunc_repr(tbl_rowids)) + ) #### # Resolve requested column names if unpack_columns is None: - unpack_columns = table.default_to_unpack + unpack_columns = self.default_to_unpack if colnames is None: - requested_colnames = table.data_colnames + requested_colnames = self.data_colnames elif isinstance(colnames, six.string_types): # Unpack columns if only a single column is requested. requested_colnames = (colnames,) @@ -2652,16 +2714,13 @@ def get_row_data( else: requested_colnames = colnames - if _debug: - logger.info('requested_colnames = %r' % (requested_colnames,)) - tup = table._resolve_requested_columns(requested_colnames) + logger.debug('requested_colnames = %r' % (requested_colnames,)) + tup = self._resolve_requested_columns(requested_colnames) nesting_xs, extern_resolve_tups, flat_intern_colnames = tup - if _debug: - logger.info( - '[deptbl.get_row_data] flat_intern_colnames = %r' - % (flat_intern_colnames,) - ) + logger.debug( + '[deptbl.get_row_data] flat_intern_colnames = %r' % (flat_intern_colnames,) + ) nonNone_flags = ut.flag_not_None_items(tbl_rowids) nonNone_tbl_rowids = ut.compress(tbl_rowids, nonNone_flags) @@ -2669,18 +2728,12 @@ def get_row_data( idxs2 = ut.index_complement(idxs1, len(tbl_rowids)) - #### - # Read data stored in SQL - # FIXME: understand unpack_scalars and keepwrap - # if table.default_onthefly: - # table._onthefly_dataget - # else: if nInput is None and ut.is_listlike(nonNone_tbl_rowids): nInput = len(nonNone_tbl_rowids) generator_version = not eager - raw_prop_list = table.get_internal_columns( + raw_prop_list = self.get_internal_columns( nonNone_tbl_rowids, flat_intern_colnames, eager=eager, @@ -2715,7 +2768,7 @@ def tuptake(list_, index_list): if generator_version: def _generator_resolve_all(): - extern_dpath = table.extern_dpath + extern_dpath = self.extern_dpath for rawprop in raw_prop_list: if rawprop is None: raise Exception( @@ -2751,7 +2804,7 @@ def _generator_resolve_all(): for try_num in range(num_retries + 1): tries_left = num_retries - try_num try: - prop_listT = table._resolve_any_external_data( + prop_listT = self._resolve_any_external_data( nonNone_tbl_rowids, raw_prop_list, extern_resolve_tups, @@ -2759,7 +2812,6 @@ def _generator_resolve_all(): read_extern, delete_on_fail, tries_left, - _debug, ) except ExternalStorageException: if 
tries_left == 0: @@ -2790,7 +2842,7 @@ def _generator_resolve_all(): return prop_list def _resolve_any_external_data( - table, + self, nonNone_tbl_rowids, raw_prop_list, extern_resolve_tups, @@ -2798,11 +2850,10 @@ def _resolve_any_external_data( read_extern, delete_on_fail, tries_left, - _debug, ): #### # Read data specified by any external columns - extern_dpath = table.extern_dpath + extern_dpath = self.extern_dpath try: prop_listT = list(zip(*raw_prop_list)) except TypeError as ex: @@ -2810,8 +2861,7 @@ def _resolve_any_external_data( raise for extern_colx, read_func in extern_resolve_tups: - if _debug: - logger.info('[deptbl.get_row_data] read_func = %r' % (read_func,)) + logger.debug('[deptbl.get_row_data] read_func = %r' % (read_func,)) data_list = [] failed_list = [] for uri in prop_listT[extern_colx]: @@ -2853,8 +2903,8 @@ def _resolve_any_external_data( ) failed_rowids = ut.compress(nonNone_tbl_rowids, failed_list) if delete_on_fail: - table._recompute_external_storage(failed_rowids) - # table.delete_rows(failed_rowids, delete_extern=None) + self._recompute_external_storage(failed_rowids) + # self.delete_rows(failed_rowids, delete_extern=None) raise ExternalStorageException( 'Some cached filenames failed to read. ' 'Need to recompute %d/%d rows' % (sum(failed_list), len(failed_list)) @@ -2863,7 +2913,7 @@ def _resolve_any_external_data( prop_listT[extern_colx] = data_list return prop_listT - def _recompute_external_storage(table, tbl_rowids): + def _recompute_external_storage(self, tbl_rowids): """ Recomputes the external file stored for this row. This DOES NOT modify the depcache internals. @@ -2871,26 +2921,26 @@ def _recompute_external_storage(table, tbl_rowids): logger.info('Recomputing external data (_recompute_external_storage)') # TODO: need to rectify parent ids? - parent_rowids = table.get_parent_rowids(tbl_rowids) - parent_rowargs = table.get_parent_rowargs(tbl_rowids) + parent_rowids = self.get_parent_rowids(tbl_rowids) + parent_rowargs = self.get_parent_rowargs(tbl_rowids) - # configs = table.get_row_configs(tbl_rowids) + # configs = self.get_row_configs(tbl_rowids) # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' # TODO; groupby config - config_rowids = table.get_row_cfgid(tbl_rowids) + config_rowids = self.get_row_cfgid(tbl_rowids) unique_cfgids, groupxs = ut.group_indices(config_rowids) for xs, cfgid in zip(groupxs, unique_cfgids): parent_ids = ut.take(parent_rowids, xs) parent_args = ut.take(parent_rowargs, xs) - config = table.get_config_from_rowid([cfgid])[0] - dirty_params_iter = table._compute_dirty_rows( + config = self.get_config_from_rowid([cfgid])[0] + dirty_params_iter = self._compute_dirty_rows( parent_ids, parent_args, config_rowid=cfgid, config=config ) # Evaulate just to ensure storage ut.evaluate_generator(dirty_params_iter) - def _recompute_and_store(table, tbl_rowids, config=None): + def _recompute_and_store(self, tbl_rowids, config=None): """ Recomputes all data stored for this row. This DOES modify the depcache internals. 
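A stripped-down sketch (not part of the patch) of the group-by-config pattern shared by _recompute_external_storage above and _recompute_and_store below: rows are bucketed by their config rowid so that each bucket is recomputed with a single config object, using the same utool helpers the patch uses.

    import utool as ut

    def recompute_grouped(table, tbl_rowids):
        parent_rowids = table.get_parent_rowids(tbl_rowids)
        parent_rowargs = table.get_parent_rowargs(tbl_rowids)
        config_rowids = table.get_row_cfgid(tbl_rowids)
        unique_cfgids, groupxs = ut.group_indices(config_rowids)
        for xs, cfgid in zip(groupxs, unique_cfgids):
            parent_ids = ut.take(parent_rowids, xs)
            parent_args = ut.take(parent_rowargs, xs)
            config = table.get_config_from_rowid([cfgid])[0]
            dirty_params_iter = table._compute_dirty_rows(
                parent_ids, parent_args, config_rowid=cfgid, config=config
            )
            # Draining the generator forces the recomputed values to be written.
            ut.evaluate_generator(dirty_params_iter)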
@@ -2898,43 +2948,42 @@ def _recompute_and_store(table, tbl_rowids, config=None): logger.info('Recomputing external data (_recompute_and_store)') if len(tbl_rowids) == 0: return - parent_rowids = table.get_parent_rowids(tbl_rowids) - parent_rowargs = table.get_parent_rowargs(tbl_rowids) - # configs = table.get_row_configs(tbl_rowids) + parent_rowids = self.get_parent_rowids(tbl_rowids) + parent_rowargs = self.get_parent_rowargs(tbl_rowids) + # configs = self.get_row_configs(tbl_rowids) # assert ut.allsame(list(map(id, configs))), 'more than one config not yet supported' # TODO; groupby config if config is None: - config_rowids = table.get_row_cfgid(tbl_rowids) + config_rowids = self.get_row_cfgid(tbl_rowids) unique_cfgids, groupxs = ut.group_indices(config_rowids) else: # This is incredibly hacky. pass - colnames = table.computable_colnames() + colnames = self.computable_colnames() for xs, cfgid in zip(groupxs, unique_cfgids): parent_ids = ut.take(parent_rowids, xs) parent_args = ut.take(parent_rowargs, xs) rowids = ut.take(tbl_rowids, xs) - config = table.get_config_from_rowid([cfgid])[0] - dirty_params_iter = table._compute_dirty_rows( + config = self.get_config_from_rowid([cfgid])[0] + dirty_params_iter = self._compute_dirty_rows( parent_ids, parent_args, config_rowid=cfgid, config=config ) # Evaulate to external and internal storage - table.db.set(table.tablename, colnames, dirty_params_iter, rowids) + self.db.set(self.tablename, colnames, dirty_params_iter, rowids) - # _onthefly_dataget # togroup_args = [parent_rowids] # grouped_parent_ids = ut.apply_grouping(parent_rowids, groupxs) # unique_args_list = [unique_configs] # raw_prop_lists = [] - # # func = ut.partial(table.preproc_func, table.depc) + # # func = ut.partial(self.preproc_func, self.depc) # def groupmap_func(group_args, unique_args): # config_ = unique_args[0] # argsT = group_args - # propgen = table.preproc_func(table.depc, *argsT, config=config_) + # propgen = self.preproc_func(self.depc, *argsT, config=config_) # return list(propgen) # def grouped_map(groupmap_func, groupxs, togroup_args, unique_args_list): @@ -2954,7 +3003,7 @@ def _recompute_and_store(table, tbl_rowids, config=None): # @profile def get_internal_columns( - table, + self, tbl_rowids, colnames=None, eager=True, @@ -2967,11 +3016,11 @@ def get_internal_columns( Access data in this table using the table PRIMARY KEY rowids (not depc PRIMARY ids) """ - prop_list = table.db.get( - table.tablename, + prop_list = self.db.get( + self.tablename, colnames, tbl_rowids, - id_colname=table.rowid_colname, + id_colname=self.rowid_colname, eager=eager, nInput=nInput, unpack_scalars=unpack_scalars, @@ -2980,7 +3029,7 @@ def get_internal_columns( ) return prop_list - def export_rows(table, rowid, target): + def export_rows(self, rowid, target): """ The goal of this is to export taggable data that can be used independantly of its dependant features. 
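A hypothetical usage sketch (not part of the patch; the 'chip' table name comes from the example depcache) contrasting the two read paths touched above: get_row_data resolves external columns by reading the files behind the stored URIs, while get_internal_columns returns the raw stored values keyed by this table's own rowids.

    table = depc['chip']
    # Resolved data: external blobs are read back from disk.
    chips = table.get_row_data(tbl_rowids, 'chip')
    # Raw internal values, e.g. the URI strings pointing at those files.
    intern_names = table.get_intern_data_col_attr('intern_colname')
    raw_rows = table.get_internal_columns(
        tbl_rowids, tuple(intern_names), unpack_scalars=False
    )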
@@ -3015,17 +3064,17 @@ def export_rows(table, rowid, target): rowid = 1 """ raise NotImplementedError('unfinished') - colnames = tuple(table.db.get_column_names(table.tablename)) - colvals = table.db.get(table.tablename, colnames, [rowid])[0] # NOQA + colnames = tuple(self.db.get_column_names(self.tablename)) + colvals = self.db.get(self.tablename, colnames, [rowid])[0] # NOQA - uuid = table.get_model_uuid([rowid])[0] - manifest_data = table.get_model_inputs(uuid) # NOQA + uuid = self.get_model_uuid([rowid])[0] + manifest_data = self.get_model_inputs(uuid) # NOQA - config_history = table.get_config_history([rowid]) # NOQA + config_history = self.get_config_history([rowid]) # NOQA - table.parent_col_attrs = table._infer_parentcol() - table.data_col_attrs - table.internal_col_attrs + self.parent_col_attrs = self._infer_parentcol() + self.data_col_attrs + self.internal_col_attrs - table.db.cur.execute('SELECT * FROM {tablename} WHERE rowid=?') + self.db.cur.execute('SELECT * FROM {tablename} WHERE rowid=?') pass diff --git a/wbia/dtool/events.py b/wbia/dtool/events.py index 06e0447fc3..8eba20e80f 100644 --- a/wbia/dtool/events.py +++ b/wbia/dtool/events.py @@ -7,6 +7,7 @@ """ from sqlalchemy import event from sqlalchemy.schema import Table +from sqlalchemy.sql import text from .types import SQL_TYPE_TO_SA_TYPE @@ -22,9 +23,33 @@ def _discovery_table_columns(inspector, table_name): #: column-id, name, data-type, nullable, default-value, is-primary-key info_rows = result.fetchall() names_to_types = {info[1]: info[2] for info in info_rows} + elif dialect == 'postgresql': + result = conn.execute( + text( + """SELECT + row_number() over () - 1, + column_name, + coalesce(domain_name, data_type), + CASE WHEN is_nullable = 'YES' THEN 0 ELSE 1 END, + column_default, + column_name = ( + SELECT column_name + FROM information_schema.table_constraints + NATURAL JOIN information_schema.constraint_column_usage + WHERE table_name = :table_name + AND constraint_type = 'PRIMARY KEY' + LIMIT 1 + ) AS pk + FROM information_schema.columns + WHERE table_name = :table_name""" + ), + table_name=table_name, + ) + info_rows = result.fetchall() + names_to_types = {info[1]: info[2] for info in info_rows} else: raise RuntimeError( - "Unknown dialect ('{dialect}'), can't introspect column information." + f"Unknown dialect ('{dialect}'), can't introspect column information." 
) return names_to_types diff --git a/wbia/dtool/example_depcache.py b/wbia/dtool/example_depcache.py index 3fe744e5c6..4da19e3af1 100644 --- a/wbia/dtool/example_depcache.py +++ b/wbia/dtool/example_depcache.py @@ -4,15 +4,21 @@ python -m dtool.example_depcache --exec-dummy_example_depcacahe --show python -m dtool.depcache_control --exec-make_graph --show """ +from pathlib import Path +from os.path import join + import utool as ut import numpy as np import uuid -from os.path import join, dirname from six.moves import zip + from wbia.dtool import depcache_control from wbia import dtool +HERE = Path(__file__).parent.resolve() + + if False: # Example of global registration DUMMY_ROOT_TABLENAME = 'dummy_annot' @@ -34,6 +40,20 @@ def dummy_global_preproc_func(depc, parent_rowids, config=None): yield 'dummy' +class DummyController: + """Just enough (IBEIS) controller to make the dependency cache examples work""" + + def __init__(self, cache_dpath): + self.cache_dpath = Path(cache_dpath) + self.cache_dpath.mkdir(exist_ok=True) + + def make_cache_db_uri(self, name): + return f'sqlite:///{self.cache_dpath}/{name}.sqlite' + + def get_cachedir(self): + return self.cache_dpath + + class DummyKptsConfig(dtool.Config): def get_param_info_list(self): return [ @@ -220,16 +240,17 @@ def testdata_depc(fname=None): def get_root_uuid(aid_list): return ut.lmap(ut.hashable_to_uuid, aid_list) - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE') + if not fname: + fname = dummy_root + cache_dpath = HERE / 'DEPCACHE' + controller = DummyController(cache_dpath) depc = dtool.DependencyCache( - root_tablename=dummy_root, - default_fname=fname, - cache_dpath=cache_dpath, - get_root_uuid=get_root_uuid, - # root_asobject=root_asobject, + controller, + fname, + get_root_uuid, + table_name=dummy_root, + root_getters=None, use_globals=False, ) @@ -725,8 +746,8 @@ def dummy_example_depcacahe(): req.execute() # ut.InstanceList( - db = list(depc.fname_to_db.values())[0] - # db_list = ut.InstanceList(depc.fname_to_db.values()) + db = list(depc._db_by_name.values())[0] + # db_list = ut.InstanceList(depc._db_by_name.values()) # db_list.print_table_csv('config', exclude_columns='config_strid') print('config table') diff --git a/wbia/dtool/example_depcache2.py b/wbia/dtool/example_depcache2.py index 27aa5af07f..1fb243b036 100644 --- a/wbia/dtool/example_depcache2.py +++ b/wbia/dtool/example_depcache2.py @@ -1,10 +1,36 @@ # -*- coding: utf-8 -*- -import utool as ut +from pathlib import Path -# import numpy as np -from os.path import join, dirname +import utool as ut from six.moves import zip +from wbia.dtool.depcache_control import DependencyCache +from wbia.dtool.example_depcache import DummyController + + +HERE = Path(__file__).parent.resolve() + + +def _depc_factory(name, cache_dir): + """DependencyCache factory for the examples + + Args: + name (str): name of the cache (e.g. 
'annot') + cache_dir (str): name of the cache directory + + """ + cache_dpath = HERE / cache_dir + controller = DummyController(cache_dpath) + depc = DependencyCache( + controller, + name, + ut.identity, + table_name=None, + root_getters=None, + use_globals=False, + ) + return depc + def depc_34_helper(depc): def register_dummy_config(tablename, parents, **kwargs): @@ -85,23 +111,7 @@ def testdata_depc3(in_memory=True): >>> #depc['viewpoint_classification'].show_input_graph() >>> ut.show_if_requested() """ - from wbia import dtool - - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE3') - - # FIXME: this only puts the sql files in memory - default_fname = ':memory:' if in_memory else None - - root = 'annot' - depc = dtool.DependencyCache( - root_tablename=root, - get_root_uuid=ut.identity, - default_fname=default_fname, - cache_dpath=cache_dpath, - use_globals=False, - ) + depc = _depc_factory('annot', 'DEPCACHE3') # ---------- # dummy_cols = dict(colnames=['data'], coltypes=[np.ndarray]) @@ -155,23 +165,7 @@ def testdata_depc4(in_memory=True): >>> #depc['viewpoint_classification'].show_input_graph() >>> ut.show_if_requested() """ - from wbia import dtool - - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE3') - - # FIXME: this only puts the sql files in memory - default_fname = ':memory:' if in_memory else None - - root = 'annot' - depc = dtool.DependencyCache( - root_tablename=root, - get_root_uuid=ut.identity, - default_fname=default_fname, - cache_dpath=cache_dpath, - use_globals=False, - ) + depc = _depc_factory('annot', 'DEPCACHE4') # ---------- # dummy_cols = dict(colnames=['data'], coltypes=[np.ndarray]) @@ -201,21 +195,8 @@ def testdata_depc4(in_memory=True): def testdata_custom_annot_depc(dummy_dependencies, in_memory=True): - from wbia import dtool - - # put the test cache in the dtool repo - dtool_repo = dirname(ut.get_module_dir(dtool)) - cache_dpath = join(dtool_repo, 'DEPCACHE5') - # FIXME: this only puts the sql files in memory - default_fname = ':memory:' if in_memory else None - root = 'annot' - depc = dtool.DependencyCache( - root_tablename=root, - get_root_uuid=ut.identity, - default_fname=default_fname, - cache_dpath=cache_dpath, - use_globals=False, - ) + depc = _depc_factory('annot', 'DEPCACHE5') + # ---------- register_dummy_config = depc_34_helper(depc) diff --git a/wbia/dtool/sql_control.py b/wbia/dtool/sql_control.py index 0d6597b6b2..3e78482649 100644 --- a/wbia/dtool/sql_control.py +++ b/wbia/dtool/sql_control.py @@ -5,23 +5,31 @@ TODO; need to use some sort of sticky bit so sql files are created with reasonable permissions. 
""" +import functools import logging import collections import os import parse import re -import threading import uuid from collections.abc import Mapping, MutableMapping -from functools import partial +from contextlib import contextmanager from os.path import join, exists import six +import sqlalchemy import utool as ut from deprecated import deprecated +from sqlalchemy.engine import LegacyRow +from sqlalchemy.schema import Table +from sqlalchemy.sql import bindparam, text, ClauseElement -from wbia.dtool import sqlite3 as lite +from wbia.dtool import lite from wbia.dtool.dump import dumps +from wbia.dtool.types import Integer, TYPE_TO_SQLTYPE +from wbia.dtool.types import initialize_postgresql_types + +import tqdm print, rrr, profile = ut.inject2(__name__) @@ -37,6 +45,8 @@ TIMEOUT = 600 # Wait for up to 600 seconds for the database to return from a locked state +BATCH_SIZE = int(1e4) + SQLColumnRichInfo = collections.namedtuple( 'SQLColumnRichInfo', ('column_id', 'name', 'type_', 'notnull', 'dflt_value', 'pk') ) @@ -63,13 +73,68 @@ METADATA_TABLE_COLUMN_NAMES = list(METADATA_TABLE_COLUMNS.keys()) -def _unpacker(results_): +def create_engine(uri, POSTGRESQL_POOL_SIZE=20, ENGINES={}): + pid = os.getpid() + if ENGINES.get('pid') != pid: + # ENGINES contains engines from the parent process that the + # child process can't use + ENGINES.clear() + ENGINES['pid'] = pid + kw = { + # The echo flag is a shortcut to set up SQLAlchemy logging + 'echo': False, + } + if uri.startswith('sqlite:') and ':memory:' in uri: + # Don't share engines for in memory sqlite databases + return sqlalchemy.create_engine(uri, **kw) + if uri not in ENGINES: + if uri.startswith('postgresql:'): + # pool_size is not available for sqlite + kw['pool_size'] = POSTGRESQL_POOL_SIZE + ENGINES[uri] = sqlalchemy.create_engine(uri, **kw) + return ENGINES[uri] + + +def compare_coldef_lists(coldef_list1, coldef_list2): + def normalize(coldef_list): + for name, coldef in coldef_list: + # Remove "rowid" which is added to postgresql tables + if name != 'rowid': + coldef_ = coldef.lower() + # Remove "default nextval" for postgresql auto-increment fields + # as sqlite doesn't need it + coldef_ = re.sub(r' default \(nextval\(.*', '', coldef_) + # Consider bigint and integer the same + if 'bigint' in coldef_: + coldef_ = re.sub(r"'([^']*)'::bigint", r'\1', coldef_) + coldef_ = re.sub(r'\bbigint\b', 'integer', coldef_) + # Consider double precision and real the same + if 'double precision' in coldef_: + coldef_ = re.sub(r'\bdouble precision\b', 'real', coldef_) + yield name.lower(), coldef_ + + coldef_list1 = list(normalize(coldef_list1)) + coldef_list2 = list(normalize(coldef_list2)) + + if len(coldef_list1) != len(coldef_list2): + return coldef_list1, coldef_list2 + for i in range(len(coldef_list1)): + name1, coldef1 = coldef_list1[i] + name2, coldef2 = coldef_list2[i] + if name1 != name2: + return coldef_list1, coldef_list2 + if coldef1 != coldef2: + return coldef_list1, coldef_list2 + return + + +def _unpacker(results): """ HELPER: Unpacks results if unpack_scalars is True. """ - if len(results_) == 0: + if not results: # Check for None or empty list results = None else: - assert len(results_) <= 1, 'throwing away results! { %r }' % (results_,) - results = results_[0] + assert len(results) <= 1, 'throwing away results! { %r }' % (results,) + results = results[0] return results @@ -79,228 +144,6 @@ def tuplize(list_): return tup_list -def flattenize(list_): - """ - maps flatten to a tuplized list - - Weird function. 
DEPRICATE - - Example: - >>> # DISABLE_DOCTEST - >>> list_ = [[1, 2, 3], [2, 3, [4, 2, 1]], [3, 2], [[1, 2], [3, 4]]] - >>> import utool - >>> from itertools import zip - >>> val_list1 = [(1, 2), (2, 4), (5, 3)] - >>> id_list1 = [(1,), (2,), (3,)] - >>> out_list1 = utool.flattenize(zip(val_list1, id_list1)) - - >>> val_list2 = [1, 4, 5] - >>> id_list2 = [(1,), (2,), (3,)] - >>> out_list2 = utool.flattenize(zip(val_list2, id_list2)) - - >>> val_list3 = [1, 4, 5] - >>> id_list3 = [1, 2, 3] - >>> out_list3 = utool.flattenize(zip(val_list3, id_list3)) - - out_list4 = list(zip(val_list3, id_list3)) - %timeit utool.flattenize(zip(val_list1, id_list1)) - %timeit utool.flattenize(zip(val_list2, id_list2)) - %timeit utool.flattenize(zip(val_list3, id_list3)) - %timeit list(zip(val_list3, id_list3)) - - 100000 loops, best of 3: 14 us per loop - 100000 loops, best of 3: 16.5 us per loop - 100000 loops, best of 3: 18 us per loop - 1000000 loops, best of 3: 1.18 us per loop - """ - tuplized_iter = map(tuplize, list_) - flatenized_list = list(map(ut.flatten, tuplized_iter)) - return flatenized_list - - -# ======================= -# SQL Context Class -# ======================= - - -class SQLExecutionContext(object): - """ - Context manager for transactional database calls - - FIXME: hash out details. I don't think anybody who programmed this - knows what is going on here. So much for fine grained control. - - Referencs: - http://stackoverflow.com/questions/9573768/understand-sqlite-multi-module-envs - - """ - - def __init__( - context, - db, - operation, - nInput=None, - auto_commit=True, - start_transaction=False, - keepwrap=False, - verbose=VERBOSE_SQL, - tablename=None, - ): - context.tablename = None - context.auto_commit = auto_commit - context.db = db - context.operation = operation - context.nInput = nInput - context.start_transaction = start_transaction - context.operation_type = get_operation_type(operation) - context.verbose = verbose - context.is_insert = context.operation_type.startswith('INSERT') - context.keepwrap = keepwrap - context.cur = None - context.connection = None - - def __enter__(context): - """ Checks to see if the operating will change the database """ - # ut.printif(lambda: '[sql] Callers: ' + ut.get_caller_name(range(3, 6)), DEBUG) - if context.nInput is not None: - context.operation_lbl = '[sql] execute nInput=%d optype=%s: ' % ( - context.nInput, - context.operation_type, - ) - else: - context.operation_lbl = '[sql] executeone optype=%s: ' % ( - context.operation_type - ) - # Start SQL Transaction - - context.connection = context.db.connection - try: - context.cur = context.connection.cursor() # HACK in a new cursor - except lite.ProgrammingError: - # Get connection for new thread - context.connection = context.db.thread_connection() - context.cur = context.connection.cursor() - - # context.cur = context.db.cur # OR USE DB CURSOR?? 
- if context.start_transaction: - # context.cur.execute('BEGIN', ()) - try: - context.cur.execute('BEGIN') - except lite.OperationalError: - context.connection.rollback() - context.cur.execute('BEGIN') - if context.verbose or VERBOSE_SQL: - logger.info(context.operation_lbl) - if context.verbose: - logger.info('[sql] operation=\n' + context.operation) - # Comment out timeing code - # if __debug__: - # if NOT_QUIET and (VERBOSE_SQL or context.verbose): - # context.tt = ut.tic(context.operation_lbl) - return context - - # @profile - def execute_and_generate_results(context, params): - """ helper for context statment """ - try: - context.cur.execute(context.operation, params) - except lite.Error as ex: - logger.info('Reporting SQLite Error') - logger.info('params = ' + ut.repr2(params, truncate=not ut.VERBOSE)) - ut.printex(ex, 'sql.Error', keys=['params']) - if ( - hasattr(ex, 'message') - and ex.message.find('probably unsupported type') > -1 - ): - logger.info( - 'ERR REPORT: given param types = ' + ut.repr2(ut.lmap(type, params)) - ) - if context.tablename is None: - if context.operation_type.startswith('SELECT'): - tablename = ut.str_between( - context.operation, 'FROM', 'WHERE' - ).strip() - else: - tablename = context.operation_type.split(' ')[-1] - else: - tablename = context.tablename - try: - coldef_list = context.db.get_coldef_list(tablename) - logger.info( - 'ERR REPORT: expected types = %s' % (ut.repr4(coldef_list),) - ) - except Exception: - pass - raise - return context._results_gen() - - # @profile - def _results_gen(context): - """HELPER - Returns as many results as there are. - Careful. Overwrites the results once you call it. - Basically: Dont call this twice. - """ - if context.is_insert: - # The sqlite3_last_insert_rowid(D) interface returns the - # rowid of the most recent successful INSERT - # into a rowid table in D - context.cur.execute('SELECT last_insert_rowid()', ()) - # Wraping fetchone in a generator for some pretty tight calls. - while True: - result = context.cur.fetchone() - if not result: - return - if context.keepwrap: - # Results are always returned wraped in a tuple - yield result - else: - # Here unpacking is conditional - # FIXME: can this if be removed? - yield result[0] if len(result) == 1 else result - - def __exit__(context, type_, value, trace): - """ Finalization of an SQLController call """ - if trace is not None: - # An SQLError is a serious offence. 
- logger.info('[sql] FATAL ERROR IN QUERY CONTEXT') - logger.info('[sql] operation=\n' + context.operation) - logger.info('[sql] Error in context manager!: ' + str(value)) - # return a falsey value on error - return False - else: - # Commit the transaction - if context.auto_commit: - context.connection.commit() - else: - logger.info('no commit %r' % context.operation_lbl) - - -def get_operation_type(operation): - """ - Parses the operation_type from an SQL operation - """ - operation = ' '.join(operation.split('\n')).strip() - operation_type = operation.split(' ')[0].strip() - if operation_type.startswith('SELECT'): - operation_args = ut.str_between(operation, operation_type, 'FROM').strip() - elif operation_type.startswith('INSERT'): - operation_args = ut.str_between(operation, operation_type, '(').strip() - elif operation_type.startswith('DROP'): - operation_args = '' - elif operation_type.startswith('ALTER'): - operation_args = '' - elif operation_type.startswith('UPDATE'): - operation_args = ut.str_between(operation, operation_type, 'FROM').strip() - elif operation_type.startswith('DELETE'): - operation_args = ut.str_between(operation, operation_type, 'FROM').strip() - elif operation_type.startswith('CREATE'): - operation_args = ut.str_between(operation, operation_type, '(').strip() - else: - operation_args = None - operation_type += ' ' + operation_args.replace('\n', ' ') - return operation_type.upper() - - def sanitize_sql(db, tablename_, columns=None): """ Sanatizes an sql tablename and column. Use sparingly """ tablename = re.sub('[^a-zA-Z_0-9]', '', tablename_) @@ -336,47 +179,6 @@ def _sanitize_sql_helper(column): return tablename, columns -def dev_test_new_schema_version( - dbname, sqldb_dpath, sqldb_fname, version_current, version_next=None -): - """ - HACK - - hacky function to ensure that only developer sees the development schema - and only on test databases - """ - TESTING_NEW_SQL_VERSION = version_current != version_next - if TESTING_NEW_SQL_VERSION: - logger.info('[sql] ATTEMPTING TO TEST NEW SQLDB VERSION') - devdb_list = [ - 'PZ_MTEST', - 'testdb1', - 'testdb2', - 'testdb_dst2', - 'emptydatabase', - ] - testing_newschmea = ut.is_developer() and dbname in devdb_list - # testing_newschmea = False - # ut.is_developer() and ibs.get_dbname() in ['PZ_MTEST', 'testdb1'] - if testing_newschmea: - # Set to true until the schema module is good then continue tests - # with this set to false - testing_force_fresh = True or ut.get_argflag('--force-fresh') - # Work on a fresh schema copy when developing - dev_sqldb_fname = ut.augpath(sqldb_fname, '_develop_schema') - sqldb_fpath = join(sqldb_dpath, sqldb_fname) - dev_sqldb_fpath = join(sqldb_dpath, dev_sqldb_fname) - ut.copy(sqldb_fpath, dev_sqldb_fpath, overwrite=testing_force_fresh) - # Set testing schema version - # ibs.db_version_expected = '1.3.6' - logger.info('[sql] TESTING NEW SQLDB VERSION: %r' % (version_next,)) - # logger.info('[sql] ... pass --force-fresh to reload any changes') - return version_next, dev_sqldb_fname - else: - logger.info('[ibs] NOT TESTING') - return version_current, sqldb_fname - - @six.add_metaclass(ut.ReloadingMetaclass) class SQLDatabaseController(object): """ @@ -413,29 +215,54 @@ def __init__(self, ctrlr): @property def version(self): - stmt = f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = ?' 
+ stmt = text( + f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = :key' + ) try: - return self.ctrlr.executeone(stmt, ('database_version',))[0] - except IndexError: # No result + return self.ctrlr.executeone( + stmt, {'key': 'database_version'}, use_fetchone_behavior=True + )[0] + except TypeError: # NoneType return None @version.setter def version(self, value): if not value: raise ValueError(value) - self.ctrlr.executeone( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) VALUES (?, ?)', - ('database_version', value), - ) + dialect = self.ctrlr._engine.dialect.name + if dialect == 'sqlite': + stmt = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value)' + 'VALUES (:key, :value)' + ) + elif dialect == 'postgresql': + stmt = text( + f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :value) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + ) + else: + raise RuntimeError(f'Unknown dialect {dialect}') + params = {'key': 'database_version', 'value': value} + self.ctrlr.executeone(stmt, params) @property def init_uuid(self): - stmt = f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = ?' + stmt = text( + f'SELECT metadata_value FROM {METADATA_TABLE_NAME} WHERE metadata_key = :key' + ) try: - value = self.ctrlr.executeone(stmt, ('database_init_uuid',))[0] - except IndexError: # No result + value = self.ctrlr.executeone( + stmt, {'key': 'database_init_uuid'}, use_fetchone_behavior=True + )[0] + except TypeError: # NoneType return None - return uuid.UUID(value) + if value is not None: + value = uuid.UUID(value) + return value @init_uuid.setter def init_uuid(self, value): @@ -443,10 +270,25 @@ def init_uuid(self, value): raise ValueError(value) elif isinstance(value, uuid.UUID): value = str(value) - self.ctrlr.executeone( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) VALUES (?, ?)', - ('database_init_uuid', value), - ) + dialect = self.ctrlr._engine.dialect.name + if dialect == 'sqlite': + stmt = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} (metadata_key, metadata_value) ' + 'VALUES (:key, :value)' + ) + elif dialect == 'postgresql': + stmt = text( + f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :value) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + ) + else: + raise RuntimeError(f'Unknown dialect {dialect}') + params = {'key': 'database_init_uuid', 'value': value} + self.ctrlr.executeone(stmt, params) # collections.abc.MutableMapping abstract methods @@ -493,18 +335,24 @@ def update(self, **kwargs): def __getattr__(self, name): # Query the database for the value represented as name key = '_'.join([self.table_name, name]) - statement = ( + statement = text( 'SELECT metadata_value ' f'FROM {METADATA_TABLE_NAME} ' - 'WHERE metadata_key = ?' 
+ 'WHERE metadata_key = :key' ) try: - value = self.ctrlr.executeone(statement, (key,))[0] - except IndexError: - # No value for the requested metadata_key + value = self.ctrlr.executeone( + statement, {'key': key}, use_fetchone_behavior=True + )[0] + except TypeError: # NoneType return None if METADATA_TABLE_COLUMNS[name]['is_coded_data']: value = eval(value) + if name == 'superkeys' and isinstance(value, list): + # superkeys looks like [('image_rowid, encounter_rowid',)] + # instead of [('image_rowid',), ('encounter_rowid',)] + if len(value) == 1 and len(value[0]) == 1: + value = [tuple(value[0][0].split(', '))] return value def __getattribute__(self, name): @@ -527,15 +375,27 @@ def __setattr__(self, name, value): key = self._get_key_name(name) # Insert or update the record - # FIXME postgresql (4-Aug-12020) 'insert or replace' is not valid for postgresql - statement = ( - f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} ' - f'(metadata_key, metadata_value) VALUES (?, ?)' - ) - params = ( - key, - value, - ) + dialect = self.ctrlr._engine.dialect.name + if dialect == 'sqlite': + statement = text( + f'INSERT OR REPLACE INTO {METADATA_TABLE_NAME} ' + f'(metadata_key, metadata_value) VALUES (:key, :value)' + ) + elif dialect == 'postgresql': + statement = text( + f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :value) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + ) + else: + raise RuntimeError(f'Unknown dialect {dialect}') + params = { + 'key': key, + 'value': value, + } self.ctrlr.executeone(statement, params) def __delattr__(self, name): @@ -544,8 +404,10 @@ def __delattr__(self, name): raise AttributeError # Insert or update the record - statement = f'DELETE FROM {METADATA_TABLE_NAME} where metadata_key = ?' - params = (self._get_key_name(name),) + statement = text( + f'DELETE FROM {METADATA_TABLE_NAME} where metadata_key = :key' + ) + params = {'key': self._get_key_name(name)} self.ctrlr.executeone(statement, params) def __dir__(self): @@ -623,46 +485,41 @@ def __iter__(self): def __len__(self): return len(self.ctrlr.get_table_names()) + 1 # for 'database' - @classmethod - def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): + def __init_engine(self): + """Create the SQLAlchemy Engine""" + self._engine = create_engine(self.uri) + + def __init__(self, uri, name, readonly=READ_ONLY, timeout=TIMEOUT): """Creates a controller instance from a connection URI + The name is primarily used with Postgres. In Postgres the the name + acts as the database schema name, because all the "databases" are + stored within one Postgres database that is namespaced + with the given ``name``. (Special names like ``_ibeis_database`` + are translated to the correct schema name during + the connection process.) + Args: uri (str): connection string or uri - timeout (int): connection timeout in seconds + name (str): name of the database (e.g. 
chips, _ibeis_database, staging) - Example: - >>> # ENABLE_DOCTEST - >>> from wbia.dtool.sql_control import * # NOQA - >>> sqldb_dpath = ut.ensure_app_resource_dir('dtool') - >>> sqldb_fname = u'test_database.sqlite3' - >>> path = os.path.join(sqldb_dpath, sqldb_fname) - >>> db_uri = 'file://{}'.format(os.path.realpath(path)) - >>> db = SQLDatabaseController.from_uri(db_uri) - >>> db.print_schema() - >>> print(db) - >>> db2 = SQLDatabaseController.from_uri(db_uri, readonly=True) - >>> db.add_table('temptable', ( - >>> ('rowid', 'INTEGER PRIMARY KEY'), - >>> ('key', 'TEXT'), - >>> ('val', 'TEXT'), - >>> ), - >>> superkeys=[('key',)]) - >>> db2.print_schema() """ - self = cls.__new__(cls) self.uri = uri + self.name = name self.timeout = timeout self.metadata = self.Metadata(self) self.readonly = readonly + self.__init_engine() + # Create a _private_ SQLAlchemy metadata instance + # TODO (27-Sept-12020) Develop API to expose elements of SQLAlchemy. + # The MetaData is unbound to ensure we don't accidentally misuse it. + self._sa_metadata = sqlalchemy.MetaData(schema=self.schema_name) + + # Reflect all known tables + self._sa_metadata.reflect(bind=self._engine) + self._tablenames = None - # FIXME (31-Jul-12020) rename to private attribute - self.thread_connections = {} - self._connection = None - # FIXME (31-Jul-12020) rename to private attribute, no direct access to the connection - # Initialize a cursor - self.cur = self.connection.cursor() if not self.readonly: # Ensure the metadata table is initialized. @@ -673,69 +530,37 @@ def from_uri(cls, uri, readonly=READ_ONLY, timeout=TIMEOUT): # Optimize the database self.optimize() - return self + @property + def is_using_sqlite(self): + return self._engine.dialect.name == 'sqlite' - def connect(self): - """Create a connection for the instance or use the existing connection""" - self._connection = lite.connect( - self.uri, detect_types=lite.PARSE_DECLTYPES, timeout=self.timeout, uri=True - ) - return self._connection + @property + def is_using_postgres(self): + return self._engine.dialect.name == 'postgresql' @property - def connection(self): - """Create a connection or reuse the existing connection""" - # TODO (31-Jul-12020) Grab the correct connection for the thread. 
- if self._connection is not None: - conn = self._connection + def schema_name(self): + """The name of the namespace schema (using with Postgres).""" + if self.is_using_postgres: + if self.name == '_ibeis_database': + schema = 'main' + elif self.name == '_ibeis_staging': + schema = 'staging' + else: + schema = self.name else: - conn = self.connect() - return conn - - def _create_connection(self): - path = self.uri.replace('file://', '') - if not exists(path): - logger.info('[sql] Initializing new database: %r' % (self.uri,)) - if self.readonly: - raise AssertionError('Cannot open a new database in readonly mode') - # Open the SQL database connection with support for custom types - # lite.enable_callback_tracebacks(True) - # self.fpath = ':memory:' - - # References: - # http://stackoverflow.com/questions/10205744/opening-sqlite3-database-from-python-in-read-only-mode - uri = self.uri - if self.readonly: - uri += '?mode=ro' - connection = lite.connect( - uri, uri=True, detect_types=lite.PARSE_DECLTYPES, timeout=self.timeout - ) - - # Keep track of what thead this was started in - threadid = threading.current_thread() - self.thread_connections[threadid] = connection - - return connection, uri + schema = None + return schema - def close(self): - self.cur = None - self.connection.close() - self.thread_connections = {} - - # def reconnect(db): - # # Call this if we move into a new thread - # assert db.fname != ':memory:', 'cant reconnect to mem' - # connection, uri = db._create_connection() - # db.connection = connection - # db.cur = db.connection.cursor() - - def thread_connection(self): - threadid = threading.current_thread() - if threadid in self.thread_connections: - connection = self.thread_connections[threadid] - else: - connection, uri = self._create_connection() - return connection + @contextmanager + def connect(self): + """Create a connection instance to wrap a SQL execution block as a context manager""" + with self._engine.connect() as conn: + if self.is_using_postgres: + conn.execute(f'CREATE SCHEMA IF NOT EXISTS {self.schema_name}') + conn.execute(text('SET SCHEMA :schema'), schema=self.schema_name) + initialize_postgresql_types(conn, self.schema_name) + yield conn @profile def _ensure_metadata_table(self): @@ -747,8 +572,15 @@ def _ensure_metadata_table(self): """ try: orig_table_kw = self.get_table_autogen_dict(METADATA_TABLE_NAME) - except (lite.OperationalError, NameError): + except ( + sqlalchemy.exc.OperationalError, # sqlite error + sqlalchemy.exc.ProgrammingError, # postgres error + NameError, + ): orig_table_kw = None + # Reset connection because schema was rolled back due to + # the error + self._connection = None meta_table_kw = ut.odict( [ @@ -813,7 +645,7 @@ def get_db_init_uuid(self, ensure=True): >>> import os >>> from wbia.dtool.sql_control import * # NOQA >>> # Check random database gets new UUID on init - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController('sqlite:///', 'testing') >>> uuid_ = db.get_db_init_uuid() >>> print('New Database: %r is valid' % (uuid_, )) >>> assert isinstance(uuid_, uuid.UUID) @@ -821,10 +653,10 @@ def get_db_init_uuid(self, ensure=True): >>> sqldb_dpath = ut.ensure_app_resource_dir('dtool') >>> sqldb_fname = u'test_database.sqlite3' >>> path = os.path.join(sqldb_dpath, sqldb_fname) - >>> db_uri = 'file://{}'.format(os.path.realpath(path)) - >>> db1 = SQLDatabaseController.from_uri(db_uri) + >>> db_uri = 'sqlite:///{}'.format(os.path.realpath(path)) + >>> db1 = SQLDatabaseController(db_uri, 'db1') >>> 
uuid_1 = db1.get_db_init_uuid() - >>> db2 = SQLDatabaseController.from_uri(db_uri) + >>> db2 = SQLDatabaseController(db_uri, 'db2') >>> uuid_2 = db2.get_db_init_uuid() >>> print('Existing Database: %r == %r' % (uuid_1, uuid_2, )) >>> assert uuid_1 == uuid_2 @@ -837,78 +669,90 @@ def get_db_init_uuid(self, ensure=True): def reboot(self): logger.info('[sql] reboot') - self.cur.close() - del self.cur - self.connection.close() - del self.connection - self.connection = lite.connect( - self.uri, detect_types=lite.PARSE_DECLTYPES, timeout=self.timeout, uri=True - ) - self.cur = self.connection.cursor() + self._engine.dispose() + # Re-initialize the engine + self.__init_engine() def backup(self, backup_filepath): """ backup_filepath = dst_fpath """ - # Create a brand new conenction to lock out current thread and any others - connection, uri = self._create_connection() + if self.is_using_postgres: + # TODO postgresql backup + return + else: + # Assert the database file exists, and copy to backup path + path = self.uri.replace('sqlite://', '') + if not exists(path): + raise IOError( + 'Could not backup the database as the URI does not exist: %r' + % (self.uri,) + ) # Start Exclusive transaction, lock out all other writers from making database changes - connection.isolation_level = 'EXCLUSIVE' - connection.execute('BEGIN EXCLUSIVE') - # Assert the database file exists, and copy to backup path - path = self.uri.replace('file://', '') - if exists(path): + with self.connect() as conn: + conn.execute('BEGIN EXCLUSIVE') ut.copy(path, backup_filepath) - else: - raise IOError( - 'Could not backup the database as the URI does not exist: %r' % (uri,) - ) - # Commit the transaction, releasing the lock - connection.commit() - # Close the connection - connection.close() def optimize(self): + if self._engine.dialect.name != 'sqlite': + return # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html#pragma-cache_size # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html - if VERBOSE_SQL: - logger.info('[sql] running sql pragma optimizions') - # self.cur.execute('PRAGMA cache_size = 0;') - # self.cur.execute('PRAGMA cache_size = 1024;') - # self.cur.execute('PRAGMA page_size = 1024;') - # logger.info('[sql] running sql pragma optimizions') - self.cur.execute('PRAGMA cache_size = 10000;') # Default: 2000 - self.cur.execute('PRAGMA temp_store = MEMORY;') - self.cur.execute('PRAGMA synchronous = OFF;') - # self.cur.execute('PRAGMA synchronous = NORMAL;') - # self.cur.execute('PRAGMA synchronous = FULL;') # Default - # self.cur.execute('PRAGMA parser_trace = OFF;') - # self.cur.execute('PRAGMA busy_timeout = 1;') - # self.cur.execute('PRAGMA default_cache_size = 0;') + logger.info('[sql] running sql pragma optimizions') + + with self.connect() as conn: + # conn.execute('PRAGMA cache_size = 0;') + # conn.execute('PRAGMA cache_size = 1024;') + # conn.execute('PRAGMA page_size = 1024;') + # logger.info('[sql] running sql pragma optimizions') + conn.execute('PRAGMA cache_size = 10000;') # Default: 2000 + conn.execute('PRAGMA temp_store = MEMORY;') + conn.execute('PRAGMA synchronous = OFF;') + # conn.execute('PRAGMA synchronous = NORMAL;') + # conn.execute('PRAGMA synchronous = FULL;') # Default + # conn.execute('PRAGMA parser_trace = OFF;') + # conn.execute('PRAGMA busy_timeout = 1;') + # conn.execute('PRAGMA default_cache_size = 0;') def shrink_memory(self): + if not self.is_using_sqlite: + return logger.info('[sql] shrink_memory') - self.connection.commit() - self.cur.execute('PRAGMA shrink_memory;') - 
self.connection.commit() + with self.connect() as conn: + conn.execute('PRAGMA shrink_memory;') def vacuum(self): + if not self.is_using_sqlite: + return logger.info('[sql] vaccum') - self.connection.commit() - self.cur.execute('VACUUM;') - self.connection.commit() + with self.connect() as conn: + conn.execute('VACUUM;') def integrity(self): + if not self.is_using_sqlite: + return logger.info('[sql] vaccum') - self.connection.commit() - self.cur.execute('PRAGMA integrity_check;') - self.connection.commit() + with self.connect() as conn: + conn.execute('PRAGMA integrity_check;') def squeeze(self): + if not self.is_using_sqlite: + return logger.info('[sql] squeeze') self.shrink_memory() self.vacuum() + def _reflect_table(self, table_name): + """Produces a SQLAlchemy Table object from the given ``table_name``""" + # Note, this on introspects once. Repeated calls will pull the Table object + # from the MetaData object. + kw = {} + if self.is_using_postgres: + kw = {'schema': self.schema_name} + return Table( + table_name, self._sa_metadata, autoload=True, autoload_with=self._engine, **kw + ) + # ============== # API INTERFACE # ============== @@ -923,11 +767,8 @@ def get_row_count(self, tblname): def get_all_rowids(self, tblname, **kwargs): """ returns a list of all rowids from a table in ascending order """ - fmtdict = { - 'tblname': tblname, - } - operation_fmt = 'SELECT rowid FROM {tblname} ORDER BY rowid ASC' - return self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) + operation = text(f'SELECT rowid FROM {tblname} ORDER BY rowid ASC') + return self.executeone(operation, **kwargs) def get_all_col_rows(self, tblname, colname): """ returns a list of all rowids from a table in ascending order """ @@ -956,27 +797,54 @@ def get_all_rowids_where(self, tblname, where_clause, params, **kwargs): return self._executeone_operation_fmt(operation_fmt, fmtdict, params, **kwargs) def check_rowid_exists(self, tablename, rowid_iter, eager=True, **kwargs): - rowid_list1 = self.get(tablename, ('rowid',), rowid_iter) + """Check for the existence of rows (``rowid_iter``) in a table (``tablename``). + Returns as sequence of rowids that exist in the given sequence. + + The 'rowid' term is an alias for the primary key. When calling this method, + you should know that the primary key may be more than one column. + + """ + # BBB (10-Oct-12020) 'rowid' only exists in SQLite and auto-magically gets mapped + # to an integer primary key. However, SQLAlchemy doesn't abide by this magic. + # The aliased column is not part of a reflected table. + # So we find and use the primary key instead. 
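The _reflect_table helper above can be exercised on its own; a rough sketch with a throwaway SQLite engine (table and column names are invented for illustration):

    from sqlalchemy import MetaData, Table, create_engine
    from sqlalchemy.sql import text

    engine = create_engine('sqlite://')     # throwaway in-memory database
    with engine.connect() as conn:
        conn.execute(text('CREATE TABLE notch (notch_rowid INTEGER PRIMARY KEY, val TEXT)'))
        # Introspects the live table once; later Table() calls reuse the MetaData cache
        table = Table('notch', MetaData(), autoload_with=conn)
        print([c.name for c in table.primary_key.columns])   # -> ['notch_rowid']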
+ table = self._reflect_table(tablename) + columns = tuple(c.name for c in table.primary_key.columns) + rowid_list1 = self.get(tablename, columns, rowid_iter) exists_list = [rowid is not None for rowid in rowid_list1] return exists_list - def _add(self, tblname, colnames, params_iter, **kwargs): + def _add(self, tblname, colnames, params_iter, unpack_scalars=True, **kwargs): """ ADDER NOTE: use add_cleanly """ - fmtdict = { - 'tblname': tblname, - 'erotemes': ', '.join(['?'] * len(colnames)), - 'params': ',\n'.join(colnames), - } - operation_fmt = """ - INSERT INTO {tblname}( - rowid, - {params} - ) VALUES (NULL, {erotemes}) - """ - rowid_list = self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) - return rowid_list + parameterized_values = [ + {col: val for col, val in zip(colnames, params)} for params in params_iter + ] + if self.is_using_postgres: + # postgresql column names are lowercase + parameterized_values = [ + {col.lower(): val for col, val in params.items()} + for params in parameterized_values + ] + table = self._reflect_table(tblname) + + # It would be possible to do one insert, + # but SQLite is not capable of returning the primary key value after a multi-value insert. + # Thus, we are stuck doing several inserts... ineffecient. + insert_stmt = sqlalchemy.insert(table) + + primary_keys = [] + with self.connect() as conn: + with conn.begin(): # new nested database transaction + for vals in parameterized_values: + result = conn.execute(insert_stmt.values(vals)) + + pk = result.inserted_primary_key + if unpack_scalars: + # Assumption at the time of writing this is that the primary key is the SQLite rowid. + # Therefore, we can assume the primary key is a single column value. + pk = pk[0] + primary_keys.append(pk) + return primary_keys def add_cleanly( self, @@ -1015,7 +883,7 @@ def add_cleanly( Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController('sqlite:///', 'testing') >>> db.add_table('dummy_table', ( >>> ('rowid', 'INTEGER PRIMARY KEY'), >>> ('key', 'TEXT'), @@ -1106,7 +974,7 @@ def rows_exist(self, tblname, rowids): """ operation = 'SELECT count(1) FROM {tblname} WHERE rowid=?'.format(tblname=tblname) for rowid in rowids: - yield bool(self.cur.execute(operation, (rowid,)).fetchone()[0]) + yield bool(self.connection.execute(operation, (rowid,)).fetchone()[0]) def get_where_eq( self, @@ -1115,30 +983,133 @@ def get_where_eq( params_iter, where_colnames, unpack_scalars=True, - eager=True, op='AND', + batch_size=BATCH_SIZE, **kwargs, ): - """hacked in function for nicer templates + """Executes a SQL select where the given parameters match/equal + the specified where columns. - unpack_scalars = True - kwargs = {} + Args: + tblname (str): table name + colnames (tuple[str]): sequence of column names + params_iter (list[list]): a sequence of a sequence with parameters, + where each item in the sequence is used in a SQL execution + where_colnames (list[str]): column names to match for equality against the same index + of the param_iter values + op (str): SQL boolean operator (e.g. AND, OR) + unpack_scalars (bool): [deprecated] use to unpack a single result from each query + only use with operations that return a single result for each query + (default: True) - Kwargs: - verbose: """ - andwhere_clauses = [colname + '=?' 
for colname in where_colnames] - logicop_ = ' %s ' % (op,) - where_clause = logicop_.join(andwhere_clauses) - return self.get_where( - tblname, - colnames, - params_iter, - where_clause, - unpack_scalars=unpack_scalars, - eager=eager, - **kwargs, + if len(where_colnames) == 1: + return self.get( + tblname, + colnames, + id_iter=(p[0] for p in params_iter), + id_colname=where_colnames[0], + unpack_scalars=unpack_scalars, + batch_size=batch_size, + **kwargs, + ) + params_iter = list(params_iter) + table = self._reflect_table(tblname) + if op.lower() != 'and' or not params_iter: + # Build the equality conditions using column type information. + # This allows us to bind the parameter with the correct type. + equal_conditions = [ + (table.c[c] == bindparam(c, type_=table.c[c].type)) + for c in where_colnames + ] + gate_func = {'and': sqlalchemy.and_, 'or': sqlalchemy.or_}[op.lower()] + where_clause = gate_func(*equal_conditions) + params = [dict(zip(where_colnames, p)) for p in params_iter] + return self.get_where( + tblname, + colnames, + params, + where_clause, + unpack_scalars=unpack_scalars, + **kwargs, + ) + + params_per_batch = int(batch_size / len(params_iter[0])) + result_map = {} + stmt = sqlalchemy.select( + [table.c[c] for c in tuple(where_colnames) + tuple(colnames)] ) + stmt = stmt.where( + sqlalchemy.tuple_(*[table.c[c] for c in where_colnames]).in_( + sqlalchemy.sql.bindparam('params', expanding=True) + ) + ) + batch_list = list(range(int(len(params_iter) / params_per_batch) + 1)) + for batch in tqdm.tqdm( + batch_list, disable=len(batch_list) <= 1, desc='[db.get(%s)]' % (tblname,) + ): + val_list = self.executeone( + stmt, + { + 'params': params_iter[ + batch * params_per_batch : (batch + 1) * params_per_batch + ] + }, + ) + for val in val_list: + key = val[: len(params_iter[0])] + values = val[len(params_iter[0]) :] + if not kwargs.get('keepwrap', False) and len(values) == 1: + values = values[0] + existing = result_map.setdefault(key, set()) + if isinstance(existing, set): + try: + existing.add(values) + except TypeError: + # unhashable type + result_map[key] = list(result_map[key]) + if values not in result_map[key]: + result_map[key].append(values) + elif values not in existing: + existing.append(values) + + results = [] + processors = [] + for c in tuple(where_colnames): + + def process(column, a): + processor = column.type.bind_processor(self._engine.dialect) + if processor: + a = processor(a) + result_processor = column.type.result_processor( + self._engine.dialect, str(column.type) + ) + if result_processor: + return result_processor(a) + return a + + processors.append(functools.partial(process, table.c[c])) + + if params_iter: + first_params = params_iter[0] + if any( + not isinstance(a, bool) + and TYPE_TO_SQLTYPE.get(type(a)) != str(table.c[c].type) + for a, c in zip(first_params, where_colnames) + ): + params_iter = ( + (processor(raw_id) for raw_id, processor in zip(id_, processors)) + for id_ in params_iter + ) + + for id_ in params_iter: + result = sorted(list(result_map.get(tuple(id_), set()))) + if unpack_scalars and isinstance(result, list): + results.append(_unpacker(result)) + else: + results.append(result) + + return results def get_where_eq_set( self, @@ -1184,7 +1155,6 @@ def get_where_eq_set( } return self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) - @profile def get_where( self, tblname, @@ -1195,39 +1165,64 @@ def get_where( eager=True, **kwargs, ): - """""" - assert isinstance(colnames, tuple), 'colnames must be a tuple' + """ + 
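The batching above leans on SQLAlchemy "expanding" bind parameters, which let a single prepared statement take a variable-length list per execution. A stripped-down sketch of the same pattern on a single key column (table and values invented for illustration; the composite-key case generalizes via sqlalchemy.tuple_):

    from sqlalchemy import MetaData, Table, bindparam, create_engine, select
    from sqlalchemy.sql import text

    engine = create_engine('sqlite://')
    with engine.connect() as conn:
        conn.execute(text('CREATE TABLE images (image_rowid INTEGER PRIMARY KEY, image_uri TEXT)'))
        conn.execute(text("INSERT INTO images (image_uri) VALUES ('a.jpg'), ('b.jpg'), ('c.jpg')"))

        table = Table('images', MetaData(), autoload_with=conn)
        stmt = select([table.c.image_uri]).where(
            table.c.image_rowid.in_(bindparam('value', expanding=True))
        )
        # The same statement is reused for every batch; only the bound list changes.
        print(conn.execute(stmt, {'value': [1, 3]}).fetchall())   # -> [('a.jpg',), ('c.jpg',)]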
Interface to do a SQL select with a where clause + + Args: + tblname (str): table name + colnames (tuple[str]): sequence of column names + params_iter (list[dict]): a sequence of dicts with parameters, + where each item in the sequence is used in a SQL execution + where_clause (str|Operation): conditional statement used in the where clause + unpack_scalars (bool): [deprecated] use to unpack a single result from each query + only use with operations that return a single result for each query + (default: True) + + """ + if not isinstance(colnames, (tuple, list)): + raise TypeError('colnames must be a sequence type of strings') + elif where_clause is not None: + if '?' in str(where_clause): # cast in case it's an SQLAlchemy object + raise ValueError( + "Statements cannot use '?' parameterization, " + "use ':name' parameters instead." + ) + elif isinstance(where_clause, str): + where_clause = text(where_clause) + + table = self._reflect_table(tblname) + stmt = sqlalchemy.select([table.c[c] for c in colnames]) if where_clause is None: - operation_fmt = """ - SELECT {colnames} - FROM {tblname} - """ - fmtdict = { - 'tblname': tblname, - 'colnames': ', '.join(colnames), - } - val_list = self._executeone_operation_fmt(operation_fmt, fmtdict, **kwargs) + val_list = self.executeone(stmt, **kwargs) else: - operation_fmt = """ - SELECT {colnames} - FROM {tblname} - WHERE {where_clauses} - """ - fmtdict = { - 'tblname': tblname, - 'colnames': ', '.join(colnames), - 'where_clauses': where_clause, - } - val_list = self._executemany_operation_fmt( - operation_fmt, - fmtdict, - params_iter=params_iter, + stmt = stmt.where(where_clause) + val_list = self.executemany( + stmt, + params_iter, unpack_scalars=unpack_scalars, eager=eager, **kwargs, ) - return val_list + + # This code is specifically for handling duplication in colnames + # because sqlalchemy removes them. + # e.g. select field1, field1, field2 from table; + # becomes + # select field1, field2 from table; + # so the items in val_list only have 2 values + # but the caller isn't expecting it so it causes problems + returned_columns = tuple([c.name for c in stmt.columns]) + if colnames == returned_columns: + return val_list + + result = [] + for val in val_list: + if isinstance(val, LegacyRow): + result.append(tuple(val[returned_columns.index(c)] for c in colnames)) + else: + result.append(val) + return result def exists_where_eq( self, @@ -1269,8 +1264,12 @@ def get_rowid_from_superkey( self, tblname, params_iter=None, superkey_colnames=None, **kwargs ): """ getter which uses the constrained superkeys instead of rowids """ - where_clause = ' AND '.join([colname + '=?' for colname in superkey_colnames]) - return self.get_where(tblname, ('rowid',), params_iter, where_clause, **kwargs) + # ??? Why can this be called with params_iter=None & superkey_colnames=None? + table = self._reflect_table(tblname) + columns = tuple(c.name for c in table.primary_key.columns) + return self.get_where_eq( + tblname, columns, params_iter, superkey_colnames, op='AND', **kwargs + ) def get( self, @@ -1280,9 +1279,10 @@ def get( id_colname='rowid', eager=True, assume_unique=False, + batch_size=BATCH_SIZE, **kwargs, ): - """getter + """Get rows of data by ID Args: tblname (str): table name to get from @@ -1290,26 +1290,8 @@ def get( id_iter (iterable): iterable of search keys id_colname (str): column to be used as the search key (default: rowid) eager (bool): use eager evaluation + assume_unique (bool): default False. 
Experimental feature that could result in a 10x speedup unpack_scalars (bool): default True - id_colname (bool): default False. Experimental feature that could result in a 10x speedup - - CommandLine: - python -m dtool.sql_control get - - Ignore: - tblname = 'annotations' - colnames = ('name_rowid',) - id_iter = aid_list - #id_iter = id_iter[0:20] - id_colname = 'rowid' - eager = True - db = ibs.db - - x1 = db.get(tblname, colnames, id_iter, assume_unique=True) - x2 = db.get(tblname, colnames, id_iter, assume_unique=False) - x1 == x2 - %timeit db.get(tblname, colnames, id_iter, assume_unique=True) - %timeit db.get(tblname, colnames, id_iter, assume_unique=False) Example: >>> # ENABLE_DOCTEST @@ -1325,16 +1307,18 @@ def get( >>> got_data = db.get('notch', colnames, id_iter=rowids) >>> assert got_data == [1, 2, 3] """ - if VERBOSE_SQL: - logger.info( - '[sql]' - + ut.get_caller_name(list(range(1, 4))) - + ' db.get(%r, %r, ...)' % (tblname, colnames) - ) - assert isinstance(colnames, tuple), 'must specify column names TUPLE to get from' - # if isinstance(colnames, six.string_types): - # colnames = (colnames,) + logger.debug( + '[sql]' + + ut.get_caller_name(list(range(1, 4))) + + ' db.get(%r, %r, ...)' % (tblname, colnames) + ) + if not isinstance(colnames, (tuple, list)): + raise TypeError('colnames must be a sequence type of strings') + # ??? Getting a single column of unique values that is matched on rowid? + # And sorts the results after the query? + # ??? This seems oddly specific for a generic method. + # Perhaps the logic should be in its own method? if ( assume_unique and id_iter is not None @@ -1342,21 +1326,14 @@ def get( and len(colnames) == 1 ): id_iter = list(id_iter) - operation_fmt = """ - SELECT {colnames} - FROM {tblname} - WHERE rowid in ({id_repr}) - ORDER BY rowid ASC - """ - fmtdict = { - 'tblname': tblname, - 'colnames': ', '.join(colnames), - 'id_repr': ','.join(map(str, id_iter)), - } - operation = operation_fmt.format(**fmtdict) - results = self.cur.execute(operation).fetchall() + columns = ', '.join(colnames) + ids_listing = ', '.join(map(str, id_iter)) + operation = f'SELECT {columns} FROM {tblname} WHERE rowid in ({ids_listing}) ORDER BY rowid ASC' + with self.connect() as conn: + results = conn.execute(operation).fetchall() import numpy as np + # ??? Why order the results if they are going to be sorted here? sortx = np.argsort(np.argsort(id_iter)) results = ut.take(results, sortx) if kwargs.get('unpack_scalars', True): @@ -1366,13 +1343,75 @@ def get( if id_iter is None: where_clause = None params_iter = [] + + return self.get_where( + tblname, colnames, params_iter, where_clause, eager=eager, **kwargs + ) + + id_iter = list(id_iter) # id_iter could be a set + table = self._reflect_table(tblname) + result_map = {} + if id_colname == 'rowid': # rowid isn't an actual column in sqlite + id_column = sqlalchemy.sql.column('rowid', Integer) else: - where_clause = id_colname + '=?' 
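The assume_unique fast path lets SQL return rows in rowid order and then restores the caller's ordering with a double argsort; the trick in isolation (arbitrary toy values):

    import numpy as np

    requested = [30, 10, 20]                    # the order the caller asked for
    fetched = sorted(requested)                 # what 'ORDER BY rowid ASC' hands back
    sortx = np.argsort(np.argsort(requested))   # where each requested id sits in sorted order
    print([fetched[i] for i in sortx])          # -> [30, 10, 20]; ut.take(fetched, sortx) is equivalent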
- params_iter = [(_rowid,) for _rowid in id_iter] + id_column = table.c[id_colname] + stmt = sqlalchemy.select([id_column] + [table.c[c] for c in colnames]) + stmt = stmt.where(id_column.in_(bindparam('value', expanding=True))) - return self.get_where( - tblname, colnames, params_iter, where_clause, eager=eager, **kwargs - ) + batch_list = list(range(int(len(id_iter) / batch_size) + 1)) + for batch in tqdm.tqdm( + batch_list, disable=len(batch_list) <= 1, desc='[db.get(%s)]' % (tblname,) + ): + val_list = self.executeone( + stmt, + {'value': id_iter[batch * batch_size : (batch + 1) * batch_size]}, + ) + + for val in val_list: + if not kwargs.get('keepwrap', False) and len(val[1:]) == 1: + values = val[1] + else: + values = val[1:] + existing = result_map.setdefault(val[0], set()) + if isinstance(existing, set): + try: + existing.add(values) + except TypeError: + # unhashable type + result_map[val[0]] = list(result_map[val[0]]) + if values not in result_map[val[0]]: + result_map[val[0]].append(values) + elif values not in existing: + existing.append(values) + + results = [] + + def process(a): + processor = id_column.type.bind_processor(self._engine.dialect) + if processor: + a = processor(a) + result_processor = id_column.type.result_processor( + self._engine.dialect, str(id_column.type) + ) + if result_processor: + return result_processor(a) + return a + + if id_iter: + first_id = id_iter[0] + if isinstance(first_id, bool) or TYPE_TO_SQLTYPE.get( + type(first_id) + ) != str(id_column.type): + id_iter = (process(id_) for id_ in id_iter) + + for id_ in id_iter: + result = sorted(list(result_map.get(id_, set()))) + if kwargs.get('unpack_scalars', True) and isinstance(result, list): + results.append(_unpacker(result)) + else: + results.append(result) + + return results def set( self, @@ -1402,27 +1441,26 @@ def set( >>> table.print_csv() >>> # Break things to test set >>> colnames = ('dummy_annot_rowid',) - >>> val_iter = [9003, 9001, 9002] + >>> val_iter = [(9003,), (9001,), (9002,)] >>> orig_data = db.get('notch', colnames, id_iter=rowids) >>> db.set('notch', colnames, val_iter, id_iter=rowids) >>> new_data = db.get('notch', colnames, id_iter=rowids) - >>> assert new_data == val_iter + >>> assert new_data == [x[0] for x in val_iter] >>> assert new_data != orig_data >>> table.print_csv() >>> depc.clear_all() """ - assert isinstance(colnames, tuple) - # if isinstance(colnames, six.string_types): - # colnames = (colnames,) + if not isinstance(colnames, (tuple, list)): + raise TypeError('colnames must be a sequence type of strings') + val_list = list(val_iter) # eager evaluation id_list = list(id_iter) # eager evaluation - if VERBOSE_SQL or (NOT_QUIET and VERYVERBOSE): - logger.info('[sql] SETTER: ' + ut.get_caller_name()) - logger.info('[sql] * tblname=%r' % (tblname,)) - logger.info('[sql] * val_list=%r' % (val_list,)) - logger.info('[sql] * id_list=%r' % (id_list,)) - logger.info('[sql] * id_colname=%r' % (id_colname,)) + logger.debug('[sql] SETTER: ' + ut.get_caller_name()) + logger.debug('[sql] * tblname=%r' % (tblname,)) + logger.debug('[sql] * val_list=%r' % (val_list,)) + logger.debug('[sql] * id_list=%r' % (id_list,)) + logger.debug('[sql] * id_colname=%r' % (id_colname,)) if duplicate_behavior == 'error': try: @@ -1449,7 +1487,7 @@ def set( for index in sorted(pop_list, reverse=True): del id_list[index] del val_list[index] - logger.info( + logger.debug( '[!set] Auto Resolution: Removed %d duplicate (id, value) pairs from the database operation' % (len(pop_list),) ) @@ -1480,6 +1518,7 
@@ def set( % (duplicate_behavior,) ) + # Check for incongruity between values and identifiers try: num_val = len(val_list) num_id = len(id_list) @@ -1487,58 +1526,70 @@ def set( except AssertionError as ex: ut.printex(ex, key_list=['num_val', 'num_id']) raise - fmtdict = { - 'tblname_str': tblname, - 'assign_str': ',\n'.join(['%s=?' % name for name in colnames]), - 'where_clause': (id_colname + '=?'), - } - operation_fmt = """ - UPDATE {tblname_str} - SET {assign_str} - WHERE {where_clause} - """ - - # TODO: The flattenize can be removed if we pass in val_lists instead - params_iter = flattenize(list(zip(val_list, id_list))) - # params_iter = list(zip(val_list, id_list)) - return self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs + # BBB (28-Sept-12020) This method's usage throughout the codebase allows + # for items in `val_iter` to be a non-sequence value. + has_unsequenced_values = val_list and not isinstance(val_list[0], (tuple, list)) + if has_unsequenced_values: + val_list = [(v,) for v in val_list] + # BBB (28-Sept-12020) This method's usage throughout the codebase allows + # for items in `id_iter` to be a tuple of one value. + has_sequenced_ids = id_list and isinstance(id_list[0], (tuple, list)) + if has_sequenced_ids: + id_list = [x[0] for x in id_list] + + # Execute the SQL updates for each set of values + id_param_name = '_identifier' + table = self._reflect_table(tblname) + stmt = table.update().values( + **{col: bindparam(f'e{i}') for i, col in enumerate(colnames)} ) + where_clause = text(id_colname + f' = :{id_param_name}') + if id_colname == 'rowid': + # Cast all item values to in, in case values are numpy.integer* + # Strangely allow for None values + id_list = [id_ if id_ is None else int(id_) for id_ in id_list] + else: # b/c rowid doesn't really exist as a column + id_column = table.c[id_colname] + where_clause = where_clause.bindparams( + bindparam(id_param_name, type_=id_column.type) + ) + stmt = stmt.where(where_clause) + with self.connect() as conn: + with conn.begin(): + for i, id in enumerate(id_list): + params = {id_param_name: id} + params.update({f'e{e}': p for e, p in enumerate(val_list[i])}) + conn.execute(stmt, **params) def delete(self, tblname, id_list, id_colname='rowid', **kwargs): + """Deletes rows from a SQL table (``tblname``) by ID, + given a sequence of IDs (``id_list``). + Optionally a different ID column can be specified via ``id_colname``. + """ - deleter. 
USE delete_rowids instead - """ - fmtdict = { - 'tblname': tblname, - 'rowid_str': (id_colname + '=?'), - } - operation_fmt = """ - DELETE - FROM {tblname} - WHERE {rowid_str} - """ - params_iter = ((_rowid,) for _rowid in id_list) - return self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) + id_param_name = '_identifier' + table = self._reflect_table(tblname) + stmt = table.delete() + where_clause = text(id_colname + f' = :{id_param_name}') + if id_colname == 'rowid': + # Cast all item values to in, in case values are numpy.integer* + # Strangely allow for None values + id_list = [id_ if id_ is None else int(id_) for id_ in id_list] + else: # b/c rowid doesn't really exist as a column + id_column = table.c[id_colname] + where_clause = where_clause.bindparams( + bindparam(id_param_name, type_=id_column.type) + ) + stmt = stmt.where(where_clause) + with self.connect() as conn: + with conn.begin(): + for id in id_list: + conn.execute(stmt, {id_param_name: id}) def delete_rowids(self, tblname, rowid_list, **kwargs): """ deletes the the rows in rowid_list """ - fmtdict = { - 'tblname': tblname, - 'rowid_str': ('rowid=?'), - } - operation_fmt = """ - DELETE - FROM {tblname} - WHERE {rowid_str} - """ - params_iter = ((_rowid,) for _rowid in rowid_list) - return self._executemany_operation_fmt( - operation_fmt, fmtdict, params_iter=params_iter, **kwargs - ) + self.delete(tblname, rowid_list, id_colname='rowid', **kwargs) # ============== # CORE WRAPPERS @@ -1550,7 +1601,7 @@ def _executeone_operation_fmt( if params is None: params = [] operation = operation_fmt.format(**fmtdict) - return self.executeone(operation, params, eager=eager, **kwargs) + return self.executeone(text(operation), params, eager=eager, **kwargs) @profile def _executemany_operation_fmt( @@ -1576,102 +1627,111 @@ def _executemany_operation_fmt( # SQLDB CORE # ========= - def executeone(db, operation, params=(), eager=True, verbose=VERBOSE_SQL): - contextkw = dict(nInput=1, verbose=verbose) - with SQLExecutionContext(db, operation, **contextkw) as context: - try: - result_iter = context.execute_and_generate_results(params) - result_list = list(result_iter) - except Exception as ex: - ut.printex(ex, key_list=[(str, 'operation'), 'params']) - # ut.sys.exit(1) - raise - return result_list - - @profile - def executemany( + def executeone( self, operation, - params_iter, - verbose=VERBOSE_SQL, - unpack_scalars=True, - nInput=None, + params=(), eager=True, + verbose=VERBOSE_SQL, + use_fetchone_behavior=False, keepwrap=False, - showprog=False, ): + """Executes the given ``operation`` once with the given set of ``params`` + + Args: + operation (str|TextClause): SQL statement + params (sequence|dict): parameters to pass in with SQL execution + eager: [deprecated] no-op + verbose: [deprecated] no-op + use_fetchone_behavior (bool): Use DBAPI ``fetchone`` behavior when outputing no rows (i.e. None) + """ - if unpack_scalars is True only a single result must be returned for each query. - """ - # --- ARGS PREPROC --- - # Aggresively compute iterator if the nInput is not given - if nInput is None: - if isinstance(params_iter, (list, tuple)): - nInput = len(params_iter) - else: - if VERBOSE_SQL: - logger.info( - '[sql!] 
WARNING: aggressive eval of params_iter because nInput=None' - ) - params_iter = list(params_iter) - nInput = len(params_iter) - else: - if VERBOSE_SQL: - logger.info('[sql] Taking params_iter as iterator') + if not isinstance(operation, ClauseElement): + raise TypeError( + "'operation' needs to be a sqlalchemy textual sql instance " + "see docs on 'sqlalchemy.sql:text' factory function; " + f"'operation' is a '{type(operation)}'" + ) + # FIXME (12-Sept-12020) Allows passing through '?' (question mark) parameters. + with self.connect() as conn: + results = conn.execute(operation, params) - # Do not compute executemany without params - if nInput == 0: - if VERBOSE_SQL: - logger.info( - '[sql!] WARNING: dont use executemany' - 'with no params use executeone instead.' - ) - return [] - # --- SQL EXECUTION --- - contextkw = { - 'nInput': nInput, - 'start_transaction': True, - 'verbose': verbose, - 'keepwrap': keepwrap, - } - with SQLExecutionContext(self, operation, **contextkw) as context: - if eager: - if showprog: - if isinstance(showprog, six.string_types): - lbl = showprog - else: - lbl = 'sqlread' - prog = ut.ProgPartial( - adjust=True, length=nInput, freq=1, lbl=lbl, bs=True - ) - params_iter = prog(params_iter) - results_iter = [ - list(context.execute_and_generate_results(params)) - for params in params_iter - ] - if unpack_scalars: - # list of iterators - _unpacker_ = partial(_unpacker) - results_iter = list(map(_unpacker_, results_iter)) - # Eager evaluation - results_list = list(results_iter) + # BBB (12-Sept-12020) Retaining insertion rowid result + # FIXME postgresql (12-Sept-12020) This won't work in postgres. + # Maybe see if ResultProxy.inserted_primary_key will work + if ( + 'insert' in str(operation).lower() + ): # cast in case it's an SQLAlchemy object + # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. + return [results.lastrowid] + elif not results.returns_rows: + return None else: + if isinstance(operation, sqlalchemy.sql.selectable.Select): + # This code is specifically for handling duplication in colnames + # because sqlalchemy removes them. + # e.g. select field1, field1, field2 from table; + # becomes + # select field1, field2 from table; + # so the items in val_list only have 2 values + # but the caller isn't expecting it so it causes problems + returned_columns = tuple([c.name for c in operation.columns]) + raw_columns = tuple([c.name for c in operation._raw_columns]) + if raw_columns != returned_columns: + results_ = [] + for r in results: + results_.append( + tuple(r[returned_columns.index(c)] for c in raw_columns) + ) + results = results_ + values = list( + [ + # BBB (12-Sept-12020) Retaining behavior to unwrap single value rows. + row[0] if not keepwrap and len(row) == 1 else row + for row in results + ] + ) + # FIXME (28-Sept-12020) No rows results in an empty list. This behavior does not + # match the resulting expectations of `fetchone`'s DBAPI spec. + # If executeone is the shortcut of `execute` and `fetchone`, + # the expectation should be to return according to DBAPI spec. 
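Since executeone now rejects plain strings and '?' placeholders, call sites are expected to pass SQLAlchemy constructs with ':name' parameters. The underlying convention, shown directly against a throwaway connection (table and values invented for illustration):

    from sqlalchemy import create_engine
    from sqlalchemy.sql import text

    engine = create_engine('sqlite://')
    with engine.connect() as conn:
        conn.execute(text('CREATE TABLE images (image_rowid INTEGER PRIMARY KEY, image_uri TEXT)'))
        conn.execute(text('INSERT INTO images (image_uri) VALUES (:uri)'), {'uri': 'a.jpg'})
        # ':name' parameters, never '?' -- the same rule executeone/get_where enforce
        rows = conn.execute(text('SELECT image_rowid FROM images WHERE image_uri = :uri'),
                            {'uri': 'a.jpg'})
        print(rows.fetchall())   # -> [(1,)]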
+ if use_fetchone_behavior and not values: # empty list + values = None + return values + + def executemany( + self, operation, params_iter, unpack_scalars=True, keepwrap=False, **kwargs + ): + """Executes the given ``operation`` once for each item in ``params_iter`` - def _tmpgen(context): - # Temporary hack to turn off eager_evaluation - for params in params_iter: - # Eval results per query yeild per iter - results = list(context.execute_and_generate_results(params)) - if unpack_scalars: - yield _unpacker(results) - else: - yield results + Args: + operation (str): SQL operation + params_iter (sequence): a sequence of sequences + containing parameters in the sql operation + unpack_scalars (bool): [deprecated] use to unpack a single result from each query + only use with operations that return a single result for each query + (default: True) - results_list = _tmpgen(context) - return results_list + """ + if not isinstance(operation, ClauseElement): + raise TypeError( + "'operation' needs to be a sqlalchemy textual sql instance " + "see docs on 'sqlalchemy.sql:text' factory function; " + f"'operation' is a '{type(operation)}'" + ) - # def commit(db): - # db.connection.commit() + results = [] + with self.connect() as conn: + with conn.begin(): + for params in params_iter: + value = self.executeone(operation, params, keepwrap=keepwrap) + # Should only be used when the user wants back on value. + # Let the error bubble up if used wrong. + # Deprecated... Do not depend on the unpacking behavior. + if unpack_scalars: + value = _unpacker(value) + results.append(value) + return results def print_dbg_schema(self): logger.info( @@ -1713,9 +1773,22 @@ def set_metadata_val(self, key, val): 'tablename': METADATA_TABLE_NAME, 'columns': 'metadata_key, metadata_value', } - op_fmtstr = 'INSERT OR REPLACE INTO {tablename} ({columns}) VALUES (?, ?)' - operation = op_fmtstr.format(**fmtkw) - params = [key, val] + dialect = self._engine.dialect.name + if dialect == 'sqlite': + op_fmtstr = ( + 'INSERT OR REPLACE INTO {tablename} ({columns}) VALUES (:key, :val)' + ) + elif dialect == 'postgresql': + op_fmtstr = f"""\ + INSERT INTO {METADATA_TABLE_NAME} + (metadata_key, metadata_value) + VALUES (:key, :val) + ON CONFLICT (metadata_key) DO UPDATE + SET metadata_value = EXCLUDED.metadata_value""" + else: + raise RuntimeError(f'Unknown dialect {dialect}') + operation = text(op_fmtstr.format(**fmtkw)) + params = {'key': key, 'val': val} self.executeone(operation, params, verbose=False) @deprecated('Use metadata property instead') @@ -1723,10 +1796,11 @@ def get_metadata_val(self, key, eval_=False, default=None): """ val is the repr string unless eval_ is true """ - where_clause = 'metadata_key=?' colnames = ('metadata_value',) params_iter = [(key,)] - vals = self.get_where(METADATA_TABLE_NAME, colnames, params_iter, where_clause) + vals = self.get_where_eq( + METADATA_TABLE_NAME, colnames, params_iter, ('metadata_key',) + ) assert len(vals) == 1, 'duplicate keys in metadata table' val = vals[0] if val is None: @@ -1767,86 +1841,81 @@ def add_column(self, tablename, colname, coltype): operation = op_fmtstr.format(**fmtkw) self.executeone(operation, [], verbose=False) - def __make_superkey_constraints(self, superkeys: list) -> list: - """Creates SQL for the 'superkey' constraint. - A 'superkey' is one or more columns that make up a unique constraint on the table. 
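For reference, the two dialect-specific upserts in set_metadata_val are meant to behave identically; the SQLite branch can be sanity-checked against an in-memory database (the Postgres branch is shown only as a comment because it needs a live server):

    from sqlalchemy import create_engine
    from sqlalchemy.sql import text

    engine = create_engine('sqlite://')
    with engine.connect() as conn:
        conn.execute(text('CREATE TABLE metadata (metadata_key TEXT UNIQUE, metadata_value TEXT)'))
        upsert = text('INSERT OR REPLACE INTO metadata (metadata_key, metadata_value) '
                      'VALUES (:key, :val)')
        # Postgres equivalent (not runnable here):
        #   INSERT INTO metadata (metadata_key, metadata_value) VALUES (:key, :val)
        #   ON CONFLICT (metadata_key) DO UPDATE SET metadata_value = EXCLUDED.metadata_value
        conn.execute(upsert, {'key': 'database_version', 'val': '1.0.0'})
        conn.execute(upsert, {'key': 'database_version', 'val': '1.1.0'})
        print(conn.execute(text('SELECT metadata_value FROM metadata')).fetchall())   # -> [('1.1.0',)]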
- - """ - has_superkeys = superkeys is not None and len(superkeys) > 0 - constraints = [] - if has_superkeys: - # Create a superkey statement for each superkey item - # superkeys = [(col), (col1, col2, ...), ...], - for columns in superkeys: - columns = ','.join(columns) - constraints.append(f'CONSTRAINT superkey UNIQUE ({columns})') - return constraints + def __make_unique_constraint(self, table_name, column_or_columns): + """Creates a SQL ``CONSTRAINT`` clause for ``UNIQUE`` column data""" + if not isinstance(column_or_columns, (list, tuple)): + columns = [column_or_columns] + else: + # Cast as list incase it's a tuple, b/c tuple + list = error + columns = list(column_or_columns) + constraint_name = '_'.join(['unique', table_name] + columns) + columns_listing = ', '.join(columns) + return f'CONSTRAINT {constraint_name} UNIQUE ({columns_listing})' def __make_column_definition(self, name: str, definition: str) -> str: """Creates SQL for the given column `name` and type, default & constraint (i.e. `definition`).""" if not name: raise ValueError(f'name cannot be an empty string paired with {definition}') - if not definition: + elif not definition: raise ValueError(f'definition cannot be an empty string paired with {name}') + if self.is_using_postgres: + if ( + name.endswith('rowid') + and 'INTEGER' in definition + and 'PRIMARY KEY' in definition + ): + definition = definition.replace('INTEGER', 'BIGSERIAL') + definition = definition.replace('REAL', 'DOUBLE PRECISION').replace( + 'INTEGER', 'BIGINT' + ) return f'{name} {definition}' def _make_add_table_sqlstr( self, tablename: str, coldef_list: list, sep=' ', **metadata_keyval ): - r"""Creates the SQL for a CREATE TABLE statement + """Creates the SQL for a CREATE TABLE statement Args: tablename (str): table name coldef_list (list): list of tuples (name, type definition) + sep (str): clause separation character(s) (default: space) + kwargs: metadata specifications Returns: str: operation - CommandLine: - python -m dtool.sql_control _make_add_table_sqlstr - - Example: - >>> # ENABLE_DOCTEST - >>> from wbia.dtool.sql_control import * # NOQA - >>> from wbia.dtool.example_depcache import testdata_depc - >>> depc = testdata_depc() - >>> tablename = 'keypoint' - >>> db = depc[tablename].db - >>> autogen_dict = db.get_table_autogen_dict(tablename) - >>> coldef_list = autogen_dict['coldef_list'] - >>> operation = db._make_add_table_sqlstr(tablename, coldef_list) - >>> print(operation) - """ if not coldef_list: raise ValueError(f'empty coldef_list specified for {tablename}') + if self.is_using_postgres and 'rowid' not in [name for name, _ in coldef_list]: + coldef_list = [('rowid', 'BIGSERIAL UNIQUE')] + list(coldef_list) + # Check for invalid keyword arguments bad_kwargs = set(metadata_keyval.keys()) - set(METADATA_TABLE_COLUMN_NAMES) if len(bad_kwargs) > 0: raise TypeError(f'got unexpected keyword arguments: {bad_kwargs}') - if ut.DEBUG2: - logger.info('[sql] schema ensuring tablename=%r' % tablename) - if ut.VERBOSE: - logger.info('') - _args = [tablename, coldef_list] - logger.info(ut.func_str(self.add_table, _args, metadata_keyval)) - logger.info('') + logger.debug('[sql] schema ensuring tablename=%r' % tablename) + logger.debug( + ut.func_str(self.add_table, [tablename, coldef_list], metadata_keyval) + ) # Create the main body of the CREATE TABLE statement with column definitions # coldef_list = [(, ,), ...] 
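The Postgres column rewriting in __make_column_definition is a plain string substitution; isolated, it amounts to roughly the following (a sketch, not the patch's exact helper):

    def translate_coldef_for_postgres(name, definition):
        # rowid-style integer primary keys become BIGSERIAL so Postgres auto-assigns them
        if name.endswith('rowid') and 'INTEGER' in definition and 'PRIMARY KEY' in definition:
            definition = definition.replace('INTEGER', 'BIGSERIAL')
        return name, definition.replace('REAL', 'DOUBLE PRECISION').replace('INTEGER', 'BIGINT')

    print(translate_coldef_for_postgres('annot_rowid', 'INTEGER PRIMARY KEY'))
    # -> ('annot_rowid', 'BIGSERIAL PRIMARY KEY')
    print(translate_coldef_for_postgres('annot_confidence', 'REAL'))
    # -> ('annot_confidence', 'DOUBLE PRECISION')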
body_list = [self.__make_column_definition(c, d) for c, d in coldef_list] # Make a list of constraints to place on the table - constraint_list = self.__make_superkey_constraints( - metadata_keyval.get('superkeys', []) - ) + # superkeys = [(, ...), ...] + constraint_list = [ + self.__make_unique_constraint(tablename, x) + for x in metadata_keyval.get('superkeys') or [] + ] constraint_list = ut.unique_ordered(constraint_list) comma = ',' + sep table_body = comma.join(body_list + constraint_list) - return f'CREATE TABLE IF NOT EXISTS {tablename} ({sep}{table_body}{sep})' + return text(f'CREATE TABLE IF NOT EXISTS {tablename} ({sep}{table_body}{sep})') def add_table(self, tablename=None, coldef_list=None, **metadata_keyval): """ @@ -1948,6 +2017,14 @@ def modify_table( colname_list = ut.take_column(coldef_list, 0) coltype_list = ut.take_column(coldef_list, 1) + # Find all dependent sequences so we can change the owners of the + # sequences to the new table (for postgresql) + dependent_sequences = [ + (colname, re.search(r"nextval\('([^']*)'", coldef).group(1)) + for colname, coldef in self.get_coldef_list(tablename) + if 'nextval' in coldef + ] + colname_original_list = colname_list[:] colname_dict = {colname: colname for colname in colname_list} colmap_dict = {} @@ -1974,6 +2051,9 @@ def modify_table( '[sql] WARNING: multiple index inserted add ' 'columns, may cause alignment issues' ) + if self.is_using_postgres: + # adjust for the additional "rowid" field + src += 1 colname_list.insert(src, dst) coltype_list.insert(src, type_) insert = True @@ -2036,6 +2116,17 @@ def modify_table( self.add_table(tablename_temp, coldef_list, **metadata_keyval2) + # Change owners of sequences from old table to new table + if self.is_using_postgres: + new_colnames = [name for name, _ in coldef_list] + for colname, sequence in dependent_sequences: + if colname in new_colnames: + self.executeone( + text( + f'ALTER SEQUENCE {sequence} OWNED BY {tablename_temp}.{colname}' + ) + ) + # Copy data src_list = [] dst_list = [] @@ -2065,31 +2156,26 @@ def get_rowid_from_superkey(x): return [None] * len(x) self.add_cleanly(tablename_temp, dst_list, data_list, get_rowid_from_superkey) - if tablename_new is None: + if tablename_new is None: # i.e. not renaming the table # Drop original table - self.drop_table(tablename) + self.drop_table(tablename, invalidate_cache=False) # Rename temp table to original table name - self.rename_table(tablename_temp, tablename) + self.rename_table(tablename_temp, tablename, invalidate_cache=False) else: # Rename new table to new name - self.rename_table(tablename_temp, tablename_new) + self.rename_table(tablename_temp, tablename_new, invalidate_cache=False) + # Any modifications are going to invalidate the cached tables. + self.invalidate_tables_cache() - def rename_table(self, tablename_old, tablename_new): - if ut.VERBOSE: - logger.info( - '[sql] schema renaming tablename=%r -> %r' - % (tablename_old, tablename_new) - ) + def rename_table(self, tablename_old, tablename_new, invalidate_cache=True): + logger.info( + '[sql] schema renaming tablename=%r -> %r' % (tablename_old, tablename_new) + ) # Technically insecure call, but all entries are statically inputted by # the database's owner, who could delete or alter the entire database # anyway. 
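The sequence-ownership bookkeeping in modify_table depends on pulling sequence names out of Postgres column defaults with the nextval regex; in isolation (the default string below is a representative example, not taken from a real database):

    import re

    coldef = "bigint DEFAULT nextval('main.annotations_rowid_seq'::regclass) NOT NULL"
    match = re.search(r"nextval\('([^']*)'", coldef)
    print(match.group(1))   # -> main.annotations_rowid_seq
    # modify_table then issues, per matching column:
    #   ALTER SEQUENCE main.annotations_rowid_seq OWNED BY <temp_table>.<colname>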
- fmtkw = { - 'tablename_old': tablename_old, - 'tablename_new': tablename_new, - } - op_fmtstr = 'ALTER TABLE {tablename_old} RENAME TO {tablename_new}' - operation = op_fmtstr.format(**fmtkw) - self.executeone(operation, [], verbose=False) + operation = text(f'ALTER TABLE {tablename_old} RENAME TO {tablename_new}') + self.executeone(operation, []) # Rename table's metadata key_old_list = [ @@ -2098,30 +2184,30 @@ def rename_table(self, tablename_old, tablename_new): key_new_list = [ tablename_new + '_' + suffix for suffix in METADATA_TABLE_COLUMN_NAMES ] - id_iter = [(key,) for key in key_old_list] + id_iter = [key for key in key_old_list] val_iter = [(key,) for key in key_new_list] colnames = ('metadata_key',) - # logger.info('Setting metadata_key from %s to %s' % (ut.repr2(id_iter), ut.repr2(val_iter))) self.set( METADATA_TABLE_NAME, colnames, val_iter, id_iter, id_colname='metadata_key' ) + if invalidate_cache: + self.invalidate_tables_cache() - def drop_table(self, tablename): - if VERBOSE_SQL: - logger.info('[sql] schema dropping tablename=%r' % tablename) + def drop_table(self, tablename, invalidate_cache=True): + logger.info('[sql] schema dropping tablename=%r' % tablename) # Technically insecure call, but all entries are statically inputted by # the database's owner, who could delete or alter the entire database # anyway. - fmtkw = { - 'tablename': tablename, - } - op_fmtstr = 'DROP TABLE IF EXISTS {tablename}' - operation = op_fmtstr.format(**fmtkw) - self.executeone(operation, [], verbose=False) + operation = f'DROP TABLE IF EXISTS {tablename}' + if self.uri.startswith('postgresql'): + operation = f'{operation} CASCADE' + self.executeone(text(operation), []) # Delete table's metadata key_list = [tablename + '_' + suffix for suffix in METADATA_TABLE_COLUMN_NAMES] self.delete(METADATA_TABLE_NAME, key_list, id_colname='metadata_key') + if invalidate_cache: + self.invalidate_tables_cache() def drop_all_tables(self): """ @@ -2130,8 +2216,8 @@ def drop_all_tables(self): self._tablenames = None for tablename in self.get_table_names(): if tablename != 'metadata': - self.drop_table(tablename) - self._tablenames = None + self.drop_table(tablename, invalidate_cache=False) + self.invalidate_tables_cache() # ============== # CONVINENCE @@ -2222,7 +2308,7 @@ def get_coldef_list(self, tablename): col_type += ' PRIMARY KEY' elif column[3] == 1: col_type += ' NOT NULL' - elif column[4] is not None: + if column[4] is not None: default_value = six.text_type(column[4]) # HACK: add parens if the value contains parens in the future # all default values should contain parens @@ -2249,7 +2335,7 @@ def get_table_autogen_dict(self, tablename): Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController('sqlite:///', 'testing') >>> tablename = 'dummy_table' >>> db.add_table(tablename, ( >>> ('rowid', 'INTEGER PRIMARY KEY'), @@ -2284,7 +2370,7 @@ def get_table_autogen_str(self, tablename): Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA - >>> db = SQLDatabaseController.from_uri(':memory:') + >>> db = SQLDatabaseController('sqlite:///', 'testing') >>> tablename = 'dummy_table' >>> db.add_table(tablename, ( >>> ('rowid', 'INTEGER PRIMARY KEY'), @@ -2378,11 +2464,35 @@ def dump_schema(self): file_.write('\t%s%s%s%s%s\n' % col) ut.view_directory(app_resource_dir) + def invalidate_tables_cache(self): + """Invalidates the controller's cache of table names and objects + Resets 
the caches and/or repopulates them. + + """ + self._tablenames = None + self._sa_metadata = sqlalchemy.MetaData() + self.get_table_names() + def get_table_names(self, lazy=False): """ Conveinience: """ if not lazy or self._tablenames is None: - self.cur.execute("SELECT name FROM sqlite_master WHERE type='table'") - tablename_list = self.cur.fetchall() + dialect = self._engine.dialect.name + if dialect == 'sqlite': + stmt = "SELECT name FROM sqlite_master WHERE type='table'" + params = {} + elif dialect == 'postgresql': + stmt = text( + """\ + SELECT table_name FROM information_schema.tables + WHERE table_type='BASE TABLE' + AND table_schema = :schema""" + ) + params = {'schema': self.schema_name} + else: + raise RuntimeError(f'Unknown dialect {dialect}') + with self.connect() as conn: + result = conn.execute(stmt, **params) + tablename_list = result.fetchall() self._tablenames = {str(tablename[0]) for tablename in tablename_list} return self._tablenames @@ -2504,9 +2614,38 @@ def get_columns(self, tablename): ] """ # check if the table exists first. Throws an error if it does not exist. - self.cur.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') - self.cur.execute("PRAGMA TABLE_INFO('" + tablename + "')") - colinfo_list = self.cur.fetchall() + with self.connect() as conn: + conn.execute('SELECT 1 FROM ' + tablename + ' LIMIT 1') + dialect = self._engine.dialect.name + if dialect == 'sqlite': + stmt = f"PRAGMA TABLE_INFO('{tablename}')" + params = {} + elif dialect == 'postgresql': + stmt = text( + """SELECT + row_number() over () - 1, + column_name, + coalesce(domain_name, data_type), + CASE WHEN is_nullable = 'YES' THEN 0 ELSE 1 END, + column_default, + column_name = ( + SELECT column_name + FROM information_schema.table_constraints + NATURAL JOIN information_schema.constraint_column_usage + WHERE table_name = :table_name + AND constraint_type = 'PRIMARY KEY' + AND table_schema = :table_schema + LIMIT 1 + ) AS pk + FROM information_schema.columns + WHERE table_name = :table_name + AND table_schema = :table_schema""" + ) + params = {'table_name': tablename, 'table_schema': self.schema_name} + + with self.connect() as conn: + result = conn.execute(stmt, **params) + colinfo_list = result.fetchall() colrichinfo_list = [SQLColumnRichInfo(*colinfo) for colinfo in colinfo_list] return colrichinfo_list @@ -2517,17 +2656,12 @@ def get_column_names(self, tablename): return column_names def get_column(self, tablename, name): - """ Conveinience: """ - _table, (_column,) = sanitize_sql(self, tablename, (name,)) - column_vals = self.executeone( - operation=""" - SELECT %s - FROM %s - ORDER BY rowid ASC - """ - % (_column, _table) + """Get all the values for the specified column (``name``) of the table (``tablename``)""" + table = self._reflect_table(tablename) + stmt = sqlalchemy.select([table.c[name]]).order_by( + *[c.asc() for c in table.primary_key.columns] ) - return column_vals + return self.executeone(stmt) def get_table_as_pandas( self, tablename, rowids=None, columns=None, exclude_columns=[] @@ -2557,15 +2691,14 @@ def get_table_as_pandas( df = pd.DataFrame(ut.dzip(column_names, column_list), index=index) return df + # TODO (25-Sept-12020) Deprecate once ResultProxy can be exposed, + # because it will allow result access by index or column name. 
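The rewritten get_table_names goes to sqlite_master on SQLite and to information_schema on Postgres; the SQLite half can be checked directly (throwaway database, invented table):

    from sqlalchemy import create_engine
    from sqlalchemy.sql import text

    engine = create_engine('sqlite://')
    with engine.connect() as conn:
        conn.execute(text('CREATE TABLE metadata (metadata_key TEXT, metadata_value TEXT)'))
        rows = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
        print({r[0] for r in rows})   # -> {'metadata'}
        # Postgres equivalent (parameterized with the controller's schema_name):
        #   SELECT table_name FROM information_schema.tables
        #   WHERE table_type='BASE TABLE' AND table_schema = :schema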
def get_table_column_data( self, tablename, columns=None, exclude_columns=[], rowids=None ): """ Grabs a table of information - CommandLine: - python -m dtool.sql_control --test-get_table_column_data - Example: >>> # ENABLE_DOCTEST >>> from wbia.dtool.sql_control import * # NOQA @@ -2574,6 +2707,10 @@ def get_table_column_data( >>> tablename = 'keypoint' >>> db = depc[tablename].db >>> column_list, column_names = db.get_table_column_data(tablename) + >>> column_list + [[], [], [], [], []] + >>> column_names + ['keypoint_rowid', 'chip_rowid', 'config_rowid', 'kpts', 'num'] """ if columns is None: all_column_names = self.get_column_names(tablename) @@ -2587,6 +2724,9 @@ def get_table_column_data( ] else: column_list = [self.get_column(tablename, name) for name in column_names] + # BBB (28-Sept-12020) The previous implementation of `executeone` returned [] + # rather than None for empty rows. + column_list = [x and x or [] for x in column_list] return column_list, column_names def make_json_table_definition(self, tablename): @@ -2728,134 +2868,120 @@ def get_table_new_transferdata(self, tablename, exclude_columns=[]): >>> print('dependsmap = %s' % (ut.repr2(dependsmap, nl=True),)) >>> print('L___') """ - import utool - - with utool.embed_on_exception_context: - all_column_names = self.get_column_names(tablename) - isvalid_list = [name not in exclude_columns for name in all_column_names] - column_names = ut.compress(all_column_names, isvalid_list) - column_list = [ - self.get_column(tablename, name) - for name in column_names - if name not in exclude_columns - ] + table = self._reflect_table(tablename) + column_names = [c.name for c in table.columns if c.name not in exclude_columns] + column_list = [self.get_column(tablename, name) for name in column_names] + + extern_colx_list = [] + extern_tablename_list = [] + extern_superkey_colname_list = [] + extern_superkey_colval_list = [] + extern_primarycolnames_list = [] + dependsmap = self.metadata[tablename].dependsmap + if dependsmap is not None: + for colname, dependtup in six.iteritems(dependsmap): + assert len(dependtup) == 3, 'must be 3 for now' + ( + extern_tablename, + extern_primary_colnames, + extern_superkey_colnames, + ) = dependtup + if extern_primary_colnames is None: + # INFER PRIMARY COLNAMES + extern_primary_colnames = self.get_table_primarykey_colnames( + extern_tablename + ) + if extern_superkey_colnames is None: - extern_colx_list = [] - extern_tablename_list = [] - extern_superkey_colname_list = [] - extern_superkey_colval_list = [] - extern_primarycolnames_list = [] - dependsmap = self.metadata[tablename].dependsmap - if dependsmap is not None: - for colname, dependtup in six.iteritems(dependsmap): - assert len(dependtup) == 3, 'must be 3 for now' - ( - extern_tablename, - extern_primary_colnames, - extern_superkey_colnames, - ) = dependtup - if extern_primary_colnames is None: - # INFER PRIMARY COLNAMES - extern_primary_colnames = self.get_table_primarykey_colnames( - extern_tablename - ) - if extern_superkey_colnames is None: - - def get_standard_superkey_colnames(tablename_): - try: - # FIXME: Rectify duplicate code - superkeys = self.get_table_superkey_colnames(tablename_) - if len(superkeys) > 1: - primary_superkey = self.metadata[ - tablename_ - ].primary_superkey - self.get_table_superkey_colnames('contributors') - if primary_superkey is None: - raise AssertionError( - ( - 'tablename_=%r has multiple superkeys=%r, ' - 'but no primary superkey.' 
- ' A primary superkey is required' - ) - % (tablename_, superkeys) + def get_standard_superkey_colnames(tablename_): + try: + # FIXME: Rectify duplicate code + superkeys = self.get_table_superkey_colnames(tablename_) + if len(superkeys) > 1: + primary_superkey = self.metadata[ + tablename_ + ].primary_superkey + self.get_table_superkey_colnames('contributors') + if primary_superkey is None: + raise AssertionError( + ( + 'tablename_=%r has multiple superkeys=%r, ' + 'but no primary superkey.' + ' A primary superkey is required' ) - else: - index = superkeys.index(primary_superkey) - superkey_colnames = superkeys[index] - elif len(superkeys) == 1: - superkey_colnames = superkeys[0] - else: - logger.info(self.get_table_csv_header(tablename_)) - self.print_table_csv( - 'metadata', exclude_columns=['metadata_value'] + % (tablename_, superkeys) ) - # Execute hack to fix contributor tables - if tablename_ == 'contributors': - # hack to fix contributors table - constraint_str = self.metadata[ - tablename_ - ].constraint - parse_result = parse.parse( - 'CONSTRAINT superkey UNIQUE ({superkey})', - constraint_str, - ) - superkey = parse_result['superkey'] - assert ( - superkey == 'contributor_tag' - ), 'hack failed1' - assert ( - self.metadata['contributors'].superkey is None - ), 'hack failed2' - self.metadata['contributors'].superkey = [ - (superkey,) - ] - return (superkey,) - else: - raise NotImplementedError( - 'Cannot Handle: len(superkeys) == 0. ' - 'Probably a degenerate case' - ) - except Exception as ex: - ut.printex( - ex, - 'Error Getting superkey colnames', - keys=['tablename_', 'superkeys'], + else: + index = superkeys.index(primary_superkey) + superkey_colnames = superkeys[index] + elif len(superkeys) == 1: + superkey_colnames = superkeys[0] + else: + logger.info(self.get_table_csv_header(tablename_)) + self.print_table_csv( + 'metadata', exclude_columns=['metadata_value'] ) - raise - return superkey_colnames - - try: - extern_superkey_colnames = get_standard_superkey_colnames( - extern_tablename - ) + # Execute hack to fix contributor tables + if tablename_ == 'contributors': + # hack to fix contributors table + constraint_str = self.metadata[tablename_].constraint + parse_result = parse.parse( + 'CONSTRAINT superkey UNIQUE ({superkey})', + constraint_str, + ) + superkey = parse_result['superkey'] + assert superkey == 'contributor_tag', 'hack failed1' + assert ( + self.metadata['contributors'].superkey is None + ), 'hack failed2' + self.metadata['contributors'].superkey = [(superkey,)] + return (superkey,) + else: + raise NotImplementedError( + 'Cannot Handle: len(superkeys) == 0. 
' + 'Probably a degenerate case' + ) except Exception as ex: ut.printex( ex, - 'Error Building Transferdata', - keys=['tablename_', 'dependtup'], + 'Error Getting superkey colnames', + keys=['tablename_', 'superkeys'], ) raise - # INFER SUPERKEY COLNAMES - colx = ut.listfind(column_names, colname) - extern_rowids = column_list[colx] - superkey_column = self.get( - extern_tablename, extern_superkey_colnames, extern_rowids - ) - extern_colx_list.append(colx) - extern_superkey_colname_list.append(extern_superkey_colnames) - extern_superkey_colval_list.append(superkey_column) - extern_tablename_list.append(extern_tablename) - extern_primarycolnames_list.append(extern_primary_colnames) + return superkey_colnames - new_transferdata = ( - column_list, - column_names, - extern_colx_list, - extern_superkey_colname_list, - extern_superkey_colval_list, - extern_tablename_list, - extern_primarycolnames_list, - ) + try: + extern_superkey_colnames = get_standard_superkey_colnames( + extern_tablename + ) + except Exception as ex: + ut.printex( + ex, + 'Error Building Transferdata', + keys=['tablename_', 'dependtup'], + ) + raise + # INFER SUPERKEY COLNAMES + colx = ut.listfind(column_names, colname) + extern_rowids = column_list[colx] + superkey_column = self.get( + extern_tablename, extern_superkey_colnames, extern_rowids + ) + extern_colx_list.append(colx) + extern_superkey_colname_list.append(extern_superkey_colnames) + extern_superkey_colval_list.append(superkey_column) + extern_tablename_list.append(extern_tablename) + extern_primarycolnames_list.append(extern_primary_colnames) + + new_transferdata = ( + column_list, + column_names, + extern_colx_list, + extern_superkey_colname_list, + extern_superkey_colval_list, + extern_tablename_list, + extern_primarycolnames_list, + ) return new_transferdata # def import_table_new_transferdata(tablename, new_transferdata): @@ -3007,6 +3133,12 @@ def find_depth(tablename, dependency_digraph): extern_tablename_list, extern_primarycolnames_list, ) = new_transferdata + if column_names[0] == 'rowid': + # This is a postgresql database, ignore the rowid column + # which is built-in to sqlite + column_names = column_names[1:] + column_list = column_list[1:] + extern_colx_list = [i - 1 for i in extern_colx_list] # FIXME: extract the primary rowid column a little bit nicer assert column_names[0].endswith('_rowid') old_rowid_list = column_list[0] @@ -3256,30 +3388,14 @@ def view_db_in_external_reader(self): # ut.cmd(sqlite3_reader, sqlite3_db_fpath) pass + @deprecated("Use 'self.metadata.database.version = version' instead") def set_db_version(self, version): - # Do things properly, get the metadata_rowid (best because we want to assert anyway) - metadata_key_list = ['database_version'] - params_iter = ((metadata_key,) for metadata_key in metadata_key_list) - where_clause = 'metadata_key=?' 
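Because Postgres-backed tables carry an explicit 'rowid' column that SQLite keeps implicit, the transfer-data path strips it and shifts the external column indexes; the bookkeeping in isolation (toy values for illustration):

    column_names = ['rowid', 'annot_rowid', 'annot_uuid', 'name_rowid']
    column_list = [[1, 2], [11, 12], ['u1', 'u2'], [5, 6]]
    extern_colx_list = [3]          # positions that reference other tables

    if column_names[0] == 'rowid':  # synthetic Postgres column, implicit in SQLite
        column_names = column_names[1:]
        column_list = column_list[1:]
        extern_colx_list = [i - 1 for i in extern_colx_list]

    print(column_names[0], extern_colx_list)   # -> annot_rowid [2]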
- # list of relationships for each image - metadata_rowid_list = self.get_where( - METADATA_TABLE_NAME, - ('metadata_rowid',), - params_iter, - where_clause, - unpack_scalars=True, - ) - assert ( - len(metadata_rowid_list) == 1 - ), 'duplicate database_version keys in database' - id_iter = ((metadata_rowid,) for metadata_rowid in metadata_rowid_list) - val_list = ((_,) for _ in [version]) - self.set(METADATA_TABLE_NAME, ('metadata_value',), val_list, id_iter) + self.metadata.database.version = version def get_sql_version(self): """ Conveinience """ - self.cur.execute('SELECT sqlite_version()') - sql_version = self.cur.fetchone() + self.connection.execute('SELECT sqlite_version()') + sql_version = self.connection.fetchone() logger.info('[sql] SELECT sqlite_version = %r' % (sql_version,)) # The version number sqlite3 module. NOT the version of SQLite library. logger.info('[sql] sqlite3.version = %r' % (lite.version,)) diff --git a/wbia/dtool/types.py b/wbia/dtool/types.py index ffbe568c09..4734a16763 100644 --- a/wbia/dtool/types.py +++ b/wbia/dtool/types.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- """Mapping of Python types to SQL types""" import io -import json import uuid import numpy as np +from utool.util_cache import from_json, to_json +import sqlalchemy +from sqlalchemy.sql import text from sqlalchemy.types import Integer as SAInteger from sqlalchemy.types import TypeDecorator, UserDefinedType @@ -40,6 +42,7 @@ class JSONCodeableType(UserDefinedType): # Abstract properties base_py_type = None col_spec = None + postgresql_base_type = 'json' def get_col_spec(self, **kw): return self.col_spec @@ -47,12 +50,9 @@ def get_col_spec(self, **kw): def bind_processor(self, dialect): def process(value): if value is None: - return value + return None else: - if isinstance(value, self.base_py_type): - return json.dumps(value) - else: - return value + return to_json(value) return process @@ -60,11 +60,11 @@ def result_processor(self, dialect, coltype): def process(value): if value is None: return value + elif dialect.name == 'postgresql': + # postgresql doesn't need the value to be json decoded + return value else: - if not isinstance(value, self.base_py_type): - return json.loads(value) - else: - return value + return from_json(value) return process @@ -74,6 +74,7 @@ class NumPyPicklableType(UserDefinedType): # Abstract properties base_py_types = None col_spec = None + postgresql_base_type = 'bytea' def get_col_spec(self, **kw): return self.col_spec @@ -119,7 +120,8 @@ class Integer(TypeDecorator): impl = SAInteger def process_bind_param(self, value, dialect): - return int(value) + if value is not None: + return int(value) class List(JSONCodeableType): @@ -159,12 +161,20 @@ def bind_processor(self, dialect): def process(value): if value is None: return value + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + + if dialect.name == 'sqlite': + return value.bytes_le + elif dialect.name == 'postgresql': + return value else: if not isinstance(value, uuid.UUID): - return '%.32x' % uuid.UUID(value).int + return uuid.UUID(value).bytes_le else: # hexstring - return '%.32x' % value.int + return value.bytes_le + raise RuntimeError(f'Unknown dialect {dialect.name}') return process @@ -172,11 +182,15 @@ def result_processor(self, dialect, coltype): def process(value): if value is None: return value - else: + if dialect.name == 'sqlite': if not isinstance(value, uuid.UUID): - return uuid.UUID(value) + return uuid.UUID(bytes_le=value) else: return value + elif dialect.name == 'postgresql': + 
return value + else: + raise RuntimeError(f'Unknown dialect {dialect.name}') return process @@ -184,3 +198,25 @@ def process(value): _USER_DEFINED_TYPES = (Dict, List, NDArray, Number, UUID) # SQL type (e.g. 'DICT') to SQLAlchemy type: SQL_TYPE_TO_SA_TYPE = {cls().get_col_spec(): cls for cls in _USER_DEFINED_TYPES} +# Map postgresql types to SQLAlchemy types (postgresql type names are lowercase) +SQL_TYPE_TO_SA_TYPE.update( + {cls().get_col_spec().lower(): cls for cls in _USER_DEFINED_TYPES} +) +SQL_TYPE_TO_SA_TYPE['INTEGER'] = Integer +SQL_TYPE_TO_SA_TYPE['integer'] = Integer +SQL_TYPE_TO_SA_TYPE['bigint'] = Integer + + +def initialize_postgresql_types(conn, schema): + domain_names = conn.execute( + """\ + SELECT domain_name FROM information_schema.domains + WHERE domain_schema = (select current_schema)""" + ).fetchall() + for type_name, cls in SQL_TYPE_TO_SA_TYPE.items(): + if type_name not in domain_names and hasattr(cls, 'postgresql_base_type'): + base_type = cls.postgresql_base_type + try: + conn.execute(f'CREATE DOMAIN {type_name} AS {base_type}') + except sqlalchemy.exc.ProgrammingError: + conn.execute(text('SET SCHEMA :schema'), schema=schema) diff --git a/wbia/entry_points.py b/wbia/entry_points.py index 25e5c92577..3f5f52ef13 100644 --- a/wbia/entry_points.py +++ b/wbia/entry_points.py @@ -73,6 +73,11 @@ def _init_wbia(dbdir=None, verbose=None, use_cache=True, web=None, **kwargs): params.parse_args() from wbia.control import IBEISControl + # Set up logging + # TODO (30-Nov-12020) This is intended to be a temporary fix to logging. + # logger.setLevel(logging.DEBUG) + # logger.addHandler(logging.StreamHandler()) + if verbose is None: verbose = ut.VERBOSE if verbose and NOT_QUIET: @@ -91,6 +96,10 @@ def _init_wbia(dbdir=None, verbose=None, use_cache=True, web=None, **kwargs): request_dbversion=request_dbversion, force_serial=force_serial, ) + # BBB (12-Jan-12021) daily database backup for the sqlite database + if not ibs.is_using_postgres_db: + ibs.daily_backup_database() + if web is None: web = ut.get_argflag( ('--webapp', '--webapi', '--web', '--browser'), @@ -326,6 +335,62 @@ def opendb_in_background(*args, **kwargs): return proc +@contextmanager +def opendb_with_web(*args, with_job_engine=False, **kwargs): + """Opens the database and starts the web server. + + Returns: + ibs, client - IBEISController and Werkzeug Client + + Example: + >>> from wbia.entry_points import opendb_with_web + >>> expected_response_data = {'status': {'success': True, 'code': 200, 'message': '', 'cache': -1}, 'response': True} + >>> with opendb_with_web('testdb1') as (ibs, client): + ... response = client.get('/api/test/heartbeat/') + ... assert expected_response_data == response.json + + """ + from wbia.control.controller_inject import get_flask_app + + # Create the controller instance + ibs = opendb(*args, **kwargs) + if with_job_engine: + # TODO start jobs engine + pass + + # Create the web application + app = get_flask_app() + # ??? Gotta attach the controller to the application? 
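As in the doctest above, `opendb_with_web` pairs a controller with a Werkzeug test client so web API tests can run in-process. A sketch of wrapping it in a pytest fixture; the fixture name and the use of `testdb1` are illustrative assumptions, not part of the patch:

    import pytest
    from wbia.entry_points import opendb_with_web

    @pytest.fixture
    def web_client():
        # yields (ibs, client); requests go through the Flask app in-process
        with opendb_with_web('testdb1') as (ibs, client):
            yield ibs, client

    def test_heartbeat(web_client):
        ibs, client = web_client
        assert client.get('/api/test/heartbeat/').json['response']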
+ setattr(app, 'ibs', ibs) + + # Return the controller and client instances to the caller + with app.test_client() as client: + yield ibs, client + + +def opendb_fg_web(*args, **kwargs): + """ + Ignore: + >>> from wbia.entry_points import * # NOQA + >>> kwargs = {'db': 'testdb1'} + >>> args = tuple() + + >>> import wbia + >>> ibs = wbia.opendb_fg_web() + + """ + # Gives you context inside the web app for testing + kwargs['start_web_loop'] = False + kwargs['web'] = True + kwargs['browser'] = False + ibs = opendb(*args, **kwargs) + from wbia.control import controller_inject + + app = controller_inject.get_flask_app() + ibs.app = app + return ibs + + def opendb_bg_web(*args, managed=False, **kwargs): """ Wrapper around opendb_in_background, returns a nice web_ibs @@ -489,29 +554,6 @@ def managed_server(): return web_ibs -def opendb_fg_web(*args, **kwargs): - """ - Ignore: - >>> from wbia.entry_points import * # NOQA - >>> kwargs = {'db': 'testdb1'} - >>> args = tuple() - - >>> import wbia - >>> ibs = wbia.opendb_fg_web() - - """ - # Gives you context inside the web app for testing - kwargs['start_web_loop'] = False - kwargs['web'] = True - kwargs['browser'] = False - ibs = opendb(*args, **kwargs) - from wbia.control import controller_inject - - app = controller_inject.get_flask_app() - ibs.app = app - return ibs - - def opendb( db=None, dbdir=None, diff --git a/wbia/gui/guiback.py b/wbia/gui/guiback.py index 339d978483..b628d6e25b 100644 --- a/wbia/gui/guiback.py +++ b/wbia/gui/guiback.py @@ -1686,9 +1686,7 @@ def merge_imagesets(back, imgsetid_list, destination_imgsetid): destination_imgsetid = imgsetid_list[destination_index] deprecated_imgsetids = list(imgsetid_list) deprecated_imgsetids.pop(destination_index) - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid) for imgsetid in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) imgsetid_list = [destination_imgsetid] * len(gid_list) ibs.set_image_imgsetids(gid_list, imgsetid_list) ibs.delete_imagesets(deprecated_imgsetids) diff --git a/wbia/init/filter_annots.py b/wbia/init/filter_annots.py index 87458a27a5..448a806b5a 100644 --- a/wbia/init/filter_annots.py +++ b/wbia/init/filter_annots.py @@ -1322,6 +1322,9 @@ def filter_annots_independent( logger.info('No annot filter returning') return avail_aids + if not avail_aids: # no need to filter if empty + return avail_aids + VerbosityContext = verb_context('FILTER_INDEPENDENT', aidcfg, verbose) VerbosityContext.startfilter(withpre=withpre) @@ -1599,6 +1602,9 @@ def filter_annots_intragroup( logger.info('No annot filter returning') return avail_aids + if not avail_aids: # no need to filter if empty + return avail_aids + VerbosityContext = verb_context('FILTER_INTRAGROUP', aidcfg, verbose) VerbosityContext.startfilter(withpre=withpre) diff --git a/wbia/init/sysres.py b/wbia/init/sysres.py index f6743a02a3..c485836ff1 100644 --- a/wbia/init/sysres.py +++ b/wbia/init/sysres.py @@ -6,11 +6,15 @@ """ import logging import os +from functools import lru_cache from os.path import exists, join, realpath +from pathlib import Path + import utool as ut import ubelt as ub from six.moves import input, zip, map from wbia import constants as const +from wbia.dtool.copy_sqlite_to_postgres import copy_sqlite_to_postgres (print, rrr, profile) = ut.inject2(__name__) @@ -45,6 +49,27 @@ def _wbia_cache_read(key, **kwargs): return ut.global_cache_read(key, appname=__APPNAME__, **kwargs) +def get_wbia_db_uri(db_dir: str = None): + """Central location to acquire the 
database URI value. + + Args: + db_dir (str): colloquial "dbdir" (default: None) + + The ``db_dir`` argument is only to be used in testing. + This function is monkeypatched by the testing environment + (see ``wbia.conftest`` for that code). + The monkeypatching is done because two or more instances of a controller + (i.e. ``IBEISController``) could be running in the same test. + In that scenario more than one URI may need to be defined, + which is not the case in production + and why the body of this function is kept fairly simple. + We ask the caller to supply the ``db_dir`` value + in order to match up the corresponding URI. + + """ + return ut.get_argval('--db-uri', default=None) + + # Specific cache getters / setters @@ -415,16 +440,11 @@ def ensure_pz_mtest(): >>> ensure_pz_mtest() """ logger.info('ensure_pz_mtest') - from wbia import sysres - - workdir = sysres.get_workdir() - mtest_zipped_url = const.ZIPPED_URLS.PZ_MTEST - mtest_dir = ut.grab_zipped_url(mtest_zipped_url, ensure=True, download_dir=workdir) - logger.info('have mtest_dir=%r' % (mtest_dir,)) + dbdir = ensure_db_from_url(const.ZIPPED_URLS.PZ_MTEST) # update the the newest database version import wbia - ibs = wbia.opendb('PZ_MTEST') + ibs = wbia.opendb(dbdir=dbdir) logger.info('cleaning up old database and ensureing everything is properly computed') ibs.db.vacuum() valid_aids = ibs.get_valid_aids() @@ -861,6 +881,11 @@ def ensure_testdb_orientation(): return ensure_db_from_url(const.ZIPPED_URLS.ORIENTATION) +@lru_cache(maxsize=None) +def ensure_testdb_assigner(): + return ensure_db_from_url(const.ZIPPED_URLS.ASSIGNER) + + def ensure_testdb_identification_example(): return ensure_db_from_url(const.ZIPPED_URLS.ID_EXAMPLE) @@ -877,6 +902,16 @@ def ensure_db_from_url(zipped_db_url): dbdir = ut.grab_zipped_url( zipped_url=zipped_db_url, ensure=True, download_dir=workdir ) + + # Determine if the implementation is using a URI for database connection. + # This is confusing, sorry. If the URI is set we are using a non-sqlite + # database connection. As such, we most translate the sqlite db. 
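For context on the branch that follows: `get_wbia_db_uri`, defined earlier in this file, simply reads the `--db-uri` command-line flag, so the sqlite-to-postgres copy only happens when the process was started with that flag. A minimal sketch; the URI shown is an assumed example, not a project default:

    import utool as ut

    uri = ut.get_argval('--db-uri', default=None)  # e.g. 'postgresql://wbia:wbia@localhost/wbia'
    if uri is not None:
        # copy_sqlite_to_postgres(Path(dbdir), uri) then migrates the unpacked
        # sqlite databases into that postgres server before the db is opened
        print('sqlite databases will be migrated to', uri)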
+ uri = get_wbia_db_uri(dbdir) + if uri: + logger.info(f"Copying '{dbdir}' databases to the database at: {uri}") + for _, future, _, _ in copy_sqlite_to_postgres(Path(dbdir), uri): + future.result() # will raise if there is a problem + logger.info('have %s=%r' % (zipped_db_url, dbdir)) return dbdir diff --git a/wbia/other/dbinfo.py b/wbia/other/dbinfo.py index f500ddddb7..9291889ef9 100644 --- a/wbia/other/dbinfo.py +++ b/wbia/other/dbinfo.py @@ -11,6 +11,8 @@ import six import numpy as np import utool as ut +import matplotlib.pyplot as plt + print, rrr, profile = ut.inject2(__name__) logger = logging.getLogger('wbia') @@ -51,15 +53,19 @@ def print_qd_info(ibs, qaid_list, daid_list, verbose=False): def get_dbinfo( ibs, verbose=True, - with_imgsize=False, - with_bytes=False, - with_contrib=False, - with_agesex=False, + with_imgsize=True, + with_bytes=True, + with_contrib=True, + with_agesex=True, with_header=True, + with_reviews=True, + with_ggr=False, + with_map=False, short=False, tag='dbinfo', aid_list=None, aids=None, + gmt_offset=3.0, ): """ @@ -167,10 +173,9 @@ def get_dbinfo( # Basic variables request_annot_subset = False _input_aid_list = aid_list # NOQA + if aid_list is None: valid_aids = ibs.get_valid_aids() - valid_nids = ibs.get_valid_nids() - valid_gids = ibs.get_valid_gids() else: if isinstance(aid_list, str): # Hack to get experiment stats on aids @@ -182,16 +187,44 @@ def get_dbinfo( ibs, acfg_name_list ) aid_list = sorted(list(set(ut.flatten(ut.flatten(expanded_aids_list))))) - # aid_list = if verbose: logger.info('Specified %d custom aids' % (len(aid_list))) request_annot_subset = True valid_aids = aid_list - valid_nids = list( - set(ibs.get_annot_nids(aid_list, distinguish_unknowns=False)) - - {const.UNKNOWN_NAME_ROWID} - ) - valid_gids = list(set(ibs.get_annot_gids(aid_list))) + + def get_dates(ibs, gid_list): + unixtime_list = ibs.get_image_unixtime2(gid_list) + unixtime_list_ = [unixtime + (gmt_offset * 60 * 60) for unixtime in unixtime_list] + datetime_list = [ + ut.unixtime_to_datetimestr(unixtime) if unixtime is not None else 'UNKNOWN' + for unixtime in unixtime_list_ + ] + date_str_list = [value[:10] for value in datetime_list] + return date_str_list + + if with_ggr: + valid_gids = list(set(ibs.get_annot_gids(valid_aids))) + date_str_list = get_dates(ibs, valid_gids) + flag_list = [ + value in ['2016/01/30', '2016/01/31', '2018/01/27', '2018/01/28'] + for value in date_str_list + ] + valid_gids = ut.compress(valid_gids, flag_list) + ggr_aids = set(ut.flatten(ibs.get_image_aids(valid_gids))) + valid_aids = sorted(list(set(valid_aids) & ggr_aids)) + + valid_nids = list( + set(ibs.get_annot_nids(valid_aids, distinguish_unknowns=False)) + - {const.UNKNOWN_NAME_ROWID} + ) + valid_gids = list(set(ibs.get_annot_gids(valid_aids))) + # valid_rids = ibs._get_all_review_rowids() + valid_rids = [] + valid_rids += ibs.get_review_rowids_from_aid1(valid_aids) + valid_rids += ibs.get_review_rowids_from_aid2(valid_aids) + valid_rids = ut.flatten(valid_rids) + valid_rids = list(set(valid_rids)) + # associated_nids = ibs.get_valid_nids(filter_empty=True) # nids with at least one annotation valid_images = ibs.images(valid_gids) valid_annots = ibs.annots(valid_aids) @@ -224,16 +257,16 @@ def get_dbinfo( ibs.check_name_mapping_consistency(nx2_aids) - if False: + if True: # Occurrence Info def compute_annot_occurrence_ids(ibs, aid_list): from wbia.algo.preproc import preproc_occurrence + import utool as ut gid_list = ibs.get_annot_gids(aid_list) gid2_aids = ut.group_items(aid_list, 
gid_list) - config = {'seconds_thresh': 4 * 60 * 60} flat_imgsetids, flat_gids = preproc_occurrence.wbia_compute_occurrences( - ibs, gid_list, config=config, verbose=False + ibs, gid_list, verbose=False ) occurid2_gids = ut.group_items(flat_gids, flat_imgsetids) occurid2_aids = { @@ -283,19 +316,19 @@ def break_annots_into_encounters(aids): # ave_enc_time = [np.mean(times) for lbl, times in ut.group_items(posixtimes, labels).items()] # ut.square_pdist(ave_enc_time) - try: - am_rowids = ibs.get_annotmatch_rowids_between_groups([valid_aids], [valid_aids])[ - 0 - ] - aid_pairs = ibs.filter_aidpairs_by_tags(min_num=0, am_rowids=am_rowids) - undirected_tags = ibs.get_aidpair_tags( - aid_pairs.T[0], aid_pairs.T[1], directed=False - ) - tagged_pairs = list(zip(aid_pairs.tolist(), undirected_tags)) - tag_dict = ut.groupby_tags(tagged_pairs, undirected_tags) - pair_tag_info = ut.map_dict_vals(len, tag_dict) - except Exception: - pair_tag_info = {} + # try: + # am_rowids = ibs.get_annotmatch_rowids_between_groups([valid_aids], [valid_aids])[ + # 0 + # ] + # aid_pairs = ibs.filter_aidpairs_by_tags(min_num=0, am_rowids=am_rowids) + # undirected_tags = ibs.get_aidpair_tags( + # aid_pairs.T[0], aid_pairs.T[1], directed=False + # ) + # tagged_pairs = list(zip(aid_pairs.tolist(), undirected_tags)) + # tag_dict = ut.groupby_tags(tagged_pairs, undirected_tags) + # pair_tag_info = ut.map_dict_vals(len, tag_dict) + # except Exception: + # pair_tag_info = {} # logger.info(ut.repr2(pair_tag_info)) @@ -405,11 +438,120 @@ def arr2str(var): ut.show_if_requested() unixtime_statstr = ut.repr3(ut.get_timestats_dict(unixtime_list, full=True), si=True) + date_str_list = get_dates(ibs, valid_gids) + ggr_dates_stats = ut.dict_hist(date_str_list) + # GPS stats gps_list_ = ibs.get_image_gps(valid_gids) gpsvalid_list = [gps != (-1, -1) for gps in gps_list_] gps_list = ut.compress(gps_list_, gpsvalid_list) + if with_map: + + def plot_kenya(ibs, ax, gps_list=[], focus=False, focus2=False, margin=0.1): + import utool as ut + import pandas as pd + import geopandas + import shapely + + if focus2: + focus = True + + world = geopandas.read_file( + geopandas.datasets.get_path('naturalearth_lowres') + ) + africa = world[world.continent == 'Africa'] + kenya = africa[africa.name == 'Kenya'] + + cities = geopandas.read_file( + geopandas.datasets.get_path('naturalearth_cities') + ) + nairobi = cities[cities.name == 'Nairobi'] + + kenya.plot(ax=ax, color='white', edgecolor='black') + + path_dict = ibs.compute_ggr_path_dict() + meru = path_dict['County Meru'] + + for key in path_dict: + path = path_dict[key] + + polygon = shapely.geometry.Polygon(path.vertices[:, ::-1]) + gdf = geopandas.GeoDataFrame([1], geometry=[polygon], crs=world.crs) + + if key.startswith('County'): + if 'Meru' in key: + gdf.plot(ax=ax, color=(1, 0, 0, 0.2), edgecolor='red') + else: + gdf.plot(ax=ax, color='grey', edgecolor='black') + if focus: + if key.startswith('Land Tenure'): + gdf.plot(ax=ax, color=(1, 0, 0, 0.0), edgecolor='blue') + + if focus2: + flag_list = [] + for gps in gps_list: + flag = meru.contains_point(gps) + flag_list.append(flag) + gps_list = ut.compress(gps_list, flag_list) + + df = pd.DataFrame( + { + 'Latitude': ut.take_column(gps_list, 0), + 'Longitude': ut.take_column(gps_list, 1), + } + ) + gdf = geopandas.GeoDataFrame( + df, geometry=geopandas.points_from_xy(df.Longitude, df.Latitude) + ) + gdf.plot(ax=ax, color='red') + + min_lat, min_lon = gdf.min() + max_lat, max_lon = gdf.max() + dom_lat = max_lat - min_lat + dom_lon = max_lon - min_lon 
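            # Pad the bounding box of the plotted GPS points by `margin`
            # (default 0.1, i.e. 10%) of its extent on each side, so the points
            # are not drawn flush against the focus window set further below.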
+ margin_lat = dom_lat * margin + margin_lon = dom_lon * margin + min_lat -= margin_lat + min_lon -= margin_lon + max_lat += margin_lat + max_lon += margin_lon + + polygon = shapely.geometry.Polygon( + [ + [min_lon, min_lat], + [min_lon, max_lat], + [max_lon, max_lat], + [max_lon, min_lat], + ] + ) + gdf = geopandas.GeoDataFrame([1], geometry=[polygon], crs=world.crs) + gdf.plot(ax=ax, color=(1, 0, 0, 0.0), edgecolor='blue') + + nairobi.plot(ax=ax, marker='*', color='black', markersize=500) + + ax.grid(False, which='major') + ax.grid(False, which='minor') + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + + if focus: + ax.set_autoscalex_on(False) + ax.set_autoscaley_on(False) + ax.set_xlim([min_lon, max_lon]) + ax.set_ylim([min_lat, max_lat]) + + fig = plt.figure(figsize=(30, 30), dpi=400) + + ax = plt.subplot(131) + plot_kenya(ibs, ax, gps_list) + ax = plt.subplot(132) + plot_kenya(ibs, ax, gps_list, focus=True) + ax = plt.subplot(133) + plot_kenya(ibs, ax, gps_list, focus2=True) + + plt.savefig('map.png', bbox_inches='tight') + def get_annot_age_stats(aid_list): annot_age_months_est_min = ibs.get_annot_age_months_est_min(aid_list) annot_age_months_est_max = ibs.get_annot_age_months_est_max(aid_list) @@ -558,14 +700,150 @@ def fix_tag_list(tag_list): contributor_rowids = ibs.get_valid_contributor_rowids() num_contributors = len(contributor_rowids) - # print - num_tabs = 5 + if verbose: + logger.info('Checking Review Info') + + # Get reviewer statistics + def get_review_decision_stats(ibs, rid_list): + review_decision_list = ibs.get_review_decision_str(rid_list) + review_decision_to_rids = ut.group_items(rid_list, review_decision_list) + review_decision_stats = { + key: len(val) for key, val in review_decision_to_rids.items() + } + return review_decision_stats + + def get_review_identity(rid_list): + review_identity_list = ibs.get_review_identity(rid_list) + review_identity_list = [ + value.replace('user:web', 'human:web') + .replace('web:None', 'web') + .replace('auto_clf', 'vamp') + .replace(':', '[') + + ']' + for value in review_identity_list + ] + return review_identity_list + + def get_review_identity_stats(ibs, rid_list): + review_identity_list = get_review_identity(rid_list) + review_identity_to_rids = ut.group_items(rid_list, review_identity_list) + review_identity_stats = { + key: len(val) for key, val in review_identity_to_rids.items() + } + return review_identity_to_rids, review_identity_stats + + def get_review_participation(review_aids_list, value_list): + review_participation_dict = {} + for review_aids, value in zip(review_aids_list, value_list): + for value_ in [value, 'Any']: + if value_ not in review_participation_dict: + review_participation_dict[value_] = {} + for aid in review_aids: + if aid not in review_participation_dict[value_]: + review_participation_dict[value_][aid] = 0 + review_participation_dict[value_][aid] += 1 + + for value in review_participation_dict: + values = list(review_participation_dict[value].values()) + mean = np.mean(values) + std = np.std(values) + thresh = int(np.around(mean + 2 * std)) + values = [ + '%02d+' % (thresh,) if value >= thresh else '%02d' % (value,) + for value in values + ] + review_participation_dict[value] = ut.dict_hist(values) + review_participation_dict[value]['AVG'] = '%0.1f +/- %0.1f' % ( + mean, + std, + ) + + return review_participation_dict + + review_decision_stats = get_review_decision_stats(ibs, valid_rids) + review_identity_to_rids, review_identity_stats = get_review_identity_stats( + ibs, valid_rids + 
) + + review_identity_to_decision_stats = { + key: get_review_decision_stats(ibs, aids) + for key, aids in six.iteritems(review_identity_to_rids) + } + + review_aids_list = ibs.get_review_aid_tuple(valid_rids) + review_decision_list = ibs.get_review_decision_str(valid_rids) + review_identity_list = get_review_identity(valid_rids) + review_decision_participation_dict = get_review_participation( + review_aids_list, review_decision_list + ) + review_identity_participation_dict = get_review_participation( + review_aids_list, review_identity_list + ) + + review_tags_list = ibs.get_review_tags(valid_rids) + review_tag_list = [ + review_tag if review_tag is None else '+'.join(sorted(review_tag)) + for review_tag in review_tags_list + ] + + review_tag_to_rids = ut.group_items(valid_rids, review_tag_list) + review_tag_stats = {key: len(val) for key, val in review_tag_to_rids.items()} + + species_list = ibs.get_annot_species_texts(valid_aids) + viewpoint_list = ibs.get_annot_viewpoints(valid_aids) + quality_list = ibs.get_annot_qualities(valid_aids) + interest_list = ibs.get_annot_interest(valid_aids) + canonical_list = ibs.get_annot_canonical(valid_aids) + + ggr_num_relevant = 0 + ggr_num_species = 0 + ggr_num_viewpoints = 0 + ggr_num_qualities = 0 + ggr_num_aois = 0 + ggr_num_cas = 0 + ggr_num_overlap = 0 + + zipped = list( + zip( + valid_aids, + species_list, + viewpoint_list, + quality_list, + interest_list, + canonical_list, + ) + ) + for aid, species_, viewpoint_, quality_, interest_, canonical_ in zipped: + assert None not in [species_, viewpoint_, quality_] + species_ = species_.lower() + viewpoint_ = viewpoint_.lower() + quality_ = int(quality_) + if species_ in ['zebra_grevys', 'zebra_plains']: + ggr_num_relevant += 1 + if species_ in ['zebra_grevys']: + ggr_num_species += 1 + if 'right' in viewpoint_: + ggr_num_viewpoints += 1 + if quality_ >= 3: + ggr_num_qualities += 1 + if interest_: + ggr_num_aois += 1 + if canonical_: + ggr_num_overlap += 1 + + if canonical_: + ggr_num_cas += 1 + + ######### + + num_tabs = 30 def align2(str_): return ut.align(str_, ':', ' :') def align_dict2(dict_): - str_ = ut.repr2(dict_, si=True) + # str_ = ut.repr2(dict_, si=True) + str_ = ut.repr3(dict_, si=True) return align2(str_) header_block_lines = [('+============================')] + ( @@ -626,17 +904,6 @@ def align_dict2(dict_): else [] ) - occurrence_block_lines = ( - [ - ('--' * num_tabs), - # ('# Occurrence Per Name (Resights) = %s' % (align_dict2(resight_name_stats),)), - # ('# Annots per Encounter (Singlesights) = %s' % (align_dict2(singlesight_annot_stats),)), - ('# Pair Tag Info (annots) = %s' % (align_dict2(pair_tag_info),)), - ] - if not short - else [] - ) - annot_per_qualview_block_lines = [ None if short else '# Annots per Viewpoint = %s' % align_dict2(viewcode2_nAnnots), None if short else '# Annots per Quality = %s' % align_dict2(qualtext2_nAnnots), @@ -644,23 +911,66 @@ def align_dict2(dict_): annot_per_agesex_block_lines = ( [ - '# Annots per Age = %s' % align_dict2(agetext2_nAnnots), - '# Annots per Sex = %s' % align_dict2(sextext2_nAnnots), + ('# Annots per Age = %s' % align_dict2(agetext2_nAnnots)), + ('# Annots per Sex = %s' % align_dict2(sextext2_nAnnots)), ] if not short and with_agesex else [] ) - contributor_block_lines = ( + annot_ggr_census = ( [ - '# Images per contributor = ' + align_dict2(contributor_tag_to_nImages), - '# Annots per contributor = ' + align_dict2(contributor_tag_to_nAnnots), - '# Quality per contributor = ' - + ut.repr2(contributor_tag_to_qualstats, 
sorted_=True), - '# Viewpoint per contributor = ' - + ut.repr2(contributor_tag_to_viewstats, sorted_=True), + ('GGR Annots: '), + (' +-Relevant: %s' % (ggr_num_relevant,)), + (" +- Grevy's Species: %s" % (ggr_num_species,)), + (' | +-AoIs: %s' % (ggr_num_aois,)), + (' | +- Right Side: %s' % (ggr_num_viewpoints,)), + (' | | +-Good Quality: %s' % (ggr_num_qualities,)), + (' +-CAs: %s' % (ggr_num_cas,)), + (' +-Filter + CA Overlap: %s' % (ggr_num_overlap,)), ] - if with_contrib + if with_ggr + else [] + ) + + occurrence_block_lines = ( + [ + ('--' * num_tabs), + ( + '# Occurrence Per Name (Resights) = %s' + % (align_dict2(resight_name_stats),) + ), + ( + '# Annots per Encounter (Singlesights) = %s' + % (align_dict2(singlesight_annot_stats),) + ), + # ('# Pair Tag Info (annots) = %s' % (align_dict2(pair_tag_info),)), + ] + if not short + else [] + ) + + reviews_block_lines = ( + [ + ('--' * num_tabs), + ('# Reviews = %d' % len(valid_rids)), + ('# Reviews per Decision = %s' % align_dict2(review_decision_stats)), + ('# Reviews per Reviewer = %s' % align_dict2(review_identity_stats)), + ( + '# Review Breakdown = %s' + % align_dict2(review_identity_to_decision_stats) + ), + ('# Reviews with Tag = %s' % align_dict2(review_tag_stats)), + ( + '# Review Participation #1 = %s' + % align_dict2(review_decision_participation_dict) + ), + ( + '# Review Participation #2 = %s' + % align_dict2(review_identity_participation_dict) + ), + ] + if with_reviews else [] ) @@ -675,8 +985,35 @@ def align_dict2(dict_): None if short else ('Img Time Stats = %s' % (align2(unixtime_statstr),)), + None + if with_ggr + else ('GGR Days = %s' % (align_dict2(ggr_dates_stats),)), ] + contributor_block_lines = ( + [ + ('--' * num_tabs), + ( + '# Images per contributor = ' + + align_dict2(contributor_tag_to_nImages) + ), + ( + '# Annots per contributor = ' + + align_dict2(contributor_tag_to_nAnnots) + ), + ( + '# Quality per contributor = ' + + align_dict2(contributor_tag_to_qualstats) + ), + ( + '# Viewpoint per contributor = ' + + align_dict2(contributor_tag_to_viewstats) + ), + ] + if with_contrib + else [] + ) + info_str_lines = ( header_block_lines + bytes_block_lines @@ -684,16 +1021,17 @@ def align_dict2(dict_): + name_block_lines + annot_block_lines + annot_per_basic_block_lines - + occurrence_block_lines + annot_per_qualview_block_lines + annot_per_agesex_block_lines + + occurrence_block_lines + + reviews_block_lines + img_block_lines - + contributor_block_lines + imgsize_stat_lines + + contributor_block_lines + [('L============================')] ) info_str = '\n'.join(ut.filter_Nones(info_str_lines)) - info_str2 = ut.indent(info_str, '[{tag}]'.format(tag=tag)) + info_str2 = ut.indent(info_str, '[{tag}] '.format(tag=tag)) if verbose: logger.info(info_str2) locals_ = locals() diff --git a/wbia/other/ibsfuncs.py b/wbia/other/ibsfuncs.py index e411acb0d3..4e4a432d05 100644 --- a/wbia/other/ibsfuncs.py +++ b/wbia/other/ibsfuncs.py @@ -17,6 +17,7 @@ import types import functools import re +from collections import OrderedDict from six.moves import zip, range, map, reduce from os.path import split, join, exists import numpy as np @@ -47,6 +48,9 @@ (print, rrr, profile) = ut.inject2(__name__, '[ibsfuncs]') logger = logging.getLogger('wbia') +# logging.getLogger().setLevel(logging.DEBUG) +# logger.setLevel(logging.DEBUG) + # Must import class before injection CLASS_INJECT_KEY, register_ibs_method = controller_inject.make_ibs_register_decorator( @@ -149,6 +153,8 @@ def filter_junk_annotations(ibs, aid_list): >>> result = 
str(filtered_aid_list) >>> print(result) """ + if not aid_list: # no need to filter if empty + return aid_list isjunk_list = ibs.get_annot_isjunk(aid_list) filtered_aid_list = ut.filterfalse_items(aid_list, isjunk_list) return filtered_aid_list @@ -939,9 +945,7 @@ def check_name_mapping_consistency(ibs, nx2_aids): """ checks that all the aids grouped in a name ahave the same name """ # DEBUGGING CODE try: - from wbia import ibsfuncs - - _nids_list = ibsfuncs.unflat_map(ibs.get_annot_name_rowids, nx2_aids) + _nids_list = unflat_map(ibs.get_annot_name_rowids, nx2_aids) assert all(map(ut.allsame, _nids_list)) except Exception as ex: # THESE SHOULD BE CONSISTENT BUT THEY ARE NOT!!? @@ -1304,6 +1308,7 @@ def check_cache_purge(ibs, ttl_days=90, dryrun=True, squeeze=True): './_ibsdb/_ibeis_cache/match_thumbs', './_ibsdb/_ibeis_cache/qres_new', './_ibsdb/_ibeis_cache/curvrank', + './_ibsdb/_ibeis_cache/curvrank_v2', './_ibsdb/_ibeis_cache/pie_neighbors', ] @@ -3611,15 +3616,12 @@ def get_primary_database_species(ibs, aid_list=None, speedhack=True): return 'zebra_grevys' if aid_list is None: aid_list = ibs.get_valid_aids(is_staged=None) - species_list = ibs.get_annot_species_texts(aid_list) - species_hist = ut.dict_hist(species_list) - if len(species_hist) == 0: + + species_count = ibs.get_database_species_count(aid_list) + if not species_count: primary_species = const.UNKNOWN else: - frequent_species = sorted( - species_hist.items(), key=lambda item: item[1], reverse=True - ) - primary_species = frequent_species[0][0] + primary_species = species_count.popitem(last=False)[0] # FIFO return primary_species @@ -3650,7 +3652,7 @@ def get_dominant_species(ibs, aid_list): @register_ibs_method -def get_database_species_count(ibs, aid_list=None): +def get_database_species_count(ibs, aid_list=None, BATCH_SIZE=25000): """ CommandLine: @@ -3662,16 +3664,49 @@ def get_database_species_count(ibs, aid_list=None): >>> import wbia # NOQA >>> #print(ut.repr2(wbia.opendb('PZ_Master0').get_database_species_count())) >>> ibs = wbia.opendb('testdb1') - >>> result = ut.repr2(ibs.get_database_species_count(), nl=False) + >>> result = ut.repr2(ibs.get_database_species_count(BATCH_SIZE=2), nl=False) >>> print(result) - {'____': 3, 'bear_polar': 2, 'zebra_grevys': 2, 'zebra_plains': 6} + {'zebra_plains': 6, '____': 3, 'zebra_grevys': 2, 'bear_polar': 2} """ if aid_list is None: aid_list = ibs.get_valid_aids() - species_list = ibs.get_annot_species_texts(aid_list) - species_count_dict = ut.item_hist(species_list) - return species_count_dict + + annotations = ibs.db._reflect_table('annotations') + species = ibs.db._reflect_table('species') + + from sqlalchemy.sql import select, func, desc, bindparam + + species_count = OrderedDict() + stmt = ( + select( + [ + species.c.species_text, + func.count(annotations.c.annot_rowid).label('num_annots'), + ] + ) + .select_from( + annotations.outerjoin( + species, annotations.c.species_rowid == species.c.species_rowid + ) + ) + .where(annotations.c.annot_rowid.in_(bindparam('aids', expanding=True))) + .group_by('species_text') + .order_by(desc('num_annots')) + ) + for batch in range(int(len(aid_list) / BATCH_SIZE) + 1): + aids = aid_list[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE] + with ibs.db.connect() as conn: + results = conn.execute(stmt, {'aids': aids}) + + for row in results: + species_text = row.species_text + if species_text is None: + species_text = const.UNKNOWN + species_count[species_text] = ( + species_count.get(species_text, 0) + row.num_annots + ) + return species_count 
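Because the rewritten query groups in SQL and orders by `desc('num_annots')`, the returned `OrderedDict` lists the most frequent species first, which is what `get_primary_database_species` above exploits via `popitem(last=False)`. An illustrative read using the doctest values, assuming `ibs` is a controller opened on `testdb1`:

    counts = ibs.get_database_species_count()
    # OrderedDict([('zebra_plains', 6), ('____', 3), ('zebra_grevys', 2), ('bear_polar', 2)])
    primary = next(iter(counts))  # 'zebra_plains', the same species popitem(last=False) yields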
@register_ibs_method @@ -4867,15 +4902,18 @@ def filter_aids_to_quality(ibs, aid_list, minqual, unknown_ok=True, speedhack=Tr >>> x1 = filter_aids_to_quality(ibs, aid_list, 'good', True, speedhack=True) >>> x2 = filter_aids_to_quality(ibs, aid_list, 'good', True, speedhack=False) """ + if not aid_list: # no need to filter if empty + return aid_list if speedhack: list_repr = ','.join(map(str, aid_list)) minqual_int = const.QUALITY_TEXT_TO_INT[minqual] if unknown_ok: - operation = 'SELECT rowid from annotations WHERE (annot_quality ISNULL OR annot_quality==-1 OR annot_quality>={minqual_int}) AND rowid IN ({aids})' + operation = 'SELECT rowid from annotations WHERE (annot_quality ISNULL OR annot_quality=-1 OR annot_quality>={minqual_int}) AND rowid IN ({aids})' else: operation = 'SELECT rowid from annotations WHERE annot_quality NOTNULL AND annot_quality>={minqual_int} AND rowid IN ({aids})' operation = operation.format(aids=list_repr, minqual_int=minqual_int) - aid_list_ = ut.take_column(ibs.db.cur.execute(operation).fetchall(), 0) + with ibs.db.connect() as conn: + aid_list_ = ut.take_column(conn.execute(operation).fetchall(), 0) else: qual_flags = list( ibs.get_quality_filterflags(aid_list, minqual, unknown_ok=unknown_ok) @@ -4893,6 +4931,8 @@ def filter_aids_to_viewpoint(ibs, aid_list, valid_yaws, unknown_ok=True): valid_yaws = ['primary', 'primary1', 'primary-1'] """ + if not aid_list: # no need to filter if empty + return aid_list def rectify_view_category(view): @ut.memoize @@ -4953,6 +4993,8 @@ def filter_aids_without_name(ibs, aid_list, invert=False, speedhack=True): >>> assert np.all(np.array(annots2_.nids) < 0) >>> assert len(annots2_) == 4 """ + if not aid_list: # no need to filter if empty + return aid_list if speedhack: list_repr = ','.join(map(str, aid_list)) if invert: @@ -4965,7 +5007,8 @@ def filter_aids_without_name(ibs, aid_list, invert=False, speedhack=True): 'SELECT rowid from annotations WHERE name_rowid>0 AND rowid IN (%s)' % (list_repr,) ) - aid_list_ = ut.take_column(ibs.db.cur.execute(operation).fetchall(), 0) + with ibs.db.connect() as conn: + aid_list_ = ut.take_column(conn.execute(operation).fetchall(), 0) else: flag_list = ibs.is_aid_unknown(aid_list) if not invert: @@ -5007,6 +5050,8 @@ def filter_annots_using_minimum_timedelta(ibs, aid_list, min_timedelta): >>> wbia.other.dbinfo.hackshow_names(ibs, filtered_aids) >>> ut.show_if_requested() """ + if not aid_list: # no need to filter if empty + return aid_list import vtool as vt # min_timedelta = 60 * 60 * 24 @@ -5062,6 +5107,8 @@ def filter_aids_without_timestamps(ibs, aid_list, invert=False): Removes aids without timestamps aid_list = ibs.get_valid_aids() """ + if not aid_list: # no need to filter if empty + return aid_list unixtime_list = ibs.get_annot_image_unixtimes(aid_list) flag_list = [unixtime != -1 for unixtime in unixtime_list] if invert: @@ -5096,12 +5143,15 @@ def filter_aids_to_species(ibs, aid_list, species, speedhack=True): >>> print(result) aid_list_ = [9, 10] """ + if not aid_list: # no need to filter if empty + return aid_list species_rowid = ibs.get_species_rowids_from_text(species) if speedhack: list_repr = ','.join(map(str, aid_list)) - operation = 'SELECT rowid from annotations WHERE (species_rowid == {species_rowid}) AND rowid IN ({aids})' + operation = 'SELECT rowid from annotations WHERE (species_rowid = {species_rowid}) AND rowid IN ({aids})' operation = operation.format(aids=list_repr, species_rowid=species_rowid) - aid_list_ = ut.take_column(ibs.db.cur.execute(operation).fetchall(), 
0) + with ibs.db.connect() as conn: + aid_list_ = ut.take_column(conn.execute(operation).fetchall(), 0) else: species_rowid_list = ibs.get_annot_species_rowids(aid_list) is_valid_species = [sid == species_rowid for sid in species_rowid_list] diff --git a/wbia/research/__init__.py b/wbia/research/__init__.py new file mode 100644 index 0000000000..6b831212de --- /dev/null +++ b/wbia/research/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from wbia.research import metrics # NOQA + +import utool as ut + +ut.noinject(__name__, '[wbia.research.__init__]', DEBUG=False) diff --git a/wbia/research/metrics.py b/wbia/research/metrics.py new file mode 100644 index 0000000000..44b99d5e87 --- /dev/null +++ b/wbia/research/metrics.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +developer convenience functions for ibs + +TODO: need to split up into sub modules: + consistency_checks + feasibility_fixes + move the export stuff to dbio + + python -m utool.util_inspect check_module_usage --pat="ibsfuncs.py" + + then there are also convineience functions that need to be ordered at least + within this file +""" +import logging +import utool as ut +from wbia.control import controller_inject +from wbia import annotmatch_funcs # NOQA +import pytz + + +PST = pytz.timezone('US/Pacific') + + +# Inject utool function +(print, rrr, profile) = ut.inject2(__name__, '[research]') +logger = logging.getLogger('wbia') + + +# Must import class before injection +CLASS_INJECT_KEY, register_ibs_method = controller_inject.make_ibs_register_decorator( + __name__ +) + + +register_api = controller_inject.get_wbia_flask_api(__name__) + + +@register_ibs_method +def research_print_metrics(ibs, tag='metrics'): + imageset_rowid_list = ibs.get_valid_imgsetids(is_special=False) + imageset_text_list = ibs.get_imageset_text(imageset_rowid_list) + + global_gid_list = [] + global_cid_list = [] + for imageset_rowid, imageset_text in zip(imageset_rowid_list, imageset_text_list): + imageset_text_ = imageset_text.strip().split(',') + if len(imageset_text_) == 3: + ggr, car, person = imageset_text_ + if ggr in ['GGR', 'GGR2']: + gid_list = ibs.get_imageset_gids(imageset_rowid) + global_gid_list += gid_list + cid = ibs.add_contributors([imageset_text])[0] + global_cid_list += [cid] * len(gid_list) + + assert len(global_gid_list) == len(set(global_gid_list)) + + ibs.set_image_contributor_rowid(global_gid_list, global_cid_list) + + ###### + + aid_list = ibs.get_valid_aids() + + species_list = ibs.get_annot_species_texts(aid_list) + viewpoint_list = ibs.get_annot_viewpoints(aid_list) + quality_list = ibs.get_annot_qualities(aid_list) + + aids = [] + zipped = list(zip(aid_list, species_list, viewpoint_list, quality_list)) + for aid, species_, viewpoint_, quality_ in zipped: + assert None not in [species_, viewpoint_, quality_] + species_ = species_.lower() + viewpoint_ = viewpoint_.lower() + quality_ = int(quality_) + if species_ != 'zebra_grevys': + continue + if 'right' not in viewpoint_: + continue + aids.append(aid) + + config = { + 'classifier_algo': 'densenet', + 'classifier_weight_filepath': 'canonical_zebra_grevys_v4', + } + prediction_list = ibs.depc_annot.get_property( + 'classifier', aids, 'class', config=config + ) + confidence_list = ibs.depc_annot.get_property( + 'classifier', aids, 'score', config=config + ) + confidence_list = [ + confidence if prediction == 'positive' else 1.0 - confidence + for prediction, confidence in zip(prediction_list, confidence_list) + ] + flags = [confidence >= 0.5 for confidence in confidence_list] 
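    # `confidence_list` was folded above into the probability of the 'positive'
    # class, so an annotation is flagged canonical exactly when the classifier
    # leans positive (probability >= 0.5).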
+ ibs.set_annot_canonical(aids, flags) + + ibs.print_dbinfo(with_ggr=True, with_map=True) diff --git a/wbia/scripts/specialdraw.py b/wbia/scripts/specialdraw.py index ee6feccb24..0262a46536 100644 --- a/wbia/scripts/specialdraw.py +++ b/wbia/scripts/specialdraw.py @@ -289,6 +289,9 @@ def double_depcache_graph(): 'CurvRankDorsal': 'CurvRank (Dorsal) Distance', 'CurvRankFinfindrHybridDorsal': 'CurvRank + FinFindR Hybrid (Dorsal) Distance', 'CurvRankFluke': 'CurvRank (Fluke) Distance', + 'CurvRankTwo': 'CurvRank V2 Distance', + 'CurvRankTwoDorsal': 'CurvRank V2 (Dorsal) Distance', + 'CurvRankTwoFluke': 'CurvRank V2 (Fluke) Distance', 'Deepsense': 'Deepsense Distance', 'Pie': 'Pie Distance', 'Finfindr': 'Finfindr Distance', @@ -316,6 +319,9 @@ def double_depcache_graph(): 'CurvRankDorsal': 'curvrank_distance_dorsal', 'CurvRankFinfindrHybridDorsal': 'curvrank_finfindr_hybrid_distance_dorsal', 'CurvRankFluke': 'curvrank_distance_fluke', + 'CurvRankTwo': 'curvrank_two_distance', + 'CurvRankTwoDorsal': 'curvrank_two_distance_dorsal', + 'CurvRankTwoFluke': 'curvrank_two_distance_fluke', 'Deepsense': 'deepsense_distance', 'Pie': 'pie_distance', 'Finfindr': 'finfindr_distance', diff --git a/wbia/tests/dtool/test__integrate_sqlite3.py b/wbia/tests/dtool/test__integrate_sqlite3.py deleted file mode 100644 index 1aca23abd0..0000000000 --- a/wbia/tests/dtool/test__integrate_sqlite3.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -import sqlite3 -import uuid - -import numpy as np -import pytest - -# We do not explicitly call code in this module because -# importing the following module is execution of the code. -import wbia.dtool._integrate_sqlite3 # noqa - - -@pytest.fixture -def db(): - with sqlite3.connect(':memory:', detect_types=sqlite3.PARSE_DECLTYPES) as con: - yield con - - -np_number_types = ( - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, -) - - -@pytest.mark.parametrize('num_type', np_number_types) -def test_register_numpy_dtypes_ints(db, num_type): - # The magic takes place in the register_numpy_dtypes function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x integer)') - - # Insert a uuid value into the table - insert_value = num_type(8) - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value - - -@pytest.mark.parametrize('num_type', (np.float32, np.float64)) -def test_register_numpy_dtypes_floats(db, num_type): - # The magic takes place in the register_numpy_dtypes function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x real)') - - # Insert a uuid value into the table - insert_value = num_type(8.0000008) - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value - - -@pytest.mark.parametrize('type_name', ('numpy', 'ndarray')) -def test_register_numpy(db, type_name): - # The magic takes place in the register_numpy function, - # which is implicitly called on module import. 
- - # Create a table that uses the type - db.execute(f'create table test(x {type_name})') - - # Insert a numpy array value into the table - insert_value = np.array([[1, 2, 3], [4, 5, 6]], np.int32) - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert (selected_value == insert_value).all() - - -def test_register_uuid(db): - # The magic takes place in the register_uuid function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x uuid)') - - # Insert a uuid value into the table - insert_value = uuid.uuid4() - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value - - -def test_register_dict(db): - # The magic takes place in the register_dict function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x dict)') - - # Insert a dict value into the table - insert_value = { - 'a': 1, - 'b': 2.2, - 'c': [[1, 2, 3], [4, 5, 6]], - } - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - for k, v in selected_value.items(): - assert v == insert_value[k] - - -def test_register_list(db): - # The magic takes place in the register_list function, - # which is implicitly called on module import. - - # Create a table that uses the type - db.execute('create table test(x list)') - - # Insert a list of list value into the table - insert_value = [[1, 2, 3], [4, 5, 6]] - db.execute('insert into test(x) values (?)', (insert_value,)) - # Query for the value - cur = db.execute('select x from test') - selected_value = cur.fetchone()[0] - assert selected_value == insert_value diff --git a/wbia/tests/dtool/test_sql_control.py b/wbia/tests/dtool/test_sql_control.py index 5cffef0dd4..7b7575bea3 100644 --- a/wbia/tests/dtool/test_sql_control.py +++ b/wbia/tests/dtool/test_sql_control.py @@ -1,29 +1,259 @@ # -*- coding: utf-8 -*- -import sqlite3 import uuid +from functools import partial +import numpy as np import pytest +import sqlalchemy.exc +from sqlalchemy import MetaData, Table +from sqlalchemy.sql import select, text from wbia.dtool.sql_control import ( METADATA_TABLE_COLUMNS, - TIMEOUT, SQLDatabaseController, ) @pytest.fixture def ctrlr(): - return SQLDatabaseController.from_uri(':memory:') + return SQLDatabaseController('sqlite:///:memory:', 'testing') + + +def make_table_definition(name, depends_on=[]): + """Creates a table definition for use with the controller's add_table method""" + definition = { + 'tablename': name, + 'coldef_list': [ + (f'{name}_id', 'INTEGER PRIMARY KEY'), + ('meta_labeler_id', 'INTEGER NOT NULL'), + ('indexer_id', 'INTEGER NOT NULL'), + ('config_id', 'INTEGER DEFAULT 0'), + ('data', 'TEXT'), + ], + 'docstr': f'docstr for {name}', + 'superkeys': [ + ('meta_labeler_id', 'indexer_id', 'config_id'), + ], + 'dependson': depends_on, + } + return definition + + +def test_instantiation_with_table_reflection(tmp_path): + db_file = (tmp_path / 'testing.db').resolve() + creating_ctrlr = SQLDatabaseController(f'sqlite:///{db_file}', 'testing') + # Assumes `add_table` is functional. If you run into failing problems + # check for failures around this method first. 
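    # The loop below creates table_a, table_b and table_c, each declaring a
    # dependency on every table created before it.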
+ created_tables = [] + table_names = map( + 'table_{}'.format, + ( + 'a', + 'b', + 'c', + ), + ) + for t in table_names: + creating_ctrlr.add_table(**make_table_definition(t, depends_on=created_tables)) + # Build up of denpendence + created_tables.append(t) + + # Delete the controller + del creating_ctrlr + + # Create the controller again for reflection (testing target) + ctrlr = SQLDatabaseController(f'sqlite:///{db_file}', 'testing') + # Verify the tables are loaded on instantiation + assert list(ctrlr._sa_metadata.tables.keys()) == ['metadata'] + created_tables + # Note, we don't have to check for the contents of the tables, + # because that's machinery within SQLAlchemy, + # which will have been tested by SQLAlchemy. + + +class TestSchemaModifiers: + """Testing the API that creates, modifies or deletes schema elements""" + + @pytest.fixture(autouse=True) + def fixture(self, ctrlr): + self.ctrlr = ctrlr + + make_table_definition = staticmethod(make_table_definition) + + @property + def _table_factory(self): + return partial(Table, autoload=True, autoload_with=self.ctrlr._engine) + + def reflect_table(self, name, metadata=None): + """Using SQLAlchemy to reflect the table at the given ``name`` + to return a SQLAlchemy Table object + + """ + if metadata is None: + metadata = MetaData() + return self._table_factory(name, metadata) + + def test_make_add_table_sqlstr(self): + table_definition = self.make_table_definition( + 'foobars', depends_on=['meta_labelers', 'indexers'] + ) + + # Call the target + sql = self.ctrlr._make_add_table_sqlstr(**table_definition) + + expected = ( + 'CREATE TABLE IF NOT EXISTS foobars ( ' + 'foobars_id INTEGER PRIMARY KEY, ' + 'meta_labeler_id INTEGER NOT NULL, ' + 'indexer_id INTEGER NOT NULL, ' + 'config_id INTEGER DEFAULT 0, ' + 'data TEXT, ' + 'CONSTRAINT unique_foobars_meta_labeler_id_indexer_id_config_id ' + 'UNIQUE (meta_labeler_id, indexer_id, config_id) )' + ) + assert sql.text == expected + + def test_add_table(self): + # Two tables... + # .. used in the creation of bars table + foos_definition = self.make_table_definition('foos') + # .. 
bars table depends on foos table + bars_definition = self.make_table_definition('bars', depends_on=['foos']) + # We test against bars table and basically neglect foos table + + # Call the target + self.ctrlr.add_table(**foos_definition) + self.ctrlr.add_table(**bars_definition) + + # Check the table has been added and verify details + # Use sqlalchemy's reflection + md = MetaData() + bars = self.reflect_table('bars', md) + metadata = self.reflect_table('metadata', md) + + # Check the table's column definitions + expected_bars_columns = [ + ('bars_id', 'wbia.dtool.types.Integer'), + ('config_id', 'wbia.dtool.types.Integer'), + ('data', 'sqlalchemy.sql.sqltypes.TEXT'), + ('indexer_id', 'wbia.dtool.types.Integer'), + ('meta_labeler_id', 'wbia.dtool.types.Integer'), + ] + found_bars_columns = [ + (c.name, '.'.join([c.type.__class__.__module__, c.type.__class__.__name__])) + for c in bars.columns + ] + assert sorted(found_bars_columns) == expected_bars_columns + # Check the table's constraints + expected_constraint_info = [ + ('PrimaryKeyConstraint', None, ['bars_id']), + ( + 'UniqueConstraint', + 'unique_bars_meta_labeler_id_indexer_id_config_id', + ['meta_labeler_id', 'indexer_id', 'config_id'], + ), + ] + found_constraint_info = [ + (x.__class__.__name__, x.name, [c.name for c in x.columns]) + for x in bars.constraints + ] + assert sorted(found_constraint_info) == expected_constraint_info + + # Check for metadata entries + results = self.ctrlr._engine.execute( + select([metadata.c.metadata_key, metadata.c.metadata_value]).where( + metadata.c.metadata_key.like('bars_%') + ) + ) + expected_metadata_rows = [ + ('bars_docstr', 'docstr for bars'), + ('bars_superkeys', "[('meta_labeler_id', 'indexer_id', 'config_id')]"), + ('bars_dependson', "['foos']"), + ] + assert results.fetchall() == expected_metadata_rows + + def test_rename_table(self): + # Assumes `add_table` passes to reduce this test's complexity. + table_name = 'cookies' + self.ctrlr.add_table(**self.make_table_definition(table_name)) + + # Call the target + new_table_name = 'deserts' + self.ctrlr.rename_table(table_name, new_table_name) + + # Check the table has been renamed use sqlalchemy's reflection + md = MetaData() + metadata = self.reflect_table('metadata', md) + + # Reflecting the table is enough to check that it's been renamed. + self.reflect_table(new_table_name, md) + + # Check for metadata entries have been renamed. + results = self.ctrlr._engine.execute( + select([metadata.c.metadata_key, metadata.c.metadata_value]).where( + metadata.c.metadata_key.like(f'{new_table_name}_%') + ) + ) + expected_metadata_rows = [ + (f'{new_table_name}_docstr', f'docstr for {table_name}'), + ( + f'{new_table_name}_superkeys', + "[('meta_labeler_id', 'indexer_id', 'config_id')]", + ), + (f'{new_table_name}_dependson', '[]'), + ] + assert results.fetchall() == expected_metadata_rows + def test_drop_table(self): + # Assumes `add_table` passes to reduce this test's complexity. 
+ table_name = 'cookies' + self.ctrlr.add_table(**self.make_table_definition(table_name)) -def test_instantiation(ctrlr): - # Check for basic connection information - assert ctrlr.uri == ':memory:' - assert ctrlr.timeout == TIMEOUT + # Call the target + self.ctrlr.drop_table(table_name) - # Check for a connection, that would have been made during instantiation - assert isinstance(ctrlr.connection, sqlite3.Connection) - assert isinstance(ctrlr.cur, sqlite3.Cursor) + # Check the table using sqlalchemy's reflection + md = MetaData() + metadata = self.reflect_table('metadata', md) + + # This error in the attempt to reflect indicates the table has been removed. + with pytest.raises(sqlalchemy.exc.NoSuchTableError): + self.reflect_table(table_name, md) + + # Check for metadata entries have been renamed. + results = self.ctrlr._engine.execute( + select([metadata.c.metadata_key, metadata.c.metadata_value]).where( + metadata.c.metadata_key.like(f'{table_name}_%') + ) + ) + assert results.fetchall() == [] + + def test_drop_all_table(self): + # Assumes `add_table` passes to reduce this test's complexity. + table_names = ['cookies', 'pies', 'cakes'] + for name in table_names: + self.ctrlr.add_table(**self.make_table_definition(name)) + + # Call the target + self.ctrlr.drop_all_tables() + + # Check the table using sqlalchemy's reflection + md = MetaData() + metadata = self.reflect_table('metadata', md) + + # This error in the attempt to reflect indicates the table has been removed. + for name in table_names: + with pytest.raises(sqlalchemy.exc.NoSuchTableError): + self.reflect_table(name, md) + + # Check for the absents of metadata for the removed tables. + results = self.ctrlr._engine.execute(select([metadata.c.metadata_key])) + expected_metadata_rows = [ + ('database_init_uuid',), + ('database_version',), + ('metadata_docstr',), + ('metadata_superkeys',), + ] + assert results.fetchall() == expected_metadata_rows def test_safely_get_db_version(ctrlr): @@ -53,14 +283,14 @@ def fixture(self, ctrlr, monkeypatch): self.ctrlr._ensure_metadata_table() # Create metadata in the table + insert_stmt = text( + 'INSERT INTO metadata (metadata_key, metadata_value) VALUES (:key, :value)' + ) for key, value in self.data.items(): unprefixed_name = key.split('_')[-1] if METADATA_TABLE_COLUMNS[unprefixed_name]['is_coded_data']: value = repr(value) - self.ctrlr.executeone( - 'INSERT INTO metadata (metadata_key, metadata_value) VALUES (?, ?)', - (key, value), - ) + self.ctrlr._engine.execute(insert_stmt, key=key, value=value) def monkey_get_table_names(self, *args, **kwargs): return ['foo', 'metadata'] @@ -118,9 +348,9 @@ def test_setting_to_none(self): assert new_value == value # Also check the table does not have the record - assert not self.ctrlr.executeone( + assert not self.ctrlr._engine.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" - ) + ).fetchone() def test_setting_unknown_key(self): # Check setting of an unknown metadata key @@ -139,9 +369,9 @@ def test_deleter(self): assert self.ctrlr.metadata.foo.docstr is None # Also check the table does not have the record - assert not self.ctrlr.executeone( + assert not self.ctrlr._engine.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" - ) + ).fetchone() def test_database_attributes(self): # Check the database version @@ -242,9 +472,9 @@ def test_delitem_for_table(self): assert self.ctrlr.metadata.foo.docstr is None # Also check the table does not have the record - assert not self.ctrlr.executeone( + assert not 
self.ctrlr._engine.execute( f"SELECT * FROM metadata WHERE metadata_key = 'foo_{key}'" - ) + ).fetchone() def test_delitem_for_database(self): # You cannot delete database version metadata @@ -261,3 +491,449 @@ def test_delitem_for_database(self): self.ctrlr.metadata.database['init_uuid'] = None # Check the value is still a uuid.UUID assert isinstance(self.ctrlr.metadata.database.init_uuid, uuid.UUID) + + +class BaseAPITestCase: + """Testing the primary *usage* API""" + + @pytest.fixture(autouse=True) + def fixture(self, ctrlr): + self.ctrlr = ctrlr + + def make_table(self, name): + self.ctrlr._engine.execute( + f'CREATE TABLE IF NOT EXISTS {name} ' + '(id INTEGER PRIMARY KEY, x TEXT, y INTEGER, z REAL)' + ) + + def populate_table(self, name): + """To be used in conjunction with ``make_table`` to populate the table + with records from 0 to 9. + + """ + insert_stmt = text(f'INSERT INTO {name} (x, y, z) VALUES (:x, :y, :z)') + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + self.ctrlr._engine.execute(insert_stmt, x=x, y=y, z=z) + + +class TestExecutionAPI(BaseAPITestCase): + def test_executeone(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + result = self.ctrlr.executeone(text(f'SELECT id, y FROM {table_name}')) + + assert result == [(i + 1, i) for i in range(0, 10)] + + def test_executeone_using_fetchone_behavior(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Call the testing target with `fetchone` method's returning behavior. + result = self.ctrlr.executeone( + text(f'SELECT id, y FROM {table_name}'), use_fetchone_behavior=True + ) + + # IMO returning None is correct, + # because that's the expectation from `fetchone`'s DBAPI spec. + assert result is None + + def test_executeone_without_results(self): + table_name = 'test_executeone' + self.make_table(table_name) + + # Call the testing target + result = self.ctrlr.executeone(text(f'SELECT id, y FROM {table_name}')) + + # IMO returning None is correct, + # because that's the expectation from `fetchone`'s DBAPI spec. 
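        # (Without use_fetchone_behavior=True, an empty result set instead
        # comes back as an empty list.)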
+ assert result == [] + + def test_executeone_on_insert(self): + # Should return id after an insert + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + result = self.ctrlr.executeone( + text(f'INSERT INTO {table_name} (y) VALUES (:y)'), {'y': 10} + ) + + # Cursory check that the result is a single int value + assert result == [11] # the result list with one unwrapped value + + # Check for the actual value associated with the resulting id + inserted_value = self.ctrlr._engine.execute( + text(f'SELECT id, y FROM {table_name} WHERE rowid = :rowid'), + rowid=result[0], + ).fetchone() + assert inserted_value == ( + 11, + 10, + ) + + def test_executemany(self): + table_name = 'test_executemany' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + results = self.ctrlr.executemany( + text(f'SELECT id, y FROM {table_name} where x = :x'), + ({'x': 'even'}, {'x': 'odd'}), + unpack_scalars=False, + ) + + # Check for results + evens = [(i + 1, i) for i in range(0, 10) if not i % 2] + odds = [(i + 1, i) for i in range(0, 10) if i % 2] + assert results == [evens, odds] + + def test_executemany_transaction(self): + table_name = 'test_executemany' + self.make_table(table_name) + + # Test a failure to execute in the transaction to test the transaction boundary. + insert = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, y:, :z)') + params = [ + dict(x='even', y=0, z=0.0), + dict(x='odd', y=1, z=1.01), + dict(x='oops', z=2.02), # error + dict(x='odd', y=3, z=3.03), + ] + with pytest.raises(sqlalchemy.exc.OperationalError): + # Call the testing target + results = self.ctrlr.executemany(insert, params) + + # Check for results + results = self.ctrlr._engine.execute(f'select count(*) from {table_name}') + assert results.fetchone()[0] == 0 + + def test_executeone_for_single_column(self): + # Should unwrap the resulting query value (no tuple wrapping) + table_name = 'test_executeone' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + result = self.ctrlr.executeone(text(f'SELECT y FROM {table_name}')) + + # Note the unwrapped values, rather than [(i,) ...] 
+ assert result == [i for i in range(0, 10)] + + +class TestAdditionAPI(BaseAPITestCase): + def test_add(self): + table_name = 'test_add' + self.make_table(table_name) + + parameter_values = [] + for i in range(0, 10): + x, y, z = ( + (i % 2) and 'odd' or 'even', + i, + i * 2.01, + ) + parameter_values.append((x, y, z)) + + # Call the testing target + ids = self.ctrlr._add(table_name, ['x', 'y', 'z'], parameter_values) + + # Verify the resulting ids + assert ids == [i + 1 for i in range(0, len(parameter_values))] + # Verify addition of records + results = self.ctrlr._engine.execute(f'SELECT id, x, y, z FROM {table_name}') + expected = [(i + 1, x, y, z) for i, (x, y, z) in enumerate(parameter_values)] + assert results.fetchall() == expected + + +class TestGettingAPI(BaseAPITestCase): + def test_get_where_without_where_condition(self): + table_name = 'test_get_where' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + results = self.ctrlr.get_where( + table_name, + ['id', 'y'], + tuple(), + None, + ) + + # Verify query + assert results == [(i + 1, i) for i in range(0, 10)] + + def test_scalar_get_where(self): + table_name = 'test_get_where' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + results = self.ctrlr.get_where( + table_name, + ['id', 'y'], + ({'id': 1},), + 'id = :id', + ) + evens = results[0] + + # Verify query + assert evens == (1, 0) + + def test_multi_row_get_where(self): + table_name = 'test_get_where' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + results = self.ctrlr.get_where( + table_name, + ['id', 'y'], + ({'x': 'even'}, {'x': 'odd'}), + 'x = :x', + unpack_scalars=False, # this makes it more than one row of results + ) + evens = results[0] + odds = results[1] + + # Verify query + assert evens == [(i + 1, i) for i in range(0, 10) if not i % 2] + assert odds == [(i + 1, i) for i in range(0, 10) if i % 2] + + def test_get_where_eq(self): + table_name = 'test_get_where_eq' + self.make_table(table_name) + + # Create some dummy records + self.populate_table(table_name) + + # Call the testing target + results = self.ctrlr.get_where_eq( + table_name, + ['id', 'y'], + (['even', 8], ['odd', 7]), # params_iter + ('x', 'y'), # where_colnames + op='AND', + unpack_scalars=True, + ) + + # Verify query + assert results == [(9, 8), (8, 7)] + + def test_get_all(self): + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + # Build the expect results of the testing target + results = conn.execute(f'SELECT id, x, z FROM {table_name}') + rows = results.fetchall() + row_mapping = {row[0]: row[1:] for row in rows if row[1]} + + # Call the testing target + data = self.ctrlr.get(table_name, ['x', 'z']) + + # Verify getting + assert data == [r for r in row_mapping.values()] + + def test_get_by_id(self): + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, 
i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + requested_ids = [2, 4, 6] + data = self.ctrlr.get(table_name, ['x', 'z'], requested_ids) + + # Build the expect results of the testing target + sql_array = ', '.join([str(id) for id in requested_ids]) + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' + ) + expected = results.fetchall() + # Verify getting + assert data == expected + + def test_get_by_numpy_array_of_ids(self): + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + requested_ids = np.array([2, 4, 6]) + data = self.ctrlr.get(table_name, ['x', 'z'], requested_ids) + + # Build the expect results of the testing target + sql_array = ', '.join([str(id) for id in requested_ids]) + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT x, z FROM {table_name} WHERE id in ({sql_array})' + ) + expected = results.fetchall() + # Verify getting + assert data == expected + + def test_get_as_unique(self): + # This test could be inaccurate, because this logical path appears + # to be bolted on the side. Usage of this path's feature is unknown. + + # Make a table for records + table_name = 'test_getting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + # Call the testing target + # The table has a INTEGER PRIMARY KEY, which essentially maps to the rowid + # in SQLite. So, we need not change the default `id_colname` param. + requested_ids = [2, 4, 6] + data = self.ctrlr.get(table_name, ['x'], requested_ids, assume_unique=True) + + # Build the expect results of the testing target + sql_array = ', '.join([str(id) for id in requested_ids]) + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT x FROM {table_name} WHERE id in ({sql_array})' + ) + # ... recall that the controller unpacks single values + expected = [row[0] for row in results] + # Verify getting + assert data == expected + + +class TestSettingAPI(BaseAPITestCase): + def test_setting(self): + # Note, this is not a comprehensive test. It only attempts to test the SQL logic. 
+ # Make a table for records + table_name = 'test_setting' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + results = conn.execute(f'SELECT id, CAST((y%2) AS BOOL) FROM {table_name}') + rows = results.fetchall() + ids = [row[0] for row in rows if row[1]] + + # Call the testing target + self.ctrlr.set( + table_name, ['x', 'z'], [('even', 0.0)] * len(ids), ids, id_colname='id' + ) + + # Verify setting + sql_array = ', '.join([str(id) for id in ids]) + with self.ctrlr.connect() as conn: + results = conn.execute( + f'SELECT id, x, z FROM {table_name} ' f'WHERE id in ({sql_array})' + ) + expected = sorted(map(lambda a: tuple([a] + ['even', 0.0]), ids)) + set_rows = sorted(results) + assert set_rows == expected + + +class TestDeletionAPI(BaseAPITestCase): + def test_delete(self): + # Make a table for records + table_name = 'test_delete' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + results = conn.execute(f'SELECT id, CAST((y % 2) AS BOOL) FROM {table_name}') + rows = results.fetchall() + del_ids = [row[0] for row in rows if row[1]] + remaining_ids = sorted([row[0] for row in rows if not row[1]]) + + # Call the testing target + self.ctrlr.delete(table_name, del_ids, 'id') + + # Verify the deletion + with self.ctrlr.connect() as conn: + results = conn.execute(f'SELECT id FROM {table_name}') + assert sorted([r[0] for r in results]) == remaining_ids + + def test_delete_rowid(self): + # Make a table for records + table_name = 'test_delete_rowid' + self.make_table(table_name) + + # Create some dummy records + insert_stmt = text(f'INSERT INTO {table_name} (x, y, z) VALUES (:x, :y, :z)') + with self.ctrlr.connect() as conn: + for i in range(0, 10): + x, y, z = (str(i), i, i * 2.01) + conn.execute(insert_stmt, x=x, y=y, z=z) + + results = conn.execute( + f'SELECT rowid, CAST((y % 2) AS BOOL) FROM {table_name}' + ) + rows = results.fetchall() + del_ids = [row[0] for row in rows if row[1]] + remaining_ids = sorted([row[0] for row in rows if not row[1]]) + + # Call the testing target + self.ctrlr.delete_rowids(table_name, del_ids) + + # Verify the deletion + with self.ctrlr.connect() as conn: + results = conn.execute(f'SELECT rowid FROM {table_name}') + assert sorted([r[0] for r in results]) == remaining_ids diff --git a/wbia/tests/dtool/test_types.py b/wbia/tests/dtool/test_types.py index d342a3ca03..8b7c26c09b 100644 --- a/wbia/tests/dtool/test_types.py +++ b/wbia/tests/dtool/test_types.py @@ -193,3 +193,20 @@ def test_uuid(db): results = db.execute(stmt) selected_value = results.fetchone()[0] assert selected_value == insert_value + + +def test_le_uuid(db): + db.execute(text('CREATE TABLE test(x UUID)')) + + # Insert a uuid value but explicitly stored as little endian + # (the way uuids were stored before sqlalchemy) + insert_value = uuid.uuid4() + stmt = text('INSERT INTO test(x) VALUES (:x)') + db.execute(stmt, x=insert_value.bytes_le) + + # Query for the value + stmt = text('SELECT x FROM test') + stmt = stmt.columns(x=UUID) + results = db.execute(stmt) + selected_value = results.fetchone()[0] + assert selected_value == insert_value 
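Reviewer note (not part of the patch): the `test_le_uuid` addition above hinges on the difference between `uuid.UUID.bytes` and `uuid.UUID.bytes_le`. The following standalone sketch uses only the standard library and is independent of the project's UUID type decorator; it only illustrates the round trip the test relies on.

import uuid

value = uuid.uuid4()

# Legacy storage layout (pre-sqlalchemy): first three UUID fields in little-endian order.
stored = value.bytes_le

# Decoding with the matching keyword restores the original value.
assert uuid.UUID(bytes_le=stored) == value

# The big-endian layout round-trips with the `bytes` keyword instead.
assert uuid.UUID(bytes=value.bytes) == value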
diff --git a/wbia/tests/web/__init__.py b/wbia/tests/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wbia/tests/web/test_routes.py b/wbia/tests/web/test_routes.py new file mode 100644 index 0000000000..967fdc9298 --- /dev/null +++ b/wbia/tests/web/test_routes.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +import wbia + + +def test_turk_identification_no_more_to_review(): + with wbia.opendb_with_web('testdb2') as (ibs, client): + resp = client.get('/turk/identification/lnbnn/') + assert resp.status_code == 200 + assert b'Traceback' not in resp.data + assert b'
No more to review!
' in resp.data diff --git a/wbia/viz/viz_matches.py b/wbia/viz/viz_matches.py index 01710ad9b4..78b7e52837 100644 --- a/wbia/viz/viz_matches.py +++ b/wbia/viz/viz_matches.py @@ -41,6 +41,8 @@ def get_query_annot_pair_info( 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', 'deepsense', 'finfindr', 'kaggle7', @@ -86,6 +88,8 @@ def get_data_annot_pair_info( 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', 'deepsense', 'finfindr', 'kaggle7', diff --git a/wbia/web/apis.py b/wbia/web/apis.py index 03ee516329..798a3e12b6 100644 --- a/wbia/web/apis.py +++ b/wbia/web/apis.py @@ -54,6 +54,17 @@ def web_embed(*args, **kwargs): ut.embed() +@register_route( + '/api/image/src/.jpg', + methods=['GET'], + __route_prefix_check__=False, + __route_postfix_check__=False, + __route_authenticate__=False, +) +def image_src_api_ext(*args, **kwargs): + return image_src_api(*args, **kwargs) + + # Special function that is a route only to ignore the JSON response, but is # actually (and should be) an API call @register_route( @@ -67,12 +78,12 @@ def image_src_api(rowid=None, thumbnail=False, fresh=False, **kwargs): Returns the image file of image Example: - >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/image/src/1/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/image/src/1/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -125,12 +136,13 @@ def annot_src_api(rowid=None, fresh=False, **kwargs): Example: >>> # xdoctest: +REQUIRES(--slow) - >>> # WEB_DOCTEST + >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/annot/src/1/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/annot/src/1/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -173,12 +185,13 @@ def background_src_api(rowid=None, fresh=False, **kwargs): Example: >>> # xdoctest: +REQUIRES(--slow) - >>> # WEB_DOCTEST + >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/background/src/1/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/background/src/1/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -223,9 +236,10 @@ def image_src_api_json(uuid=None, **kwargs): >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... resp = web_ibs.send_wbia_request('/api/image/src/json/0a9bc03d-a75e-8d14-0153-e2949502aba7/', type_='get', json=False) - >>> print(resp) + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... 
resp = client.get('/api/image/src/json/0a9bc03d-a75e-8d14-0153-e2949502aba7/') + >>> print(resp.data) + b'\xff\xd8\xff\xe0\x00\x10JFIF... RESTful: Method: GET @@ -425,37 +439,21 @@ def image_upload_zip(**kwargs): @register_api('/api/test/helloworld/', methods=['GET', 'POST', 'DELETE', 'PUT']) def hello_world(*args, **kwargs): """ - CommandLine: - python -m wbia.web.apis --exec-hello_world:0 - python -m wbia.web.apis --exec-hello_world:1 Example: - >>> # xdoctest: +REQUIRES(--web-tests) - >>> from wbia.web.app import * # NOQA - >>> import wbia - >>> web_ibs = wbia.opendb_bg_web(browser=True, start_job_queue=False, url_suffix='/api/test/helloworld/?test0=0') # start_job_queue=False) - >>> print('web_ibs = %r' % (web_ibs,)) - >>> print('Server will run until control c') - >>> web_ibs.terminate2() - - Example1: >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.app import * # NOQA >>> import wbia >>> import requests >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', start_job_queue=False, managed=True) as web_ibs: - ... web_port = ibs.get_web_port_via_scan() - ... if web_port is None: - ... raise ValueError('IA web server is not running on any expected port') - ... domain = 'http://127.0.0.1:%s' % (web_port, ) - ... url = domain + '/api/test/helloworld/?test0=0' + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/test/helloworld/?test0=0') ... payload = { ... 'test1' : 'test1', ... 'test2' : None, # NOTICE test2 DOES NOT SHOW UP ... } - ... resp = requests.post(url, data=payload) - ... print(resp) + ... resp = client.post('/api/test/helloworld/', data=payload) + """ logger.info('+------------ HELLO WORLD ------------') logger.info('Args: %r' % (args,)) diff --git a/wbia/web/apis_detect.py b/wbia/web/apis_detect.py index 79acf69282..d31abd5297 100644 --- a/wbia/web/apis_detect.py +++ b/wbia/web/apis_detect.py @@ -11,7 +11,6 @@ from wbia.web import appfuncs as appf import numpy as np -(print, rrr, profile) = ut.inject2(__name__) logger = logging.getLogger('wbia') CLASS_INJECT_KEY, register_ibs_method = controller_inject.make_ibs_register_decorator( @@ -440,6 +439,7 @@ def process_detection_html(ibs, **kwargs): return result_dict +# this is where the aids_list response from commit_localization_results is packaged & returned in json @register_ibs_method @accessor_decors.getter_1to1 def detect_cnn_json(ibs, gid_list, detect_func, config={}, **kwargs): @@ -465,40 +465,54 @@ def detect_cnn_json(ibs, gid_list, detect_func, config={}, **kwargs): >>> print(results_dict) """ # TODO: Return confidence here as well + def _json_result(ibs, aid): + result = { + 'id': aid, + 'uuid': ibs.get_annot_uuids(aid), + 'xtl': ibs.get_annot_bboxes(aid)[0], + 'ytl': ibs.get_annot_bboxes(aid)[1], + 'left': ibs.get_annot_bboxes(aid)[0], + 'top': ibs.get_annot_bboxes(aid)[1], + 'width': ibs.get_annot_bboxes(aid)[2], + 'height': ibs.get_annot_bboxes(aid)[3], + 'theta': round(ibs.get_annot_thetas(aid), 4), + 'confidence': round(ibs.get_annot_detect_confidence(aid), 4), + 'class': ibs.get_annot_species_texts(aid), + 'species': ibs.get_annot_species_texts(aid), + 'viewpoint': ibs.get_annot_viewpoints(aid), + 'quality': ibs.get_annot_qualities(aid), + 'multiple': ibs.get_annot_multiple(aid), + 'interest': ibs.get_annot_interest(aid), + } + return result + image_uuid_list = ibs.get_image_uuids(gid_list) ibs.assert_valid_gids(gid_list) - # Get detections from depc + # Get detections from depc --- this output will be affected by assigner aids_list = detect_func(gid_list, **config) 
- results_list = [ - [ - { - 'id': aid, - 'uuid': ibs.get_annot_uuids(aid), - 'xtl': ibs.get_annot_bboxes(aid)[0], - 'ytl': ibs.get_annot_bboxes(aid)[1], - 'left': ibs.get_annot_bboxes(aid)[0], - 'top': ibs.get_annot_bboxes(aid)[1], - 'width': ibs.get_annot_bboxes(aid)[2], - 'height': ibs.get_annot_bboxes(aid)[3], - 'theta': round(ibs.get_annot_thetas(aid), 4), - 'confidence': round(ibs.get_annot_detect_confidence(aid), 4), - 'class': ibs.get_annot_species_texts(aid), - 'species': ibs.get_annot_species_texts(aid), - 'viewpoint': ibs.get_annot_viewpoints(aid), - 'quality': ibs.get_annot_qualities(aid), - 'multiple': ibs.get_annot_multiple(aid), - 'interest': ibs.get_annot_interest(aid), - } - for aid in aid_list - ] - for aid_list in aids_list - ] + results_list = [] + has_assignments = False + for aid_list in aids_list: + result_list = [] + for aid in aid_list: + if not isinstance(aid, tuple): # we have an assignment + result = _json_result(ibs, aid) + else: + assert len(aid) > 0 + has_assignments = True + result = [] + for val in aid: + result.append(_json_result(ibs, val)) + result_list.append(result) + results_list.append(result_list) + score_list = [0.0] * len(gid_list) # Wrap up results with other information results_dict = { 'image_uuid_list': image_uuid_list, 'results_list': results_list, 'score_list': score_list, + 'has_assignments': has_assignments, } return results_dict @@ -895,6 +909,8 @@ def commit_localization_results( use_labeler_species=False, orienter_algo=None, orienter_model_tag=None, + assigner_algo=None, + assigner_model_tag=None, update_json_log=True, apply_nms_post_use_labeler_species=True, **kwargs, @@ -981,10 +997,21 @@ def commit_localization_results( if len(bbox_list) > 0: ibs.set_annot_bboxes(aid_list, bbox_list, theta_list=theta_list) + if assigner_algo is not None: + # aids_list is a list of lists of aids, now we want a list of lists of + all_assignments = [] + for aids in aids_list: + assigned, unassigned = ibs.assign_parts_one_image(aids) + # unassigned aids should also be tuples to indicate they went through the assigner + unassigned = [(aid,) for aid in unassigned] + all_assignments.append(assigned + unassigned) + aids_list = all_assignments + ibs._clean_species() if update_json_log: ibs.log_detections(aid_list) + # list of list of ints return aids_list diff --git a/wbia/web/apis_engine.py b/wbia/web/apis_engine.py index 022293f0de..aa7700d929 100644 --- a/wbia/web/apis_engine.py +++ b/wbia/web/apis_engine.py @@ -262,6 +262,7 @@ def start_identify_annots( Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.apis_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') @@ -452,6 +453,7 @@ def start_identify_annots_query( Example: >>> # DISABLE_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.apis_engine import * # NOQA >>> import wbia >>> #domain = 'localhost' @@ -533,6 +535,8 @@ def sanitize(state): 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', ): curvrank_daily_tag = query_config_dict.get('curvrank_daily_tag', '') if len(curvrank_daily_tag) > 144: diff --git a/wbia/web/apis_microsoft.py b/wbia/web/apis_microsoft.py index 5e11535d6a..904f348751 100644 --- a/wbia/web/apis_microsoft.py +++ b/wbia/web/apis_microsoft.py @@ -1060,7 +1060,7 @@ def microsoft_identify( $ref: "#/definitions/Annotation" - name: algorithm in: formData - 
description: The algorithm you with to run ID with. Must be one of "HotSpotter", "CurvRank", "Finfindr", or "Deepsense" + description: The algorithm you with to run ID with. Must be one of "HotSpotter", "CurvRank", "CurvRankTwo", "Finfindr", or "Deepsense" required: true type: string - name: callback_url @@ -1104,11 +1104,13 @@ def microsoft_identify( assert algorithm in [ 'hotspotter', 'curvrank', + 'curvrank_v2', + 'curvrankv2', 'deepsense', 'finfindr', 'kaggle7', 'kaggleseven', - ], 'Must specify the algorithm for ID as HotSpotter, CurvRank, Deepsense, Finfindr, Kaggle7' + ], 'Must specify the algorithm for ID as HotSpotter, CurvRank, CurvRankTwo, Deepsense, Finfindr, Kaggle7' parameter = 'callback_url' assert callback_url is None or isinstance( @@ -1140,6 +1142,10 @@ def microsoft_identify( query_config_dict = { 'pipeline_root': 'CurvRankFluke', } + elif algorithm in ['curvrank_v2', 'curvrankv2']: + query_config_dict = { + 'pipeline_root': 'CurvRankTwoFluke', + } elif algorithm in ['deepsense']: query_config_dict = { 'pipeline_root': 'Deepsense', diff --git a/wbia/web/apis_query.py b/wbia/web/apis_query.py index d448d3dd9d..14274e107a 100644 --- a/wbia/web/apis_query.py +++ b/wbia/web/apis_query.py @@ -325,6 +325,9 @@ def review_graph_match_html( Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) + >>> # DISABLE_DOCTEST + >>> # Disabled because this test uses opendb_bg_web, which hangs the test runner and leaves zombie processes >>> from wbia.web.apis_query import * # NOQA >>> import wbia >>> web_ibs = wbia.opendb_bg_web('testdb1') # , domain='http://52.33.105.88') @@ -377,6 +380,7 @@ def review_graph_match_html( Example2: >>> # DISABLE_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> # This starts off using web to get information, but finishes the rest in python >>> from wbia.web.apis_query import * # NOQA >>> import wbia @@ -433,6 +437,8 @@ def review_graph_match_html( 'curvrankdorsal', 'curvrankfinfindrhybriddorsal', 'curvrankfluke', + 'curvranktwodorsal', + 'curvranktwofluke', 'deepsense', 'finfindr', 'kaggle7', @@ -544,16 +550,6 @@ def review_graph_match_html( @register_route('/test/review/query/chip/', methods=['GET']) def review_query_chips_test(**kwargs): - """ - CommandLine: - python -m wbia.web.apis_query review_query_chips_test --show - - Example: - >>> # SCRIPT - >>> import wbia - >>> web_ibs = wbia.opendb_bg_web( - >>> browser=True, url_suffix='/test/review/query/chip/?__format__=true') - """ ibs = current_app.ibs # the old block curvature dtw @@ -568,6 +564,10 @@ def review_query_chips_test(**kwargs): query_config_dict = {'pipeline_root': 'CurvRankFinfindrHybridDorsal'} elif 'use_curvrank_fluke' in request.args: query_config_dict = {'pipeline_root': 'CurvRankFluke'} + elif 'use_curvrank_v2_dorsal' in request.args: + query_config_dict = {'pipeline_root': 'CurvRankTwoDorsal'} + elif 'use_curvrank_v2_fluke' in request.args: + query_config_dict = {'pipeline_root': 'CurvRankTwoFluke'} elif 'use_deepsense' in request.args: query_config_dict = {'pipeline_root': 'Deepsense'} elif 'use_finfindr' in request.args: @@ -762,6 +762,24 @@ def query_chips_graph_complete(ibs, aid_list, query_config_dict={}, k=5, **kwarg return result_dict +@register_ibs_method +def log_render_status(ibs, *args): + import os + + json_log_path = ibs.get_logdir_local() + json_log_filename = 'render.log' + json_log_filepath = os.path.join(json_log_path, json_log_filename) + logger.info('Logging renders added to: %r' % (json_log_filepath,)) + + try: + 
with open(json_log_filepath, 'a') as json_log_file: + line = ','.join(['%s' % (arg,) for arg in args]) + line = '%s\n' % (line,) + json_log_file.write(line) + except Exception: + logger.info('WRITE RENDER.LOG FAILED') + + @register_ibs_method @register_api('/api/query/graph/', methods=['GET', 'POST']) def query_chips_graph( @@ -772,7 +790,7 @@ def query_chips_graph( query_config_dict={}, echo_query_params=True, cache_images=True, - n=16, + n=30, view_orientation='horizontal', return_summary=True, **kwargs, @@ -963,6 +981,20 @@ def convert_to_uuid(nid): except Exception: filepath_matches = None extern_flag = 'error' + log_render_status( + ibs, + cm.qaid, + daid, + quuid, + duuid, + cm, + qreq_, + view_orientation, + True, + False, + filepath_matches, + extern_flag, + ) try: _, filepath_heatmask = ensure_review_image( ibs, @@ -976,6 +1008,20 @@ def convert_to_uuid(nid): except Exception: filepath_heatmask = None extern_flag = 'error' + log_render_status( + ibs, + cm.qaid, + daid, + quuid, + duuid, + cm, + qreq_, + view_orientation, + False, + True, + filepath_heatmask, + extern_flag, + ) try: _, filepath_clean = ensure_review_image( ibs, @@ -989,6 +1035,20 @@ def convert_to_uuid(nid): except Exception: filepath_clean = None extern_flag = 'error' + log_render_status( + ibs, + cm.qaid, + daid, + quuid, + duuid, + cm, + qreq_, + view_orientation, + False, + False, + filepath_clean, + extern_flag, + ) if filepath_matches is not None: args = ( @@ -1408,13 +1468,12 @@ def query_chips_graph_v2( >>> # Open local instance >>> ibs = wbia.opendb('PZ_MTEST') >>> uuid_list = ibs.annots().uuids[0:10] - >>> # Start up the web instance - >>> web_ibs = wbia.opendb_bg_web(db='PZ_MTEST', web=True, browser=False) >>> data = dict(annot_uuid_list=uuid_list) - >>> resp = web_ibs.send_wbia_request('/api/query/graph/v2/', **data) - >>> print('resp = %r' % (resp,)) - >>> #cmdict_list = json_dict['response'] - >>> #assert 'score_list' in cmdict_list[0] + >>> # Start up the web instance + >>> with wbia.opendb_with_web(db='PZ_MTEST') as (ibs, client): + ... 
resp = client.post('/api/query/graph/v2/', data=data) + >>> resp.json + {'status': {'success': False, 'code': 608, 'message': 'Invalid image and/or annotation UUIDs (0, 1)', 'cache': -1}, 'response': {'invalid_image_uuid_list': [], 'invalid_annot_uuid_list': [[0, '...']]}} Example: >>> # DEBUG_SCRIPT diff --git a/wbia/web/apis_sync.py b/wbia/web/apis_sync.py index 594289db53..04f8139653 100644 --- a/wbia/web/apis_sync.py +++ b/wbia/web/apis_sync.py @@ -47,7 +47,7 @@ REMOTE_UUID = None -REMOTE_URL = 'http://%s:%s' % (REMOTE_DOMAIN, REMOTE_PORT) +REMOTE_URL = 'https://%s:%s' % (REMOTE_DOMAIN, REMOTE_PORT) REMOTE_UUID = None if REMOTE_UUID is None else uuid.UUID(REMOTE_UUID) @@ -366,7 +366,9 @@ def sync_get_training_data(ibs, species_name, force_update=False, **kwargs): name_texts = ibs._sync_get_annot_endpoint('/api/annot/name/text/', aid_list) name_uuids = ibs._sync_get_annot_endpoint('/api/annot/name/uuid/', aid_list) images = ibs._sync_get_annot_endpoint('/api/annot/image/rowid/', aid_list) - gpaths = [ibs._construct_route_url_ibs('/api/image/src/%s/' % gid) for gid in images] + gpaths = [ + ibs._construct_route_url_ibs('/api/image/src/%s.jpg' % gid) for gid in images + ] specieses = [species_name] * len(aid_list) gid_list = ibs.add_images(gpaths) diff --git a/wbia/web/app.py b/wbia/web/app.py index 9c397cd926..1368673d1f 100644 --- a/wbia/web/app.py +++ b/wbia/web/app.py @@ -37,14 +37,13 @@ def tst_html_error(): r""" This test will show what our current errors look like - CommandLine: - python -m wbia.web.app --exec-tst_html_error - Example: - >>> # DISABLE_DOCTEST - >>> from wbia.web.app import * # NOQA >>> import wbia - >>> web_ibs = wbia.opendb_bg_web(browser=True, start_job_queue=False, url_suffix='/api/image/imagesettext/?__format__=True') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/api/image/imagesettext/?__format__=True') + >>> print(resp) + + """ pass diff --git a/wbia/web/job_engine.py b/wbia/web/job_engine.py index f61ff4ed66..d125e17f8b 100644 --- a/wbia/web/job_engine.py +++ b/wbia/web/job_engine.py @@ -162,6 +162,7 @@ def initialize_job_manager(ibs): Example: >>> # DISABLE_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> ibs = wbia.opendb(defaultdb='testdb1') @@ -174,24 +175,6 @@ def initialize_job_manager(ibs): >>> ibs.close_job_manager() >>> print('Closing success.') - Example: - >>> # xdoctest: +REQUIRES(--web-tests) - >>> from wbia.web.job_engine import * # NOQA - >>> import wbia - >>> import requests - >>> with wbia.opendb_bg_web(db='testdb1', managed=True) as web_instance: - ... web_port = ibs.get_web_port_via_scan() - ... if web_port is None: - ... raise ValueError('IA web server is not running on any expected port') - ... baseurl = 'http://127.0.1.1:%s' % (web_port, ) - ... _payload = {'image_attrs_list': [], 'annot_attrs_list': []} - ... payload = ut.map_dict_vals(ut.to_json, _payload) - ... resp1 = requests.post(baseurl + '/api/test/helloworld/?f=b', data=payload) - ... #resp2 = requests.post(baseurl + '/api/image/json/', data=payload) - ... #print(resp2) - ... #json_dict = resp2.json() - ... #text = json_dict['response'] - ... 
#print(text) """ ibs.job_manager = ut.DynStruct() @@ -265,6 +248,7 @@ def get_job_id_list(ibs): Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') @@ -317,6 +301,7 @@ def get_job_status(ibs, jobid=None): Example: >>> # xdoctest: +REQUIRES(--web-tests) + >>> # xdoctest: +REQUIRES(--job-engine-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') @@ -362,8 +347,10 @@ def get_job_metadata(ibs, jobid): python -m wbia.web.job_engine --exec-get_job_metadata:0 --fg Example: + >>> # xdoctest: +REQUIRES(--web-tests) >>> # xdoctest: +REQUIRES(--slow) - >>> # WEB_DOCTEST + >>> # xdoctest: +REQUIRES(--job-engine-tests) + >>> # xdoctest: +REQUIRES(--web-tests) >>> from wbia.web.job_engine import * # NOQA >>> import wbia >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: # , domain='http://52.33.105.88') diff --git a/wbia/web/routes.py b/wbia/web/routes.py index 84070ad708..71d3d3dea7 100644 --- a/wbia/web/routes.py +++ b/wbia/web/routes.py @@ -279,6 +279,8 @@ def _date_list(gid_list): return date_list def filter_annots_imageset(aid_list): + if not aid_list: # no need to filter if empty + return aid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -296,6 +298,8 @@ def filter_annots_imageset(aid_list): return aid_list def filter_images_imageset(gid_list): + if not gid_list: # no need to filter if empty + return gid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -313,6 +317,8 @@ def filter_images_imageset(gid_list): return gid_list def filter_names_imageset(nid_list): + if not nid_list: # no need to filter if empty + return nid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -333,6 +339,8 @@ def filter_names_imageset(nid_list): return nid_list def filter_annots_general(ibs, aid_list): + if not aid_list: # no need to filter if empty + return aid_list if ibs.dbname == 'GGR-IBEIS': # Grevy's filter_kw = { @@ -1013,6 +1021,8 @@ def _date_list(gid_list): return date_list def filter_annots_imageset(aid_list): + if not aid_list: # no need to filter if empty + return aid_list try: imgsetid = request.args.get('imgsetid', '') imgsetid = int(imgsetid) @@ -1030,6 +1040,8 @@ def filter_annots_imageset(aid_list): return aid_list def filter_annots_general(ibs, aid_list): + if not aid_list: # no need to filter if empty + return aid_list if ibs.dbname == 'GGR-IBEIS': # Grevy's filter_kw = { @@ -1233,6 +1245,8 @@ def _date_list(gid_list): return date_list def filter_species_of_interest(gid_list): + if not gid_list: # no need to filter if empty + return gid_list wanted_set = set(['zebra_plains', 'zebra_grevys', 'giraffe_masai']) aids_list = ibs.get_image_aids(gid_list) speciess_list = ut.unflat_map(ibs.get_annot_species_texts, aids_list) @@ -1245,6 +1259,8 @@ def filter_species_of_interest(gid_list): return gid_list_filtered def filter_viewpoints_of_interest(gid_list, allowed_viewpoint_list): + if not gid_list: # no need to filter if empty + return gid_list aids_list = ibs.get_image_aids(gid_list) wanted_set = set(allowed_viewpoint_list) viewpoints_list = ut.unflat_map(ibs.get_annot_viewpoints, aids_list) @@ -1257,6 +1273,8 @@ def filter_viewpoints_of_interest(gid_list, allowed_viewpoint_list): 
return gid_list_filtered def filter_bad_metadata(gid_list): + if not gid_list: # no need to filter if empty + return gid_list wanted_set = set(['2015/03/01', '2015/03/02', '2016/01/30', '2016/01/31']) date_list = _date_list(gid_list) gps_list = ibs.get_image_gps(gid_list) @@ -1267,6 +1285,8 @@ def filter_bad_metadata(gid_list): return gid_list_filtered def filter_bad_quality(gid_list, allowed_quality_list): + if not gid_list: # no need to filter if empty + return gid_list aids_list = ibs.get_image_aids(gid_list) wanted_set = set(allowed_quality_list) qualities_list = ut.unflat_map(ibs.get_annot_quality_texts, aids_list) @@ -1484,7 +1504,7 @@ def view_imagesets(**kwargs): all_gid_list = ibs.get_valid_gids() all_aid_list = ibs.get_valid_aids() - gids_list = [ibs.get_valid_gids(imgsetid=imgsetid_) for imgsetid_ in imgsetid_list] + gids_list = ibs.get_valid_gids(imgsetid_list=imgsetid_list) num_gids = list(map(len, gids_list)) ###################################################################################### @@ -1777,9 +1797,7 @@ def view_images(**kwargs): None if imgsetid_ == 'None' or imgsetid_ == '' else int(imgsetid_) for imgsetid_ in imgsetid_list ] - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid) for imgsetid_ in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) else: gid_list = ibs.get_valid_gids() filtered = False @@ -1868,9 +1886,7 @@ def view_annotations(**kwargs): None if imgsetid_ == 'None' or imgsetid_ == '' else int(imgsetid_) for imgsetid_ in imgsetid_list ] - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid_) for imgsetid_ in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) aid_list = ut.flatten(ibs.get_image_aids(gid_list)) else: aid_list = ibs.get_valid_aids() @@ -2028,9 +2044,7 @@ def view_names(**kwargs): None if imgsetid_ == 'None' or imgsetid_ == '' else int(imgsetid_) for imgsetid_ in imgsetid_list ] - gid_list = ut.flatten( - [ibs.get_valid_gids(imgsetid=imgsetid_) for imgsetid_ in imgsetid_list] - ) + gid_list = ut.flatten(ibs.get_valid_gids(imgsetid_list=imgsetid_list)) aid_list = ut.flatten(ibs.get_image_aids(gid_list)) nid_list = ibs.get_annot_name_rowids(aid_list) else: @@ -3960,11 +3974,6 @@ def check_engine_identification_query_object( if current_app.QUERY_OBJECT_JOBID is None: current_app.QUERY_OBJECT = None current_app.QUERY_OBJECT_JOBID = ibs.start_web_query_all() - # import wbia - # web_ibs = wbia.opendb_bg_web(dbdir=ibs.dbdir, port=6000) - # query_object_jobid = web_ibs.send_wbia_request('/api/engine/query/graph/') - # logger.info('query_object_jobid = %r' % (query_object_jobid, )) - # current_app.QUERY_OBJECT_JOBID = query_object_jobid query_object_status_dict = ibs.get_job_status(current_app.QUERY_OBJECT_JOBID) args = ( @@ -4004,11 +4013,11 @@ def turk_identification( >>> # SCRIPT >>> from wbia.other.ibsfuncs import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: - ... resp = web_ibs.get('/turk/identification/lnbnn/') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... 
resp = client.get('/turk/identification/lnbnn/') >>> ut.quit_if_noshow() >>> import wbia.plottool as pt - >>> ut.render_html(resp.content) + >>> ut.render_html(resp.data.decode('utf8')) >>> ut.show_if_requested() """ from wbia.web import apis_query @@ -4040,7 +4049,10 @@ def turk_identification( review_cfg[ 'max_num' ] = global_feedback_limit # Controls the top X to be randomly sampled and displayed to all concurrent users - values = query_object.pop() + try: + values = query_object.pop() + except StopIteration as e: + return appf.template(None, 'simple', title=str(e).capitalize()) (review_aid1_list, review_aid2_list), review_confidence = values review_aid1_list = [review_aid1_list] review_aid2_list = [review_aid2_list] @@ -4653,6 +4665,7 @@ def turk_identification_graph_refer( annot_uuid_list=annot_uuid_list, hogwild_species=species, creation_imageset_rowid_list=[imgsetid], + census=True, ) elif option in ['rosemary']: imgsetid_ = ibs.get_imageset_imgsetids_from_text('RosemaryLoopsData') @@ -4745,8 +4758,8 @@ def turk_identification_hardcase(*args, **kwargs): >>> # SCRIPT >>> from wbia.other.ibsfuncs import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('PZ_Master1', managed=True) as web_ibs: - ... resp = web_ibs.get('/turk/identification/hardcase/') + >>> with wbia.opendb_with_web('PZ_Master1') as (ibs, client): + ... resp = client.get('/turk/identification/hardcase/') Ignore: import wbia @@ -4791,6 +4804,7 @@ def turk_identification_graph( hogwild_species=None, creation_imageset_rowid_list=None, kaia=False, + census=False, **kwargs, ): """ @@ -4805,11 +4819,11 @@ def turk_identification_graph( >>> # SCRIPT >>> from wbia.other.ibsfuncs import * # NOQA >>> import wbia - >>> with wbia.opendb_bg_web('testdb1', managed=True) as web_ibs: - ... resp = web_ibs.get('/turk/identification/graph/') + >>> with wbia.opendb_with_web('testdb1') as (ibs, client): + ... resp = client.get('/turk/identification/graph/') >>> ut.quit_if_noshow() >>> import wbia.plottool as pt - >>> ut.render_html(resp.content) + >>> ut.render_html(resp.data.decode('utf8')) >>> ut.show_if_requested() """ ibs = current_app.ibs @@ -4953,6 +4967,22 @@ def turk_identification_graph( 'redun.neg': 2, 'redun.pos': 2, } + elif census: + logger.info('[routes] Graph is in CA-mode') + query_config_dict = { + 'autoreview.enabled': True, + 'autoreview.prioritize_nonpos': True, + 'inference.enabled': True, + 'ranking.enabled': True, + 'ranking.ntop': 20, + 'redun.enabled': True, + 'redun.enforce_neg': True, + 'redun.enforce_pos': True, + 'redun.neg.only_auto': False, + 'redun.neg': 3, + 'redun.pos': 3, + 'algo.hardcase': False, + } else: logger.info('[routes] Graph is not in hardcase-mode') query_config_dict = {} diff --git a/wbia/web/templates/index.html b/wbia/web/templates/index.html index 1ea2493c45..a0d762bd40 100644 --- a/wbia/web/templates/index.html +++ b/wbia/web/templates/index.html @@ -20,7 +20,7 @@

Welcome to IBEIS
For more information: http://wbia.org/
- To view the code repository: https://github.com/WildbookOrg/wbia + To view the code repository: https://github.com/WildMeOrg/wildbook-ia
To view the API settings: {{ request.url }}{{ url_for('api_root').lstrip('/') }}
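
Reviewer note (not part of the patch): throughout this diff, doctests and web tests move from `wbia.opendb_bg_web` (background server plus HTTP requests) to `wbia.opendb_with_web`, which yields the controller and an in-process test client. A minimal sketch of that pattern follows, assuming `opendb_with_web` behaves as the new doctests above use it; the helper name `check_turk_identification_route` is illustrative only.

import wbia

def check_turk_identification_route():
    # opendb_with_web yields (controller, test client); requests are handled
    # in-process, so no background web server or port scanning is needed.
    with wbia.opendb_with_web('testdb1') as (ibs, client):
        resp = client.get('/turk/identification/lnbnn/')
        assert resp.status_code == 200
        assert b'Traceback' not in resp.data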