From 52b207d866cc5e2498c976fe8a0530bbab786d53 Mon Sep 17 00:00:00 2001 From: Mykola Skrynnyk <45297092+mykolaskrynnyk@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:45:03 +0100 Subject: [PATCH] Initial commit --- .dockerignore | 43 ++++ .gitattributes | 2 + .gitignore | 142 ++++++++++++ Dockerfile | 10 + LICENSE | 28 +++ Makefile | 8 + README.md | 148 +++++++++++++ data/locations.csv | 269 +++++++++++++++++++++++ data/units.csv | 201 +++++++++++++++++ docker-compose.yaml | 20 ++ images/architecture.drawio.svg | 4 + main.py | 52 +++++ requirements.txt | 16 ++ requirements_dev.txt | 6 + sql/create_tables.sql | 137 ++++++++++++ sql/import_data.sql | 11 + src/__init__.py | 4 + src/authentication.py | 146 +++++++++++++ src/database/__init__.py | 9 + src/database/choices.py | 63 ++++++ src/database/connection.py | 46 ++++ src/database/signals.py | 360 +++++++++++++++++++++++++++++++ src/database/trends.py | 285 ++++++++++++++++++++++++ src/database/users.py | 198 +++++++++++++++++ src/dependencies.py | 62 ++++++ src/entities/__init__.py | 117 ++++++++++ src/entities/base.py | 120 +++++++++++ src/entities/parameters.py | 86 ++++++++ src/entities/signal.py | 43 ++++ src/entities/trend.py | 44 ++++ src/entities/user.py | 55 +++++ src/entities/utils.py | 123 +++++++++++ src/exceptions.py | 44 ++++ src/genai.py | 106 +++++++++ src/routers/__init__.py | 15 ++ src/routers/choices.py | 59 +++++ src/routers/signals.py | 153 +++++++++++++ src/routers/trends.py | 112 ++++++++++ src/routers/users.py | 70 ++++++ src/storage.py | 162 ++++++++++++++ src/utils.py | 143 ++++++++++++ tests/conftest.py | 22 ++ tests/test_authentication.py | 78 +++++++ tests/test_choices.py | 41 ++++ tests/test_signals_and_trends.py | 120 +++++++++++ tests/test_users.py | 57 +++++ 46 files changed, 4040 insertions(+) create mode 100644 .dockerignore create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 data/locations.csv create mode 100644 data/units.csv create mode 100644 docker-compose.yaml create mode 100644 images/architecture.drawio.svg create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 requirements_dev.txt create mode 100644 sql/create_tables.sql create mode 100644 sql/import_data.sql create mode 100644 src/__init__.py create mode 100644 src/authentication.py create mode 100644 src/database/__init__.py create mode 100644 src/database/choices.py create mode 100644 src/database/connection.py create mode 100644 src/database/signals.py create mode 100644 src/database/trends.py create mode 100644 src/database/users.py create mode 100644 src/dependencies.py create mode 100644 src/entities/__init__.py create mode 100644 src/entities/base.py create mode 100644 src/entities/parameters.py create mode 100644 src/entities/signal.py create mode 100644 src/entities/trend.py create mode 100644 src/entities/user.py create mode 100644 src/entities/utils.py create mode 100644 src/exceptions.py create mode 100644 src/genai.py create mode 100644 src/routers/__init__.py create mode 100644 src/routers/choices.py create mode 100644 src/routers/signals.py create mode 100644 src/routers/trends.py create mode 100644 src/routers/users.py create mode 100644 src/storage.py create mode 100644 src/utils.py create mode 100644 tests/conftest.py create mode 100644 tests/test_authentication.py create mode 100644 tests/test_choices.py create mode 100644 tests/test_signals_and_trends.py 
create mode 100644 tests/test_users.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..f6b4abd --- /dev/null +++ b/.dockerignore @@ -0,0 +1,43 @@ +# Ignore Python cache files +__pycache__/ +*.pyc +*.pyo +*.pyd + +# Ignore virtual environment directories +venv/ +env/ +.venv/ +.env/ + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Local configuration files +*.env +*.local + +# Ignore test and coverage files +tests/ +*.cover +.coverage +nosetests.xml +coverage.xml +*.log + +# Ignore IDE/editor specific files +.vscode/ +.idea/ + +# Ignore Docker files +Dockerfile +docker-compose.yml + +# Ignore documentation files +docs/ +*.md + +# Ignore other unnecessary files +*.DS_Store +*.tmp +*.temp diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..749fb3f --- /dev/null +++ b/.gitignore @@ -0,0 +1,142 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Manually added for this project +.idea/ +**/.DS_Store diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2dab8d0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.11.7-slim +RUN apt-get update -y \ + && apt-get install libpq-dev -y \ + && rm -rf /var/lib/apt/lists/* +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir --upgrade -r requirements.txt +COPY . . +EXPOSE 8000 +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..24ae370 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2024, United Nations Development Programme + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..898ae03 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +install: + pip install --upgrade pip && pip install -r requirements_dev.txt +format: + isort . --profile black --multi-line 3 && black . 
+lint: + pylint main.py src/ +test: + python -m pytest tests/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..5ee5bb4 --- /dev/null +++ b/README.md @@ -0,0 +1,148 @@ +# Future Trends and Signals System (FTSS) API + +[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/release/python-3110/) +[![License](https://img.shields.io/github/license/undp-data/ftss-api)](https://github.com/undp-data/ftss-api/blob/main/LICENSE) +[![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) +[![Conventional Commits](https://img.shields.io/badge/Conventional%20Commits-1.0.0-%23FE5196?logo=conventionalcommits&logoColor=white)](https://conventionalcommits.org) + +This repository hosts the API that powers the [UNDP Future Trends and Signals System](https://signals.data.undp.org) (FTSS). +The API is written in Python using [FastAPI](https://fastapi.tiangolo.com) and deployed on Azure App Services. +It serves as an intermediary between the front-end application and the back-end database. The codebase is an open-source +version of the original project, transferred from Azure DevOps. + +## Table of Contents + +- [Introduction](#introduction) +- [Getting Started](#getting-started) +- [Build and Test](#build-and-test) +- [Contribute](#contribute) +- [License](#license) +- [Contact](#contact) + +## Introduction + +The FTSS is an internal system built for the staff of the United Nations Development Programme, designed to capture +signals of change and identify emerging trends within and outside the organisation. This repository hosts the back-end +API that powers the platform and is accompanied by the [front-end repository](https://github.com/undp-data/fe-signals-and-trends). + +The API is written and tested in Python `3.11` using the [FastAPI](https://fastapi.tiangolo.com) framework. Database and +storage routines are implemented in an asynchronous manner, making the application fast and responsive. The API is +deployed on Azure App Services to the development and production environments from the `dev` and `main` branches +respectively. The API interacts with a PostgreSQL database deployed as an Azure Database for PostgreSQL instance. The +instance comprises `staging` and `production` databases. An Azure Blob Storage container stores images used as +illustrations for signals and trends. The simplified architecture of the whole application is shown in the image below. + +Commits to the `staging` branch in the front-end repository and the `dev` branch in this repository trigger CI/CD pipelines +for the staging environment. While there is a single database instance, the data in the staging environment is isolated in +the `staging` database/schema, separate from the `production` database/schema within the same instance. The same +logic applies to the blob storage – images uploaded in the staging environment are managed separately from those in the +production environment. + +![Preview](images/architecture.drawio.svg) + +Authentication in the API happens via tokens (JWT) issued by Microsoft Entra when users log in to the front-end +application. Some endpoints that retrieve approved signals/trends are accessible with a static API key +for integration with other applications. + +## Getting Started + +To run the application locally, you can use either your local environment or a Docker container.
Either way, +clone the repository and navigate to the project directory first: + +```shell +# Clone the repository +git clone https://github.com/undp-data/ftss-api + +# Navigate to the project folder +cd ftss-api +``` + +You must also ensure that the following environment variables are set up: + +```text +# Authentication +TENANT_ID="<tenant-id>" +CLIENT_ID="<client-id>" +API_KEY="<api-key>" # for accessing "public" endpoints + +# Database and Storage +DB_CONNECTION="postgresql://<username>:<password>@<host>:5432/<database>" +SAS_URL="https://<account>.blob.core.windows.net/<container>?<sas-token>" + +# Azure OpenAI, only required for `/signals/generation` +AZURE_OPENAI_ENDPOINT="https://<resource>.openai.azure.com/" +AZURE_OPENAI_API_KEY="<azure-openai-api-key>" + +# Testing, only required to run tests, must be a valid token of a regular user +API_JWT="<json-web-token>" +``` + +### Local Environment + +For this scenario, you will need a connection string to the staging database. + +```bash +# Create and activate a virtual environment. +python3 -m venv venv +source venv/bin/activate + +# Install core dependencies. +pip install -r requirements.txt + +# Launch the application. +uvicorn main:app --reload +``` + +Once launched, the application will be running at http://127.0.0.1:8000. + +### Docker Environment + +For this scenario, you do not need a connection string, as a fresh PostgreSQL instance will be +set up for you in the container. Ensure that the Docker engine is running on your machine, then execute: + +```shell +# Start the containers +docker compose up --build -d +``` + +Once launched, the application will be running at http://127.0.0.1:8000. + +## Build and Test + +The codebase provides some basic tests written in `pytest`. To run them, ensure you have specified a valid token in your +`API_JWT` environment variable. Then run: + +```shell +# run all tests +python -m pytest tests/ + +# or alternatively +make test +``` + +Note that some tests for the search endpoints might fail, as the tests run against dynamically changing databases. + +## Contribute + +All contributions must follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). +The codebase is formatted with `black` and `isort`. Use the provided [Makefile](Makefile) for these +routine operations. Make sure to run the linter against your code. + +1. Clone or fork the repository +2. Create a new branch (`git checkout -b feature-branch`) +3. Make your changes +4. Ensure your code is properly formatted (`make format`) +5. Run the linter and check for any issues (`make lint`) +6. Execute the tests (`make test`) +7. Commit your changes (`git commit -m 'Add some feature'`) +8. Push to the branch (`git push origin feature-branch`) +9. Open a pull request to the `dev` branch +10. Once tested in the staging environment, open a pull request to the `main` branch + +## Contact + +This project was originally developed and is maintained by [Data Futures Exchange (DFx)](https://data.undp.org) at UNDP. +If you are facing any issues or would like to make suggestions, feel free to +[open an issue](https://github.com/undp-data/ftss-api/issues/new/choose). +For enquiries about DFx, visit [Contact Us](https://data.undp.org/contact-us).
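As a quick smoke test once the app is running, you can call the API from Python. The snippet below is an illustration and not part of the repository: the `access_token` header name matches the `APIKeyHeader` declared in `src/authentication.py`, while the `/signals` path is an assumption based on the routers registered in `main.py`.

```python
import os

import httpx

# Call a "public" endpoint with the static API key (hypothetical path).
response = httpx.get(
    "http://127.0.0.1:8000/signals",
    headers={"access_token": os.environ["API_KEY"]},
)
response.raise_for_status()
print(response.json())
```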
diff --git a/data/locations.csv b/data/locations.csv new file mode 100644 index 0000000..5a9cd6a --- /dev/null +++ b/data/locations.csv @@ -0,0 +1,269 @@ +id,name,iso,region,bureau +1,Global,,Global, +2,Antarctica,,Antarctica, +3,Australia and New Zealand,,Australia and New Zealand, +4,Central Asia,,Central Asia, +5,Eastern Asia,,Eastern Asia, +6,Eastern Europe,,Eastern Europe, +7,Latin America and the Caribbean,,Latin America and the Caribbean, +8,Melanesia,,Melanesia, +9,Micronesia,,Micronesia, +10,Northern Africa,,Northern Africa, +11,Northern America,,Northern America, +12,Northern Europe,,Northern Europe, +13,Polynesia,,Polynesia, +14,South-eastern Asia,,South-eastern Asia, +15,Southern Asia,,Southern Asia, +16,Southern Europe,,Southern Europe, +17,Sub-Saharan Africa,,Sub-Saharan Africa, +18,Western Asia,,Western Asia, +19,Western Europe,,Western Europe, +20,Afghanistan,AFG,Southern Asia,RBAP +21,Albania,ALB,Southern Europe,RBEC +22,Antarctica,ATA,Antarctica, +23,Algeria,DZA,Northern Africa,RBAS +24,American Samoa,ASM,Polynesia, +25,Andorra,AND,Southern Europe, +26,Angola,AGO,Sub-Saharan Africa,RBA +27,Antigua and Barbuda,ATG,Latin America and the Caribbean,RBLAC +28,Azerbaijan,AZE,Western Asia,RBEC +29,Argentina,ARG,Latin America and the Caribbean,RBLAC +30,Australia,AUS,Australia and New Zealand, +31,Austria,AUT,Western Europe, +32,Bahamas,BHS,Latin America and the Caribbean,RBLAC +33,Bahrain,BHR,Western Asia,RBAS +34,Bangladesh,BGD,Southern Asia,RBAP +35,Armenia,ARM,Western Asia,RBEC +36,Barbados,BRB,Latin America and the Caribbean,RBLAC +37,Belgium,BEL,Western Europe, +38,Bermuda,BMU,Northern America, +39,Bhutan,BTN,Southern Asia,RBAP +40,Bolivia (Plurinational State of),BOL,Latin America and the Caribbean,RBLAC +41,Bosnia and Herzegovina,BIH,Southern Europe,RBEC +42,Botswana,BWA,Sub-Saharan Africa,RBA +43,Bouvet Island,BVT,Latin America and the Caribbean, +44,Brazil,BRA,Latin America and the Caribbean,RBLAC +45,Belize,BLZ,Latin America and the Caribbean,RBLAC +46,British Indian Ocean Territory,IOT,Sub-Saharan Africa, +47,Solomon Islands,SLB,Melanesia,RBAP +48,British Virgin Islands,VGB,Latin America and the Caribbean, +49,Brunei Darussalam,BRN,South-eastern Asia,RBAP +50,Bulgaria,BGR,Eastern Europe, +51,Myanmar,MMR,South-eastern Asia,RBAP +52,Burundi,BDI,Sub-Saharan Africa,RBA +53,Belarus,BLR,Eastern Europe,RBEC +54,Cambodia,KHM,South-eastern Asia,RBAP +55,Cameroon,CMR,Sub-Saharan Africa,RBA +56,Canada,CAN,Northern America, +57,Cabo Verde,CPV,Sub-Saharan Africa,RBA +58,Cayman Islands,CYM,Latin America and the Caribbean, +59,Central African Republic,CAF,Sub-Saharan Africa,RBA +60,Sri Lanka,LKA,Southern Asia,RBAP +61,Chad,TCD,Sub-Saharan Africa,RBA +62,Chile,CHL,Latin America and the Caribbean,RBLAC +63,China,CHN,Eastern Asia,RBAP +64,Christmas Island,CXR,Australia and New Zealand, +65,Cocos (Keeling) Islands,CCK,Australia and New Zealand, +66,Colombia,COL,Latin America and the Caribbean,RBLAC +67,Comoros,COM,Sub-Saharan Africa,RBA +68,Mayotte,MYT,Sub-Saharan Africa, +69,Congo,COG,Sub-Saharan Africa,RBA +70,Democratic Republic of the Congo,COD,Sub-Saharan Africa,RBA +71,Cook Islands,COK,Polynesia, +72,Costa Rica,CRI,Latin America and the Caribbean,RBLAC +73,Croatia,HRV,Southern Europe, +74,Cuba,CUB,Latin America and the Caribbean,RBLAC +75,Cyprus,CYP,Western Asia, +76,Czechia,CZE,Eastern Europe, +77,Benin,BEN,Sub-Saharan Africa,RBA +78,Denmark,DNK,Northern Europe, +79,Dominica,DMA,Latin America and the Caribbean,RBLAC +80,Dominican Republic,DOM,Latin America and the Caribbean,RBLAC 
+81,Ecuador,ECU,Latin America and the Caribbean,RBLAC +82,El Salvador,SLV,Latin America and the Caribbean,RBLAC +83,Equatorial Guinea,GNQ,Sub-Saharan Africa,RBA +84,Ethiopia,ETH,Sub-Saharan Africa,RBA +85,Eritrea,ERI,Sub-Saharan Africa,RBA +86,Estonia,EST,Northern Europe, +87,Faroe Islands,FRO,Northern Europe, +88,Falkland Islands (Malvinas),FLK,Latin America and the Caribbean, +89,South Georgia and the South Sandwich Islands,SGS,Latin America and the Caribbean, +90,Fiji,FJI,Melanesia,RBAP +91,Finland,FIN,Northern Europe, +92,Åland Islands,ALA,Northern Europe, +93,France,FRA,Western Europe, +94,French Guiana,GUF,Latin America and the Caribbean, +95,French Polynesia,PYF,Polynesia, +96,French Southern Territories,ATF,Sub-Saharan Africa, +97,Djibouti,DJI,Sub-Saharan Africa,RBAS +98,Gabon,GAB,Sub-Saharan Africa,RBA +99,Georgia,GEO,Western Asia,RBEC +100,Gambia,GMB,Sub-Saharan Africa,RBA +101,State of Palestine,PSE,Western Asia,RBAS +102,Germany,DEU,Western Europe, +103,Ghana,GHA,Sub-Saharan Africa,RBA +104,Gibraltar,GIB,Southern Europe, +105,Kiribati,KIR,Micronesia,RBAP +106,Greece,GRC,Southern Europe, +107,Greenland,GRL,Northern America, +108,Grenada,GRD,Latin America and the Caribbean,RBLAC +109,Guadeloupe,GLP,Latin America and the Caribbean, +110,Guam,GUM,Micronesia, +111,Guatemala,GTM,Latin America and the Caribbean,RBLAC +112,Guinea,GIN,Sub-Saharan Africa,RBA +113,Guyana,GUY,Latin America and the Caribbean,RBLAC +114,Haiti,HTI,Latin America and the Caribbean,RBLAC +115,Heard Island and McDonald Islands,HMD,Australia and New Zealand, +116,Holy See,VAT,Southern Europe, +117,Honduras,HND,Latin America and the Caribbean,RBLAC +118,"China, Hong Kong Special Administrative Region",HKG,Eastern Asia, +119,Hungary,HUN,Eastern Europe, +120,Iceland,ISL,Northern Europe, +121,India,IND,Southern Asia,RBAP +122,Indonesia,IDN,South-eastern Asia,RBAP +123,Iran (Islamic Republic of),IRN,Southern Asia,RBAP +124,Iraq,IRQ,Western Asia,RBAS +125,Ireland,IRL,Northern Europe, +126,Israel,ISR,Western Asia, +127,Italy,ITA,Southern Europe, +128,Côte d’Ivoire,CIV,Sub-Saharan Africa,RBA +129,Jamaica,JAM,Latin America and the Caribbean,RBLAC +130,Japan,JPN,Eastern Asia, +131,Kazakhstan,KAZ,Central Asia,RBEC +132,Jordan,JOR,Western Asia,RBAS +133,Kenya,KEN,Sub-Saharan Africa,RBA +134,Democratic People's Republic of Korea,PRK,Eastern Asia,RBAP +135,Republic of Korea,KOR,Eastern Asia, +136,Kuwait,KWT,Western Asia,RBAS +137,Kyrgyzstan,KGZ,Central Asia,RBEC +138,Lao People's Democratic Republic,LAO,South-eastern Asia,RBAP +139,Lebanon,LBN,Western Asia,RBAS +140,Lesotho,LSO,Sub-Saharan Africa,RBA +141,Latvia,LVA,Northern Europe, +142,Liberia,LBR,Sub-Saharan Africa,RBA +143,Libya,LBY,Northern Africa,RBAS +144,Liechtenstein,LIE,Western Europe, +145,Lithuania,LTU,Northern Europe, +146,Luxembourg,LUX,Western Europe, +147,"China, Macao Special Administrative Region",MAC,Eastern Asia, +148,Madagascar,MDG,Sub-Saharan Africa,RBA +149,Malawi,MWI,Sub-Saharan Africa,RBA +150,Malaysia,MYS,South-eastern Asia,RBAP +151,Maldives,MDV,Southern Asia,RBAP +152,Mali,MLI,Sub-Saharan Africa,RBA +153,Malta,MLT,Southern Europe, +154,Martinique,MTQ,Latin America and the Caribbean, +155,Mauritania,MRT,Sub-Saharan Africa,RBA +156,Mauritius,MUS,Sub-Saharan Africa,RBA +157,Mexico,MEX,Latin America and the Caribbean,RBLAC +158,Monaco,MCO,Western Europe, +159,Mongolia,MNG,Eastern Asia,RBAP +160,Republic of Moldova,MDA,Eastern Europe,RBEC +161,Montenegro,MNE,Southern Europe,RBEC +162,Montserrat,MSR,Latin America and the Caribbean, 
+163,Morocco,MAR,Northern Africa,RBAS +164,Mozambique,MOZ,Sub-Saharan Africa,RBA +165,Oman,OMN,Western Asia,RBAS +166,Namibia,NAM,Sub-Saharan Africa,RBA +167,Nauru,NRU,Micronesia,RBAP +168,Nepal,NPL,Southern Asia,RBAP +169,Netherlands (Kingdom of the),NLD,Western Europe, +170,Curaçao,CUW,Latin America and the Caribbean, +171,Aruba,ABW,Latin America and the Caribbean, +172,Sint Maarten (Dutch part),SXM,Latin America and the Caribbean, +173,"Bonaire, Sint Eustatius and Saba",BES,Latin America and the Caribbean, +174,New Caledonia,NCL,Melanesia, +175,Vanuatu,VUT,Melanesia,RBAP +176,New Zealand,NZL,Australia and New Zealand, +177,Nicaragua,NIC,Latin America and the Caribbean,RBLAC +178,Niger,NER,Sub-Saharan Africa,RBA +179,Nigeria,NGA,Sub-Saharan Africa,RBA +180,Niue,NIU,Polynesia, +181,Norfolk Island,NFK,Australia and New Zealand, +182,Norway,NOR,Northern Europe, +183,Northern Mariana Islands,MNP,Micronesia, +184,United States Minor Outlying Islands,UMI,Micronesia, +185,Micronesia (Federated States of),FSM,Micronesia,RBAP +186,Marshall Islands,MHL,Micronesia,RBAP +187,Palau,PLW,Micronesia,RBAP +188,Pakistan,PAK,Southern Asia,RBAP +189,Panama,PAN,Latin America and the Caribbean,RBLAC +190,Papua New Guinea,PNG,Melanesia,RBAP +191,Paraguay,PRY,Latin America and the Caribbean,RBLAC +192,Peru,PER,Latin America and the Caribbean,RBLAC +193,Philippines,PHL,South-eastern Asia,RBAP +194,Pitcairn,PCN,Polynesia, +195,Poland,POL,Eastern Europe, +196,Portugal,PRT,Southern Europe, +197,Guinea-Bissau,GNB,Sub-Saharan Africa,RBA +198,Timor-Leste,TLS,South-eastern Asia,RBAP +199,Puerto Rico,PRI,Latin America and the Caribbean, +200,Qatar,QAT,Western Asia,RBAS +201,Réunion,REU,Sub-Saharan Africa, +202,Romania,ROU,Eastern Europe, +203,Russian Federation,RUS,Eastern Europe, +204,Rwanda,RWA,Sub-Saharan Africa,RBA +205,Saint Barthélemy,BLM,Latin America and the Caribbean, +206,Saint Helena,SHN,Sub-Saharan Africa, +207,Saint Kitts and Nevis,KNA,Latin America and the Caribbean,RBLAC +208,Anguilla,AIA,Latin America and the Caribbean, +209,Saint Lucia,LCA,Latin America and the Caribbean,RBLAC +210,Saint Martin (French Part),MAF,Latin America and the Caribbean, +211,Saint Pierre and Miquelon,SPM,Northern America, +212,Saint Vincent and the Grenadines,VCT,Latin America and the Caribbean,RBLAC +213,San Marino,SMR,Southern Europe, +214,Sao Tome and Principe,STP,Sub-Saharan Africa,RBA +215,Sark,CRQ,Northern Europe, +216,Saudi Arabia,SAU,Western Asia,RBAS +217,Senegal,SEN,Sub-Saharan Africa,RBA +218,Serbia,SRB,Southern Europe,RBEC +219,Seychelles,SYC,Sub-Saharan Africa,RBA +220,Sierra Leone,SLE,Sub-Saharan Africa,RBA +221,Singapore,SGP,South-eastern Asia,RBAP +222,Slovakia,SVK,Eastern Europe, +223,Viet Nam,VNM,South-eastern Asia,RBAP +224,Slovenia,SVN,Southern Europe, +225,Somalia,SOM,Sub-Saharan Africa,RBAS +226,South Africa,ZAF,Sub-Saharan Africa,RBA +227,Zimbabwe,ZWE,Sub-Saharan Africa,RBA +228,Spain,ESP,Southern Europe, +229,South Sudan,SSD,Sub-Saharan Africa,RBA +230,Sudan,SDN,Northern Africa,RBAS +231,Western Sahara,ESH,Northern Africa, +232,Suriname,SUR,Latin America and the Caribbean,RBLAC +233,Svalbard and Jan Mayen Islands,SJM,Northern Europe, +234,Eswatini,SWZ,Sub-Saharan Africa,RBA +235,Sweden,SWE,Northern Europe, +236,Switzerland,CHE,Western Europe, +237,Syrian Arab Republic,SYR,Western Asia,RBAS +238,Tajikistan,TJK,Central Asia,RBEC +239,Thailand,THA,South-eastern Asia,RBAP +240,Togo,TGO,Sub-Saharan Africa,RBA +241,Tokelau,TKL,Polynesia, +242,Tonga,TON,Polynesia,RBAP +243,Trinidad and Tobago,TTO,Latin America 
and the Caribbean,RBLAC +244,United Arab Emirates,ARE,Western Asia,RBAS +245,Tunisia,TUN,Northern Africa,RBAS +246,Türkiye,TUR,Western Asia,RBEC +247,Turkmenistan,TKM,Central Asia,RBEC +248,Turks and Caicos Islands,TCA,Latin America and the Caribbean, +249,Tuvalu,TUV,Polynesia,RBAP +250,Uganda,UGA,Sub-Saharan Africa,RBA +251,Ukraine,UKR,Eastern Europe,RBEC +252,North Macedonia,MKD,Southern Europe,RBEC +253,Egypt,EGY,Northern Africa,RBAS +254,United Kingdom of Great Britain and Northern Ireland,GBR,Northern Europe, +255,Guernsey,GGY,Northern Europe, +256,Jersey,JEY,Northern Europe, +257,Isle of Man,IMN,Northern Europe, +258,United Republic of Tanzania,TZA,Sub-Saharan Africa,RBA +259,United States of America,USA,Northern America, +260,United States Virgin Islands,VIR,Latin America and the Caribbean, +261,Burkina Faso,BFA,Sub-Saharan Africa,RBA +262,Uruguay,URY,Latin America and the Caribbean,RBLAC +263,Uzbekistan,UZB,Central Asia,RBEC +264,Venezuela (Bolivarian Republic of),VEN,Latin America and the Caribbean,RBLAC +265,Wallis and Futuna Islands,WLF,Polynesia, +266,Samoa,WSM,Polynesia,RBAP +267,Yemen,YEM,Western Asia,RBAS +268,Zambia,ZMB,Sub-Saharan Africa,RBA diff --git a/data/units.csv b/data/units.csv new file mode 100644 index 0000000..68fe0b5 --- /dev/null +++ b/data/units.csv @@ -0,0 +1,201 @@ +id,name,region +1,Afghanistan,RBAP +2,Albania,RBEC +3,Algeria,RBAS +4,Angola,RBA +5,Anguilla,RBLAC +6,Antigua and Barbuda,RBLAC +7,Argentina,RBLAC +8,Armenia,RBEC +9,Aruba,RBLAC +10,Azerbaijan,RBEC +11,Bureau of External Relations and Advocacy (BERA),HQ +12,BMS,HQ +13,Bureau for Policy and Programme Support (BPPS),HQ +14,Bahrain,RBAS +15,Bangladesh,RBAP +16,Barbados,RBLAC +17,Belize,RBLAC +18,Benin,RBA +19,Bermuda,RBLAC +20,Bhutan,RBAP +21,Bolivia,RBLAC +22,Bosnia and Herzegovina,RBEC +23,Botswana,RBA +24,Brazil,RBLAC +25,Brunei Darussalam,RBAP +26,Burkina Faso,RBA +27,Burundi,RBA +28,Cambodia,RBAP +29,Cameroon,RBA +30,Cayman Islands,RBLAC +31,Central African Republic,RBA +32,Chad,RBA +33,Chief Digital Office,HQ +34,Chief Digital Office (CDO),HQ +35,Chile,RBLAC +36,China,RBAP +37,Colombia,RBLAC +38,Comoros,RBA +39,Congo,RBA +40,Cook Islands,RBAP +41,Costa Rica,RBLAC +42,Crisis Bureau (CB),HQ +43,Cuba,RBLAC +44,Curacao and Sint Maarten,RBLAC +45,Cyprus (project office),RBEC +46,Côte d’Ivoire,RBA +47,DPRK,RBAP +48,Democratic Republic of Congo,RBA +49,Development financing,HQ +50,Djibouti,RBAS +51,Dom. 
Republic,RBLAC +52,Ecuador,RBLAC +53,Egypt,RBAS +54,El Salvador,RBLAC +55,Equatorial Guinea,RBA +56,Eritrea,RBA +57,Eswatini,RBA +58,Ethiopia,RBA +59,Executive Office (ExO),HQ +60,External,Other +61,Federated States of Micronesia,RBAP +62,Fiji MCO,RBAP +63,Future Fellows,HQ +64,Gabon,RBA +65,Gambia,RBA +66,Georgia,RBEC +67,Ghana,RBA +68,"Global Centre for Technology, Innovation and Sustainable Development (Singapore)",BPPS +69,Grenada,RBLAC +70,Guatemala,RBLAC +71,Guinea,RBA +72,Guinea-Bissau,RBA +73,Guyana,RBLAC +74,Guyana-Suriname,RBLAC +75,Haiti,RBLAC +76,Honduras,RBLAC +77,Human Development Report Office (HDRO),HQ +78,Independent Evaluation Office (IEO),HQ +79,India,RBAP +80,Indonesia,RBAP +81,Iran,RBAP +82,Iraq,RBAS +83,"Istanbul International Center for Private Sector in Development (Istanbul, Turkey)",BPPS +84,Jamaica,RBLAC +85,Jordan,RBAS +86,Kazakhstan,RBEC +87,Kenya,RBA +88,Kiribati,RBAP +89,Kosovo (as per UNSCR 1244),RBEC +90,Kuwait,RBAS +91,Kyrgyzstan,RBEC +92,Lao PDR,RBAP +93,Lebanon,RBAS +94,Lesotho,RBA +95,Liberia,RBA +96,Libya,RBAS +97,Madagascar,RBA +98,Malawi,RBA +99,Malaysia MCO,RBAP +100,Maldives,RBAP +101,Mali,RBA +102,Marshall Islands,RBAP +103,Mauritania,RBA +104,Mauritius,RBA +105,Mexico,RBLAC +106,Moldova,RBEC +107,Mongolia,RBAP +108,Montserrat,RBLAC +109,Morocco,RBAS +110,Mozambique,RBA +111,Myanmar,RBAP +112,"Nairobi Global Centre on Resilient Ecosystems and Desertification (Nairobi, Kenya)",BPPS +113,Namibia,RBA +114,Nauru,RBAP +115,Nepal,RBAP +116,Nicaragua,RBLAC +117,Niger,RBA +118,Nigeria,RBA +119,Niue,RBAP +120,North Macedonia,RBEC +121,Office of Audit and Investigations (OAI),HQ +122,"Oslo Governance Centre (Oslo, Norway)",BPPS +123,PNG,RBAP +124,Pakistan,RBAP +125,Palau,RBAP +126,Panama,RBLAC +127,Paraguay,RBLAC +128,People,RBAS +129,Peru,RBLAC +130,Philippines,RBAP +131,Prog for Palestinian,RBAS +132,"Regional Bureau for Africa, New York, USA",HQ +133,"Regional Bureau for Arab States, New York, USA",HQ +134,"Regional Bureau for Asia and the Pacific, New York",HQ +135,"Regional Bureau for Europe and the CIS, New York, USA",HQ +136,"Regional Bureau for Latin America and the Caribbean, New York, USA",HQ +137,"Regional Centre: Panama City, Panama",RBLAC +138,"Regional Hub: Amman, Jordan",RBAS +139,"Regional Hub: Bangkok, Thailand",RBAP +140,"Regional Hub: Dakar, Senegal",RBA +141,"Regional Hub: Istanbul, Turkey",RBEC +142,"Regional Hub: Nairobi, Kenya",RBA +143,"Regional Hub: Pretoria, South Africa",RBA +144,"Regional Service Centre: Addis Ababa, Ethiopia",RBA +145,Republic of Belarus,RBEC +146,Republic of Cape Verde,RBA +147,Republic of Montenegro,RBEC +148,Republic of South Sudan,RBA +149,Republic of the,RBAS +150,Republic of the Sudan,RBAS +151,Rwanda,RBA +152,Saint Kitts and Nevis,RBLAC +153,Saint Lucia and Saint Vincent,RBLAC +154,Samoa MCO,RBAP +155,Sao Tome and Principe,RBA +156,Saudi Arabia,RBAS +157,Senegal,RBA +158,Serbia,RBEC +159,Seychelles,RBA +160,Sierra Leone,RBA +161,Singapore,RBAP +162,Solomon Islands,RBAP +163,Somalia,RBAS +164,South Africa,RBA +165,Sri Lanka,RBAP +166,Suriname,RBLAC +167,Syria,RBAS +168,Tajikistan,RBEC +169,Tanzania,RBA +170,Thailand,RBAP +171,The Bahamas,RBLAC +172,Timor Leste,RBAP +173,Togo,RBA +174,Tokelau,RBAP +175,Tonga,RBAP +176,Trinidad & Tobago,RBLAC +177,Tunisia,RBAS +178,Turkey,RBEC +179,Turkmenistan,RBEC +180,Turks and Caicos Islands,RBLAC +181,Tuvalu,RBAP +182,"UNDP Nordic Representation Office (Copenhagen, Denmark)",BERA +183,"UNDP Office in Geneva (Geneva, Switzerland)",BERA +184,"UNDP Representation Office in 
Brussels (Brussels, Belgium)",BERA +185,"UNDP Representation Office in Japan (Tokyo, Japan)",BERA +186,"UNDP Seoul Policy Centre for Knowledge Exchange through SDG Partnerships (Seoul, Republic of Korea)",BPPS +187,"UNDP Washington Representation Office (Washington, USA)",BERA +188,Uganda,RBA +189,Ukraine,RBEC +190,Uruguay,RBLAC +191,Uzbekistan,RBEC +192,Vanuatu,RBAP +193,Venezuela,RBLAC +194,Viet Nam,RBAP +195,Yemen,RBAS +196,Zambia,RBA +197,Zimbabwe,RBA +198,the British Virgin Islands,RBLAC +199,the Commonwealth of Dominica,RBLAC +200,the Grenadines,RBLAC diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..4adec8a --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,20 @@ +services: + web: + build: + context: . + env_file: .env + environment: + - DB_CONNECTION=postgresql://postgres:password@db:5432/postgres + ports: + - "8000:8000" + depends_on: + - db + db: + image: postgres:16.4-alpine + environment: + POSTGRES_PASSWORD: password + ports: + - "5432:5432" + volumes: + - ./sql:/docker-entrypoint-initdb.d + - ./data:/docker-entrypoint-initdb.d/data diff --git a/images/architecture.drawio.svg b/images/architecture.drawio.svg new file mode 100644 index 0000000..c097e4b --- /dev/null +++ b/images/architecture.drawio.svg @@ -0,0 +1,4 @@ + + + +
[architecture.drawio.svg embedded text omitted: the diagram shows the ftss-api repository deploying from the main branch to the ftss-api.azurewebsites.net Azure Web App (production) and from the dev branch to the ftss-api-dev.azurewebsites.net Azure Web App (staging); the UNDP-Signals-And-Trends Static Web Apps deploying from the production and staging branches and serving HTTP requests from regular users and from developers and testers; and both web apps managing the production and staging databases in the ftss Azure Database for PostgreSQL and the production and staging images in the ftss Blob Storage.]
\ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..2756a42 --- /dev/null +++ b/main.py @@ -0,0 +1,52 @@ +""" +This application serves as an API endpoint for the Signals and Trends project that connects +the frontend platform with the backend database. +""" + +from dotenv import load_dotenv +from fastapi import Depends, FastAPI + +from src import routers +from src.authentication import authenticate_user + +load_dotenv() + +app = FastAPI( + debug=False, + title="Future Trends and Signals API", + version="3.0.0-beta", + summary="""The Future Trends and Signals (FTSS) API powers user experiences on the UNDP Future + Trends and Signals System by providing functionality to manage signals, trends and users.""", + description="""The FTSS API serves as an interface for the + [UNDP Future Trends and Signals System](https://signals.data.undp.org), + facilitating interaction between the front-end application and the underlying relational database. + This API enables users to submit, retrieve, and update data related to signals, trends, and user + profiles within the platform. + + As a private API, it mandates authentication for all endpoints to ensure secure access. + Authentication is achieved by including the `access_token` in the request header, utilising JWT tokens + issued by [Microsoft Entra](https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens). + This mechanism not only secures the API but also allows for the automatic recording of user information + derived from the API token. Approved signals and trends can be accessed using a predefined API key for + integration with other applications. + """.strip().replace( + " ", " " + ), + contact={ + "name": "UNDP Data Futures Platform", + "url": "https://data.undp.org", + "email": "data@undp.org", + }, + openapi_tags=[ + {"name": "signals", "description": "CRUD operations on signals."}, + {"name": "trends", "description": "CRUD operations on trends."}, + {"name": "users", "description": "CRUD operations on users."}, + {"name": "choices", "description": "List valid options for form fields."}, + ], + docs_url="/", + redoc_url=None, +) + + +for router in routers.ALL: + app.include_router(router=router, dependencies=[Depends(authenticate_user)]) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..aa12ec0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +fastapi == 0.115.3 +uvicorn == 0.32.0 +python-dotenv ~= 1.0.1 +httpx ~= 0.27.2 +pyjwt[crypto] ~= 2.9.0 +pydantic[email] ~= 2.9.2 +psycopg == 3.2.3 +pandas ~= 2.2.3 +openpyxl ~= 3.1.5 +scikit-learn ~= 1.5.2 +azure-storage-blob ~= 12.23.0 +aiohttp ~= 3.10.10 +pillow ~= 11.0.0 +beautifulsoup4 ~= 4.12.3 +lxml ~= 5.3.0 +openai == 1.52.2 diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..a0cf82c --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,6 @@ +-r requirements.txt +black ~= 24.10.0 +isort ~= 5.13.2 +pylint ~= 3.3.1 +pytest ~= 8.3.3 +notebook ~= 7.2.2 diff --git a/sql/create_tables.sql b/sql/create_tables.sql new file mode 100644 index 0000000..e6d5b14 --- /dev/null +++ b/sql/create_tables.sql @@ -0,0 +1,137 @@ +/* +The initialisation script to create tables in an empty database. + +The database schema assumes a denormalised form. This allows data to be inserted "as is", minimising +the differences between the API layer and database layer and making CRUD operations simpler.
+Given the expected size of the database, this design will have a marginal impact on efficiency +even in the long run. + +The database tables comprise: + +1. Users +2. Signals +3. Trends +4. Connections – a junction table for connected signals/trends to model a many-to-many relationship. +5. Locations – stores country and area metadata based on UN M49 that are used for signal locations. +6. Units – stores metadata on UNDP units used to assign user units and filter signals. +*/ + +-- users table and indices +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + email VARCHAR(255) UNIQUE NOT NULL, + role VARCHAR(255) NOT NULL, + name VARCHAR(255), + unit VARCHAR(255), + acclab BOOLEAN +); + +CREATE INDEX ON users (email); +CREATE INDEX ON users (role); + +-- signals table and indices +CREATE TABLE signals ( + id SERIAL PRIMARY KEY, + status VARCHAR(255) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + created_by VARCHAR(255) NOT NULL, + created_for VARCHAR(255), + modified_at TIMESTAMP NOT NULL DEFAULT NOW(), + modified_by VARCHAR(255) NOT NULL, + headline TEXT, + description TEXT, + attachment TEXT, -- a URL to Azure Blob Storage + steep_primary TEXT, + steep_secondary TEXT[], + signature_primary TEXT, + signature_secondary TEXT[], + sdgs TEXT[], + created_unit VARCHAR(255), + url TEXT, + relevance TEXT, + keywords TEXT[], + location TEXT, + score TEXT, + text_search_field tsvector GENERATED ALWAYS AS (to_tsvector('english', coalesce(headline, '') || ' ' || coalesce(description, ''))) STORED +); + +CREATE INDEX ON signals ( + status, + created_by, + created_for, + created_unit, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + location, + score +); +CREATE INDEX ON signals USING GIN (text_search_field); + +-- trends table and indices +CREATE TABLE trends ( + id SERIAL PRIMARY KEY, + status VARCHAR(255) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + created_by VARCHAR(255) NOT NULL, + created_for TEXT, + modified_at TIMESTAMP NOT NULL DEFAULT NOW(), + modified_by VARCHAR(255) NOT NULL, + headline TEXT, + description TEXT, + attachment TEXT, + steep_primary TEXT, + steep_secondary TEXT[], + signature_primary TEXT, + signature_secondary TEXT[], + sdgs TEXT[], + assigned_to TEXT, + time_horizon TEXT, + impact_rating TEXT, + impact_description TEXT, + text_search_field tsvector GENERATED ALWAYS AS (to_tsvector('english', coalesce(headline, '') || ' ' || coalesce(description, ''))) STORED +); + +CREATE INDEX ON trends ( + status, + created_for, + assigned_to, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + time_horizon, + impact_rating +); +CREATE INDEX ON trends USING GIN (text_search_field); + +-- junction table for connected signals/trends to model many-to-many relationship +CREATE TABLE connections ( + signal_id INT REFERENCES signals(id) ON DELETE CASCADE, + trend_id INT REFERENCES trends(id) ON DELETE CASCADE, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + created_by VARCHAR(255) NOT NULL, + CONSTRAINT connection_pk PRIMARY KEY (signal_id, trend_id) +); + +-- locations table and indices +CREATE TABLE locations ( + id SERIAL PRIMARY KEY, + name VARCHAR(128) NOT NULL, + iso VARCHAR(3), + region VARCHAR(128) NOT NULL, + bureau VARCHAR(5) +); +CREATE INDEX ON locations (name, region, bureau); + +-- units table and indices +CREATE TABLE units ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + region VARCHAR(255) +); +CREATE INDEX ON units (name, region);
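The generated `text_search_field` columns pair with the GIN indexes to support full-text search; the search functions later in this patch query them through `websearch_to_tsquery`. Below is a minimal sketch of such a query from Python, for illustration only and not part of the patch; it assumes a database initialised with this schema and a `DB_CONNECTION` string as in the README.

```python
import os

import psycopg
from psycopg.rows import dict_row

# Connect with a dict row factory, mirroring the settings in src/database/connection.py.
with psycopg.connect(os.environ["DB_CONNECTION"], row_factory=dict_row) as conn:
    rows = conn.execute(
        """
        SELECT id, headline
        FROM signals
        WHERE text_search_field @@ websearch_to_tsquery('english', %s)
        LIMIT 10;
        """,
        ("renewable energy",),  # an arbitrary example search phrase
    ).fetchall()
    for row in rows:
        print(row["id"], row["headline"])
```

diff --git a/sql/import_data.sql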
b/sql/import_data.sql new file mode 100644 index 0000000..1965fd8 --- /dev/null +++ b/sql/import_data.sql @@ -0,0 +1,11 @@ +/* +The initialisation script to import locations and units data into an empty database, +from within the Docker container. This script is automatically executed by Docker +Compose after `create_tables.sql`. + */ + +-- import locations data +\copy locations FROM 'docker-entrypoint-initdb.d/data/locations.csv' DELIMITER ',' CSV HEADER; + +-- import units data +\copy units FROM 'docker-entrypoint-initdb.d/data/units.csv' DELIMITER ',' CSV HEADER; diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..8e7302c --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,4 @@ +""" +API components, including database functions, entity (data) models, routers, endpoint +dependencies for RBAC, Azure Blob Storage and GenAI. +""" diff --git a/src/authentication.py b/src/authentication.py new file mode 100644 index 0000000..0f748f3 --- /dev/null +++ b/src/authentication.py @@ -0,0 +1,146 @@ +""" +Dependencies for API authentication using JWT tokens from Microsoft Entra. +""" + +import os + +import httpx +import jwt +from fastapi import Depends, Security +from fastapi.security import APIKeyHeader +from psycopg import AsyncCursor + +from . import database as db +from . import exceptions +from .entities import Role, User + +api_key_header = APIKeyHeader( + name="access_token", + description="The API access token for the application.", + auto_error=True, +) + + +async def get_jwks() -> dict[str, dict]: + """ + Get the JSON Web Key Set (JWKS) containing the public keys. + + Returns + ------- + keys : dict[str, dict] + A mapping of kid to JWK. + """ + # obtain OpenID configuration + tenant_id = os.environ["TENANT_ID"] + endpoint = f"https://login.microsoftonline.com/{tenant_id}/v2.0/.well-known/openid-configuration" + async with httpx.AsyncClient() as client: + response = await client.get(endpoint) + configuration = response.json() + + # get the JSON Web Keys + response = await client.get(configuration["jwks_uri"]) + response.raise_for_status() + jwks = response.json() + + # get a mapping of key IDs to keys + keys = {key["kid"]: key for key in jwks["keys"]} + return keys + + +async def get_jwk(token: str) -> jwt.PyJWK: + """ + Obtain a JSON Web Key (JWK) for a token. + + Parameters + ---------- + token : str + A JSON Web Token issued by Microsoft Entra. + + Returns + ------- + jwk : jwt.PyJWK + A ready-to-use JWK object. + """ + header = jwt.get_unverified_header(token) + try: + jwks = await get_jwks() + except httpx.HTTPError: + jwks = {} + jwk = jwks.get(header["kid"]) + if jwk is None: + raise ValueError("JWK could not be obtained or found") + jwk = jwt.PyJWK.from_dict(jwk, "RS256") + return jwk + + +async def decode_token(token: str) -> dict: + """ + Decode and verify a payload of a JSON Web Token (JWT). + + Parameters + ---------- + token : str + A JSON Web Token issued by Microsoft Entra. + + Returns + ------- + payload : dict + The decoded payload that should include email + and name for further authentication.
+ """ + jwk = await get_jwk(token) + tenant_id = os.environ["TENANT_ID"] + payload = jwt.decode( + jwt=token, + key=jwk.key, + algorithms=["RS256"], + audience=os.environ["CLIENT_ID"], + issuer=f"https://sts.windows.net/{tenant_id}/", + options={ + "verify_signature": True, + "verify_aud": True, + "verify_exp": True, + "verify_iss": True, + }, + ) + return payload + + +async def authenticate_user( + token: str = Security(api_key_header), + cursor: AsyncCursor = Depends(db.yield_cursor), +) -> User: + """ + Authenticate a user with a valid JWT token from Microsoft Entra ID. + + The function is used for dependency injection in endpoints to authenticate incoming requests. + The tokens must be issued by Microsoft Entra ID. For a list of available attributes, see + https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference + + Parameters + ---------- + token : str + Either a predefined api_key to access "public" endpoints or a valid signed JWT. + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + user : User + Pydantic model for a User object (if authentication succeeded). + """ + if token == os.environ.get("API_KEY"): + # dummy user object for anonymous access + user = User(email="name.surname@undp.org", role=Role.VISITOR) + return user + try: + payload = await decode_token(token) + except jwt.exceptions.PyJWTError as e: + raise exceptions.not_authenticated from e + email, name = payload.get("unique_name"), payload.get("name") + if email is None or name is None: + raise exceptions.not_authenticated + if (user := await db.read_user_by_email(cursor, email)) is None: + user = User(email=email, role=Role.USER, name=name) + await db.create_user(cursor, user) + return user diff --git a/src/database/__init__.py b/src/database/__init__.py new file mode 100644 index 0000000..6a2c1e0 --- /dev/null +++ b/src/database/__init__.py @@ -0,0 +1,9 @@ +""" +Functions for connecting to and performing CRUD operations in a PostgreSQL database. +""" + +from .choices import * +from .connection import yield_cursor +from .signals import * +from .trends import * +from .users import * diff --git a/src/database/choices.py b/src/database/choices.py new file mode 100644 index 0000000..98da71f --- /dev/null +++ b/src/database/choices.py @@ -0,0 +1,63 @@ +""" +Functions for reading data related to choice lists. +""" + +from psycopg import AsyncCursor + +__all__ = ["get_unit_names", "get_unit_regions", "get_location_names"] + + +async def get_unit_names(cursor: AsyncCursor) -> list[str]: + """ + Read unit names from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[str] + A list of unit names. + """ + await cursor.execute("SELECT name FROM units ORDER BY name;") + return [row["name"] async for row in cursor] + + +async def get_unit_regions(cursor: AsyncCursor) -> list[str]: + """ + Read unit regions from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[str] + A list of unique unit regions. + """ + await cursor.execute("SELECT DISTINCT region FROM units ORDER BY region;") + return [row["region"] async for row in cursor] + + +async def get_location_names(cursor: AsyncCursor) -> list[str]: + """ + Read location names from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. 
+ + Returns + ------- + list[str] + A list of location names that includes geographic regions, + countries and territories based on UNSD M49. + """ + # do not order, so that regions appear first + await cursor.execute("SELECT name FROM locations;") + return [row["name"] async for row in cursor] diff --git a/src/database/connection.py b/src/database/connection.py new file mode 100644 index 0000000..995b936 --- /dev/null +++ b/src/database/connection.py @@ -0,0 +1,46 @@ +""" +Database connection functions based on the Psycopg 3 project. +""" + +import os + +import psycopg +from psycopg.rows import dict_row + + +async def get_connection() -> psycopg.AsyncConnection: + """ + Get a connection to a PostgreSQL database. + + The connection includes a row factory to return database rows as dictionaries + and a cursor factory that ensures client-side binding. See the + [documentation](https://www.psycopg.org/psycopg3/docs/basic/from_pg2.html#server-side-binding) + for details. + + Returns + ------- + conn : psycopg.AsyncConnection + A database connection object with row and cursor factory settings. + """ + conn = await psycopg.AsyncConnection.connect( + conninfo=os.environ["DB_CONNECTION"], + autocommit=False, + row_factory=dict_row, + cursor_factory=psycopg.AsyncClientCursor, + ) + return conn + + +async def yield_cursor() -> psycopg.AsyncCursor: + """ + Yield a PostgreSQL database cursor object to be used for dependency injection. + + Yields + ------ + cursor : psycopg.AsyncCursor + A database cursor object. + """ + # handle rollbacks from the context manager and close on exit + async with await get_connection() as conn: + async with conn.cursor() as cursor: + yield cursor
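`yield_cursor` is meant to be consumed through FastAPI's dependency injection: each request gets its own connection, and the context managers commit on a clean exit or roll back on an exception. A hypothetical endpoint, shown purely as an illustration and not part of the patch (the `/health/db` path is made up), sketches the intended wiring:

```python
from fastapi import APIRouter, Depends
from psycopg import AsyncCursor

from src import database as db

router = APIRouter()


@router.get("/health/db")
async def check_database(cursor: AsyncCursor = Depends(db.yield_cursor)) -> dict:
    # A trivial round-trip; rows come back as dictionaries thanks to dict_row.
    await cursor.execute("SELECT 1 AS ok;")
    row = await cursor.fetchone()
    return {"database_ok": row["ok"] == 1}
```

diff --git a/src/database/signals.py b/src/database/signals.py new file mode 100644 index 0000000..4e1c591 --- /dev/null +++ b/src/database/signals.py @@ -0,0 +1,360 @@ +""" +CRUD operations for signal entities. +""" + +from psycopg import AsyncCursor, sql + +from .. import storage +from ..entities import Signal, SignalFilters, SignalPage, Status + +__all__ = [ + "search_signals", + "create_signal", + "read_signal", + "update_signal", + "delete_signal", + "read_user_signals", +] + + +async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalPage: + """ + Search signals in the database using filters and pagination. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + filters : SignalFilters + Query filters for search, including pagination. + + Returns + ------- + page : SignalPage + Paginated search results for signals.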
+ """ + query = """ + SELECT + *, COUNT(*) OVER() AS total_count + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + LEFT OUTER JOIN ( + SELECT + name AS unit_name, + region AS unit_region + FROM + units + ) AS u + ON + s.created_unit = u.unit_name + LEFT OUTER JOIN ( + SELECT + name AS location, + region AS location_region, + bureau AS location_bureau + FROM + locations + ) AS l + ON + s.location = l.location + WHERE + (%(ids)s IS NULL OR id = ANY(%(ids)s)) + AND status = ANY(%(statuses)s) + AND (%(created_by)s IS NULL OR created_by = %(created_by)s) + AND (%(created_for)s IS NULL OR created_for = %(created_for)s) + AND (%(steep_primary)s IS NULL OR steep_primary = %(steep_primary)s) + AND (%(steep_secondary)s IS NULL OR steep_secondary && %(steep_secondary)s) + AND (%(signature_primary)s IS NULL OR signature_primary = %(signature_primary)s) + AND (%(signature_secondary)s IS NULL OR signature_secondary && %(signature_secondary)s) + AND (%(location)s IS NULL OR (s.location = %(location)s) OR location_region = %(location)s) + AND (%(bureau)s IS NULL OR location_bureau = %(bureau)s) + AND (%(sdgs)s IS NULL OR %(sdgs)s && sdgs) + AND (%(score)s IS NULL OR score = %(score)s) + AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) + AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) + ORDER BY + {} {} + OFFSET + %(offset)s + LIMIT + %(limit)s + ; + """ + query = sql.SQL(query).format( + sql.Identifier(filters.order_by), + sql.SQL(filters.direction), + ) + await cursor.execute(query, filters.model_dump()) + rows = await cursor.fetchall() + # extract total count of rows matching the WHERE clause + page = SignalPage.from_search(rows, filters) + return page + + +async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: + """ + Insert a signal into the database, connect it to trends and upload an attachment + to Azure Blob Storage if applicable. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal : Signal + A signal object to insert. + + Returns + ------- + signal_id : int + An ID of the signal in the database. 
+ """ + query = """ + INSERT INTO signals ( + status, + created_by, + created_for, + modified_by, + headline, + description, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + created_unit, + url, + relevance, + keywords, + location, + score + ) + VALUES ( + %(status)s, + %(created_by)s, + %(created_for)s, + %(modified_by)s, + %(headline)s, + %(description)s, + %(steep_primary)s, + %(steep_secondary)s, + %(signature_primary)s, + %(signature_secondary)s, + %(sdgs)s, + %(created_unit)s, + %(url)s, + %(relevance)s, + %(keywords)s, + %(location)s, + %(score)s + ) + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + row = await cursor.fetchone() + signal_id = row["id"] + + # add connected trends if any are present + for trend_id in signal.connected_trends or []: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + + # upload an image + if signal.attachment is not None: + try: + blob_url = await storage.upload_image( + signal_id, "signals", signal.attachment + ) + except Exception as e: + print(e) + else: + query = "UPDATE signals SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, signal_id)) + return signal_id + + +async def read_signal(cursor: AsyncCursor, uid: int) -> Signal | None: + """ + Read a signal from the database using an ID. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the signal to retrieve data for. + + Returns + ------- + Signal | None + A signal if it exits, otherwise None. + """ + query = """ + SELECT + * + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + WHERE + id = %s + ; + """ + await cursor.execute(query, (uid,)) + if (row := await cursor.fetchone()) is None: + return None + return Signal(**row) + + +async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: + """ + Update a signal in the database, update its connected trends and update an attachment + in the Azure Blob Storage if applicable. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal : Signal + A signal object to update. + + Returns + ------- + int | None + A signal ID if the update has been performed, otherwise None. 
+ """ + query = """ + UPDATE + signals + SET + status = COALESCE(%(status)s, status), + created_for = COALESCE(%(created_for)s, created_for), + modified_at = NOW(), + modified_by = %(modified_by)s, + headline = COALESCE(%(headline)s, headline), + description = COALESCE(%(description)s, description), + steep_primary = COALESCE(%(steep_primary)s, steep_primary), + steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), + signature_primary = COALESCE(%(signature_primary)s, signature_primary), + signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), + sdgs = COALESCE(%(sdgs)s, sdgs), + created_unit = COALESCE(%(created_unit)s, created_unit), + url = COALESCE(%(url)s, url), + relevance = COALESCE(%(relevance)s, relevance), + keywords = COALESCE(%(keywords)s, keywords), + location = COALESCE(%(location)s, location), + score = COALESCE(%(score)s, score) + WHERE + id = %(id)s + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + if (row := await cursor.fetchone()) is None: + return None + signal_id = row["id"] + + # update connected trends if any are present + await cursor.execute("DELETE FROM connections WHERE signal_id = %s;", (signal_id,)) + for trend_id in signal.connected_trends or []: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + + # upload an image if it is not a URL to an existing image + blob_url = await storage.update_image(signal_id, "signals", signal.attachment) + query = "UPDATE signals SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, signal_id)) + + return signal_id + + +async def delete_signal(cursor: AsyncCursor, uid: int) -> Signal | None: + """ + Delete a signal from the database and, if applicable, an image from + Azure Blob Storage, using an ID. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the signal to delete. + + Returns + ------- + Signal | None + A deleted signal object if the operation has been successful, otherwise None. + """ + query = "DELETE FROM signals WHERE id = %s RETURNING *;" + await cursor.execute(query, (uid,)) + if (row := await cursor.fetchone()) is None: + return None + signal = Signal(**row) + if signal.attachment is not None: + await storage.delete_image(entity_id=signal.id, folder_name="signals") + return signal + + +async def read_user_signals( + cursor: AsyncCursor, + user_email: str, + status: Status, +) -> list[Signal]: + """ + Read signals from the database using a user email and status filter. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_email : str + An email of the user whose signals to read. + status : Status + A status of signals to filter by. + + Returns + ------- + list[Signal] + A list of matching signals. + """ + query = """ + SELECT + * + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + WHERE + created_by = %s AND status = %s + ; + """ + await cursor.execute(query, (user_email, status)) + return [Signal(**row) async for row in cursor] diff --git a/src/database/trends.py b/src/database/trends.py new file mode 100644 index 0000000..1b914da --- /dev/null +++ b/src/database/trends.py @@ -0,0 +1,285 @@ +""" +CRUD operations for trend entities. +""" + +from psycopg import AsyncCursor, sql + +from .. 
import storage
+from ..entities import Trend, TrendFilters, TrendPage
+
+__all__ = [
+    "search_trends",
+    "create_trend",
+    "read_trend",
+    "update_trend",
+    "delete_trend",
+]
+
+
+async def search_trends(cursor: AsyncCursor, filters: TrendFilters) -> TrendPage:
+    """
+    Search trends in the database using filters and pagination.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    filters : TrendFilters
+        Query filters for search, including pagination.
+
+    Returns
+    -------
+    page : TrendPage
+        Paginated search results for trends.
+    """
+    query = """
+        SELECT
+            *, COUNT(*) OVER() AS total_count
+        FROM
+            trends AS t
+        LEFT OUTER JOIN (
+            SELECT
+                trend_id, array_agg(signal_id) AS connected_signals
+            FROM
+                connections
+            GROUP BY
+                trend_id
+        ) AS c
+        ON
+            t.id = c.trend_id
+        WHERE
+            (%(ids)s IS NULL OR id = ANY(%(ids)s))
+            AND status = ANY(%(statuses)s)
+            AND (%(created_by)s IS NULL OR created_by = %(created_by)s)
+            AND (%(created_for)s IS NULL OR created_for = %(created_for)s)
+            AND (%(steep_primary)s IS NULL OR steep_primary = %(steep_primary)s)
+            AND (%(steep_secondary)s IS NULL OR steep_secondary && %(steep_secondary)s)
+            AND (%(signature_primary)s IS NULL OR signature_primary = %(signature_primary)s)
+            AND (%(signature_secondary)s IS NULL OR signature_secondary && %(signature_secondary)s)
+            AND (%(sdgs)s IS NULL OR sdgs && %(sdgs)s)
+            AND (%(assigned_to)s IS NULL OR assigned_to = %(assigned_to)s)
+            AND (%(time_horizon)s IS NULL OR time_horizon = %(time_horizon)s)
+            AND (%(impact_rating)s IS NULL OR impact_rating = %(impact_rating)s)
+            AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s))
+        ORDER BY
+            {} {}
+        OFFSET
+            %(offset)s
+        LIMIT
+            %(limit)s
+        ;
+    """
+    query = sql.SQL(query).format(
+        sql.Identifier(filters.order_by),
+        sql.SQL(filters.direction),
+    )
+    await cursor.execute(query, filters.model_dump())
+    rows = await cursor.fetchall()
+    page = TrendPage.from_search(rows, filters)
+    return page
+
+
+async def create_trend(cursor: AsyncCursor, trend: Trend) -> int:
+    """
+    Insert a trend into the database, connect it to signals and upload an attachment
+    to Azure Blob Storage if applicable.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    trend : Trend
+        A trend object to insert.
+
+    Returns
+    -------
+    trend_id : int
+        An ID of the trend in the database.
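+
+    Examples
+    --------
+    A minimal sketch, assuming an open `AsyncCursor`; the field values are
+    illustrative only:
+
+    >>> trend = Trend(headline="Example trend", created_by="jane.doe@undp.org")
+    >>> await create_trend(cursor, trend)  # doctest: +SKIP
+    1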
+ """ + query = """ + INSERT INTO trends ( + status, + created_by, + created_for, + modified_by, + headline, + description, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + assigned_to, + time_horizon, + impact_rating, + impact_description + ) + VALUES ( + %(status)s, + %(created_by)s, + %(created_for)s, + %(modified_by)s, + %(headline)s, + %(description)s, + %(steep_primary)s, + %(steep_secondary)s, + %(signature_primary)s, + %(signature_secondary)s, + %(sdgs)s, + %(assigned_to)s, + %(time_horizon)s, + %(impact_rating)s, + %(impact_description)s + ) + RETURNING + id + ; + """ + await cursor.execute(query, trend.model_dump()) + row = await cursor.fetchone() + trend_id = row["id"] + + # add connected signals if any are present + for signal_id in trend.connected_signals or []: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, trend.created_by)) + + # upload an image + if trend.attachment is not None: + try: + blob_url = await storage.upload_image(trend_id, "trends", trend.attachment) + except Exception as e: + print(e) + else: + query = "UPDATE trends SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, trend_id)) + return trend_id + + +async def read_trend(cursor: AsyncCursor, uid: int) -> Trend | None: + """ + Read a trend from the database using an ID. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the trend to retrieve data for. + + Returns + ------- + Trend | None + A trend if it exits, otherwise None. + """ + query = """ + SELECT + * + FROM + trends AS t + LEFT OUTER JOIN ( + SELECT + trend_id, array_agg(signal_id) AS connected_signals + FROM + connections + GROUP BY + trend_id + ) AS c + ON + t.id = c.trend_id + WHERE + id = %s + ; + """ + await cursor.execute(query, (uid,)) + if (row := await cursor.fetchone()) is None: + return None + return Trend(**row) + + +async def update_trend(cursor: AsyncCursor, trend: Trend) -> int | None: + """ + Update a trend in the database, update its connected signals and update an attachment + in the Azure Blob Storage if applicable. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + trend : Trend + A trend object to update. + + Returns + ------- + int | None + A trend ID if the update has been performed, otherwise None. 
+ """ + query = """ + UPDATE + trends + SET + status = COALESCE(%(status)s, status), + created_for = COALESCE(%(created_for)s, created_for), + modified_at = NOW(), + modified_by = %(modified_by)s, + headline = COALESCE(%(headline)s, headline), + description = COALESCE(%(description)s, description), + steep_primary = COALESCE(%(steep_primary)s, steep_primary), + steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), + signature_primary = COALESCE(%(signature_primary)s, signature_primary), + signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), + sdgs = COALESCE(%(sdgs)s, sdgs), + assigned_to = COALESCE(%(assigned_to)s, assigned_to), + time_horizon = COALESCE(%(time_horizon)s, time_horizon), + impact_rating = COALESCE(%(impact_rating)s, impact_rating), + impact_description = COALESCE(%(impact_description)s, impact_description) + WHERE + id = %(id)s + RETURNING + id + ; + """ + await cursor.execute(query, trend.model_dump()) + if (row := await cursor.fetchone()) is None: + return None + trend_id = row["id"] + + # update connected signals if any are present + await cursor.execute("DELETE FROM connections WHERE trend_id = %s;", (trend_id,)) + for signal_id in trend.connected_signals or []: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, trend.created_by)) + + # upload an image if it is not a URL to an existing image + blob_url = await storage.update_image(trend_id, "trends", trend.attachment) + query = "UPDATE trends SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, trend_id)) + + return trend_id + + +async def delete_trend(cursor: AsyncCursor, uid: int) -> Trend | None: + """ + Delete a trend from the database and, if applicable, an image from + Azure Blob Storage, using an ID. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the signal to delete. + + Returns + ------- + Trend | None + A deleted trend object if the operation has been successful, otherwise None. + """ + query = "DELETE FROM trends WHERE id = %s RETURNING *;" + await cursor.execute(query, (uid,)) + if (row := await cursor.fetchone()) is None: + return None + trend = Trend(**row) + if trend.attachment is not None: + await storage.delete_image(entity_id=trend.id, folder_name="trends") + return trend diff --git a/src/database/users.py b/src/database/users.py new file mode 100644 index 0000000..06ef088 --- /dev/null +++ b/src/database/users.py @@ -0,0 +1,198 @@ +""" +CRUD operations for user entities. +""" + +from psycopg import AsyncCursor + +from ..entities import User, UserFilters, UserPage + +__all__ = [ + "search_users", + "create_user", + "read_user_by_email", + "read_user", + "update_user", + "get_acclab_users", +] + + +async def search_users(cursor: AsyncCursor, filters: UserFilters) -> UserPage: + """ + Search users in the database using filters and pagination. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + filters : UserFilters + Query filters for search, including pagination. + + Returns + ------- + page : UserPage + Paginated search results for users. 
+ """ + query = """ + SELECT + *, COUNT(*) OVER() AS total_count + FROM + users + WHERE + role = ANY(%(roles)s) + AND (%(query)s IS NULL OR name ~* %(query)s) + ORDER BY + name + OFFSET + %(offset)s + LIMIT + %(limit)s + ; + """ + await cursor.execute(query, filters.model_dump()) + rows = await cursor.fetchall() + page = UserPage.from_search(rows, filters) + return page + + +async def create_user(cursor: AsyncCursor, user: User) -> int: + """ + Insert a user into the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user : User + A user object to insert. + + Returns + ------- + int + An ID of the user in the database. + """ + query = """ + INSERT INTO users ( + created_at, + email, + role, + name, + unit, + acclab + ) + VALUES ( + %(created_at)s, + %(email)s, + %(role)s, + %(name)s, + %(unit)s, + %(acclab)s + ) + RETURNING + id + ; + """ + await cursor.execute(query, user.model_dump()) + row = await cursor.fetchone() + return row["id"] + + +async def read_user_by_email(cursor: AsyncCursor, email: str) -> User | None: + """ + Read a user from the database using an email address. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + email : str + An email address. + + Returns + ------- + user : User + A user object if found, otherwise None. + """ + query = "SELECT * FROM users WHERE email = %s;" + await cursor.execute(query, (email,)) + if (row := await cursor.fetchone()) is None: + return None + user = User(**row) + return user + + +async def read_user(cursor: AsyncCursor, uid: int) -> User | None: + """ + Read a user from the database using an ID. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the user to retrieve data for. + + Returns + ------- + User | None + A user if it exits, otherwise None. + """ + query = "SELECT * FROM users WHERE id = %s;" + await cursor.execute(query, (uid,)) + if (row := await cursor.fetchone()) is None: + return None + return User(**row) + + +async def update_user(cursor: AsyncCursor, user: User) -> int | None: + """ + Update a user in the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user : User + A user object to update. + + Returns + ------- + int | None + A user ID if the update has been performed, otherwise None. + """ + query = """ + UPDATE + users + SET + role = %(role)s, + name = %(name)s, + unit = %(unit)s, + acclab = %(acclab)s + WHERE + email = %(email)s + RETURNING + id + ; + """ + await cursor.execute(query, user.model_dump()) + if (row := await cursor.fetchone()) is None: + return None + return row["id"] + + +async def get_acclab_users(cursor: AsyncCursor) -> list[str]: + """ + Get emails of users who are part of the Accelerator Labs. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[str] + A list of emails for users who are part of the Accelerator Labs. + """ + query = "SELECT email FROM users WHERE acclab = TRUE;" + await cursor.execute(query) + return [row["email"] async for row in cursor] diff --git a/src/dependencies.py b/src/dependencies.py new file mode 100644 index 0000000..df9f7d5 --- /dev/null +++ b/src/dependencies.py @@ -0,0 +1,62 @@ +""" +Functions used for dependency injection for role-based access control. +""" + +from typing import Annotated + +from fastapi import Depends, Path +from psycopg import AsyncCursor + +from . import database as db +from . 
+from .authentication import authenticate_user
+from .entities import User
+
+__all__ = [
+    "require_admin",
+    "require_curator",
+    "require_user",
+    "require_creator",
+]
+
+
+async def require_admin(user: User = Depends(authenticate_user)) -> User:
+    """Require that the user is assigned an admin role."""
+    if not user.is_admin:
+        raise exceptions.permission_denied
+    return user
+
+
+async def require_curator(user: User = Depends(authenticate_user)) -> User:
+    """Require that the user is assigned at least a curator role."""
+    if not user.is_staff:
+        raise exceptions.permission_denied
+    return user
+
+
+async def require_user(user: User = Depends(authenticate_user)) -> User:
+    """Require that the user is assigned at least a user role and is not a visitor."""
+    if not user.is_regular:
+        raise exceptions.permission_denied
+    return user
+
+
+async def require_creator(
+    uid: Annotated[int, Path(description="The ID of the signal to be updated")],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+) -> User:
+    """Require that the user is at least a curator or is the creator of the signal."""
+    # admins and curators can modify all signals
+    if user.is_staff:
+        return user
+    # check if the user created the original signal
+    signal = await db.read_signal(cursor, uid)
+    if signal is None:
+        raise exceptions.not_found
+    if signal.created_by != user.email:
+        raise exceptions.permission_denied
+    # regular users can modify their own signals; a status change cannot be
+    # verified here because the request body is not available in this dependency
+    return user
diff --git a/src/entities/__init__.py b/src/entities/__init__.py
new file mode 100644
index 0000000..7c421d8
--- /dev/null
+++ b/src/entities/__init__.py
@@ -0,0 +1,117 @@
+"""
+Entities (models) for receiving, managing and sending data.
+"""
+
+from math import ceil
+from typing import Self
+
+from pydantic import BaseModel, Field
+
+from .parameters import *
+from .signal import *
+from .trend import *
+from .user import *
+from .utils import *
+
+
+class Page(BaseModel):
+    """
+    A paginated results model, holding pagination metadata and search results.
+    """
+
+    per_page: int = Field(description="Number of entities per page to retrieve.")
+    current_page: int = Field(
+        description="Current page for which entities should be retrieved."
+    )
+    total_pages: int = Field(description="Total number of pages for pagination.")
+    total_count: int = Field(description="Total number of entities in the database.")
+    data: list[Signal] | list[Trend] | list[User]
+
+    @classmethod
+    def from_search(cls, rows: list[dict], pagination: Pagination) -> Self:
+        """
+        Create a paginated results model from search results.
+
+        Parameters
+        ----------
+        rows : list[dict]
+            A list of results returned from the database.
+        pagination : Pagination
+            The pagination object used to retrieve the results.
+
+        Returns
+        -------
+        page : Page
+            The paginated results model.
+        """
+        total = rows[0]["total_count"] if rows else 0
+        page = cls(
+            per_page=pagination.limit,
+            current_page=pagination.page,
+            total_pages=ceil(total / pagination.limit),
+            total_count=total,
+            data=rows,
+        )
+        return page
+
+    def sanitise(self, user: User) -> Self:
+        """
+        Remove items from the list the user has no permissions to access.
+
+        Parameters
+        ----------
+        user : User
+            A user making the request.
+
+        Returns
+        -------
+        self : Self
+            The paginated results model with sanitised data.
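+
+        Examples
+        --------
+        A sketch of the intended behaviour; the rows, filters and user are
+        illustrative only:
+
+        >>> page = SignalPage.from_search(rows, filters)  # doctest: +SKIP
+        >>> page = page.sanitise(User(email="jane.doe@undp.org", role=Role.USER))  # doctest: +SKIP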
+ """ + if user.role == Role.ADMIN: + # all signals/trends are shown + pass + elif user.role == Role.CURATOR: + # all signals/trends are shown except other users' drafts + self.data = [ + entity + for entity in self.data + if not ( + entity.status == Status.DRAFT and entity.created_by != user.email + ) + ] + return self + elif user.role == Role.USER: + # only approved signals/trends are shown or signals/trends authored by the user + self.data = [ + entity + for entity in self.data + if entity.status == Status.APPROVED or entity.created_by == user.email + ] + else: + # only approved signals/trends after anonymisation are shown + self.data = [ + entity.anonymise() + for entity in self.data + if entity.status == Status.APPROVED + ] + return self + return self + + +class UserPage(Page): + """A specialised paginated results model for users.""" + + data: list[User] + + +class SignalPage(Page): + """A specialised paginated results model for signals.""" + + data: list[Signal] + + +class TrendPage(Page): + """A specialised paginated results model for trends.""" + + data: list[Trend] diff --git a/src/entities/base.py b/src/entities/base.py new file mode 100644 index 0000000..4ddc157 --- /dev/null +++ b/src/entities/base.py @@ -0,0 +1,120 @@ +""" +Entity (model) definitions for base objects that others inherit from. +""" + +from datetime import UTC, datetime + +from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator + +from . import utils + +__all__ = ["BaseMetadata", "BaseEntity", "timestamp"] + + +def timestamp() -> str: + """ + Get the current timestamp in the ISO format. + + Returns + ------- + str + A timestamp in the ISO format. + """ + return datetime.now(UTC).isoformat() + + +class BaseMetadata(BaseModel): + """Base metadata for database objects.""" + + id: int = Field(default=1) + created_at: str = Field(default_factory=timestamp) + + @field_validator("created_at", mode="before") + @classmethod + def format_created_at(cls, value): + """(De)serialisation function for `created_at` timestamp.""" + if isinstance(value, str): + return value + return value.isoformat() + + +class BaseEntity(BaseMetadata): + """Base entity for signals and trends.""" + + status: utils.Status = Field( + default=utils.Status.NEW, + description="Current signal review status.", + ) + created_by: EmailStr | None = Field(default=None) + created_for: str | None = Field(default=None) + modified_at: str = Field(default_factory=timestamp) + modified_by: EmailStr | None = Field(default=None) + headline: str | None = Field( + default=None, + description="A clear and concise title headline.", + ) + description: str | None = Field( + default=None, + description="A clear and concise description.", + ) + attachment: str | None = Field( + default=None, + description="An optional base64-encoded image/URL for illustration.", + ) + steep_primary: utils.Steep | None = Field( + default=None, + description="Primary category in terms of STEEP+V analysis methodology.", + ) + steep_secondary: list[utils.Steep] | None = Field( + default=None, + description="Secondary categories in terms of STEEP+V analysis methodology.", + ) + signature_primary: utils.Signature | None = Field( + default=None, + description="Primary category in terms of UNDP Signature Solutions methodology.", + ) + signature_secondary: list[utils.Signature] | None = Field( + default=None, + description="Secondary categories in terms of UNDP Signature Solutions methodology.", + ) + sdgs: list[utils.Goal] | None = Field( + default=None, + 
description="Relevant Sustainable Development Goals.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "headline": "The cost of corruption", + "description": "Corruption is one of the scourges of modern life. Its costs are staggering.", + "steep_primary": utils.Steep.ECONOMIC, + "steep_secondary": [utils.Steep.SOCIAL], + "signature_primary": utils.Signature.GOVERNANCE, + "signature_secondary": [ + utils.Signature.POVERTY, + utils.Signature.RESILIENCE, + ], + "sdgs": [utils.Goal.G16, utils.Goal.G17], + } + } + ) + + @field_validator("modified_at", mode="before") + @classmethod + def format_modified_at(cls, value): + """(De)serialisation function for `modified_at` timestamp.""" + if isinstance(value, str): + return value + return value.isoformat() + + def anonymise(self) -> "BaseEntity": + """ + Anonymise an entity by removing personal information, such as user emails. + + This function is to be used to retrieve information for visitors, i.e., + not authenticated users, to preserve privacy of other users. + """ + email_mask = "email.hidden@undp.org" + self.created_by = email_mask + self.modified_by = email_mask + return self diff --git a/src/entities/parameters.py b/src/entities/parameters.py new file mode 100644 index 0000000..d019efb --- /dev/null +++ b/src/entities/parameters.py @@ -0,0 +1,86 @@ +""" +Dataclasses used to define query parameters in API endpoints. +""" + +from typing import Literal + +from pydantic import BaseModel, Field, computed_field + +from .utils import Goal, Horizon, Rating, Role, Score, Signature, Status, Steep + +__all__ = [ + "Pagination", + "SignalFilters", + "TrendFilters", + "UserFilters", +] + + +class Pagination(BaseModel): + """A container class for pagination parameters.""" + + page: int = Field(default=1, gt=0) + per_page: int = Field(default=10, gt=0, le=10_000) + order_by: str = Field(default="created_at") + direction: Literal["desc", "asc"] = Field( + default="desc", + description="Ascending or descending order.", + ) + + @computed_field + def limit(self) -> int: + """ + An alias for `per_page` can be dropped in favour of + limit: int = Field(default=10, gt=0, le=10_000, alias="per_page") + once https://github.com/fastapi/fastapi/discussions/12401 is resolved. + + Returns + ------- + int + A value of `per_page`. 
+ """ + return self.per_page + + @computed_field + def offset(self) -> int: + """An offset value that can be used in a database query.""" + return self.limit * (self.page - 1) + + +class BaseFilters(Pagination): + """Base filtering parameters shared by signal and trend filters.""" + + ids: list[int] | None = Field(default=None) + statuses: list[Status] = Field(default=(Status.APPROVED,)) + created_by: str | None = Field(default=None) + created_for: str | None = Field(default=None) + steep_primary: Steep | None = Field(default=None) + steep_secondary: list[Steep] | None = Field(default=None) + signature_primary: Signature | None = Field(default=None) + signature_secondary: list[Signature] | None = Field(default=None) + sdgs: list[Goal] | None = Field(default=None) + query: str | None = Field(default=None) + + +class SignalFilters(BaseFilters): + """Filter parameters for searching signals.""" + + location: str | None = Field(default=None) + bureau: str | None = Field(default=None) + score: Score | None = Field(default=None) + unit: str | None = Field(default=None) + + +class TrendFilters(BaseFilters): + """Filter parameters for searching trends.""" + + assigned_to: str | None = Field(default=None) + time_horizon: Horizon | None = Field(default=None) + impact_rating: Rating | None = Field(default=None) + + +class UserFilters(Pagination): + """Filter parameters for searching users.""" + + roles: list[Role] = Field(default=(Role.VISITOR, Role.CURATOR, Role.ADMIN)) + query: str | None = Field(default=None) diff --git a/src/entities/signal.py b/src/entities/signal.py new file mode 100644 index 0000000..5112e46 --- /dev/null +++ b/src/entities/signal.py @@ -0,0 +1,43 @@ +""" +Entity (model) definitions for signal objects. +""" + +from pydantic import ConfigDict, Field + +from . import utils +from .base import BaseEntity + +__all__ = ["Signal"] + + +class Signal(BaseEntity): + """The signal entity model used in the database and API endpoints.""" + + created_unit: str | None = Field(default=None) + url: str | None = Field(default=None) + relevance: str | None = Field(default=None) + keywords: list[str] | None = Field( + default=None, + description="Use up to 3 clear, simple keywords for ease of searchability.", + ) + location: str | None = Field( + default=None, + description="Region and/or country for which this signal has greatest relevance.", + ) + score: utils.Score | None = Field(default=None) + connected_trends: list[int] | None = Field( + default=None, + description="IDs of trends connected to this signal.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": BaseEntity.model_config["json_schema_extra"]["example"] + | { + "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", + "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", + "keywords": ["economy", "governance"], + "location": "Global", + } + } + ) diff --git a/src/entities/trend.py b/src/entities/trend.py new file mode 100644 index 0000000..88b6ae8 --- /dev/null +++ b/src/entities/trend.py @@ -0,0 +1,44 @@ +""" +Entity (model) definitions for trend objects. +""" + +from pydantic import ConfigDict, Field, field_validator + +from . 
import utils
+from .base import BaseEntity
+
+__all__ = ["Trend"]
+
+
+class Trend(BaseEntity):
+    """The trend entity model used in the database and API endpoints."""
+
+    assigned_to: str | None = Field(default=None)
+    time_horizon: utils.Horizon | None = Field(default=None)
+    impact_rating: utils.Rating | None = Field(default=None)
+    impact_description: str | None = Field(default=None)
+    connected_signals: list[int] | None = Field(
+        default=None,
+        description="IDs of signals connected to this trend.",
+    )
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": BaseEntity.model_config["json_schema_extra"]["example"]
+            | {
+                "time_horizon": utils.Horizon.MEDIUM,
+                "impact_rating": utils.Rating.HIGH,
+                "impact_description": "Anywhere between 1.4 and 35 percent of climate action funds may be lost due to corruption.",
+            }
+        }
+    )
+
+    @field_validator("impact_rating", mode="before")
+    @classmethod
+    def from_string(cls, value):
+        """
+        Coerce an integer value for `impact_rating` to a string.
+        """
+        if value is None or isinstance(value, str):
+            return value
+        return str(value)
diff --git a/src/entities/user.py b/src/entities/user.py
new file mode 100644
index 0000000..8cde39a
--- /dev/null
+++ b/src/entities/user.py
@@ -0,0 +1,55 @@
+"""
+Entity (model) definitions for user objects.
+"""
+
+from pydantic import ConfigDict, EmailStr, Field
+
+from .base import BaseMetadata
+from .utils import Role
+
+__all__ = ["User"]
+
+
+class User(BaseMetadata):
+    """The user entity model used in the database and API endpoints."""
+
+    email: EmailStr = Field(description="Work email ending with @undp.org.")
+    role: Role = Field(default=Role.VISITOR)
+    name: str | None = Field(default=None, min_length=5, description="Full name.")
+    unit: str | None = Field(
+        default=None,
+        min_length=2,
+        description="UNDP unit name from a predefined list.",
+    )
+    acclab: bool | None = Field(
+        default=None,
+        description="Whether or not a user is part of the Accelerator Labs.",
+    )
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "id": 1,
+                "email": "john.doe@undp.org",
+                "role": "Curator",
+                "name": "John Doe",
+                "unit": "BPPS",
+                "acclab": True,
+            }
+        }
+    )
+
+    @property
+    def is_admin(self):
+        """Check if the user is an admin."""
+        return self.role == Role.ADMIN
+
+    @property
+    def is_staff(self):
+        """Check if the user is a curator or admin."""
+        return self.role in {Role.ADMIN, Role.CURATOR}
+
+    @property
+    def is_regular(self):
+        """Check if the user is a regular user, not a visitor using an API key."""
+        return self.role in {Role.ADMIN, Role.CURATOR, Role.USER}
diff --git a/src/entities/utils.py b/src/entities/utils.py
new file mode 100644
index 0000000..b475897
--- /dev/null
+++ b/src/entities/utils.py
@@ -0,0 +1,123 @@
+"""
+Utility classes that define valid string options.
+"""
+
+from enum import StrEnum
+
+__all__ = [
+    "Role",
+    "Status",
+    "Steep",
+    "Signature",
+    "Goal",
+    "Score",
+    "Horizon",
+    "Rating",
+    "Bureau",
+]
+
+
+class Role(StrEnum):
+    """
+    User roles for RBAC. Admins, curators and users are actual users logged in
+    to the platform who authenticate via JWT. The visitor role is assigned to
+    a dummy user authenticated with an API key.
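+
+    Each role serialises to a plain string, for example:
+
+    >>> Role.ADMIN.value
+    'Admin'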
+ """ + + ADMIN = "Admin" # curator + can change the roles of other users + CURATOR = "Curator" # user + can edit and approve signals and trends + USER = "User" # visitor + can submit signals + VISITOR = "Visitor" # can only view signals and trends + + +class Status(StrEnum): + """Signal/trend review statuses.""" + + DRAFT = "Draft" + NEW = "New" + APPROVED = "Approved" + ARCHIVED = "Archived" + + +class Steep(StrEnum): + """Categories in terms of Steep-V methodology.""" + + SOCIAL = "Social – Issues related to human culture, demography, communication, movement and migration, work and education" + TECHNOLOGICAL = "Technological – Made culture, tools, devices, systems, infrastructure and networks" + ECONOMIC = "Economic – Issues of value, money, financial tools and systems, business and business models, exchanges and transactions" + ENVIRONMENTAL = "Environmental – The natural world, living environment, sustainability, resources, climate and health" + POLITICAL = "Political – Legal issues, policy, governance, rules and regulations and organizational systems" + VALUES = "Values – Ethics, spirituality, ideology or other forms of values" + + +class Signature(StrEnum): + """The six Signature Solutions of the United Nations Development Programme.""" + + POVERTY = "Poverty and Inequality" + GOVERNANCE = "Governance" + RESILIENCE = "Resilience" + ENVIRONMENT = "Environment" + ENERGY = "Energy" + GENDER = "Gender Equality" + # 3 enables + INNOVATION = "Strategic Innovation" + DIGITALISATION = "Digitalisation" + FINANCING = "Development Financing" + + +class Goal(StrEnum): + """The 17 United Nations Sustainable Development Goals.""" + + G1 = "GOAL 1: No Poverty" + G2 = "GOAL 2: Zero Hunger" + G3 = "GOAL 3: Good Health and Well-being" + G4 = "GOAL 4: Quality Education" + G5 = "GOAL 5: Gender Equality" + G6 = "GOAL 6: Clean Water and Sanitation" + G7 = "GOAL 7: Affordable and Clean Energy" + G8 = "GOAL 8: Decent Work and Economic Growth" + G9 = "GOAL 9: Industry, Innovation and Infrastructure" + G10 = "GOAL 10: Reduced Inequality" + G11 = "GOAL 11: Sustainable Cities and Communities" + G12 = "GOAL 12: Responsible Consumption and Production" + G13 = "GOAL 13: Climate Action" + G14 = "GOAL 14: Life Below Water" + G15 = "GOAL 15: Life on Land" + G16 = "GOAL 16: Peace and Justice Strong Institutions" + G17 = "GOAL 17: Partnerships to achieve the Goal" + + +class Score(StrEnum): + """Signal novelty scores.""" + + ONE = "1 — Non-novel (known, but potentially notable in particular context)" + TWO = "2" + THREE = "3 — Potentially novel or uncertain, but not clear in its potential impact" + FOUR = "4" + FIVE = "5 — Something that introduces or points to a potentially interesting or consequential change in direction of trends" + + +class Horizon(StrEnum): + """Trend impact horizons.""" + + SHORT = "Horizon 1 (0-3 years)" + MEDIUM = "Horizon 2 (3-7 years)" + LONG = "Horizon 3 (7-10 years)" + + +class Rating(StrEnum): + """Trend impact rating.""" + + LOW = "1 – Low" + MODERATE = "2 – Moderate" + HIGH = "3 – Significant" + + +class Bureau(StrEnum): + """Bureaus of the United Nations Development Programme.""" + + RBA = "RBA" + RBAP = "RBAP" + RBAS = "RBAS" + RBEC = "RBEC" + RBLAC = "RBLAC" diff --git a/src/exceptions.py b/src/exceptions.py new file mode 100644 index 0000000..52b11bc --- /dev/null +++ b/src/exceptions.py @@ -0,0 +1,44 @@ +""" +Exceptions raised by API endpoints. 
+""" + +from fastapi import HTTPException, status + +__all__ = [ + "id_mismatch", + "not_authenticated", + "permission_denied", + "not_found", + "content_error", + "generation_error", +] + +id_mismatch = HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Resource ID in body does not match path ID.", +) + +not_authenticated = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated.", +) + +permission_denied = HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permissions to perform this action.", +) + +not_found = HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The requested resource could not be found.", +) + +content_error = HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="The content from the URL could not be fetched.", +) + +generation_error = HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="A signal could not be generated from the content.", +) diff --git a/src/genai.py b/src/genai.py new file mode 100644 index 0000000..6834c57 --- /dev/null +++ b/src/genai.py @@ -0,0 +1,106 @@ +""" +Functions for generating signals from web content using Azure OpenAI. +""" + +import json +import os + +from openai import AsyncAzureOpenAI + +from .entities import Signal + +__all__ = ["get_system_message", "get_client", "generate_signal"] + + +def get_system_message() -> str: + """ + Get a system message for generating a signal from web content. + + Returns + ------- + system_message : str + A system message that can be used to generate a signal. + """ + schema = Signal.model_json_schema() + schema.pop("example", None) + # generate content only for the following fields + fields = { + "headline", + "description", + "steep_primary", + "steep_secondary", + "signature_primary", + "signature_secondary", + "sdgs", + "keywords", + } + schema["properties"] = { + k: v for k, v in schema["properties"].items() if k in fields + } + system_message = f""" + You are a Signal Scanner within the Strategy & Futures Team at the United Nations Development Programme. + Your task is to generate a Signal from web content provided by the user. A Signal is defined as a single + piece of evidence or indicator that points to, relates to, or otherwise supports a trend. + It can also stand alone as a potential indicator of future change in one or more trends. + + ### Rules + 1. Your output must be a valid JSON string object without any markdown that can be directly passed to `json.loads`. + 2. The JSON string must conform to the schema below. + 3. The response must be in English, so translate content if necessary. + 4. For `headline` and `description`, do not just copy-paste text, instead summarize the information + in a concise yet insightful manner. + + ### Signal Schema + + ```json + {json.dumps(schema, indent=2)} + ``` + """.strip() + return system_message + + +def get_client() -> AsyncAzureOpenAI: + """ + Get an asynchronous Azure OpenAI client. + + Returns + ------- + client : AsyncAzureOpenAI + An asynchronous client for Azure OpenAI. + """ + client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_version="2024-02-15-preview", + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + timeout=10, + ) + return client + + +async def generate_signal(text: str) -> Signal: + """ + Generate a signal from a text. + + Parameters + ---------- + text : str + A text, i.e., web content, to be analysed. 
+ + Returns + ------- + signal : Signal + A signal entity generated from the text. + """ + client = get_client() + system_message = get_system_message() + response = await client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": text}, + ], + temperature=0.3, # vary the output to alleviate occasional errors + ) + content = response.choices[0].message.content + signal = Signal(**json.loads(content)) + return signal diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 0000000..cf23153 --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1,15 @@ +""" +Router definitions for API endpoints. +""" + +from .choices import router as choice_router +from .signals import router as signal_router +from .trends import router as trend_router +from .users import router as user_router + +ALL = [ + choice_router, + signal_router, + trend_router, + user_router, +] diff --git a/src/routers/choices.py b/src/routers/choices.py new file mode 100644 index 0000000..4d5f395 --- /dev/null +++ b/src/routers/choices.py @@ -0,0 +1,59 @@ +""" +A router for obtaining valid choice options. +""" + +from fastapi import APIRouter, Depends +from psycopg import AsyncCursor + +from .. import database as db +from .. import exceptions +from ..authentication import authenticate_user +from ..entities import utils + +router = APIRouter(prefix="/choices", tags=["choices"]) + + +CREATED_FOR = [ + "General scanning", + "Global Signals Spotlight 2024", + "Global Signals Spotlight 2023", + "HDR 2023", + "Sustainable Finance Hub 2023", +] + + +@router.get("", response_model=dict, dependencies=[Depends(authenticate_user)]) +async def read_choices(cursor: AsyncCursor = Depends(db.yield_cursor)): + """ + List valid options for all fields. + """ + choices = { + name.lower(): [member.value for member in getattr(utils, name)] + for name in utils.__all__ + } + choices["created_for"] = CREATED_FOR + choices["unit_name"] = await db.get_unit_names(cursor) + choices["unit_region"] = await db.get_unit_regions(cursor) + choices["location"] = await db.get_location_names(cursor) + return choices + + +@router.get( + "/{name}", response_model=list[str], dependencies=[Depends(authenticate_user)] +) +async def read_field_choices(name: str, cursor: AsyncCursor = Depends(db.yield_cursor)): + """ + List valid options for a given field. + """ + match name: + case "unit_name": + choices = await db.get_unit_names(cursor) + case "unit_region": + choices = await db.get_unit_regions(cursor) + case "location": + choices = await db.get_location_names(cursor) + case name if name.capitalize() in utils.__all__: + choices = [member.value for member in getattr(utils, name.capitalize())] + case _: + raise exceptions.not_found + return choices diff --git a/src/routers/signals.py b/src/routers/signals.py new file mode 100644 index 0000000..bfb5057 --- /dev/null +++ b/src/routers/signals.py @@ -0,0 +1,153 @@ +""" +A router for retrieving, submitting and updating signals. +""" + +from typing import Annotated + +import pandas as pd +from fastapi import APIRouter, Depends, Path, Query +from psycopg import AsyncCursor + +from .. import database as db +from .. 
import exceptions, genai, utils
+from ..authentication import authenticate_user
+from ..dependencies import require_creator, require_curator, require_user
+from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User
+
+router = APIRouter(prefix="/signals", tags=["signals"])
+
+
+@router.get("/search", response_model=SignalPage)
+async def search_signals(
+    filters: Annotated[SignalFilters, Query()],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Search signals in the database using pagination and filters."""
+    page = await db.search_signals(cursor, filters)
+    return page.sanitise(user)
+
+
+@router.get("/export", response_model=None, dependencies=[Depends(require_curator)])
+async def export_signals(
+    filters: Annotated[SignalFilters, Query()],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Export signals that match the filters from the database. You can export up to
+    10k rows at once.
+    """
+    page = await db.search_signals(cursor, filters)
+
+    # prettify the data
+    df = pd.DataFrame([signal.model_dump() for signal in page.data])
+    df = utils.binarise_columns(df, ["steep_secondary", "signature_secondary", "sdgs"])
+    df["keywords"] = df["keywords"].str.join("; ")
+    df["connected_trends"] = df["connected_trends"].str.join("; ")
+
+    # add acclab indicator variable
+    emails = await db.users.get_acclab_users(cursor)
+    df["acclab"] = df["created_by"].isin(emails)
+
+    response = utils.write_to_response(df, "signals")
+    return response
+
+
+@router.get("/generation", response_model=Signal)
+async def generate_signal(
+    url: str = Query(
+        description="A public webpage URL whose content will be used to generate a signal."
+    ),
+    user: User = Depends(require_user),
+):
+    """Generate a signal from web content using OpenAI."""
+    try:
+        content = await utils.scrape_content(url)
+    except Exception as e:
+        print(e)
+        raise exceptions.content_error
+    try:
+        signal = await genai.generate_signal(content)
+    except Exception as e:
+        print(e)
+        raise exceptions.generation_error
+    signal.created_by = user.email
+    signal.created_unit = user.unit
+    signal.url = url
+    return signal
+
+
+@router.post("", response_model=Signal, status_code=201)
+async def create_signal(
+    signal: Signal,
+    user: User = Depends(require_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Submit a signal to the database. If the signal has a base64 encoded image
+    attachment, it will be uploaded to Azure Blob Storage.
+    """
+    signal.created_by = user.email
+    signal.modified_by = user.email
+    signal.created_unit = user.unit
+    signal_id = await db.create_signal(cursor, signal)
+    return await db.read_signal(cursor, signal_id)
+
+
+@router.get("/me", response_model=list[Signal])
+async def read_my_signals(
+    status: Status = Query(),
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Retrieve signals with a given status submitted by the current user.
+    """
+    return await db.read_user_signals(cursor, user.email, status)
+
+
+@router.get("/{uid}", response_model=Signal)
+async def read_signal(
+    uid: Annotated[int, Path(description="The ID of the signal to retrieve")],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Retrieve a signal from the database using an ID. Trends connected to the signal
+    can be retrieved using IDs from the `signal.connected_trends` field.
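+
+    For example, `GET /signals/42` (the ID is illustrative) returns the signal
+    together with the IDs of its connected trends.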
+ """ + if (signal := await db.read_signal(cursor, uid)) is None: + raise exceptions.not_found + if user.role == Role.VISITOR and signal.status != Status.APPROVED: + raise exceptions.permission_denied + return signal + + +@router.put("/{uid}", response_model=Signal) +async def update_signal( + uid: Annotated[int, Path(description="The ID of the signal to be updated")], + signal: Signal, + user: User = Depends(require_creator), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Update a signal in the database.""" + if uid != signal.id: + raise exceptions.id_mismatch + signal.modified_by = user.email + if (signal_id := await db.update_signal(cursor, signal)) is None: + raise exceptions.not_found + return await db.read_signal(cursor, signal_id) + + +@router.delete("/{uid}", response_model=Signal, dependencies=[Depends(require_creator)]) +async def delete_signal( + uid: Annotated[int, Path(description="The ID of the signal to be deleted")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Delete a signal from the database using IDs. This also deletes an image attachment from + Azure Blob Storage if there is one. + """ + if (signal := await db.delete_signal(cursor, uid)) is None: + raise exceptions.not_found + return signal diff --git a/src/routers/trends.py b/src/routers/trends.py new file mode 100644 index 0000000..d85bda2 --- /dev/null +++ b/src/routers/trends.py @@ -0,0 +1,112 @@ +""" +A router for retrieving, submitting and updating trends. +""" + +from typing import Annotated + +import pandas as pd +from fastapi import APIRouter, Depends, Path, Query, status +from psycopg import AsyncCursor + +from .. import database as db +from .. import exceptions, utils +from ..authentication import authenticate_user +from ..dependencies import require_curator +from ..entities import Role, Status, Trend, TrendFilters, TrendPage, User + +router = APIRouter(prefix="/trends", tags=["trends"]) + + +@router.get("/search", response_model=TrendPage) +async def search_trends( + filters: Annotated[TrendFilters, Query()], + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Search trends in the database using pagination and filters.""" + page = await db.search_trends(cursor, filters) + return page.sanitise(user) + + +@router.get("/export", response_model=None, dependencies=[Depends(require_curator)]) +async def export_trends( + filters: Annotated[TrendFilters, Query()], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Export trends that match the filters from the database. You can export up to + 10k rows at once. + """ + page = await db.search_trends(cursor, filters) + + # prettify the data + df = pd.DataFrame([trend.model_dump() for trend in page.data]) + df = utils.binarise_columns(df, ["steep_secondary", "signature_secondary", "sdgs"]) + df["connected_signals_count"] = df["connected_signals"].str.len() + df.drop("connected_signals", axis=1, inplace=True) + + response = utils.write_to_response(df, "trends") + return response + + +@router.post("", response_model=Trend, status_code=status.HTTP_201_CREATED) +async def create_trend( + trend: Trend, + user: User = Depends(require_curator), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Submit a trend to the database. If the trend has a base64 encoded image + attachment, it will be uploaded to Azure Blob Storage. 
+ """ + trend.created_by = user.email + trend.modified_by = user.email + trend_id = await db.create_trend(cursor, trend) + return await db.read_trend(cursor, trend_id) + + +@router.get("/{uid}", response_model=Trend) +async def read_trend( + uid: Annotated[int, Path(description="The ID of the trend to retrieve")], + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Retrieve a trend form the database using an ID. Signals connected to the trend + can be retrieved using IDs from the `trend.connected_signals` field. + """ + if (trend := await db.read_trend(cursor, uid)) is None: + raise exceptions.not_found + if user.role == Role.VISITOR and trend.status != Status.APPROVED: + raise exceptions.permission_denied + return trend + + +@router.put("/{uid}", response_model=Trend) +async def update_trend( + uid: Annotated[int, Path(description="The ID of the trend to be updated")], + trend: Trend, + user: User = Depends(require_curator), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Update a trend in the database.""" + if uid != trend.id: + raise exceptions.id_mismatch + trend.modified_by = user.email + if (trend_id := await db.update_trend(cursor, trend=trend)) is None: + raise exceptions.not_found + return await db.read_trend(cursor, trend_id) + + +@router.delete("/{uid}", response_model=Trend, dependencies=[Depends(require_curator)]) +async def delete_trend( + uid: Annotated[int, Path(description="The ID of the trend to be deleted")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Delete a trend from the database using IDs. This also deletes an image attachment from + Azure Blob Storage if there is one. + """ + if (trend := await db.delete_trend(cursor, uid)) is None: + raise exceptions.not_found + return trend diff --git a/src/routers/users.py b/src/routers/users.py new file mode 100644 index 0000000..e244458 --- /dev/null +++ b/src/routers/users.py @@ -0,0 +1,70 @@ +""" +A router for creating, reading and updating trends. +""" + +from typing import Annotated + +from fastapi import APIRouter, Depends, Path, Query +from psycopg import AsyncCursor + +from .. import database as db +from .. 
+from ..authentication import authenticate_user
+from ..dependencies import require_admin, require_user
+from ..entities import Role, User, UserFilters, UserPage
+
+router = APIRouter(prefix="/users", tags=["users"])
+
+
+@router.get("/search", response_model=UserPage, dependencies=[Depends(require_admin)])
+async def search_users(
+    filters: Annotated[UserFilters, Query()],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Search users in the database using pagination and filters."""
+    page = await db.search_users(cursor, filters)
+    return page
+
+
+@router.get("/me", response_model=User)
+async def read_current_user(user: User = Depends(authenticate_user)):
+    """Read the current user information from a JWT token."""
+    if user is None:
+        raise exceptions.not_found
+    return user
+
+
+@router.get("/{uid}", response_model=User, dependencies=[Depends(require_admin)])
+async def read_user(
+    uid: Annotated[int, Path(description="The ID of the user to retrieve")],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Read a user from the database using an ID."""
+    if (user := await db.read_user(cursor, uid)) is None:
+        raise exceptions.not_found
+    return user
+
+
+@router.put("/{uid}", response_model=User)
+async def update_user(
+    uid: Annotated[int, Path(description="The ID of the user to be updated")],
+    user_new: User,
+    user: User = Depends(require_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Update a user in the database. Non-admin users can only update
+    their own name, unit and accelerator lab flag. Only admin users can
+    update other users' roles.
+    """
+    if uid != user_new.id:
+        raise exceptions.id_mismatch
+    if user.role == Role.ADMIN:
+        pass
+    elif user.email != user_new.email or user.id != user_new.id:
+        raise exceptions.permission_denied
+    elif user.role != user_new.role:
+        raise exceptions.permission_denied
+    if (user_id := await db.update_user(cursor, user_new)) is None:
+        raise exceptions.not_found
+    return await db.read_user(cursor, user_id)
diff --git a/src/storage.py b/src/storage.py
new file mode 100644
index 0000000..b27ca5c
--- /dev/null
+++ b/src/storage.py
@@ -0,0 +1,162 @@
+"""
+Utilities for interacting with Azure Blob Storage for uploading and deleting image attachments.
+"""
+
+import os
+from typing import Literal
+from urllib.parse import urlparse
+
+from azure.core.exceptions import ResourceNotFoundError
+from azure.storage.blob import ContentSettings
+from azure.storage.blob.aio import ContainerClient
+
+from .utils import convert_to_thumbnail
+
+__all__ = [
+    "upload_image",
+    "delete_image",
+    "update_image",
+]
+
+
+def get_folder_path(folder_name: Literal["signals", "trends"]) -> str:
+    """
+    Get a path to an image folder derived from a database connection string.
+
+    This makes it possible to manage images from staging and production environments separately.
+
+    Returns
+    -------
+    str
+        A folder path to save signal/trend images to.
+    """
+    database_name = urlparse(os.getenv("DB_CONNECTION")).path.strip("/")
+    return f"{database_name}/{folder_name}"
+
+
+def get_container_client() -> ContainerClient:
+    """
+    Get an asynchronous container client for Azure Blob Storage.
+
+    Returns
+    -------
+    client : ContainerClient
+        An asynchronous container client.
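+
+    Examples
+    --------
+    A usage sketch, assuming the `SAS_URL` environment variable is set:
+
+    >>> async with get_container_client() as client:  # doctest: +SKIP
+    ...     blobs = [name async for name in client.list_blob_names()]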
+ """ + client = ContainerClient.from_container_url(container_url=os.environ["SAS_URL"]) + return client + + +async def upload_image( + entity_id: int, + folder_name: Literal["signals", "trends"], + image_string: str, +) -> str: + """ + Upload a thumbnail JPEG version of an image to Azure Blob Storage and return a (public) URL. + + The function converts the image to a JPEG format and rescales it to 720p. + + Parameters + ---------- + entity_id : str + Signal or trend ID for which an image is uploaded. + folder_name : Literal['signals', 'trends'] + Folder name to save the image to. + image_string : str + Base64-encoded image data. + + Returns + ------- + blob_url : str + A (public) URL pointing to the image file on Blob Storage + that can be used to embed the image in HTML. + """ + # decode the image string + image_string = image_string.split(sep=",", maxsplit=1)[-1] + image_data = convert_to_thumbnail(image_string) + + # connect and upload to the storage + async with get_container_client() as client: + folder_path = get_folder_path(folder_name) + blob_client = await client.upload_blob( + name=f"{folder_path}/{entity_id}.jpeg", + data=image_data, + blob_type="BlockBlob", + overwrite=True, + content_settings=ContentSettings(content_type="image/jpeg"), + ) + + # remove URL parameters that expose a SAS token + blob_url = blob_client.url.split("?")[0] + return blob_url + + +async def delete_image( + entity_id: int, + folder_name: Literal["signals", "trends"], +) -> bool: + """ + Remove an image from Azure Blob Storage. + + Parameters + ---------- + entity_id : str + Signal or trend ID whose image is to be deleted. + folder_name : Literal['signals', 'trends'] + Folder name to delete the image from. + + Returns + ------- + True if the blob has been deleted and False otherwise. + """ + folder_path = get_folder_path(folder_name) + async with get_container_client() as client: + try: + await client.delete_blob(f"{folder_path}/{entity_id}.jpeg") + except ResourceNotFoundError: + return False + return True + + +async def update_image( + entity_id: int, + folder_name: Literal["signals", "trends"], + attachment: str | None, +) -> str | None: + """ + Update an image attachment on Azure Blob Storage. + + If the attachment is None, the attachment will be deleted, if it is a base64-encoded + image, the attachment will be updated, it is a URL to an image, no action will be taken. + + Parameters + ---------- + entity_id : str + Signal or trend ID whose image is to be updated. + folder_name : Literal['signals', 'trends'] + Folder name to delete the image from. + attachment : str | None + A base64-encoded image data, existing attachment URL or None. + + Returns + ------- + str | None + A string to the current/updated image or None if it has been deleted or update failed. + """ + if attachment is None: + await delete_image(entity_id, folder_name) + return None + if attachment.startswith("https"): + return attachment + + try: + blob_url = await upload_image( + entity_id=entity_id, + folder_name=folder_name, + image_string=attachment, + ) + except Exception as e: + print(e) + return None + return blob_url diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..0cf9e4a --- /dev/null +++ b/src/utils.py @@ -0,0 +1,143 @@ +""" +Miscellaneous utilities for data wrangling. 
+""" + +import base64 +from datetime import UTC, datetime +from io import BytesIO +from typing import Literal + +import httpx +import pandas as pd +from bs4 import BeautifulSoup +from fastapi.responses import StreamingResponse +from PIL import Image +from sklearn.preprocessing import MultiLabelBinarizer + + +def convert_to_thumbnail(image_string: str) -> bytes: + """ + Convert an image to a JPEG thumbnail no larger than HD quality. + + Parameters + ---------- + image_string : str + Base64 encoded image string. + + Returns + ------- + image_data : bytes + The image thumbnail encoded as JPEG. + """ + # read the original image + buffer = BytesIO(base64.b64decode(image_string)) + image = Image.open(buffer) + + # resize, save in-memory and return bytes + buffer = BytesIO() + image.thumbnail((1280, 720)) + image.convert("RGB").save(buffer, format="jpeg") + image_data = buffer.getvalue() + return image_data + + +async def scrape_content(url: str) -> str: + """ + Scrape content of a web page to be fed to an OpenAI model. + + Parameters + ---------- + url : str + A publicly accessible URL. + + Returns + ------- + str + Web content of a page "as is". + """ + headers = {"User-Agent": "Mozilla/5.0"} + async with httpx.AsyncClient(headers=headers, timeout=10) as client: + response = await client.get(url) + soup = BeautifulSoup(response.content, features="lxml") + return soup.text + + +def format_column_name(prefix: str, value: str) -> str: + """ + Format column names for dummy columns created by MultiLabelBinarizer. + + This function is used for prettifying exports. + + Parameters + ---------- + prefix : str + A prefix to assign to a column. + value : str + The original column name value. + + Returns + ------- + str + Formated column name. + """ + value = value.split("–")[0] + value = value.strip().lower() + value = value.replace(" ", "_") + return f"{prefix}_{value}" + + +def binarise_columns(df: pd.DataFrame, columns: list[str]): + """ + Binarise columns containing array values. + + Parameters + ---------- + df : pd.DataFrame + Arbitrary dataframe. + columns : list[str] + Columns to binarise. + + Returns + ------- + df : pd.DataFrame + Mutated data frame. + """ + for column in columns: + mlb = MultiLabelBinarizer() + # fill in missing values with an empty list + values = mlb.fit_transform(df[column].apply(lambda x: x or [])) + df_dummies = pd.DataFrame(values, columns=mlb.classes_) + df_dummies.rename(lambda x: format_column_name(column, x), axis=1, inplace=True) + df = df.join(df_dummies) + df.drop(columns, axis=1, inplace=True) + return df + + +def write_to_response( + df: pd.DataFrame, + kind: Literal["signals", "trends"], +) -> StreamingResponse: + """ + Write a data frame to an Excel file in a Streaming response that can be returned by the API. + + Parameters + ---------- + df : pd.DataFrame + A data frame of exported signals/trends. + kind : Literal["signals", "trends"] + A kind of the data being exported to include in the file name. + + Returns + ------- + response : StreamingResponse + A response object containing the exported data that can be returned by the API. 
+ """ + buffer = BytesIO() + df.to_excel(buffer, index=False) + file_name = f"ftss-{kind}-{datetime.now(UTC):%y-%m-%d}.xlsx" + response = StreamingResponse( + BytesIO(buffer.getvalue()), + media_type="application/vnd.ms-excel", + headers={"Content-Disposition": f"attachment; filename={file_name}"}, + ) + return response diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..47d7588 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +""" +Fixtures for setting up the tests shared across the test suite. +""" + +import os + +import pytest +from dotenv import load_dotenv + +load_dotenv() + + +@pytest.fixture(scope="session", params=[os.environ["API_KEY"], os.environ["API_JWT"]]) +def headers(request) -> dict[str, str]: + """Header for authentication with an API key or a JWT for a regular user (not curator or admin).""" + return {"access_token": request.param} + + +@pytest.fixture(scope="session") +def headers_with_jwt(request) -> dict[str, str]: + """Header for authentication with a JWT for a regular user (not curator or admin).""" + return {"access_token": os.environ["API_JWT"]} diff --git a/tests/test_authentication.py b/tests/test_authentication.py new file mode 100644 index 0000000..985b1d0 --- /dev/null +++ b/tests/test_authentication.py @@ -0,0 +1,78 @@ +""" +Basic tests to ensure the authentication is required and API key works for specific endpoints. +Currently, the tests do not cover JWT-based authentication. +""" + +import re +from typing import Literal + +from fastapi.testclient import TestClient +from pytest import mark + +from main import app + +client = TestClient(app) + + +def get_endpoints(pattern: str, method: Literal["GET", "POST", "PUT"] | None = None): + """ + Convenience method to get all endpoints for a specific method. + + Parameters + ---------- + pattern : str + A regex pattern to match against endpoint path. + method : Literal["GET", "POST", "PUT"] | None + An optional type of HTTP method to filter endpoints. + + Returns + ------- + endpoints : list[tuple[str, str]] + A list of tuples containing endpoint path and method name. + """ + endpoints = [] + for route in app.routes: + if re.search(pattern, route.path): + if method is None or {method} & route.methods: + endpoints.append((route.path, list(route.methods)[0])) + return endpoints + + +def test_read_docs(): + """Ensure the documentation page is accessible without authentication.""" + response = client.get("/") + assert response.status_code == 200, "Documentation page is inaccessible" + + +@mark.parametrize( + "endpoint,method", + get_endpoints(r"signals|trends|users|choices"), +) +def test_authentication_required(endpoint: str, method: str): + """Check if endpoints, except for documentation, require authentication.""" + response = client.request(method, endpoint) + assert response.status_code == 403 + + +@mark.parametrize( + "endpoint,status_code", + [ + ("/signals/search", 200), + ("/signals/export", 403), + ("/signals/10", 200), + ("/trends/search", 200), + ("/trends/export", 403), + ("/trends/10", 200), + ("/users/search", 403), + ("/users/me", 200), + ("/users/1", 403), + ("/choices", 200), + ("/choices/status", 200), + ], +) +def test_authentication_get(endpoint: str, status_code: int, headers: dict): + """ + Check if authentication for GET endpoints works as expected. 
+ """ + response = client.get(endpoint, headers=headers) + assert response.status_code == status_code diff --git a/tests/test_choices.py b/tests/test_choices.py new file mode 100644 index 0000000..e643850 --- /dev/null +++ b/tests/test_choices.py @@ -0,0 +1,41 @@ +""" +Basic tests for choices endpoints. +""" + +from enum import StrEnum + +from fastapi.testclient import TestClient +from pytest import mark + +from main import app +from src.entities import Goal, Role, Status + +client = TestClient(app) + +enums = [("status", Status), ("role", Role), ("goal", Goal)] + + +def test_choices(headers: dict): + endpoint = "/choices" + response = client.get(endpoint, headers=headers) + assert response.status_code == 200 + + data = response.json() + assert isinstance(data, dict) + for k, v in data.items(): + assert isinstance(k, str) + assert isinstance(v, list) + for name, enum in enums: + assert name in data + assert data[name] == [x.value for x in enum] + + +@mark.parametrize("name,enum", enums) +def test_choice_name(name: str, enum: StrEnum, headers: dict): + endpoint = f"/choices/{name}" + response = client.get(endpoint, headers=headers) + assert response.status_code == 200 + + data = response.json() + assert isinstance(data, list) + assert data == [x.value for x in enum] diff --git a/tests/test_signals_and_trends.py b/tests/test_signals_and_trends.py new file mode 100644 index 0000000..cdf6054 --- /dev/null +++ b/tests/test_signals_and_trends.py @@ -0,0 +1,120 @@ +""" +Basic tests for search and CRUD operations on signals and trends. +""" + +from typing import Literal + +from fastapi.testclient import TestClient +from pytest import mark + +from main import app +from src.entities import Goal, Pagination, Signal, Trend + +client = TestClient(app) + + +@mark.parametrize("path", ["signals", "trends"]) +@mark.parametrize("page", [None, 1, 2]) +@mark.parametrize("per_page", [None, 10, 20]) +@mark.parametrize("goal", [None, Goal.G13]) +@mark.parametrize("query", [None, "climate"]) +@mark.parametrize("ids", [None, list(range(100))]) +def test_search( + path: Literal["signals", "trends"], + page: int | None, + per_page: int | None, + goal: Goal | None, + query: str | None, + ids: list[int] | None, + headers_with_jwt: dict, +): + endpoint = f"/{path}/search" + params = { + "page": page, + "per_page": per_page, + "goal": goal, + "query": query, + "ids": ids, + } + params = {k: v for k, v in params.items() if v is not None} + + # ensure the pagination values are set to defaults if not used in the request + page = page or Pagination.model_fields["page"].default + per_page = per_page or Pagination.model_fields["per_page"].default + + response = client.get(endpoint, params=params, headers=headers_with_jwt) + assert response.status_code == 200 + results = response.json() + assert results.get("current_page") == page + assert results.get("per_page") == per_page + assert isinstance(results.get("data"), list) + assert 0 < len(results["data"]) <= per_page + match path: + case "signals": + entity = Signal(**results["data"][0]) + case "trends": + entity = Trend(**results["data"][0]) + case _: + raise ValueError(f"Unknown path: {path}") + assert entity.id == results["data"][0]["id"] + + +@mark.parametrize("path", ["signals", "trends"]) +@mark.parametrize("uid", list(range(10, 20))) +def test_read_by_id( + path: Literal["signals", "trends"], + uid: int, + headers_with_jwt: dict, +): + endpoint = f"/{path}/{uid}" + response = client.get(endpoint, headers=headers_with_jwt) + assert response.status_code in {200, 404} + if 
diff --git a/tests/test_signals_and_trends.py b/tests/test_signals_and_trends.py
new file mode 100644
index 0000000..cdf6054
--- /dev/null
+++ b/tests/test_signals_and_trends.py
@@ -0,0 +1,120 @@
+"""
+Basic tests for search and CRUD operations on signals and trends.
+"""
+
+from typing import Literal
+
+from fastapi.testclient import TestClient
+from pytest import mark
+
+from main import app
+from src.entities import Goal, Pagination, Signal, Trend
+
+client = TestClient(app)
+
+
+@mark.parametrize("path", ["signals", "trends"])
+@mark.parametrize("page", [None, 1, 2])
+@mark.parametrize("per_page", [None, 10, 20])
+@mark.parametrize("goal", [None, Goal.G13])
+@mark.parametrize("query", [None, "climate"])
+@mark.parametrize("ids", [None, list(range(100))])
+def test_search(
+    path: Literal["signals", "trends"],
+    page: int | None,
+    per_page: int | None,
+    goal: Goal | None,
+    query: str | None,
+    ids: list[int] | None,
+    headers_with_jwt: dict,
+):
+    endpoint = f"/{path}/search"
+    params = {
+        "page": page,
+        "per_page": per_page,
+        "goal": goal,
+        "query": query,
+        "ids": ids,
+    }
+    params = {k: v for k, v in params.items() if v is not None}
+
+    # ensure the pagination values are set to defaults if not used in the request
+    page = page or Pagination.model_fields["page"].default
+    per_page = per_page or Pagination.model_fields["per_page"].default
+
+    response = client.get(endpoint, params=params, headers=headers_with_jwt)
+    assert response.status_code == 200
+    results = response.json()
+    assert results.get("current_page") == page
+    assert results.get("per_page") == per_page
+    assert isinstance(results.get("data"), list)
+    assert 0 < len(results["data"]) <= per_page
+    match path:
+        case "signals":
+            entity = Signal(**results["data"][0])
+        case "trends":
+            entity = Trend(**results["data"][0])
+        case _:
+            raise ValueError(f"Unknown path: {path}")
+    assert entity.id == results["data"][0]["id"]
+
+
+@mark.parametrize("path", ["signals", "trends"])
+@mark.parametrize("uid", list(range(10, 20)))
+def test_read_by_id(
+    path: Literal["signals", "trends"],
+    uid: int,
+    headers_with_jwt: dict,
+):
+    endpoint = f"/{path}/{uid}"
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code in {200, 404}
+    if response.status_code == 200:
+        data = response.json()
+        # validate against the entity that matches the path
+        entity = Signal(**data) if path == "signals" else Trend(**data)
+        assert entity.id == data["id"]
+
+
+def test_crud(headers_with_jwt: dict):
+    """Currently, testing for signals only as a staff role is required to manage trends."""
+    # instantiate a test object
+    entity = Signal(**Signal.model_config["json_schema_extra"]["example"])
+
+    # create
+    endpoint = "/signals"
+    response = client.post(endpoint, json=entity.model_dump(), headers=headers_with_jwt)
+    assert response.status_code == 201
+    data = response.json()
+    assert entity.headline == data["headline"]
+    assert entity.description == data["description"]
+    assert entity.sdgs == data["sdgs"]
+
+    # read
+    endpoint = "/signals/{}".format(data["id"])
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 200
+    data = response.json()
+    assert entity.headline == data["headline"]
+    assert entity.description == data["description"]
+    assert entity.sdgs == data["sdgs"]
+
+    # update
+    endpoint = "/signals/{}".format(data["id"])
+    data |= {
+        "headline": "New Headline",
+        "description": "Lorem ipsum " * 10,
+        "sdgs": [Goal.G1, Goal.G17],
+    }
+    response = client.put(endpoint, json=data, headers=headers_with_jwt)
+    assert response.status_code == 200
+    data = response.json()
+    assert entity.headline != data["headline"]
+    assert entity.description != data["description"]
+    assert entity.sdgs != data["sdgs"]
+
+    # delete
+    endpoint = "/signals/{}".format(data["id"])
+    response = client.delete(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 200
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 404
diff --git a/tests/test_users.py b/tests/test_users.py
new file mode 100644
index 0000000..2cbdca9
--- /dev/null
+++ b/tests/test_users.py
@@ -0,0 +1,57 @@
+"""
+Basic tests for user endpoints.
+"""
+
+from fastapi.testclient import TestClient
+from pytest import mark
+
+from main import app
+from src.entities import Role, User
+
+client = TestClient(app)
+
+
+def test_me(headers: dict):
+    endpoint = "/users/me"
+    response = client.get(endpoint, headers=headers)
+    assert response.status_code == 200
+
+    user = User(**response.json())
+    assert user.role in {Role.VISITOR, Role.USER}
+    assert not user.is_admin
+    assert not user.is_staff
+
+
+@mark.parametrize(
+    "unit",
+    [
+        "Bureau for Policy and Programme Support (BPPS)",
+        "Chief Digital Office (CDO)",
+        "Executive Office (ExO)",
+    ],
+)
+@mark.parametrize("acclab", [True, False])
+def test_update(unit: str, acclab: bool, headers_with_jwt: dict):
+    # get the user data from the database
+    endpoint = "/users/me"
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 200
+    user = User(**response.json())
+    assert user.role == Role.USER, "The JWT must belong to a regular user"
+
+    # regular users should be able to update their profile
+    endpoint = f"/users/{user.id}"
+    user.acclab = acclab
+    user.unit = unit
+    response = client.put(endpoint, json=user.model_dump(), headers=headers_with_jwt)
+    assert response.status_code == 200
+    data = response.json()
+    assert user.id == data["id"]
+    assert user.is_regular
+    assert user.acclab == acclab
+    assert user.unit == unit
+
+    # regular users should not be able to change their role
+    user.role = Role.ADMIN
+    response = client.put(endpoint, json=user.model_dump(), headers=headers_with_jwt)
+    assert response.status_code == 403
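+
+
+# To run the test suite locally, API_KEY and API_JWT must be available in the
+# environment (e.g. via the .env file loaded in tests/conftest.py), after
+# which a plain pytest invocation should work:
+#
+#     python -m pytest tests/ -v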