diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..f6b4abd --- /dev/null +++ b/.dockerignore @@ -0,0 +1,43 @@ +# Ignore Python cache files +__pycache__/ +*.pyc +*.pyo +*.pyd + +# Ignore virtual environment directories +venv/ +env/ +.venv/ +.env/ + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Local configuration files +*.env +*.local + +# Ignore test and coverage files +tests/ +*.cover +.coverage +nosetests.xml +coverage.xml +*.log + +# Ignore IDE/editor specific files +.vscode/ +.idea/ + +# Ignore Docker files +Dockerfile +docker-compose.yml + +# Ignore documentation files +docs/ +*.md + +# Ignore other unnecessary files +*.DS_Store +*.tmp +*.temp diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..749fb3f --- /dev/null +++ b/.gitignore @@ -0,0 +1,142 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Manually added for this project +.idea/ +**/.DS_Store diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2dab8d0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.11.7-slim +RUN apt-get update -y \ + && apt-get install libpq-dev -y \ + && rm -rf /var/lib/apt/lists/* +WORKDIR /app +COPY requirements.txt . 
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+COPY . .
+EXPOSE 8000
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..24ae370
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright (c) 2024, United Nations Development Programme
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..898ae03
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,8 @@
+install:
+	pip install --upgrade pip && pip install -r requirements_dev.txt
+format:
+	isort . --profile black --multi-line 3 && black .
+lint:
+	pylint main.py src/
+test:
+	python -m pytest tests/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5ee5bb4
--- /dev/null
+++ b/README.md
@@ -0,0 +1,148 @@
+# Future Trends and Signals System (FTSS) API
+
+[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/release/python-3110/)
+[![License](https://img.shields.io/github/license/undp-data/ftss-api)](https://github.com/undp-data/ftss-api/blob/main/LICENSE)
+[![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
+[![Conventional Commits](https://img.shields.io/badge/Conventional%20Commits-1.0.0-%23FE5196?logo=conventionalcommits&logoColor=white)](https://conventionalcommits.org)
+
+This repository hosts the API that powers the [UNDP Future Trends and Signals System](https://signals.data.undp.org) (FTSS).
+The API is written using [FastAPI](https://fastapi.tiangolo.com) in Python and deployed on Azure App Services.
+It serves as an intermediary between the front-end application and back-end database. The codebase is an open-sourced
+version of the original project, transferred from Azure DevOps.
+
+## Table of Contents
+
+- [Introduction](#introduction)
+- [Getting Started](#getting-started)
+- [Build and Test](#build-and-test)
+- [Contribute](#contribute)
+- [License](#license)
+- [Contact](#contact)
+
+## Introduction
+
+The FTSS is an internal system built for the staff of the United Nations Development Programme, designed to capture
+signals of change and identify emerging trends within and outside the organisation. This repository hosts the back-end
+API that powers the platform and is accompanied by the [front-end repository](https://github.com/undp-data/fe-signals-and-trends).
+
+The API is written and tested in Python `3.11` using the [FastAPI](https://fastapi.tiangolo.com) framework. Database and
+storage routines are implemented in an asynchronous manner, making the application fast and responsive. The API is
+deployed on Azure App Services to development and production environments from the `dev` and `main` branches
+respectively. The API interacts with a PostgreSQL database deployed as an Azure Database for PostgreSQL instance. The
+instance comprises `staging` and `production` databases. An Azure Blob Storage container stores images used as
+illustrations for signals and trends. The simplified architecture of the whole application is shown in the image below.
+
+Commits to the `staging` branch in the front-end repository and the `dev` branch in this repository trigger CI/CD pipelines for
+the staging environment. While there is a single database instance, the data in the staging environment is isolated in
+the `staging` database/schema, separate from the `production` database/schema. The same
+logic applies to the blob storage – images uploaded in the staging environment are managed separately from those in the
+production environment.
+
+![Preview](images/architecture.drawio.svg)
+
+Authentication in the API happens via tokens (JWT) issued by Microsoft Entra upon user log-in in the front-end
+application. Some endpoints to retrieve approved signals/trends are accessible with a static API key
+for integration with other applications.
+
+## Getting Started
+
+To run the application locally, you can use either your local environment or a Docker container. Either way,
+clone the repository and navigate to the project directory first:
+
+```shell
+# Clone the repository
+git clone https://github.com/undp-data/ftss-api
+
+# Navigate to the project folder
+cd ftss-api
+```
+
+You must also ensure that the following environment variables are set up:
+
+```text
+# Authentication
+TENANT_ID=""
+CLIENT_ID=""
+API_KEY="" # for accessing "public" endpoints
+
+# Database and Storage
+DB_CONNECTION="postgresql://<username>:<password>@<host>:5432/<database>"
+SAS_URL="https://<account>.blob.core.windows.net/<container>?<sas-token>"
+
+# Azure OpenAI, only required for `/signals/generation`
+AZURE_OPENAI_ENDPOINT="https://<resource>.openai.azure.com/"
+AZURE_OPENAI_API_KEY=""
+
+# Testing, only required to run tests, must be a valid token of a regular user
+API_JWT=""
+```
+
+### Local Environment
+
+For this scenario, you will need a connection string to the staging database.
+
+```bash
+# Create and activate a virtual environment.
+python3 -m venv venv
+source venv/bin/activate
+
+# Install core dependencies.
+pip install -r requirements.txt
+
+# Launch the application.
+uvicorn main:app --reload
+```
+
+Once launched, the application will be running at http://127.0.0.1:8000.
+
+### Docker Environment
+
+For this scenario, you do not need a connection string as a fresh PostgreSQL instance will be
+set up for you in the container.
Ensure that the Docker engine is running on your machine, then execute:
+
+```shell
+# Start the containers
+docker compose up --build -d
+```
+
+Once launched, the application will be running at http://127.0.0.1:8000.
+
+## Build and Test
+
+The codebase provides some basic tests written in `pytest`. To run them, ensure you have specified a valid token in your
+`API_JWT` environment variable. Then run:
+
+```shell
+# run all tests
+python -m pytest tests/
+
+# or alternatively
+make test
+```
+
+Note that some tests for search endpoints might fail as the tests are run against dynamically changing databases.
+
+## Contribute
+
+All contributions must follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/).
+The codebase is formatted with `black` and `isort`. Use the provided [Makefile](Makefile) for these
+routine operations. Make sure to run the linter against your code.
+
+1. Clone or fork the repository
+2. Create a new branch (`git checkout -b feature-branch`)
+3. Make your changes
+4. Ensure your code is properly formatted (`make format`)
+5. Run the linter and check for any issues (`make lint`)
+6. Execute the tests (`make test`)
+7. Commit your changes (`git commit -m 'Add some feature'`)
+8. Push to the branch (`git push origin feature-branch`)
+9. Open a pull request to the `dev` branch
+10. Once tested in the staging environment, open a pull request to the `main` branch
+
+## License
+
+The project is licensed under the BSD 3-Clause License. See the [LICENSE](LICENSE) file for details.
+
+## Contact
+
+This project was originally developed and is maintained by the [Data Futures Exchange (DFx)](https://data.undp.org) at UNDP.
+If you are facing any issues or would like to make some suggestions, feel free to
+[open an issue](https://github.com/undp-data/ftss-api/issues/new/choose).
+For enquiries about DFx, visit [Contact Us](https://data.undp.org/contact-us).
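As a quick end-to-end check of a running instance, you can query one of the "public" endpoints with the static API key. The following is a minimal sketch using `httpx` (already a dependency in `requirements.txt`); the header name `access_token` is taken from `src/authentication.py`, but the `/signals/search` path and its query parameters are assumptions here, so consult the interactive docs served at `/` for the actual routes.

```python
"""Minimal smoke test for a locally running FTSS API instance."""

import os

import httpx

BASE_URL = "http://127.0.0.1:8000"

# The API reads the token from an `access_token` header (see APIKeyHeader in
# src/authentication.py); the static API_KEY grants read-only "visitor" access.
headers = {"access_token": os.environ["API_KEY"]}

# NOTE: hypothetical route and parameters; check the OpenAPI docs at BASE_URL
# for the real signal search endpoint and its filters.
response = httpx.get(f"{BASE_URL}/signals/search", headers=headers, params={"page": 1})
response.raise_for_status()
print(response.json())
```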
diff --git a/data/locations.csv b/data/locations.csv new file mode 100644 index 0000000..5a9cd6a --- /dev/null +++ b/data/locations.csv @@ -0,0 +1,269 @@ +id,name,iso,region,bureau +1,Global,,Global, +2,Antarctica,,Antarctica, +3,Australia and New Zealand,,Australia and New Zealand, +4,Central Asia,,Central Asia, +5,Eastern Asia,,Eastern Asia, +6,Eastern Europe,,Eastern Europe, +7,Latin America and the Caribbean,,Latin America and the Caribbean, +8,Melanesia,,Melanesia, +9,Micronesia,,Micronesia, +10,Northern Africa,,Northern Africa, +11,Northern America,,Northern America, +12,Northern Europe,,Northern Europe, +13,Polynesia,,Polynesia, +14,South-eastern Asia,,South-eastern Asia, +15,Southern Asia,,Southern Asia, +16,Southern Europe,,Southern Europe, +17,Sub-Saharan Africa,,Sub-Saharan Africa, +18,Western Asia,,Western Asia, +19,Western Europe,,Western Europe, +20,Afghanistan,AFG,Southern Asia,RBAP +21,Albania,ALB,Southern Europe,RBEC +22,Antarctica,ATA,Antarctica, +23,Algeria,DZA,Northern Africa,RBAS +24,American Samoa,ASM,Polynesia, +25,Andorra,AND,Southern Europe, +26,Angola,AGO,Sub-Saharan Africa,RBA +27,Antigua and Barbuda,ATG,Latin America and the Caribbean,RBLAC +28,Azerbaijan,AZE,Western Asia,RBEC +29,Argentina,ARG,Latin America and the Caribbean,RBLAC +30,Australia,AUS,Australia and New Zealand, +31,Austria,AUT,Western Europe, +32,Bahamas,BHS,Latin America and the Caribbean,RBLAC +33,Bahrain,BHR,Western Asia,RBAS +34,Bangladesh,BGD,Southern Asia,RBAP +35,Armenia,ARM,Western Asia,RBEC +36,Barbados,BRB,Latin America and the Caribbean,RBLAC +37,Belgium,BEL,Western Europe, +38,Bermuda,BMU,Northern America, +39,Bhutan,BTN,Southern Asia,RBAP +40,Bolivia (Plurinational State of),BOL,Latin America and the Caribbean,RBLAC +41,Bosnia and Herzegovina,BIH,Southern Europe,RBEC +42,Botswana,BWA,Sub-Saharan Africa,RBA +43,Bouvet Island,BVT,Latin America and the Caribbean, +44,Brazil,BRA,Latin America and the Caribbean,RBLAC +45,Belize,BLZ,Latin America and the Caribbean,RBLAC +46,British Indian Ocean Territory,IOT,Sub-Saharan Africa, +47,Solomon Islands,SLB,Melanesia,RBAP +48,British Virgin Islands,VGB,Latin America and the Caribbean, +49,Brunei Darussalam,BRN,South-eastern Asia,RBAP +50,Bulgaria,BGR,Eastern Europe, +51,Myanmar,MMR,South-eastern Asia,RBAP +52,Burundi,BDI,Sub-Saharan Africa,RBA +53,Belarus,BLR,Eastern Europe,RBEC +54,Cambodia,KHM,South-eastern Asia,RBAP +55,Cameroon,CMR,Sub-Saharan Africa,RBA +56,Canada,CAN,Northern America, +57,Cabo Verde,CPV,Sub-Saharan Africa,RBA +58,Cayman Islands,CYM,Latin America and the Caribbean, +59,Central African Republic,CAF,Sub-Saharan Africa,RBA +60,Sri Lanka,LKA,Southern Asia,RBAP +61,Chad,TCD,Sub-Saharan Africa,RBA +62,Chile,CHL,Latin America and the Caribbean,RBLAC +63,China,CHN,Eastern Asia,RBAP +64,Christmas Island,CXR,Australia and New Zealand, +65,Cocos (Keeling) Islands,CCK,Australia and New Zealand, +66,Colombia,COL,Latin America and the Caribbean,RBLAC +67,Comoros,COM,Sub-Saharan Africa,RBA +68,Mayotte,MYT,Sub-Saharan Africa, +69,Congo,COG,Sub-Saharan Africa,RBA +70,Democratic Republic of the Congo,COD,Sub-Saharan Africa,RBA +71,Cook Islands,COK,Polynesia, +72,Costa Rica,CRI,Latin America and the Caribbean,RBLAC +73,Croatia,HRV,Southern Europe, +74,Cuba,CUB,Latin America and the Caribbean,RBLAC +75,Cyprus,CYP,Western Asia, +76,Czechia,CZE,Eastern Europe, +77,Benin,BEN,Sub-Saharan Africa,RBA +78,Denmark,DNK,Northern Europe, +79,Dominica,DMA,Latin America and the Caribbean,RBLAC +80,Dominican Republic,DOM,Latin America and the Caribbean,RBLAC 
+81,Ecuador,ECU,Latin America and the Caribbean,RBLAC +82,El Salvador,SLV,Latin America and the Caribbean,RBLAC +83,Equatorial Guinea,GNQ,Sub-Saharan Africa,RBA +84,Ethiopia,ETH,Sub-Saharan Africa,RBA +85,Eritrea,ERI,Sub-Saharan Africa,RBA +86,Estonia,EST,Northern Europe, +87,Faroe Islands,FRO,Northern Europe, +88,Falkland Islands (Malvinas),FLK,Latin America and the Caribbean, +89,South Georgia and the South Sandwich Islands,SGS,Latin America and the Caribbean, +90,Fiji,FJI,Melanesia,RBAP +91,Finland,FIN,Northern Europe, +92,Åland Islands,ALA,Northern Europe, +93,France,FRA,Western Europe, +94,French Guiana,GUF,Latin America and the Caribbean, +95,French Polynesia,PYF,Polynesia, +96,French Southern Territories,ATF,Sub-Saharan Africa, +97,Djibouti,DJI,Sub-Saharan Africa,RBAS +98,Gabon,GAB,Sub-Saharan Africa,RBA +99,Georgia,GEO,Western Asia,RBEC +100,Gambia,GMB,Sub-Saharan Africa,RBA +101,State of Palestine,PSE,Western Asia,RBAS +102,Germany,DEU,Western Europe, +103,Ghana,GHA,Sub-Saharan Africa,RBA +104,Gibraltar,GIB,Southern Europe, +105,Kiribati,KIR,Micronesia,RBAP +106,Greece,GRC,Southern Europe, +107,Greenland,GRL,Northern America, +108,Grenada,GRD,Latin America and the Caribbean,RBLAC +109,Guadeloupe,GLP,Latin America and the Caribbean, +110,Guam,GUM,Micronesia, +111,Guatemala,GTM,Latin America and the Caribbean,RBLAC +112,Guinea,GIN,Sub-Saharan Africa,RBA +113,Guyana,GUY,Latin America and the Caribbean,RBLAC +114,Haiti,HTI,Latin America and the Caribbean,RBLAC +115,Heard Island and McDonald Islands,HMD,Australia and New Zealand, +116,Holy See,VAT,Southern Europe, +117,Honduras,HND,Latin America and the Caribbean,RBLAC +118,"China, Hong Kong Special Administrative Region",HKG,Eastern Asia, +119,Hungary,HUN,Eastern Europe, +120,Iceland,ISL,Northern Europe, +121,India,IND,Southern Asia,RBAP +122,Indonesia,IDN,South-eastern Asia,RBAP +123,Iran (Islamic Republic of),IRN,Southern Asia,RBAP +124,Iraq,IRQ,Western Asia,RBAS +125,Ireland,IRL,Northern Europe, +126,Israel,ISR,Western Asia, +127,Italy,ITA,Southern Europe, +128,Côte d’Ivoire,CIV,Sub-Saharan Africa,RBA +129,Jamaica,JAM,Latin America and the Caribbean,RBLAC +130,Japan,JPN,Eastern Asia, +131,Kazakhstan,KAZ,Central Asia,RBEC +132,Jordan,JOR,Western Asia,RBAS +133,Kenya,KEN,Sub-Saharan Africa,RBA +134,Democratic People's Republic of Korea,PRK,Eastern Asia,RBAP +135,Republic of Korea,KOR,Eastern Asia, +136,Kuwait,KWT,Western Asia,RBAS +137,Kyrgyzstan,KGZ,Central Asia,RBEC +138,Lao People's Democratic Republic,LAO,South-eastern Asia,RBAP +139,Lebanon,LBN,Western Asia,RBAS +140,Lesotho,LSO,Sub-Saharan Africa,RBA +141,Latvia,LVA,Northern Europe, +142,Liberia,LBR,Sub-Saharan Africa,RBA +143,Libya,LBY,Northern Africa,RBAS +144,Liechtenstein,LIE,Western Europe, +145,Lithuania,LTU,Northern Europe, +146,Luxembourg,LUX,Western Europe, +147,"China, Macao Special Administrative Region",MAC,Eastern Asia, +148,Madagascar,MDG,Sub-Saharan Africa,RBA +149,Malawi,MWI,Sub-Saharan Africa,RBA +150,Malaysia,MYS,South-eastern Asia,RBAP +151,Maldives,MDV,Southern Asia,RBAP +152,Mali,MLI,Sub-Saharan Africa,RBA +153,Malta,MLT,Southern Europe, +154,Martinique,MTQ,Latin America and the Caribbean, +155,Mauritania,MRT,Sub-Saharan Africa,RBA +156,Mauritius,MUS,Sub-Saharan Africa,RBA +157,Mexico,MEX,Latin America and the Caribbean,RBLAC +158,Monaco,MCO,Western Europe, +159,Mongolia,MNG,Eastern Asia,RBAP +160,Republic of Moldova,MDA,Eastern Europe,RBEC +161,Montenegro,MNE,Southern Europe,RBEC +162,Montserrat,MSR,Latin America and the Caribbean, 
+163,Morocco,MAR,Northern Africa,RBAS +164,Mozambique,MOZ,Sub-Saharan Africa,RBA +165,Oman,OMN,Western Asia,RBAS +166,Namibia,NAM,Sub-Saharan Africa,RBA +167,Nauru,NRU,Micronesia,RBAP +168,Nepal,NPL,Southern Asia,RBAP +169,Netherlands (Kingdom of the),NLD,Western Europe, +170,Curaçao,CUW,Latin America and the Caribbean, +171,Aruba,ABW,Latin America and the Caribbean, +172,Sint Maarten (Dutch part),SXM,Latin America and the Caribbean, +173,"Bonaire, Sint Eustatius and Saba",BES,Latin America and the Caribbean, +174,New Caledonia,NCL,Melanesia, +175,Vanuatu,VUT,Melanesia,RBAP +176,New Zealand,NZL,Australia and New Zealand, +177,Nicaragua,NIC,Latin America and the Caribbean,RBLAC +178,Niger,NER,Sub-Saharan Africa,RBA +179,Nigeria,NGA,Sub-Saharan Africa,RBA +180,Niue,NIU,Polynesia, +181,Norfolk Island,NFK,Australia and New Zealand, +182,Norway,NOR,Northern Europe, +183,Northern Mariana Islands,MNP,Micronesia, +184,United States Minor Outlying Islands,UMI,Micronesia, +185,Micronesia (Federated States of),FSM,Micronesia,RBAP +186,Marshall Islands,MHL,Micronesia,RBAP +187,Palau,PLW,Micronesia,RBAP +188,Pakistan,PAK,Southern Asia,RBAP +189,Panama,PAN,Latin America and the Caribbean,RBLAC +190,Papua New Guinea,PNG,Melanesia,RBAP +191,Paraguay,PRY,Latin America and the Caribbean,RBLAC +192,Peru,PER,Latin America and the Caribbean,RBLAC +193,Philippines,PHL,South-eastern Asia,RBAP +194,Pitcairn,PCN,Polynesia, +195,Poland,POL,Eastern Europe, +196,Portugal,PRT,Southern Europe, +197,Guinea-Bissau,GNB,Sub-Saharan Africa,RBA +198,Timor-Leste,TLS,South-eastern Asia,RBAP +199,Puerto Rico,PRI,Latin America and the Caribbean, +200,Qatar,QAT,Western Asia,RBAS +201,Réunion,REU,Sub-Saharan Africa, +202,Romania,ROU,Eastern Europe, +203,Russian Federation,RUS,Eastern Europe, +204,Rwanda,RWA,Sub-Saharan Africa,RBA +205,Saint Barthélemy,BLM,Latin America and the Caribbean, +206,Saint Helena,SHN,Sub-Saharan Africa, +207,Saint Kitts and Nevis,KNA,Latin America and the Caribbean,RBLAC +208,Anguilla,AIA,Latin America and the Caribbean, +209,Saint Lucia,LCA,Latin America and the Caribbean,RBLAC +210,Saint Martin (French Part),MAF,Latin America and the Caribbean, +211,Saint Pierre and Miquelon,SPM,Northern America, +212,Saint Vincent and the Grenadines,VCT,Latin America and the Caribbean,RBLAC +213,San Marino,SMR,Southern Europe, +214,Sao Tome and Principe,STP,Sub-Saharan Africa,RBA +215,Sark,CRQ,Northern Europe, +216,Saudi Arabia,SAU,Western Asia,RBAS +217,Senegal,SEN,Sub-Saharan Africa,RBA +218,Serbia,SRB,Southern Europe,RBEC +219,Seychelles,SYC,Sub-Saharan Africa,RBA +220,Sierra Leone,SLE,Sub-Saharan Africa,RBA +221,Singapore,SGP,South-eastern Asia,RBAP +222,Slovakia,SVK,Eastern Europe, +223,Viet Nam,VNM,South-eastern Asia,RBAP +224,Slovenia,SVN,Southern Europe, +225,Somalia,SOM,Sub-Saharan Africa,RBAS +226,South Africa,ZAF,Sub-Saharan Africa,RBA +227,Zimbabwe,ZWE,Sub-Saharan Africa,RBA +228,Spain,ESP,Southern Europe, +229,South Sudan,SSD,Sub-Saharan Africa,RBA +230,Sudan,SDN,Northern Africa,RBAS +231,Western Sahara,ESH,Northern Africa, +232,Suriname,SUR,Latin America and the Caribbean,RBLAC +233,Svalbard and Jan Mayen Islands,SJM,Northern Europe, +234,Eswatini,SWZ,Sub-Saharan Africa,RBA +235,Sweden,SWE,Northern Europe, +236,Switzerland,CHE,Western Europe, +237,Syrian Arab Republic,SYR,Western Asia,RBAS +238,Tajikistan,TJK,Central Asia,RBEC +239,Thailand,THA,South-eastern Asia,RBAP +240,Togo,TGO,Sub-Saharan Africa,RBA +241,Tokelau,TKL,Polynesia, +242,Tonga,TON,Polynesia,RBAP +243,Trinidad and Tobago,TTO,Latin America 
and the Caribbean,RBLAC +244,United Arab Emirates,ARE,Western Asia,RBAS +245,Tunisia,TUN,Northern Africa,RBAS +246,Türkiye,TUR,Western Asia,RBEC +247,Turkmenistan,TKM,Central Asia,RBEC +248,Turks and Caicos Islands,TCA,Latin America and the Caribbean, +249,Tuvalu,TUV,Polynesia,RBAP +250,Uganda,UGA,Sub-Saharan Africa,RBA +251,Ukraine,UKR,Eastern Europe,RBEC +252,North Macedonia,MKD,Southern Europe,RBEC +253,Egypt,EGY,Northern Africa,RBAS +254,United Kingdom of Great Britain and Northern Ireland,GBR,Northern Europe, +255,Guernsey,GGY,Northern Europe, +256,Jersey,JEY,Northern Europe, +257,Isle of Man,IMN,Northern Europe, +258,United Republic of Tanzania,TZA,Sub-Saharan Africa,RBA +259,United States of America,USA,Northern America, +260,United States Virgin Islands,VIR,Latin America and the Caribbean, +261,Burkina Faso,BFA,Sub-Saharan Africa,RBA +262,Uruguay,URY,Latin America and the Caribbean,RBLAC +263,Uzbekistan,UZB,Central Asia,RBEC +264,Venezuela (Bolivarian Republic of),VEN,Latin America and the Caribbean,RBLAC +265,Wallis and Futuna Islands,WLF,Polynesia, +266,Samoa,WSM,Polynesia,RBAP +267,Yemen,YEM,Western Asia,RBAS +268,Zambia,ZMB,Sub-Saharan Africa,RBA diff --git a/data/units.csv b/data/units.csv new file mode 100644 index 0000000..68fe0b5 --- /dev/null +++ b/data/units.csv @@ -0,0 +1,201 @@ +id,name,region +1,Afghanistan,RBAP +2,Albania,RBEC +3,Algeria,RBAS +4,Angola,RBA +5,Anguilla,RBLAC +6,Antigua and Barbuda,RBLAC +7,Argentina,RBLAC +8,Armenia,RBEC +9,Aruba,RBLAC +10,Azerbaijan,RBEC +11,Bureau of External Relations and Advocacy (BERA),HQ +12,BMS,HQ +13,Bureau for Policy and Programme Support (BPPS),HQ +14,Bahrain,RBAS +15,Bangladesh,RBAP +16,Barbados,RBLAC +17,Belize,RBLAC +18,Benin,RBA +19,Bermuda,RBLAC +20,Bhutan,RBAP +21,Bolivia,RBLAC +22,Bosnia and Herzegovina,RBEC +23,Botswana,RBA +24,Brazil,RBLAC +25,Brunei Darussalam,RBAP +26,Burkina Faso,RBA +27,Burundi,RBA +28,Cambodia,RBAP +29,Cameroon,RBA +30,Cayman Islands,RBLAC +31,Central African Republic,RBA +32,Chad,RBA +33,Chief Digital Office,HQ +34,Chief Digital Office (CDO),HQ +35,Chile,RBLAC +36,China,RBAP +37,Colombia,RBLAC +38,Comoros,RBA +39,Congo,RBA +40,Cook Islands,RBAP +41,Costa Rica,RBLAC +42,Crisis Bureau (CB),HQ +43,Cuba,RBLAC +44,Curacao and Sint Maarten,RBLAC +45,Cyprus (project office),RBEC +46,Côte d’Ivoire,RBA +47,DPRK,RBAP +48,Democratic Republic of Congo,RBA +49,Development financing,HQ +50,Djibouti,RBAS +51,Dom. 
Republic,RBLAC +52,Ecuador,RBLAC +53,Egypt,RBAS +54,El Salvador,RBLAC +55,Equatorial Guinea,RBA +56,Eritrea,RBA +57,Eswatini,RBA +58,Ethiopia,RBA +59,Executive Office (ExO),HQ +60,External,Other +61,Federated States of Micronesia,RBAP +62,Fiji MCO,RBAP +63,Future Fellows,HQ +64,Gabon,RBA +65,Gambia,RBA +66,Georgia,RBEC +67,Ghana,RBA +68,"Global Centre for Technology, Innovation and Sustainable Development (Singapore)",BPPS +69,Grenada,RBLAC +70,Guatemala,RBLAC +71,Guinea,RBA +72,Guinea-Bissau,RBA +73,Guyana,RBLAC +74,Guyana-Suriname,RBLAC +75,Haiti,RBLAC +76,Honduras,RBLAC +77,Human Development Report Office (HDRO),HQ +78,Independent Evaluation Office (IEO),HQ +79,India,RBAP +80,Indonesia,RBAP +81,Iran,RBAP +82,Iraq,RBAS +83,"Istanbul International Center for Private Sector in Development (Istanbul, Turkey)",BPPS +84,Jamaica,RBLAC +85,Jordan,RBAS +86,Kazakhstan,RBEC +87,Kenya,RBA +88,Kiribati,RBAP +89,Kosovo (as per UNSCR 1244),RBEC +90,Kuwait,RBAS +91,Kyrgyzstan,RBEC +92,Lao PDR,RBAP +93,Lebanon,RBAS +94,Lesotho,RBA +95,Liberia,RBA +96,Libya,RBAS +97,Madagascar,RBA +98,Malawi,RBA +99,Malaysia MCO,RBAP +100,Maldives,RBAP +101,Mali,RBA +102,Marshall Islands,RBAP +103,Mauritania,RBA +104,Mauritius,RBA +105,Mexico,RBLAC +106,Moldova,RBEC +107,Mongolia,RBAP +108,Montserrat,RBLAC +109,Morocco,RBAS +110,Mozambique,RBA +111,Myanmar,RBAP +112,"Nairobi Global Centre on Resilient Ecosystems and Desertification (Nairobi, Kenya)",BPPS +113,Namibia,RBA +114,Nauru,RBAP +115,Nepal,RBAP +116,Nicaragua,RBLAC +117,Niger,RBA +118,Nigeria,RBA +119,Niue,RBAP +120,North Macedonia,RBEC +121,Office of Audit and Investigations (OAI),HQ +122,"Oslo Governance Centre (Oslo, Norway)",BPPS +123,PNG,RBAP +124,Pakistan,RBAP +125,Palau,RBAP +126,Panama,RBLAC +127,Paraguay,RBLAC +128,People,RBAS +129,Peru,RBLAC +130,Philippines,RBAP +131,Prog for Palestinian,RBAS +132,"Regional Bureau for Africa, New York, USA",HQ +133,"Regional Bureau for Arab States, New York, USA",HQ +134,"Regional Bureau for Asia and the Pacific, New York",HQ +135,"Regional Bureau for Europe and the CIS, New York, USA",HQ +136,"Regional Bureau for Latin America and the Caribbean, New York, USA",HQ +137,"Regional Centre: Panama City, Panama",RBLAC +138,"Regional Hub: Amman, Jordan",RBAS +139,"Regional Hub: Bangkok, Thailand",RBAP +140,"Regional Hub: Dakar, Senegal",RBA +141,"Regional Hub: Istanbul, Turkey",RBEC +142,"Regional Hub: Nairobi, Kenya",RBA +143,"Regional Hub: Pretoria, South Africa",RBA +144,"Regional Service Centre: Addis Ababa, Ethiopia",RBA +145,Republic of Belarus,RBEC +146,Republic of Cape Verde,RBA +147,Republic of Montenegro,RBEC +148,Republic of South Sudan,RBA +149,Republic of the,RBAS +150,Republic of the Sudan,RBAS +151,Rwanda,RBA +152,Saint Kitts and Nevis,RBLAC +153,Saint Lucia and Saint Vincent,RBLAC +154,Samoa MCO,RBAP +155,Sao Tome and Principe,RBA +156,Saudi Arabia,RBAS +157,Senegal,RBA +158,Serbia,RBEC +159,Seychelles,RBA +160,Sierra Leone,RBA +161,Singapore,RBAP +162,Solomon Islands,RBAP +163,Somalia,RBAS +164,South Africa,RBA +165,Sri Lanka,RBAP +166,Suriname,RBLAC +167,Syria,RBAS +168,Tajikistan,RBEC +169,Tanzania,RBA +170,Thailand,RBAP +171,The Bahamas,RBLAC +172,Timor Leste,RBAP +173,Togo,RBA +174,Tokelau,RBAP +175,Tonga,RBAP +176,Trinidad & Tobago,RBLAC +177,Tunisia,RBAS +178,Turkey,RBEC +179,Turkmenistan,RBEC +180,Turks and Caicos Islands,RBLAC +181,Tuvalu,RBAP +182,"UNDP Nordic Representation Office (Copenhagen, Denmark)",BERA +183,"UNDP Office in Geneva (Geneva, Switzerland)",BERA +184,"UNDP Representation Office in 
Brussels (Brussels, Belgium)",BERA +185,"UNDP Representation Office in Japan (Tokyo, Japan)",BERA +186,"UNDP Seoul Policy Centre for Knowledge Exchange through SDG Partnerships (Seoul, Republic of Korea)",BPPS +187,"UNDP Washington Representation Office (Washington, USA)",BERA +188,Uganda,RBA +189,Ukraine,RBEC +190,Uruguay,RBLAC +191,Uzbekistan,RBEC +192,Vanuatu,RBAP +193,Venezuela,RBLAC +194,Viet Nam,RBAP +195,Yemen,RBAS +196,Zambia,RBA +197,Zimbabwe,RBA +198,the British Virgin Islands,RBLAC +199,the Commonwealth of Dominica,RBLAC +200,the Grenadines,RBLAC diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..4adec8a --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,20 @@ +services: + web: + build: + context: . + env_file: .env + environment: + - DB_CONNECTION=postgresql://postgres:password@db:5432/postgres + ports: + - "8000:8000" + depends_on: + - db + db: + image: postgres:16.4-alpine + environment: + POSTGRES_PASSWORD: password + ports: + - "5432:5432" + volumes: + - ./sql:/docker-entrypoint-initdb.d + - ./data:/docker-entrypoint-initdb.d/data diff --git a/images/architecture.drawio.svg b/images/architecture.drawio.svg new file mode 100644 index 0000000..c097e4b --- /dev/null +++ b/images/architecture.drawio.svg @@ -0,0 +1,4 @@ + + + +
[SVG markup omitted: only the drawio diagram's label text survived extraction. The labels describe the architecture: the ftss-api repository deploys to the ftss-api.azurewebsites.net Azure Web App from the main branch (production) and to the ftss-api-dev.azurewebsites.net Azure Web App from the dev branch (staging); the UNDP-Signals-And-Trends Static Web Apps deploy from the production and staging branches and route HTTP requests from regular users (production website) and from developers and testers (staging website); each environment manages its own database in the ftss Azure Database for PostgreSQL and its own images in the ftss Blob Storage.]
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..2756a42
--- /dev/null
+++ b/main.py
@@ -0,0 +1,52 @@
+"""
+This application serves as an API endpoint for the Signals and Trends project that connects
+the frontend platform with the backend database.
+"""
+
+from dotenv import load_dotenv
+from fastapi import Depends, FastAPI
+
+from src import routers
+from src.authentication import authenticate_user
+
+load_dotenv()
+
+app = FastAPI(
+    debug=False,
+    title="Future Trends and Signals API",
+    version="3.0.0-beta",
+    summary="""The Future Trends and Signals (FTSS) API powers user experiences on the UNDP Future
+    Trends and Signals System by providing functionality to manage signals, trends and users.""",
+    description="""The FTSS API serves as an interface for the
+    [UNDP Future Trends and Signals System](https://signals.data.undp.org),
+    facilitating interaction between the front-end application and the underlying relational database.
+    This API enables users to submit, retrieve, and update data related to signals, trends, and user
+    profiles within the platform.
+
+    As a private API, it mandates authentication for all endpoints to ensure secure access.
+    Authentication is achieved by including the `access_token` in the request header, utilising JWT tokens
+    issued by [Microsoft Entra](https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens).
+    This mechanism not only secures the API but also allows for the automatic recording of user information
+    derived from the API token. Approved signals and trends can be accessed using a predefined API key for
+    integration with other applications.
+    """.strip().replace(
+        "    ", " "
+    ),
+    contact={
+        "name": "UNDP Data Futures Platform",
+        "url": "https://data.undp.org",
+        "email": "data@undp.org",
+    },
+    openapi_tags=[
+        {"name": "signals", "description": "CRUD operations on signals."},
+        {"name": "trends", "description": "CRUD operations on trends."},
+        {"name": "users", "description": "CRUD operations on users."},
+        {"name": "choices", "description": "List valid options for form fields."},
+    ],
+    docs_url="/",
+    redoc_url=None,
+)
+
+
+for router in routers.ALL:
+    app.include_router(router=router, dependencies=[Depends(authenticate_user)])
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..aa12ec0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,16 @@
+fastapi == 0.115.3
+uvicorn == 0.32.0
+python-dotenv ~= 1.0.1
+httpx ~= 0.27.2
+pyjwt[crypto] ~= 2.9.0
+pydantic[email] ~= 2.9.2
+psycopg == 3.2.3
+pandas ~= 2.2.3
+openpyxl ~= 3.1.5
+scikit-learn ~= 1.5.2
+azure-storage-blob ~= 12.23.0
+aiohttp ~= 3.10.10
+pillow ~= 11.0.0
+beautifulsoup4 ~= 4.12.3
+lxml ~= 5.3.0
+openai == 1.52.2
diff --git a/requirements_dev.txt b/requirements_dev.txt
new file mode 100644
index 0000000..a0cf82c
--- /dev/null
+++ b/requirements_dev.txt
@@ -0,0 +1,6 @@
+-r requirements.txt
+black ~= 24.10.0
+isort ~= 5.13.2
+pylint ~= 3.3.1
+pytest ~= 8.3.3
+notebook ~= 7.2.2
diff --git a/sql/create_tables.sql b/sql/create_tables.sql
new file mode 100644
index 0000000..e6d5b14
--- /dev/null
+++ b/sql/create_tables.sql
@@ -0,0 +1,137 @@
+/*
+The initialisation script to create tables in an empty database.
+
+The database schema assumes a denormalised form. This makes it possible to insert data "as is", minimising
+the differences between the API layer and database layer and making CRUD operations simpler.
+
+Given the expected size of the database, this design will have a marginal impact on efficiency
+even in the long run.
+
+The database tables comprise:
+
+1. Users
+2. Signals
+3. Trends
+4. Connections – a junction table for connected signals/trends to model a many-to-many relationship.
+5. Locations – stores country and area metadata based on UN M49 that are used for signal location.
+6. Units – stores metadata on UNDP units used to assign user units and filter signals.
+*/
+
+-- users table and indices
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    email VARCHAR(255) UNIQUE NOT NULL,
+    role VARCHAR(255) NOT NULL,
+    name VARCHAR(255),
+    unit VARCHAR(255),
+    acclab BOOLEAN
+);
+
+CREATE INDEX ON users (email);
+CREATE INDEX ON users (role);
+
+-- signals table and indices
+CREATE TABLE signals (
+    id SERIAL PRIMARY KEY,
+    status VARCHAR(255) NOT NULL,
+    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    created_by VARCHAR(255) NOT NULL,
+    created_for VARCHAR(255),
+    modified_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    modified_by VARCHAR(255) NOT NULL,
+    headline TEXT,
+    description TEXT,
+    attachment TEXT, -- a URL to Azure Blob Storage
+    steep_primary TEXT,
+    steep_secondary TEXT[],
+    signature_primary TEXT,
+    signature_secondary TEXT[],
+    sdgs TEXT[],
+    created_unit VARCHAR(255),
+    url TEXT,
+    relevance TEXT,
+    keywords TEXT[],
+    location TEXT,
+    score TEXT,
+    text_search_field tsvector GENERATED ALWAYS AS (to_tsvector('english', headline || ' ' || description)) STORED
+);
+
+CREATE INDEX ON signals (
+    status,
+    created_by,
+    created_for,
+    created_unit,
+    steep_primary,
+    steep_secondary,
+    signature_primary,
+    signature_secondary,
+    sdgs,
+    location,
+    score
+);
+CREATE INDEX ON signals USING GIN (text_search_field);
+
+-- trends table and indices
+CREATE TABLE trends (
+    id SERIAL PRIMARY KEY,
+    status VARCHAR(255) NOT NULL,
+    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    created_by VARCHAR(255) NOT NULL,
+    created_for TEXT,
+    modified_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    modified_by VARCHAR(255) NOT NULL,
+    headline TEXT,
+    description TEXT,
+    attachment TEXT,
+    steep_primary TEXT,
+    steep_secondary TEXT[],
+    signature_primary TEXT,
+    signature_secondary TEXT[],
+    sdgs TEXT[],
+    assigned_to TEXT,
+    time_horizon TEXT,
+    impact_rating TEXT,
+    impact_description TEXT,
+    text_search_field tsvector GENERATED ALWAYS AS (to_tsvector('english', headline || ' ' || description)) STORED
+);
+
+CREATE INDEX ON trends (
+    status,
+    created_for,
+    assigned_to,
+    steep_primary,
+    steep_secondary,
+    signature_primary,
+    signature_secondary,
+    sdgs,
+    time_horizon,
+    impact_rating
+);
+CREATE INDEX ON trends USING GIN (text_search_field);
+
+-- junction table for connected signals/trends to model many-to-many relationship
+CREATE TABLE connections (
+    signal_id INT REFERENCES signals(id) ON DELETE CASCADE,
+    trend_id INT REFERENCES trends(id) ON DELETE CASCADE,
+    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    created_by VARCHAR(255) NOT NULL,
+    CONSTRAINT connection_pk PRIMARY KEY (signal_id, trend_id)
+);
+
+-- locations table and indices
+CREATE TABLE locations (
+    id SERIAL PRIMARY KEY,
+    name VARCHAR(128) NOT NULL,
+    iso VARCHAR(3),
+    region VARCHAR(128) NOT NULL,
+    bureau VARCHAR(5)
+);
+CREATE INDEX ON locations (name, region, bureau);
+
+-- units table and indices
+CREATE TABLE units (
+    id SERIAL PRIMARY KEY,
+    name TEXT NOT NULL,
+    region VARCHAR(255)
+);
+CREATE INDEX ON units (name, region);
diff --git a/sql/import_data.sql b/sql/import_data.sql
new file mode 100644
index 0000000..1965fd8
--- /dev/null
+++ b/sql/import_data.sql
@@ -0,0 +1,11 @@
+/*
+The initialisation script to import locations and units data into an empty database,
+from within the docker container. This script is automatically executed by docker
+compose after `create_tables.sql`.
+ */
+
+-- import locations data
+\copy locations FROM 'docker-entrypoint-initdb.d/data/locations.csv' DELIMITER ',' CSV HEADER;
+
+-- import units data
+\copy units FROM 'docker-entrypoint-initdb.d/data/units.csv' DELIMITER ',' CSV HEADER;
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..8e7302c
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,4 @@
+"""
+API components, including database functions, entity (data) models, routers, endpoint
+dependencies for RBAC, Azure Blob Storage and GenAI.
+"""
diff --git a/src/authentication.py b/src/authentication.py
new file mode 100644
index 0000000..0f748f3
--- /dev/null
+++ b/src/authentication.py
@@ -0,0 +1,146 @@
+"""
+Dependencies for API authentication using JWT tokens from Microsoft Entra.
+"""
+
+import os
+
+import httpx
+import jwt
+from fastapi import Depends, Security
+from fastapi.security import APIKeyHeader
+from psycopg import AsyncCursor
+
+from . import database as db
+from . import exceptions
+from .entities import Role, User
+
+api_key_header = APIKeyHeader(
+    name="access_token",
+    description="The API access token for the application.",
+    auto_error=True,
+)
+
+
+async def get_jwks() -> dict[str, dict]:
+    """
+    Get the JSON Web Key Set (JWKS) containing the public keys.
+
+    Returns
+    -------
+    keys : dict[str, dict]
+        A mapping of key IDs (kid) to JWKs.
+    """
+    # obtain OpenID configuration
+    tenant_id = os.environ["TENANT_ID"]
+    endpoint = f"https://login.microsoftonline.com/{tenant_id}/v2.0/.well-known/openid-configuration"
+    async with httpx.AsyncClient() as client:
+        response = await client.get(endpoint)
+        configuration = response.json()
+
+        # get the JSON Web Keys
+        response = await client.get(configuration["jwks_uri"])
+        response.raise_for_status()
+        jwks = response.json()
+
+    # get a mapping of key IDs to keys
+    keys = {key["kid"]: key for key in jwks["keys"]}
+    return keys
+
+
+async def get_jwk(token: str) -> jwt.PyJWK:
+    """
+    Obtain a JSON Web Key (JWK) for a token.
+
+    Parameters
+    ----------
+    token : str
+        A JSON Web Token issued by Microsoft Entra.
+
+    Returns
+    -------
+    jwk : jwt.PyJWK
+        A ready-to-use JWK object.
+    """
+    header = jwt.get_unverified_header(token)
+    try:
+        jwks = await get_jwks()
+    except httpx.HTTPError:
+        jwks = {}
+    jwk = jwks.get(header["kid"])
+    if jwk is None:
+        raise ValueError("JWK could not be obtained or found")
+    jwk = jwt.PyJWK.from_dict(jwk, "RS256")
+    return jwk
+
+
+async def decode_token(token: str) -> dict:
+    """
+    Decode and verify a payload of a JSON Web Token (JWT).
+
+    Parameters
+    ----------
+    token : str
+        A JSON Web Token issued by Microsoft Entra.
+
+    Returns
+    -------
+    payload : dict
+        The decoded payload that should include email
+        and name for further authentication.
+ """ + jwk = await get_jwk(token) + tenant_id = os.environ["TENANT_ID"] + payload = jwt.decode( + jwt=token, + key=jwk.key, + algorithms=["RS256"], + audience=os.environ["CLIENT_ID"], + issuer=f"https://sts.windows.net/{tenant_id}/", + options={ + "verify_signature": True, + "verify_aud": True, + "verify_exp": True, + "verify_iss": True, + }, + ) + return payload + + +async def authenticate_user( + token: str = Security(api_key_header), + cursor: AsyncCursor = Depends(db.yield_cursor), +) -> User: + """ + Authenticate a user with a valid JWT token from Microsoft Entra ID. + + The function is used for dependency injection in endpoints to authenticate incoming requests. + The tokens must be issued by Microsoft Entra ID. For a list of available attributes, see + https://learn.microsoft.com/en-us/entra/identity-platform/id-token-claims-reference + + Parameters + ---------- + token : str + Either a predefined api_key to access "public" endpoints or a valid signed JWT. + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + user : User + Pydantic model for a User object (if authentication succeeded). + """ + if token == os.environ.get("API_KEY"): + # dummy user object for anonymous access + user = User(email="name.surname@undp.org", role=Role.VISITOR) + return user + try: + payload = await decode_token(token) + except jwt.exceptions.PyJWTError as e: + raise exceptions.not_authenticated from e + email, name = payload.get("unique_name"), payload.get("name") + if email is None or name is None: + raise exceptions.not_authenticated + if (user := await db.read_user_by_email(cursor, email)) is None: + user = User(email=email, role=Role.USER, name=name) + await db.create_user(cursor, user) + return user diff --git a/src/database/__init__.py b/src/database/__init__.py new file mode 100644 index 0000000..6a2c1e0 --- /dev/null +++ b/src/database/__init__.py @@ -0,0 +1,9 @@ +""" +Functions for connecting to and performing CRUD operations in a PostgreSQL database. +""" + +from .choices import * +from .connection import yield_cursor +from .signals import * +from .trends import * +from .users import * diff --git a/src/database/choices.py b/src/database/choices.py new file mode 100644 index 0000000..98da71f --- /dev/null +++ b/src/database/choices.py @@ -0,0 +1,63 @@ +""" +Functions for reading data related to choice lists. +""" + +from psycopg import AsyncCursor + +__all__ = ["get_unit_names", "get_unit_regions", "get_location_names"] + + +async def get_unit_names(cursor: AsyncCursor) -> list[str]: + """ + Read unit names from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[str] + A list of unit names. + """ + await cursor.execute("SELECT name FROM units ORDER BY name;") + return [row["name"] async for row in cursor] + + +async def get_unit_regions(cursor: AsyncCursor) -> list[str]: + """ + Read unit regions from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[str] + A list of unique unit regions. + """ + await cursor.execute("SELECT DISTINCT region FROM units ORDER BY region;") + return [row["region"] async for row in cursor] + + +async def get_location_names(cursor: AsyncCursor) -> list[str]: + """ + Read location names from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. 
+
+    Returns
+    -------
+    list[str]
+        A list of location names that includes geographic regions,
+        countries and territories based on UNSD M49.
+    """
+    # do not order, so that regions appear first
+    await cursor.execute("SELECT name FROM locations;")
+    return [row["name"] async for row in cursor]
diff --git a/src/database/connection.py b/src/database/connection.py
new file mode 100644
index 0000000..995b936
--- /dev/null
+++ b/src/database/connection.py
@@ -0,0 +1,46 @@
+"""
+Database connection functions based on the Psycopg 3 project.
+"""
+
+import os
+
+import psycopg
+from psycopg.rows import dict_row
+
+
+async def get_connection() -> psycopg.AsyncConnection:
+    """
+    Get a connection to a PostgreSQL database.
+
+    The connection includes a row factory to return database rows as dictionaries
+    and a cursor factory that ensures client-side binding. See the
+    [documentation](https://www.psycopg.org/psycopg3/docs/basic/from_pg2.html#server-side-binding)
+    for details.
+
+    Returns
+    -------
+    conn : psycopg.AsyncConnection
+        A database connection object with row and cursor factory settings.
+    """
+    conn = await psycopg.AsyncConnection.connect(
+        conninfo=os.environ["DB_CONNECTION"],
+        autocommit=False,
+        row_factory=dict_row,
+        cursor_factory=psycopg.AsyncClientCursor,
+    )
+    return conn
+
+
+async def yield_cursor() -> psycopg.AsyncCursor:
+    """
+    Yield a PostgreSQL database cursor object to be used for dependency injection.
+
+    Yields
+    ------
+    cursor : psycopg.AsyncCursor
+        A database cursor object.
+    """
+    # handle rollbacks from the context manager and close on exit
+    async with await get_connection() as conn:
+        async with conn.cursor() as cursor:
+            yield cursor
diff --git a/src/database/signals.py b/src/database/signals.py
new file mode 100644
index 0000000..4e1c591
--- /dev/null
+++ b/src/database/signals.py
@@ -0,0 +1,360 @@
+"""
+CRUD operations for signal entities.
+"""
+
+from psycopg import AsyncCursor, sql
+
+from .. import storage
+from ..entities import Signal, SignalFilters, SignalPage, Status
+
+__all__ = [
+    "search_signals",
+    "create_signal",
+    "read_signal",
+    "update_signal",
+    "delete_signal",
+    "read_user_signals",
+]
+
+
+async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalPage:
+    """
+    Search signals in the database using filters and pagination.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    filters : SignalFilters
+        Query filters for search, including pagination.
+
+    Returns
+    -------
+    page : SignalPage
+        Paginated search results for signals.
+ """ + query = """ + SELECT + *, COUNT(*) OVER() AS total_count + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + LEFT OUTER JOIN ( + SELECT + name AS unit_name, + region AS unit_region + FROM + units + ) AS u + ON + s.created_unit = u.unit_name + LEFT OUTER JOIN ( + SELECT + name AS location, + region AS location_region, + bureau AS location_bureau + FROM + locations + ) AS l + ON + s.location = l.location + WHERE + (%(ids)s IS NULL OR id = ANY(%(ids)s)) + AND status = ANY(%(statuses)s) + AND (%(created_by)s IS NULL OR created_by = %(created_by)s) + AND (%(created_for)s IS NULL OR created_for = %(created_for)s) + AND (%(steep_primary)s IS NULL OR steep_primary = %(steep_primary)s) + AND (%(steep_secondary)s IS NULL OR steep_secondary && %(steep_secondary)s) + AND (%(signature_primary)s IS NULL OR signature_primary = %(signature_primary)s) + AND (%(signature_secondary)s IS NULL OR signature_secondary && %(signature_secondary)s) + AND (%(location)s IS NULL OR (s.location = %(location)s) OR location_region = %(location)s) + AND (%(bureau)s IS NULL OR location_bureau = %(bureau)s) + AND (%(sdgs)s IS NULL OR %(sdgs)s && sdgs) + AND (%(score)s IS NULL OR score = %(score)s) + AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) + AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) + ORDER BY + {} {} + OFFSET + %(offset)s + LIMIT + %(limit)s + ; + """ + query = sql.SQL(query).format( + sql.Identifier(filters.order_by), + sql.SQL(filters.direction), + ) + await cursor.execute(query, filters.model_dump()) + rows = await cursor.fetchall() + # extract total count of rows matching the WHERE clause + page = SignalPage.from_search(rows, filters) + return page + + +async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: + """ + Insert a signal into the database, connect it to trends and upload an attachment + to Azure Blob Storage if applicable. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal : Signal + A signal object to insert. + + Returns + ------- + signal_id : int + An ID of the signal in the database. 
+    """
+    query = """
+        INSERT INTO signals (
+            status,
+            created_by,
+            created_for,
+            modified_by,
+            headline,
+            description,
+            steep_primary,
+            steep_secondary,
+            signature_primary,
+            signature_secondary,
+            sdgs,
+            created_unit,
+            url,
+            relevance,
+            keywords,
+            location,
+            score
+        )
+        VALUES (
+            %(status)s,
+            %(created_by)s,
+            %(created_for)s,
+            %(modified_by)s,
+            %(headline)s,
+            %(description)s,
+            %(steep_primary)s,
+            %(steep_secondary)s,
+            %(signature_primary)s,
+            %(signature_secondary)s,
+            %(sdgs)s,
+            %(created_unit)s,
+            %(url)s,
+            %(relevance)s,
+            %(keywords)s,
+            %(location)s,
+            %(score)s
+        )
+        RETURNING
+            id
+        ;
+    """
+    await cursor.execute(query, signal.model_dump())
+    row = await cursor.fetchone()
+    signal_id = row["id"]
+
+    # add connected trends if any are present
+    for trend_id in signal.connected_trends or []:
+        query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);"
+        await cursor.execute(query, (signal_id, trend_id, signal.created_by))
+
+    # upload an image
+    if signal.attachment is not None:
+        try:
+            blob_url = await storage.upload_image(
+                signal_id, "signals", signal.attachment
+            )
+        except Exception as e:
+            print(e)
+        else:
+            query = "UPDATE signals SET attachment = %s WHERE id = %s;"
+            await cursor.execute(query, (blob_url, signal_id))
+    return signal_id
+
+
+async def read_signal(cursor: AsyncCursor, uid: int) -> Signal | None:
+    """
+    Read a signal from the database using an ID.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    uid : int
+        An ID of the signal to retrieve data for.
+
+    Returns
+    -------
+    Signal | None
+        A signal if it exists, otherwise None.
+    """
+    query = """
+        SELECT
+            *
+        FROM
+            signals AS s
+        LEFT OUTER JOIN (
+            SELECT
+                signal_id, array_agg(trend_id) AS connected_trends
+            FROM
+                connections
+            GROUP BY
+                signal_id
+        ) AS c
+        ON
+            s.id = c.signal_id
+        WHERE
+            id = %s
+        ;
+    """
+    await cursor.execute(query, (uid,))
+    if (row := await cursor.fetchone()) is None:
+        return None
+    return Signal(**row)
+
+
+async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None:
+    """
+    Update a signal in the database, update its connected trends and update an attachment
+    in the Azure Blob Storage if applicable.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    signal : Signal
+        A signal object to update.
+
+    Returns
+    -------
+    int | None
+        A signal ID if the update has been performed, otherwise None.
+ """ + query = """ + UPDATE + signals + SET + status = COALESCE(%(status)s, status), + created_for = COALESCE(%(created_for)s, created_for), + modified_at = NOW(), + modified_by = %(modified_by)s, + headline = COALESCE(%(headline)s, headline), + description = COALESCE(%(description)s, description), + steep_primary = COALESCE(%(steep_primary)s, steep_primary), + steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), + signature_primary = COALESCE(%(signature_primary)s, signature_primary), + signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), + sdgs = COALESCE(%(sdgs)s, sdgs), + created_unit = COALESCE(%(created_unit)s, created_unit), + url = COALESCE(%(url)s, url), + relevance = COALESCE(%(relevance)s, relevance), + keywords = COALESCE(%(keywords)s, keywords), + location = COALESCE(%(location)s, location), + score = COALESCE(%(score)s, score) + WHERE + id = %(id)s + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + if (row := await cursor.fetchone()) is None: + return None + signal_id = row["id"] + + # update connected trends if any are present + await cursor.execute("DELETE FROM connections WHERE signal_id = %s;", (signal_id,)) + for trend_id in signal.connected_trends or []: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + + # upload an image if it is not a URL to an existing image + blob_url = await storage.update_image(signal_id, "signals", signal.attachment) + query = "UPDATE signals SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, signal_id)) + + return signal_id + + +async def delete_signal(cursor: AsyncCursor, uid: int) -> Signal | None: + """ + Delete a signal from the database and, if applicable, an image from + Azure Blob Storage, using an ID. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the signal to delete. + + Returns + ------- + Signal | None + A deleted signal object if the operation has been successful, otherwise None. + """ + query = "DELETE FROM signals WHERE id = %s RETURNING *;" + await cursor.execute(query, (uid,)) + if (row := await cursor.fetchone()) is None: + return None + signal = Signal(**row) + if signal.attachment is not None: + await storage.delete_image(entity_id=signal.id, folder_name="signals") + return signal + + +async def read_user_signals( + cursor: AsyncCursor, + user_email: str, + status: Status, +) -> list[Signal]: + """ + Read signals from the database using a user email and status filter. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_email : str + An email of the user whose signals to read. + status : Status + A status of signals to filter by. + + Returns + ------- + list[Signal] + A list of matching signals. + """ + query = """ + SELECT + * + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + WHERE + created_by = %s AND status = %s + ; + """ + await cursor.execute(query, (user_email, status)) + return [Signal(**row) async for row in cursor] diff --git a/src/database/trends.py b/src/database/trends.py new file mode 100644 index 0000000..1b914da --- /dev/null +++ b/src/database/trends.py @@ -0,0 +1,285 @@ +""" +CRUD operations for trend entities. +""" + +from psycopg import AsyncCursor, sql + +from .. 
import storage
+from ..entities import Trend, TrendFilters, TrendPage
+
+__all__ = [
+    "search_trends",
+    "create_trend",
+    "read_trend",
+    "update_trend",
+    "delete_trend",
+]
+
+
+async def search_trends(cursor: AsyncCursor, filters: TrendFilters) -> TrendPage:
+    """
+    Search trends in the database using filters and pagination.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    filters : TrendFilters
+        Query filters for search, including pagination.
+
+    Returns
+    -------
+    page : TrendPage
+        Paginated search results for trends.
+    """
+    query = """
+        SELECT
+            *, COUNT(*) OVER() AS total_count
+        FROM
+            trends AS t
+        LEFT OUTER JOIN (
+            SELECT
+                trend_id, array_agg(signal_id) AS connected_signals
+            FROM
+                connections
+            GROUP BY
+                trend_id
+        ) AS c
+        ON
+            t.id = c.trend_id
+        WHERE
+            (%(ids)s IS NULL OR id = ANY(%(ids)s))
+            AND status = ANY(%(statuses)s)
+            AND (%(created_by)s IS NULL OR created_by = %(created_by)s)
+            AND (%(created_for)s IS NULL OR created_for = %(created_for)s)
+            AND (%(steep_primary)s IS NULL OR steep_primary = %(steep_primary)s)
+            AND (%(steep_secondary)s IS NULL OR steep_secondary && %(steep_secondary)s)
+            AND (%(signature_primary)s IS NULL OR signature_primary = %(signature_primary)s)
+            AND (%(signature_secondary)s IS NULL OR signature_secondary && %(signature_secondary)s)
+            AND (%(sdgs)s IS NULL OR sdgs && %(sdgs)s)
+            AND (%(assigned_to)s IS NULL OR assigned_to = %(assigned_to)s)
+            AND (%(time_horizon)s IS NULL OR time_horizon = %(time_horizon)s)
+            AND (%(impact_rating)s IS NULL OR impact_rating = %(impact_rating)s)
+            AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s))
+        ORDER BY
+            {} {}
+        OFFSET
+            %(offset)s
+        LIMIT
+            %(limit)s
+        ;
+    """
+    query = sql.SQL(query).format(
+        sql.Identifier(filters.order_by),
+        sql.SQL(filters.direction),
+    )
+    await cursor.execute(query, filters.model_dump())
+    rows = await cursor.fetchall()
+    page = TrendPage.from_search(rows, filters)
+    return page
+
+
+async def create_trend(cursor: AsyncCursor, trend: Trend) -> int:
+    """
+    Insert a trend into the database, connect it to signals and upload an attachment
+    to Azure Blob Storage if applicable.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    trend : Trend
+        A trend object to insert.
+
+    Returns
+    -------
+    trend_id : int
+        An ID of the trend in the database.
+    """
+    query = """
+        INSERT INTO trends (
+            status,
+            created_by,
+            created_for,
+            modified_by,
+            headline,
+            description,
+            steep_primary,
+            steep_secondary,
+            signature_primary,
+            signature_secondary,
+            sdgs,
+            assigned_to,
+            time_horizon,
+            impact_rating,
+            impact_description
+        )
+        VALUES (
+            %(status)s,
+            %(created_by)s,
+            %(created_for)s,
+            %(modified_by)s,
+            %(headline)s,
+            %(description)s,
+            %(steep_primary)s,
+            %(steep_secondary)s,
+            %(signature_primary)s,
+            %(signature_secondary)s,
+            %(sdgs)s,
+            %(assigned_to)s,
+            %(time_horizon)s,
+            %(impact_rating)s,
+            %(impact_description)s
+        )
+        RETURNING
+            id
+        ;
+    """
+    await cursor.execute(query, trend.model_dump())
+    row = await cursor.fetchone()
+    trend_id = row["id"]
+
+    # add connected signals if any are present
+    for signal_id in trend.connected_signals or []:
+        query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);"
+        await cursor.execute(query, (signal_id, trend_id, trend.created_by))
+
+    # upload an image
+    if trend.attachment is not None:
+        try:
+            blob_url = await storage.upload_image(trend_id, "trends", trend.attachment)
+        except Exception as e:
+            print(e)
+        else:
+            query = "UPDATE trends SET attachment = %s WHERE id = %s;"
+            await cursor.execute(query, (blob_url, trend_id))
+    return trend_id
+
+
+async def read_trend(cursor: AsyncCursor, uid: int) -> Trend | None:
+    """
+    Read a trend from the database using an ID.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    uid : int
+        An ID of the trend to retrieve data for.
+
+    Returns
+    -------
+    Trend | None
+        A trend if it exists, otherwise None.
+    """
+    query = """
+        SELECT
+            *
+        FROM
+            trends AS t
+        LEFT OUTER JOIN (
+            SELECT
+                trend_id, array_agg(signal_id) AS connected_signals
+            FROM
+                connections
+            GROUP BY
+                trend_id
+        ) AS c
+        ON
+            t.id = c.trend_id
+        WHERE
+            id = %s
+        ;
+    """
+    await cursor.execute(query, (uid,))
+    if (row := await cursor.fetchone()) is None:
+        return None
+    return Trend(**row)
+
+
+async def update_trend(cursor: AsyncCursor, trend: Trend) -> int | None:
+    """
+    Update a trend in the database, update its connected signals and update an attachment
+    in Azure Blob Storage if applicable.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    trend : Trend
+        A trend object to update.
+
+    Returns
+    -------
+    int | None
+        A trend ID if the update has been performed, otherwise None.
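+
+    Examples
+    --------
+    Because the query wraps each field in `COALESCE`, fields left as `None`
+    keep their stored values. A sketch (the `cursor` name is illustrative):
+
+    >>> trend = await read_trend(cursor, 42)       # start from the stored row
+    >>> trend.headline = "Updated headline"
+    >>> await update_trend(cursor, trend)
+    42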
+    """
+    query = """
+        UPDATE
+            trends
+        SET
+            status = COALESCE(%(status)s, status),
+            created_for = COALESCE(%(created_for)s, created_for),
+            modified_at = NOW(),
+            modified_by = %(modified_by)s,
+            headline = COALESCE(%(headline)s, headline),
+            description = COALESCE(%(description)s, description),
+            steep_primary = COALESCE(%(steep_primary)s, steep_primary),
+            steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary),
+            signature_primary = COALESCE(%(signature_primary)s, signature_primary),
+            signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary),
+            sdgs = COALESCE(%(sdgs)s, sdgs),
+            assigned_to = COALESCE(%(assigned_to)s, assigned_to),
+            time_horizon = COALESCE(%(time_horizon)s, time_horizon),
+            impact_rating = COALESCE(%(impact_rating)s, impact_rating),
+            impact_description = COALESCE(%(impact_description)s, impact_description)
+        WHERE
+            id = %(id)s
+        RETURNING
+            id
+        ;
+    """
+    await cursor.execute(query, trend.model_dump())
+    if (row := await cursor.fetchone()) is None:
+        return None
+    trend_id = row["id"]
+
+    # update connected signals if any are present
+    await cursor.execute("DELETE FROM connections WHERE trend_id = %s;", (trend_id,))
+    for signal_id in trend.connected_signals or []:
+        query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);"
+        await cursor.execute(query, (signal_id, trend_id, trend.created_by))
+
+    # upload an image if it is not a URL to an existing image
+    blob_url = await storage.update_image(trend_id, "trends", trend.attachment)
+    query = "UPDATE trends SET attachment = %s WHERE id = %s;"
+    await cursor.execute(query, (blob_url, trend_id))
+
+    return trend_id
+
+
+async def delete_trend(cursor: AsyncCursor, uid: int) -> Trend | None:
+    """
+    Delete a trend from the database and, if applicable, an image from
+    Azure Blob Storage, using an ID.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    uid : int
+        An ID of the trend to delete.
+
+    Returns
+    -------
+    Trend | None
+        A deleted trend object if the operation has been successful, otherwise None.
+    """
+    query = "DELETE FROM trends WHERE id = %s RETURNING *;"
+    await cursor.execute(query, (uid,))
+    if (row := await cursor.fetchone()) is None:
+        return None
+    trend = Trend(**row)
+    if trend.attachment is not None:
+        await storage.delete_image(entity_id=trend.id, folder_name="trends")
+    return trend
diff --git a/src/database/users.py b/src/database/users.py
new file mode 100644
index 0000000..06ef088
--- /dev/null
+++ b/src/database/users.py
@@ -0,0 +1,198 @@
+"""
+CRUD operations for user entities.
+"""
+
+from psycopg import AsyncCursor
+
+from ..entities import User, UserFilters, UserPage
+
+__all__ = [
+    "search_users",
+    "create_user",
+    "read_user_by_email",
+    "read_user",
+    "update_user",
+    "get_acclab_users",
+]
+
+
+async def search_users(cursor: AsyncCursor, filters: UserFilters) -> UserPage:
+    """
+    Search users in the database using filters and pagination.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    filters : UserFilters
+        Query filters for search, including pagination.
+
+    Returns
+    -------
+    page : UserPage
+        Paginated search results for users.
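+
+    Examples
+    --------
+    A sketch, assuming an open psycopg cursor with a dict row factory
+    (the `cursor` name is illustrative):
+
+    >>> page = await search_users(cursor, UserFilters(per_page=10, query="jane"))
+    >>> page.current_page
+    1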
+    """
+    query = """
+        SELECT
+            *, COUNT(*) OVER() AS total_count
+        FROM
+            users
+        WHERE
+            role = ANY(%(roles)s)
+            AND (%(query)s IS NULL OR name ~* %(query)s)
+        ORDER BY
+            name
+        OFFSET
+            %(offset)s
+        LIMIT
+            %(limit)s
+        ;
+    """
+    await cursor.execute(query, filters.model_dump())
+    rows = await cursor.fetchall()
+    page = UserPage.from_search(rows, filters)
+    return page
+
+
+async def create_user(cursor: AsyncCursor, user: User) -> int:
+    """
+    Insert a user into the database.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    user : User
+        A user object to insert.
+
+    Returns
+    -------
+    int
+        An ID of the user in the database.
+    """
+    query = """
+        INSERT INTO users (
+            created_at,
+            email,
+            role,
+            name,
+            unit,
+            acclab
+        )
+        VALUES (
+            %(created_at)s,
+            %(email)s,
+            %(role)s,
+            %(name)s,
+            %(unit)s,
+            %(acclab)s
+        )
+        RETURNING
+            id
+        ;
+    """
+    await cursor.execute(query, user.model_dump())
+    row = await cursor.fetchone()
+    return row["id"]
+
+
+async def read_user_by_email(cursor: AsyncCursor, email: str) -> User | None:
+    """
+    Read a user from the database using an email address.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    email : str
+        An email address.
+
+    Returns
+    -------
+    User | None
+        A user object if found, otherwise None.
+    """
+    query = "SELECT * FROM users WHERE email = %s;"
+    await cursor.execute(query, (email,))
+    if (row := await cursor.fetchone()) is None:
+        return None
+    user = User(**row)
+    return user
+
+
+async def read_user(cursor: AsyncCursor, uid: int) -> User | None:
+    """
+    Read a user from the database using an ID.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    uid : int
+        An ID of the user to retrieve data for.
+
+    Returns
+    -------
+    User | None
+        A user if it exists, otherwise None.
+    """
+    query = "SELECT * FROM users WHERE id = %s;"
+    await cursor.execute(query, (uid,))
+    if (row := await cursor.fetchone()) is None:
+        return None
+    return User(**row)
+
+
+async def update_user(cursor: AsyncCursor, user: User) -> int | None:
+    """
+    Update a user in the database.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+    user : User
+        A user object to update.
+
+    Returns
+    -------
+    int | None
+        A user ID if the update has been performed, otherwise None.
+    """
+    query = """
+        UPDATE
+            users
+        SET
+            role = %(role)s,
+            name = %(name)s,
+            unit = %(unit)s,
+            acclab = %(acclab)s
+        WHERE
+            email = %(email)s
+        RETURNING
+            id
+        ;
+    """
+    await cursor.execute(query, user.model_dump())
+    if (row := await cursor.fetchone()) is None:
+        return None
+    return row["id"]
+
+
+async def get_acclab_users(cursor: AsyncCursor) -> list[str]:
+    """
+    Get emails of users who are part of the Accelerator Labs.
+
+    Parameters
+    ----------
+    cursor : AsyncCursor
+        An async database cursor.
+
+    Returns
+    -------
+    list[str]
+        A list of emails for users who are part of the Accelerator Labs.
+    """
+    query = "SELECT email FROM users WHERE acclab = TRUE;"
+    await cursor.execute(query)
+    return [row["email"] async for row in cursor]
diff --git a/src/dependencies.py b/src/dependencies.py
new file mode 100644
index 0000000..df9f7d5
--- /dev/null
+++ b/src/dependencies.py
@@ -0,0 +1,62 @@
+"""
+Functions used for dependency injection for role-based access control.
+"""
+
+from typing import Annotated
+
+from fastapi import Depends, Path
+from psycopg import AsyncCursor
+
+from . import database as db
+from . 
import exceptions
+from .authentication import authenticate_user
+from .entities import User
+
+__all__ = [
+    "require_admin",
+    "require_curator",
+    "require_user",
+    "require_creator",
+]
+
+
+async def require_admin(user: User = Depends(authenticate_user)) -> User:
+    """Require that the user is assigned an admin role."""
+    if not user.is_admin:
+        raise exceptions.permission_denied
+    return user
+
+
+async def require_curator(user: User = Depends(authenticate_user)) -> User:
+    """Require that the user is assigned at least a curator role."""
+    if not user.is_staff:
+        raise exceptions.permission_denied
+    return user
+
+
+async def require_user(user: User = Depends(authenticate_user)) -> User:
+    """Require that the user is assigned at least a user role and is not a visitor."""
+    if not user.is_regular:
+        raise exceptions.permission_denied
+    return user
+
+
+async def require_creator(
+    uid: Annotated[int, Path(description="The ID of the signal to be updated")],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+) -> User:
+    """Require that the user is at least a curator or is the creator of the signal."""
+    # admins and curators can modify all signals
+    if user.is_staff:
+        return user
+    # check if the user created the original signal
+    signal = await db.read_signal(cursor, uid)
+    if signal is None:
+        raise exceptions.not_found
+    if signal.created_by != user.email:
+        raise exceptions.permission_denied
+    # the signal belongs to the user; note that restricting status changes by
+    # regular users cannot be checked here because the incoming payload is not
+    # available to this dependency
+    return user
diff --git a/src/entities/__init__.py b/src/entities/__init__.py
new file mode 100644
index 0000000..7c421d8
--- /dev/null
+++ b/src/entities/__init__.py
@@ -0,0 +1,117 @@
+"""
+Entities (models) for receiving, managing and sending data.
+"""
+
+from math import ceil
+from typing import Self
+
+from pydantic import BaseModel, Field
+
+from .parameters import *
+from .signal import *
+from .trend import *
+from .user import *
+from .utils import *
+
+
+class Page(BaseModel):
+    """
+    A paginated results model, holding pagination metadata and search results.
+    """
+
+    per_page: int = Field(description="Number of entities per page to retrieve.")
+    current_page: int = Field(
+        description="Current page for which entities should be retrieved."
+    )
+    total_pages: int = Field(description="Total number of pages for pagination.")
+    total_count: int = Field(description="Total number of entities in the database.")
+    data: list[Signal] | list[Trend] | list[User]
+
+    @classmethod
+    def from_search(cls, rows: list[dict], pagination: Pagination) -> Self:
+        """
+        Create paginated results model from search results.
+
+        Parameters
+        ----------
+        rows : list[dict]
+            A list of results returned from the database.
+        pagination : Pagination
+            The pagination object used to retrieve the results.
+
+        Returns
+        -------
+        page : Page
+            The paginated results model.
+        """
+        total = rows[0]["total_count"] if rows else 0
+        page = cls(
+            per_page=pagination.limit,
+            current_page=pagination.page,
+            total_pages=ceil(total / pagination.limit),
+            total_count=total,
+            data=rows,
+        )
+        return page
+
+    def sanitise(self, user: User) -> Self:
+        """
+        Remove items from the list the user has no permissions to access.
+
+        Parameters
+        ----------
+        user : User
+            A user making the request.
+
+        Returns
+        -------
+        self : Self
+            The paginated results model with sanitised data.
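+
+        Examples
+        --------
+        An illustrative sketch: a visitor only keeps approved entries, with
+        author emails masked by `anonymise`:
+
+        >>> page = page.sanitise(User(email="api.user@undp.org", role=Role.VISITOR))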
+ """ + if user.role == Role.ADMIN: + # all signals/trends are shown + pass + elif user.role == Role.CURATOR: + # all signals/trends are shown except other users' drafts + self.data = [ + entity + for entity in self.data + if not ( + entity.status == Status.DRAFT and entity.created_by != user.email + ) + ] + return self + elif user.role == Role.USER: + # only approved signals/trends are shown or signals/trends authored by the user + self.data = [ + entity + for entity in self.data + if entity.status == Status.APPROVED or entity.created_by == user.email + ] + else: + # only approved signals/trends after anonymisation are shown + self.data = [ + entity.anonymise() + for entity in self.data + if entity.status == Status.APPROVED + ] + return self + return self + + +class UserPage(Page): + """A specialised paginated results model for users.""" + + data: list[User] + + +class SignalPage(Page): + """A specialised paginated results model for signals.""" + + data: list[Signal] + + +class TrendPage(Page): + """A specialised paginated results model for trends.""" + + data: list[Trend] diff --git a/src/entities/base.py b/src/entities/base.py new file mode 100644 index 0000000..4ddc157 --- /dev/null +++ b/src/entities/base.py @@ -0,0 +1,120 @@ +""" +Entity (model) definitions for base objects that others inherit from. +""" + +from datetime import UTC, datetime + +from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator + +from . import utils + +__all__ = ["BaseMetadata", "BaseEntity", "timestamp"] + + +def timestamp() -> str: + """ + Get the current timestamp in the ISO format. + + Returns + ------- + str + A timestamp in the ISO format. + """ + return datetime.now(UTC).isoformat() + + +class BaseMetadata(BaseModel): + """Base metadata for database objects.""" + + id: int = Field(default=1) + created_at: str = Field(default_factory=timestamp) + + @field_validator("created_at", mode="before") + @classmethod + def format_created_at(cls, value): + """(De)serialisation function for `created_at` timestamp.""" + if isinstance(value, str): + return value + return value.isoformat() + + +class BaseEntity(BaseMetadata): + """Base entity for signals and trends.""" + + status: utils.Status = Field( + default=utils.Status.NEW, + description="Current signal review status.", + ) + created_by: EmailStr | None = Field(default=None) + created_for: str | None = Field(default=None) + modified_at: str = Field(default_factory=timestamp) + modified_by: EmailStr | None = Field(default=None) + headline: str | None = Field( + default=None, + description="A clear and concise title headline.", + ) + description: str | None = Field( + default=None, + description="A clear and concise description.", + ) + attachment: str | None = Field( + default=None, + description="An optional base64-encoded image/URL for illustration.", + ) + steep_primary: utils.Steep | None = Field( + default=None, + description="Primary category in terms of STEEP+V analysis methodology.", + ) + steep_secondary: list[utils.Steep] | None = Field( + default=None, + description="Secondary categories in terms of STEEP+V analysis methodology.", + ) + signature_primary: utils.Signature | None = Field( + default=None, + description="Primary category in terms of UNDP Signature Solutions methodology.", + ) + signature_secondary: list[utils.Signature] | None = Field( + default=None, + description="Secondary categories in terms of UNDP Signature Solutions methodology.", + ) + sdgs: list[utils.Goal] | None = Field( + default=None, + 
description="Relevant Sustainable Development Goals.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "headline": "The cost of corruption", + "description": "Corruption is one of the scourges of modern life. Its costs are staggering.", + "steep_primary": utils.Steep.ECONOMIC, + "steep_secondary": [utils.Steep.SOCIAL], + "signature_primary": utils.Signature.GOVERNANCE, + "signature_secondary": [ + utils.Signature.POVERTY, + utils.Signature.RESILIENCE, + ], + "sdgs": [utils.Goal.G16, utils.Goal.G17], + } + } + ) + + @field_validator("modified_at", mode="before") + @classmethod + def format_modified_at(cls, value): + """(De)serialisation function for `modified_at` timestamp.""" + if isinstance(value, str): + return value + return value.isoformat() + + def anonymise(self) -> "BaseEntity": + """ + Anonymise an entity by removing personal information, such as user emails. + + This function is to be used to retrieve information for visitors, i.e., + not authenticated users, to preserve privacy of other users. + """ + email_mask = "email.hidden@undp.org" + self.created_by = email_mask + self.modified_by = email_mask + return self diff --git a/src/entities/parameters.py b/src/entities/parameters.py new file mode 100644 index 0000000..d019efb --- /dev/null +++ b/src/entities/parameters.py @@ -0,0 +1,86 @@ +""" +Dataclasses used to define query parameters in API endpoints. +""" + +from typing import Literal + +from pydantic import BaseModel, Field, computed_field + +from .utils import Goal, Horizon, Rating, Role, Score, Signature, Status, Steep + +__all__ = [ + "Pagination", + "SignalFilters", + "TrendFilters", + "UserFilters", +] + + +class Pagination(BaseModel): + """A container class for pagination parameters.""" + + page: int = Field(default=1, gt=0) + per_page: int = Field(default=10, gt=0, le=10_000) + order_by: str = Field(default="created_at") + direction: Literal["desc", "asc"] = Field( + default="desc", + description="Ascending or descending order.", + ) + + @computed_field + def limit(self) -> int: + """ + An alias for `per_page` can be dropped in favour of + limit: int = Field(default=10, gt=0, le=10_000, alias="per_page") + once https://github.com/fastapi/fastapi/discussions/12401 is resolved. + + Returns + ------- + int + A value of `per_page`. 
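+
+        Examples
+        --------
+        >>> p = Pagination(page=3, per_page=10)
+        >>> (p.limit, p.offset)
+        (10, 20)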
+ """ + return self.per_page + + @computed_field + def offset(self) -> int: + """An offset value that can be used in a database query.""" + return self.limit * (self.page - 1) + + +class BaseFilters(Pagination): + """Base filtering parameters shared by signal and trend filters.""" + + ids: list[int] | None = Field(default=None) + statuses: list[Status] = Field(default=(Status.APPROVED,)) + created_by: str | None = Field(default=None) + created_for: str | None = Field(default=None) + steep_primary: Steep | None = Field(default=None) + steep_secondary: list[Steep] | None = Field(default=None) + signature_primary: Signature | None = Field(default=None) + signature_secondary: list[Signature] | None = Field(default=None) + sdgs: list[Goal] | None = Field(default=None) + query: str | None = Field(default=None) + + +class SignalFilters(BaseFilters): + """Filter parameters for searching signals.""" + + location: str | None = Field(default=None) + bureau: str | None = Field(default=None) + score: Score | None = Field(default=None) + unit: str | None = Field(default=None) + + +class TrendFilters(BaseFilters): + """Filter parameters for searching trends.""" + + assigned_to: str | None = Field(default=None) + time_horizon: Horizon | None = Field(default=None) + impact_rating: Rating | None = Field(default=None) + + +class UserFilters(Pagination): + """Filter parameters for searching users.""" + + roles: list[Role] = Field(default=(Role.VISITOR, Role.CURATOR, Role.ADMIN)) + query: str | None = Field(default=None) diff --git a/src/entities/signal.py b/src/entities/signal.py new file mode 100644 index 0000000..5112e46 --- /dev/null +++ b/src/entities/signal.py @@ -0,0 +1,43 @@ +""" +Entity (model) definitions for signal objects. +""" + +from pydantic import ConfigDict, Field + +from . import utils +from .base import BaseEntity + +__all__ = ["Signal"] + + +class Signal(BaseEntity): + """The signal entity model used in the database and API endpoints.""" + + created_unit: str | None = Field(default=None) + url: str | None = Field(default=None) + relevance: str | None = Field(default=None) + keywords: list[str] | None = Field( + default=None, + description="Use up to 3 clear, simple keywords for ease of searchability.", + ) + location: str | None = Field( + default=None, + description="Region and/or country for which this signal has greatest relevance.", + ) + score: utils.Score | None = Field(default=None) + connected_trends: list[int] | None = Field( + default=None, + description="IDs of trends connected to this signal.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": BaseEntity.model_config["json_schema_extra"]["example"] + | { + "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", + "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", + "keywords": ["economy", "governance"], + "location": "Global", + } + } + ) diff --git a/src/entities/trend.py b/src/entities/trend.py new file mode 100644 index 0000000..88b6ae8 --- /dev/null +++ b/src/entities/trend.py @@ -0,0 +1,44 @@ +""" +Entity (model) definitions for trend objects. +""" + +from pydantic import ConfigDict, Field, field_validator + +from . 
import utils +from .base import BaseEntity + +__all__ = ["Trend"] + + +class Trend(BaseEntity): + """The trend entity model used in the database and API endpoints.""" + + assigned_to: str | None = Field(default=None) + time_horizon: utils.Horizon | None = Field(default=None) + impact_rating: utils.Rating | None = Field(default=None) + impact_description: str | None = Field(default=None) + connected_signals: list[int] | None = Field( + default=None, + description="IDs of signals connected to this trend.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": BaseEntity.model_config["json_schema_extra"]["example"] + | { + "time_horizon": utils.Horizon.MEDIUM, + "impact_rating": utils.Rating.HIGH, + "impact_description": "Anywhere between 1.4 and 35 percent of climate action funds may be lost due to corruption.", + } + } + ) + + @field_validator("impact_rating", mode="before") + @classmethod + def from_string(cls, value): + """ + Coerce an integer string for `impact_rating` to a string + """ + if value is None or isinstance(value, str): + return value + return str(value) diff --git a/src/entities/user.py b/src/entities/user.py new file mode 100644 index 0000000..8cde39a --- /dev/null +++ b/src/entities/user.py @@ -0,0 +1,55 @@ +""" +Entity (model) definitions for user objects. +""" + +from pydantic import ConfigDict, EmailStr, Field + +from .base import BaseMetadata +from .utils import Role + +__all__ = ["User"] + + +class User(BaseMetadata): + """The user entity model used in the database and API endpoints.""" + + email: EmailStr = Field(description="Work email ending with @undp.org.") + role: Role = Field(default=Role.VISITOR) + name: str | None = Field(default=None, min_length=5, description="Full name.") + unit: str | None = Field( + default=None, + min_length=2, + description="UNDP unit name from a predefined list.", + ) + acclab: bool | None = Field( + default=None, + description="Whether or not a user is part of the Accelerator Labs.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "email": "john.doe@undp.org", + "role": "Curator", + "name": "John Doe", + "unit": "BPPS", + "acclab": True, + } + } + ) + + @property + def is_admin(self): + """Check if the user is an admin.""" + return self.role == Role.ADMIN + + @property + def is_staff(self): + """Check if the user is a curator or admin.""" + return self.role in {Role.ADMIN, Role.CURATOR} + + @property + def is_regular(self): + """Check if the user is a regular user, not a visitor using API key.""" + return self.role in {Role.ADMIN, Role.CURATOR, Role.USER} diff --git a/src/entities/utils.py b/src/entities/utils.py new file mode 100644 index 0000000..b475897 --- /dev/null +++ b/src/entities/utils.py @@ -0,0 +1,123 @@ +""" +Utility classes that define valid string options. +""" + +from enum import StrEnum + +__all__ = [ + "Role", + "Status", + "Steep", + "Signature", + "Goal", + "Score", + "Horizon", + "Rating", + "Bureau", +] + + +class Role(StrEnum): + """ + User roles for RBAC. Admins, curators and users are actual users logged in + to the platform who authenticate via JWT. Visitor role is assigned to + a dummy user authenticated with an API key. 
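+
+    Since this is a `StrEnum`, members compare equal to their string values,
+    which is how roles stored in the database map back onto the enum:
+
+    >>> Role("Admin") is Role.ADMIN
+    True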
+ """ + + ADMIN = "Admin" # curator + can change the roles of other users + CURATOR = "Curator" # user + can edit and approve signals and trends + USER = "User" # visitor + can submit signals + VISITOR = "Visitor" # can only view signals and trends + + +class Status(StrEnum): + """Signal/trend review statuses.""" + + DRAFT = "Draft" + NEW = "New" + APPROVED = "Approved" + ARCHIVED = "Archived" + + +class Steep(StrEnum): + """Categories in terms of Steep-V methodology.""" + + SOCIAL = "Social – Issues related to human culture, demography, communication, movement and migration, work and education" + TECHNOLOGICAL = "Technological – Made culture, tools, devices, systems, infrastructure and networks" + ECONOMIC = "Economic – Issues of value, money, financial tools and systems, business and business models, exchanges and transactions" + ENVIRONMENTAL = "Environmental – The natural world, living environment, sustainability, resources, climate and health" + POLITICAL = "Political – Legal issues, policy, governance, rules and regulations and organizational systems" + VALUES = "Values – Ethics, spirituality, ideology or other forms of values" + + +class Signature(StrEnum): + """The six Signature Solutions of the United Nations Development Programme.""" + + POVERTY = "Poverty and Inequality" + GOVERNANCE = "Governance" + RESILIENCE = "Resilience" + ENVIRONMENT = "Environment" + ENERGY = "Energy" + GENDER = "Gender Equality" + # 3 enables + INNOVATION = "Strategic Innovation" + DIGITALISATION = "Digitalisation" + FINANCING = "Development Financing" + + +class Goal(StrEnum): + """The 17 United Nations Sustainable Development Goals.""" + + G1 = "GOAL 1: No Poverty" + G2 = "GOAL 2: Zero Hunger" + G3 = "GOAL 3: Good Health and Well-being" + G4 = "GOAL 4: Quality Education" + G5 = "GOAL 5: Gender Equality" + G6 = "GOAL 6: Clean Water and Sanitation" + G7 = "GOAL 7: Affordable and Clean Energy" + G8 = "GOAL 8: Decent Work and Economic Growth" + G9 = "GOAL 9: Industry, Innovation and Infrastructure" + G10 = "GOAL 10: Reduced Inequality" + G11 = "GOAL 11: Sustainable Cities and Communities" + G12 = "GOAL 12: Responsible Consumption and Production" + G13 = "GOAL 13: Climate Action" + G14 = "GOAL 14: Life Below Water" + G15 = "GOAL 15: Life on Land" + G16 = "GOAL 16: Peace and Justice Strong Institutions" + G17 = "GOAL 17: Partnerships to achieve the Goal" + + +class Score(StrEnum): + """Signal novelty scores.""" + + ONE = "1 — Non-novel (known, but potentially notable in particular context)" + TWO = "2" + THREE = "3 — Potentially novel or uncertain, but not clear in its potential impact" + FOUR = "4" + FIVE = "5 — Something that introduces or points to a potentially interesting or consequential change in direction of trends" + + +class Horizon(StrEnum): + """Trend impact horizons.""" + + SHORT = "Horizon 1 (0-3 years)" + MEDIUM = "Horizon 2 (3-7 years)" + LONG = "Horizon 3 (7-10 years)" + + +class Rating(StrEnum): + """Trend impact rating.""" + + LOW = "1 – Low" + MODERATE = "2 – Moderate" + HIGH = "3 – Significant" + + +class Bureau(StrEnum): + """Bureaus of the United Nations Development Programme.""" + + RBA = "RBA" + RBAP = "RBAP" + RBAS = "RBAS" + RBEC = "RBEC" + RBLAC = "RBLAC" diff --git a/src/exceptions.py b/src/exceptions.py new file mode 100644 index 0000000..52b11bc --- /dev/null +++ b/src/exceptions.py @@ -0,0 +1,44 @@ +""" +Exceptions raised by API endpoints. 
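+
+These are module-level `HTTPException` instances reused across routers, e.g.:
+
+    if (signal := await db.read_signal(cursor, uid)) is None:
+        raise exceptions.not_found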
+""" + +from fastapi import HTTPException, status + +__all__ = [ + "id_mismatch", + "not_authenticated", + "permission_denied", + "not_found", + "content_error", + "generation_error", +] + +id_mismatch = HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Resource ID in body does not match path ID.", +) + +not_authenticated = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated.", +) + +permission_denied = HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permissions to perform this action.", +) + +not_found = HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The requested resource could not be found.", +) + +content_error = HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="The content from the URL could not be fetched.", +) + +generation_error = HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="A signal could not be generated from the content.", +) diff --git a/src/genai.py b/src/genai.py new file mode 100644 index 0000000..6834c57 --- /dev/null +++ b/src/genai.py @@ -0,0 +1,106 @@ +""" +Functions for generating signals from web content using Azure OpenAI. +""" + +import json +import os + +from openai import AsyncAzureOpenAI + +from .entities import Signal + +__all__ = ["get_system_message", "get_client", "generate_signal"] + + +def get_system_message() -> str: + """ + Get a system message for generating a signal from web content. + + Returns + ------- + system_message : str + A system message that can be used to generate a signal. + """ + schema = Signal.model_json_schema() + schema.pop("example", None) + # generate content only for the following fields + fields = { + "headline", + "description", + "steep_primary", + "steep_secondary", + "signature_primary", + "signature_secondary", + "sdgs", + "keywords", + } + schema["properties"] = { + k: v for k, v in schema["properties"].items() if k in fields + } + system_message = f""" + You are a Signal Scanner within the Strategy & Futures Team at the United Nations Development Programme. + Your task is to generate a Signal from web content provided by the user. A Signal is defined as a single + piece of evidence or indicator that points to, relates to, or otherwise supports a trend. + It can also stand alone as a potential indicator of future change in one or more trends. + + ### Rules + 1. Your output must be a valid JSON string object without any markdown that can be directly passed to `json.loads`. + 2. The JSON string must conform to the schema below. + 3. The response must be in English, so translate content if necessary. + 4. For `headline` and `description`, do not just copy-paste text, instead summarize the information + in a concise yet insightful manner. + + ### Signal Schema + + ```json + {json.dumps(schema, indent=2)} + ``` + """.strip() + return system_message + + +def get_client() -> AsyncAzureOpenAI: + """ + Get an asynchronous Azure OpenAI client. + + Returns + ------- + client : AsyncAzureOpenAI + An asynchronous client for Azure OpenAI. + """ + client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_version="2024-02-15-preview", + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + timeout=10, + ) + return client + + +async def generate_signal(text: str) -> Signal: + """ + Generate a signal from a text. + + Parameters + ---------- + text : str + A text, i.e., web content, to be analysed. 
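+        Typically the raw page text returned by `utils.scrape_content`.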
+ + Returns + ------- + signal : Signal + A signal entity generated from the text. + """ + client = get_client() + system_message = get_system_message() + response = await client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": text}, + ], + temperature=0.3, # vary the output to alleviate occasional errors + ) + content = response.choices[0].message.content + signal = Signal(**json.loads(content)) + return signal diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 0000000..cf23153 --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1,15 @@ +""" +Router definitions for API endpoints. +""" + +from .choices import router as choice_router +from .signals import router as signal_router +from .trends import router as trend_router +from .users import router as user_router + +ALL = [ + choice_router, + signal_router, + trend_router, + user_router, +] diff --git a/src/routers/choices.py b/src/routers/choices.py new file mode 100644 index 0000000..4d5f395 --- /dev/null +++ b/src/routers/choices.py @@ -0,0 +1,59 @@ +""" +A router for obtaining valid choice options. +""" + +from fastapi import APIRouter, Depends +from psycopg import AsyncCursor + +from .. import database as db +from .. import exceptions +from ..authentication import authenticate_user +from ..entities import utils + +router = APIRouter(prefix="/choices", tags=["choices"]) + + +CREATED_FOR = [ + "General scanning", + "Global Signals Spotlight 2024", + "Global Signals Spotlight 2023", + "HDR 2023", + "Sustainable Finance Hub 2023", +] + + +@router.get("", response_model=dict, dependencies=[Depends(authenticate_user)]) +async def read_choices(cursor: AsyncCursor = Depends(db.yield_cursor)): + """ + List valid options for all fields. + """ + choices = { + name.lower(): [member.value for member in getattr(utils, name)] + for name in utils.__all__ + } + choices["created_for"] = CREATED_FOR + choices["unit_name"] = await db.get_unit_names(cursor) + choices["unit_region"] = await db.get_unit_regions(cursor) + choices["location"] = await db.get_location_names(cursor) + return choices + + +@router.get( + "/{name}", response_model=list[str], dependencies=[Depends(authenticate_user)] +) +async def read_field_choices(name: str, cursor: AsyncCursor = Depends(db.yield_cursor)): + """ + List valid options for a given field. + """ + match name: + case "unit_name": + choices = await db.get_unit_names(cursor) + case "unit_region": + choices = await db.get_unit_regions(cursor) + case "location": + choices = await db.get_location_names(cursor) + case name if name.capitalize() in utils.__all__: + choices = [member.value for member in getattr(utils, name.capitalize())] + case _: + raise exceptions.not_found + return choices diff --git a/src/routers/signals.py b/src/routers/signals.py new file mode 100644 index 0000000..bfb5057 --- /dev/null +++ b/src/routers/signals.py @@ -0,0 +1,153 @@ +""" +A router for retrieving, submitting and updating signals. +""" + +from typing import Annotated + +import pandas as pd +from fastapi import APIRouter, Depends, Path, Query +from psycopg import AsyncCursor + +from .. import database as db +from .. 
import exceptions, genai, utils
+from ..authentication import authenticate_user
+from ..dependencies import require_creator, require_curator, require_user
+from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User
+
+router = APIRouter(prefix="/signals", tags=["signals"])
+
+
+@router.get("/search", response_model=SignalPage)
+async def search_signals(
+    filters: Annotated[SignalFilters, Query()],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Search signals in the database using pagination and filters."""
+    page = await db.search_signals(cursor, filters)
+    return page.sanitise(user)
+
+
+@router.get("/export", response_model=None, dependencies=[Depends(require_curator)])
+async def export_signals(
+    filters: Annotated[SignalFilters, Query()],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Export signals that match the filters from the database. You can export up to
+    10k rows at once.
+    """
+    page = await db.search_signals(cursor, filters)
+
+    # prettify the data
+    df = pd.DataFrame([signal.model_dump() for signal in page.data])
+    df = utils.binarise_columns(df, ["steep_secondary", "signature_secondary", "sdgs"])
+    df["keywords"] = df["keywords"].str.join("; ")
+    df["connected_trends"] = df["connected_trends"].str.join("; ")
+
+    # add acclab indicator variable
+    emails = await db.users.get_acclab_users(cursor)
+    df["acclab"] = df["created_by"].isin(emails)
+
+    response = utils.write_to_response(df, "signals")
+    return response
+
+
+@router.get("/generation", response_model=Signal)
+async def generate_signal(
+    url: str = Query(
+        description="A public webpage URL whose content will be used to generate a signal."
+    ),
+    user: User = Depends(require_user),
+):
+    """Generate a signal from web content using OpenAI."""
+    try:
+        content = await utils.scrape_content(url)
+    except Exception as e:
+        print(e)
+        raise exceptions.content_error
+    try:
+        signal = await genai.generate_signal(content)
+    except Exception as e:
+        print(e)
+        raise exceptions.generation_error
+    signal.created_by = user.email
+    signal.created_unit = user.unit
+    signal.url = url
+    return signal
+
+
+@router.post("", response_model=Signal, status_code=201)
+async def create_signal(
+    signal: Signal,
+    user: User = Depends(require_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Submit a signal to the database. If the signal has a base64 encoded image
+    attachment, it will be uploaded to Azure Blob Storage.
+    """
+    signal.created_by = user.email
+    signal.modified_by = user.email
+    signal.created_unit = user.unit
+    signal_id = await db.create_signal(cursor, signal)
+    return await db.read_signal(cursor, signal_id)
+
+
+@router.get("/me", response_model=list[Signal])
+async def read_my_signals(
+    status: Status = Query(),
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Retrieve signals with a given status submitted by the current user.
+    """
+    return await db.read_user_signals(cursor, user.email, status)
+
+
+@router.get("/{uid}", response_model=Signal)
+async def read_signal(
+    uid: Annotated[int, Path(description="The ID of the signal to retrieve")],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Retrieve a signal from the database using an ID. Trends connected to the signal
+    can be retrieved using IDs from the `signal.connected_trends` field.
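+
+    Example request (illustrative; the `access_token` header carries either an
+    API key or a JWT):
+
+        GET /signals/42
+        access_token: <api-key-or-jwt>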
+ """ + if (signal := await db.read_signal(cursor, uid)) is None: + raise exceptions.not_found + if user.role == Role.VISITOR and signal.status != Status.APPROVED: + raise exceptions.permission_denied + return signal + + +@router.put("/{uid}", response_model=Signal) +async def update_signal( + uid: Annotated[int, Path(description="The ID of the signal to be updated")], + signal: Signal, + user: User = Depends(require_creator), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Update a signal in the database.""" + if uid != signal.id: + raise exceptions.id_mismatch + signal.modified_by = user.email + if (signal_id := await db.update_signal(cursor, signal)) is None: + raise exceptions.not_found + return await db.read_signal(cursor, signal_id) + + +@router.delete("/{uid}", response_model=Signal, dependencies=[Depends(require_creator)]) +async def delete_signal( + uid: Annotated[int, Path(description="The ID of the signal to be deleted")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Delete a signal from the database using IDs. This also deletes an image attachment from + Azure Blob Storage if there is one. + """ + if (signal := await db.delete_signal(cursor, uid)) is None: + raise exceptions.not_found + return signal diff --git a/src/routers/trends.py b/src/routers/trends.py new file mode 100644 index 0000000..d85bda2 --- /dev/null +++ b/src/routers/trends.py @@ -0,0 +1,112 @@ +""" +A router for retrieving, submitting and updating trends. +""" + +from typing import Annotated + +import pandas as pd +from fastapi import APIRouter, Depends, Path, Query, status +from psycopg import AsyncCursor + +from .. import database as db +from .. import exceptions, utils +from ..authentication import authenticate_user +from ..dependencies import require_curator +from ..entities import Role, Status, Trend, TrendFilters, TrendPage, User + +router = APIRouter(prefix="/trends", tags=["trends"]) + + +@router.get("/search", response_model=TrendPage) +async def search_trends( + filters: Annotated[TrendFilters, Query()], + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Search trends in the database using pagination and filters.""" + page = await db.search_trends(cursor, filters) + return page.sanitise(user) + + +@router.get("/export", response_model=None, dependencies=[Depends(require_curator)]) +async def export_trends( + filters: Annotated[TrendFilters, Query()], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Export trends that match the filters from the database. You can export up to + 10k rows at once. + """ + page = await db.search_trends(cursor, filters) + + # prettify the data + df = pd.DataFrame([trend.model_dump() for trend in page.data]) + df = utils.binarise_columns(df, ["steep_secondary", "signature_secondary", "sdgs"]) + df["connected_signals_count"] = df["connected_signals"].str.len() + df.drop("connected_signals", axis=1, inplace=True) + + response = utils.write_to_response(df, "trends") + return response + + +@router.post("", response_model=Trend, status_code=status.HTTP_201_CREATED) +async def create_trend( + trend: Trend, + user: User = Depends(require_curator), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Submit a trend to the database. If the trend has a base64 encoded image + attachment, it will be uploaded to Azure Blob Storage. 
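+
+    The attachment can be a bare base64 string or a data URL such as
+    `data:image/jpeg;base64,...`; anything up to the first comma is stripped
+    before decoding.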
+    """
+    trend.created_by = user.email
+    trend.modified_by = user.email
+    trend_id = await db.create_trend(cursor, trend)
+    return await db.read_trend(cursor, trend_id)
+
+
+@router.get("/{uid}", response_model=Trend)
+async def read_trend(
+    uid: Annotated[int, Path(description="The ID of the trend to retrieve")],
+    user: User = Depends(authenticate_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Retrieve a trend from the database using an ID. Signals connected to the trend
+    can be retrieved using IDs from the `trend.connected_signals` field.
+    """
+    if (trend := await db.read_trend(cursor, uid)) is None:
+        raise exceptions.not_found
+    if user.role == Role.VISITOR and trend.status != Status.APPROVED:
+        raise exceptions.permission_denied
+    return trend
+
+
+@router.put("/{uid}", response_model=Trend)
+async def update_trend(
+    uid: Annotated[int, Path(description="The ID of the trend to be updated")],
+    trend: Trend,
+    user: User = Depends(require_curator),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Update a trend in the database."""
+    if uid != trend.id:
+        raise exceptions.id_mismatch
+    trend.modified_by = user.email
+    if (trend_id := await db.update_trend(cursor, trend=trend)) is None:
+        raise exceptions.not_found
+    return await db.read_trend(cursor, trend_id)
+
+
+@router.delete("/{uid}", response_model=Trend, dependencies=[Depends(require_curator)])
+async def delete_trend(
+    uid: Annotated[int, Path(description="The ID of the trend to be deleted")],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Delete a trend from the database using an ID. This also deletes an image attachment from
+    Azure Blob Storage if there is one.
+    """
+    if (trend := await db.delete_trend(cursor, uid)) is None:
+        raise exceptions.not_found
+    return trend
diff --git a/src/routers/users.py b/src/routers/users.py
new file mode 100644
index 0000000..e244458
--- /dev/null
+++ b/src/routers/users.py
@@ -0,0 +1,70 @@
+"""
+A router for creating, reading and updating users.
+"""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Path, Query
+from psycopg import AsyncCursor
+
+from .. import database as db
+from .. 
import exceptions
+from ..authentication import authenticate_user
+from ..dependencies import require_admin, require_user
+from ..entities import Role, User, UserFilters, UserPage
+
+router = APIRouter(prefix="/users", tags=["users"])
+
+
+@router.get("/search", response_model=UserPage, dependencies=[Depends(require_admin)])
+async def search_users(
+    filters: Annotated[UserFilters, Query()],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Search users in the database using pagination and filters."""
+    page = await db.search_users(cursor, filters)
+    return page
+
+
+@router.get("/me", response_model=User)
+async def read_current_user(user: User = Depends(authenticate_user)):
+    """Read the current user information from a JWT token."""
+    if user is None:
+        raise exceptions.not_found
+    return user
+
+
+@router.get("/{uid}", response_model=User, dependencies=[Depends(require_admin)])
+async def read_user(
+    uid: Annotated[int, Path(description="The ID of the user to retrieve")],
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """Read a user from the database using an ID."""
+    if (user := await db.read_user(cursor, uid)) is None:
+        raise exceptions.not_found
+    return user
+
+
+@router.put("/{uid}", response_model=User)
+async def update_user(
+    uid: Annotated[int, Path(description="The ID of the user to be updated")],
+    user_new: User,
+    user: User = Depends(require_user),
+    cursor: AsyncCursor = Depends(db.yield_cursor),
+):
+    """
+    Update a user in the database. Non-admin users can only update
+    their own name, unit and accelerator lab flag. Only admin users can
+    update other users' roles.
+    """
+    if uid != user_new.id:
+        raise exceptions.id_mismatch
+    if user.role == Role.ADMIN:
+        pass
+    elif user.email != user_new.email or user.id != user_new.id:
+        raise exceptions.permission_denied
+    elif user.role != user_new.role:
+        raise exceptions.permission_denied
+    if (user_id := await db.update_user(cursor, user_new)) is None:
+        raise exceptions.not_found
+    return await db.read_user(cursor, user_id)
diff --git a/src/storage.py b/src/storage.py
new file mode 100644
index 0000000..b27ca5c
--- /dev/null
+++ b/src/storage.py
@@ -0,0 +1,162 @@
+"""
+Utilities for interacting with Azure Blob Storage for uploading and deleting image attachments.
+"""
+
+import os
+from typing import Literal
+from urllib.parse import urlparse
+
+from azure.core.exceptions import ResourceNotFoundError
+from azure.storage.blob import ContentSettings
+from azure.storage.blob.aio import ContainerClient
+
+from .utils import convert_to_thumbnail
+
+__all__ = [
+    "upload_image",
+    "delete_image",
+    "update_image",
+]
+
+
+def get_folder_path(folder_name: Literal["signals", "trends"]) -> str:
+    """
+    Get a path to an image folder derived from a database connection string.
+
+    This makes it possible to manage images from staging and production environments separately.
+
+    Returns
+    -------
+    str
+        A folder path to save signal/trend images to.
+    """
+    database_name = urlparse(os.getenv("DB_CONNECTION")).path.strip("/")
+    return f"{database_name}/{folder_name}"
+
+
+def get_container_client() -> ContainerClient:
+    """
+    Get an asynchronous container client for Azure Blob Storage.
+
+    Returns
+    -------
+    client : ContainerClient
+        An asynchronous container client.
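+
+    Examples
+    --------
+    Used as an async context manager so the connection is released (the blob
+    path below is illustrative):
+
+    >>> async with get_container_client() as client:
+    ...     await client.delete_blob("staging/signals/1.jpeg")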
+    """
+    client = ContainerClient.from_container_url(container_url=os.environ["SAS_URL"])
+    return client
+
+
+async def upload_image(
+    entity_id: int,
+    folder_name: Literal["signals", "trends"],
+    image_string: str,
+) -> str:
+    """
+    Upload a thumbnail JPEG version of an image to Azure Blob Storage and return a (public) URL.
+
+    The function converts the image to a JPEG format and rescales it to 720p.
+
+    Parameters
+    ----------
+    entity_id : int
+        Signal or trend ID for which an image is uploaded.
+    folder_name : Literal['signals', 'trends']
+        Folder name to save the image to.
+    image_string : str
+        Base64-encoded image data.
+
+    Returns
+    -------
+    blob_url : str
+        A (public) URL pointing to the image file on Blob Storage
+        that can be used to embed the image in HTML.
+    """
+    # decode the image string
+    image_string = image_string.split(sep=",", maxsplit=1)[-1]
+    image_data = convert_to_thumbnail(image_string)
+
+    # connect and upload to the storage
+    async with get_container_client() as client:
+        folder_path = get_folder_path(folder_name)
+        blob_client = await client.upload_blob(
+            name=f"{folder_path}/{entity_id}.jpeg",
+            data=image_data,
+            blob_type="BlockBlob",
+            overwrite=True,
+            content_settings=ContentSettings(content_type="image/jpeg"),
+        )
+
+    # remove URL parameters that expose a SAS token
+    blob_url = blob_client.url.split("?")[0]
+    return blob_url
+
+
+async def delete_image(
+    entity_id: int,
+    folder_name: Literal["signals", "trends"],
+) -> bool:
+    """
+    Remove an image from Azure Blob Storage.
+
+    Parameters
+    ----------
+    entity_id : int
+        Signal or trend ID whose image is to be deleted.
+    folder_name : Literal['signals', 'trends']
+        Folder name to delete the image from.
+
+    Returns
+    -------
+    bool
+        True if the blob has been deleted and False otherwise.
+    """
+    folder_path = get_folder_path(folder_name)
+    async with get_container_client() as client:
+        try:
+            await client.delete_blob(f"{folder_path}/{entity_id}.jpeg")
+        except ResourceNotFoundError:
+            return False
+    return True
+
+
+async def update_image(
+    entity_id: int,
+    folder_name: Literal["signals", "trends"],
+    attachment: str | None,
+) -> str | None:
+    """
+    Update an image attachment on Azure Blob Storage.
+
+    If the attachment is None, the image is deleted; if it is a base64-encoded
+    image, the image is updated; if it is a URL to an existing image, no action
+    is taken.
+
+    Parameters
+    ----------
+    entity_id : int
+        Signal or trend ID whose image is to be updated.
+    folder_name : Literal['signals', 'trends']
+        Folder name the image is stored in.
+    attachment : str | None
+        Base64-encoded image data, an existing attachment URL or None.
+
+    Returns
+    -------
+    str | None
+        A URL of the current/updated image or None if it has been deleted or the update failed.
+    """
+    if attachment is None:
+        await delete_image(entity_id, folder_name)
+        return None
+    if attachment.startswith("https"):
+        return attachment
+
+    try:
+        blob_url = await upload_image(
+            entity_id=entity_id,
+            folder_name=folder_name,
+            image_string=attachment,
+        )
+    except Exception as e:
+        print(e)
+        return None
+    return blob_url
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..0cf9e4a
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,143 @@
+"""
+Miscellaneous utilities for data wrangling.
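+
+Includes image thumbnailing, web-page scraping for the generation endpoint,
+and helpers for exporting signals/trends to Excel.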
+"""
+
+import base64
+from datetime import UTC, datetime
+from io import BytesIO
+from typing import Literal
+
+import httpx
+import pandas as pd
+from bs4 import BeautifulSoup
+from fastapi.responses import StreamingResponse
+from PIL import Image
+from sklearn.preprocessing import MultiLabelBinarizer
+
+
+def convert_to_thumbnail(image_string: str) -> bytes:
+    """
+    Convert an image to a JPEG thumbnail no larger than HD quality.
+
+    Parameters
+    ----------
+    image_string : str
+        Base64 encoded image string.
+
+    Returns
+    -------
+    image_data : bytes
+        The image thumbnail encoded as JPEG.
+    """
+    # read the original image
+    buffer = BytesIO(base64.b64decode(image_string))
+    image = Image.open(buffer)
+
+    # resize, save in-memory and return bytes
+    buffer = BytesIO()
+    image.thumbnail((1280, 720))
+    image.convert("RGB").save(buffer, format="jpeg")
+    image_data = buffer.getvalue()
+    return image_data
+
+
+async def scrape_content(url: str) -> str:
+    """
+    Scrape content of a web page to be fed to an OpenAI model.
+
+    Parameters
+    ----------
+    url : str
+        A publicly accessible URL.
+
+    Returns
+    -------
+    str
+        Web content of a page "as is".
+    """
+    headers = {"User-Agent": "Mozilla/5.0"}
+    async with httpx.AsyncClient(headers=headers, timeout=10) as client:
+        response = await client.get(url)
+    soup = BeautifulSoup(response.content, features="lxml")
+    return soup.text
+
+
+def format_column_name(prefix: str, value: str) -> str:
+    """
+    Format column names for dummy columns created by MultiLabelBinarizer.
+
+    This function is used for prettifying exports.
+
+    Parameters
+    ----------
+    prefix : str
+        A prefix to assign to a column.
+    value : str
+        The original column name value.
+
+    Returns
+    -------
+    str
+        Formatted column name.
+    """
+    value = value.split("–")[0]
+    value = value.strip().lower()
+    value = value.replace(" ", "_")
+    return f"{prefix}_{value}"
+
+
+def binarise_columns(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+    """
+    Binarise columns containing array values.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Arbitrary dataframe.
+    columns : list[str]
+        Columns to binarise.
+
+    Returns
+    -------
+    df : pd.DataFrame
+        Mutated data frame.
+    """
+    for column in columns:
+        mlb = MultiLabelBinarizer()
+        # fill in missing values with an empty list
+        values = mlb.fit_transform(df[column].apply(lambda x: x or []))
+        df_dummies = pd.DataFrame(values, columns=mlb.classes_)
+        df_dummies.rename(lambda x: format_column_name(column, x), axis=1, inplace=True)
+        df = df.join(df_dummies)
+    df.drop(columns, axis=1, inplace=True)
+    return df
+
+
+def write_to_response(
+    df: pd.DataFrame,
+    kind: Literal["signals", "trends"],
+) -> StreamingResponse:
+    """
+    Write a data frame to an Excel file in a Streaming response that can be returned by the API.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        A data frame of exported signals/trends.
+    kind : Literal["signals", "trends"]
+        A kind of the data being exported to include in the file name.
+
+    Returns
+    -------
+    response : StreamingResponse
+        A response object containing the exported data that can be returned by the API.
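+
+    Examples
+    --------
+    The file name encodes the entity kind and the current UTC date, e.g.
+    `ftss-signals-24-06-01.xlsx`.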
+    """
+    buffer = BytesIO()
+    df.to_excel(buffer, index=False)
+    file_name = f"ftss-{kind}-{datetime.now(UTC):%y-%m-%d}.xlsx"
+    response = StreamingResponse(
+        BytesIO(buffer.getvalue()),
+        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+        headers={"Content-Disposition": f"attachment; filename={file_name}"},
+    )
+    return response
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..47d7588
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,22 @@
+"""
+Fixtures for setting up the tests shared across the test suite.
+"""
+
+import os
+
+import pytest
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+@pytest.fixture(scope="session", params=[os.environ["API_KEY"], os.environ["API_JWT"]])
+def headers(request) -> dict[str, str]:
+    """Header for authentication with an API key or a JWT for a regular user (not curator or admin)."""
+    return {"access_token": request.param}
+
+
+@pytest.fixture(scope="session")
+def headers_with_jwt() -> dict[str, str]:
+    """Header for authentication with a JWT for a regular user (not curator or admin)."""
+    return {"access_token": os.environ["API_JWT"]}
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
new file mode 100644
index 0000000..985b1d0
--- /dev/null
+++ b/tests/test_authentication.py
@@ -0,0 +1,78 @@
+"""
+Basic tests to ensure the authentication is required and API key works for specific endpoints.
+Currently, the tests do not cover JWT-based authentication.
+"""
+
+import re
+from typing import Literal
+
+from fastapi.testclient import TestClient
+from pytest import mark
+
+from main import app
+
+client = TestClient(app)
+
+
+def get_endpoints(pattern: str, method: Literal["GET", "POST", "PUT"] | None = None):
+    """
+    Convenience function to get all endpoints for a specific method.
+
+    Parameters
+    ----------
+    pattern : str
+        A regex pattern to match against endpoint path.
+    method : Literal["GET", "POST", "PUT"] | None
+        An optional type of HTTP method to filter endpoints.
+
+    Returns
+    -------
+    endpoints : list[tuple[str, str]]
+        A list of tuples containing endpoint path and method name.
+    """
+    endpoints = []
+    for route in app.routes:
+        if re.search(pattern, route.path):
+            if method is None or {method} & route.methods:
+                endpoints.append((route.path, list(route.methods)[0]))
+    return endpoints
+
+
+def test_read_docs():
+    """Ensure the documentation page is accessible without authentication."""
+    response = client.get("/")
+    assert response.status_code == 200, "Documentation page is inaccessible"
+
+
+@mark.parametrize(
+    "endpoint,method",
+    get_endpoints(r"signals|trends|users|choices"),
+)
+def test_authentication_required(endpoint: str, method: str):
+    """Check if endpoints, except for documentation, require authentication."""
+    response = client.request(method, endpoint)
+    assert response.status_code == 403
+
+
+@mark.parametrize(
+    "endpoint,status_code",
+    [
+        ("/signals/search", 200),
+        ("/signals/export", 403),
+        ("/signals/10", 200),
+        ("/trends/search", 200),
+        ("/trends/export", 403),
+        ("/trends/10", 200),
+        ("/users/search", 403),
+        ("/users/me", 200),
+        ("/users/1", 403),
+        ("/choices", 200),
+        ("/choices/status", 200),
+    ],
+)
+def test_authentication_get(endpoint: str, status_code: int, headers: dict):
+    """
+    Check if authentication for GET endpoints works as expected.
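+
+    The `headers` fixture supplies the `access_token` header, parametrised over
+    both credentials configured in the test environment.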
+    """
+    response = client.get(endpoint, headers=headers)
+    assert response.status_code == status_code
diff --git a/tests/test_choices.py b/tests/test_choices.py
new file mode 100644
index 0000000..e643850
--- /dev/null
+++ b/tests/test_choices.py
@@ -0,0 +1,41 @@
+"""
+Basic tests for choices endpoints.
+"""
+
+from enum import StrEnum
+
+from fastapi.testclient import TestClient
+from pytest import mark
+
+from main import app
+from src.entities import Goal, Role, Status
+
+client = TestClient(app)
+
+enums = [("status", Status), ("role", Role), ("goal", Goal)]
+
+
+def test_choices(headers: dict):
+    """Check that the choices endpoint lists the values of every enumeration."""
+    endpoint = "/choices"
+    response = client.get(endpoint, headers=headers)
+    assert response.status_code == 200
+
+    data = response.json()
+    assert isinstance(data, dict)
+    for k, v in data.items():
+        assert isinstance(k, str)
+        assert isinstance(v, list)
+    for name, enum in enums:
+        assert name in data
+        assert data[name] == [x.value for x in enum]
+
+
+@mark.parametrize("name,enum", enums)
+def test_choice_name(name: str, enum: type[StrEnum], headers: dict):
+    """Check that each named choice endpoint returns the values of its enumeration."""
+    endpoint = f"/choices/{name}"
+    response = client.get(endpoint, headers=headers)
+    assert response.status_code == 200
+
+    data = response.json()
+    assert isinstance(data, list)
+    assert data == [x.value for x in enum]
diff --git a/tests/test_signals_and_trends.py b/tests/test_signals_and_trends.py
new file mode 100644
index 0000000..cdf6054
--- /dev/null
+++ b/tests/test_signals_and_trends.py
@@ -0,0 +1,120 @@
+"""
+Basic tests for search and CRUD operations on signals and trends.
+"""
+
+from typing import Literal
+
+from fastapi.testclient import TestClient
+from pytest import mark
+
+from main import app
+from src.entities import Goal, Pagination, Signal, Trend
+
+client = TestClient(app)
+
+
+@mark.parametrize("path", ["signals", "trends"])
+@mark.parametrize("page", [None, 1, 2])
+@mark.parametrize("per_page", [None, 10, 20])
+@mark.parametrize("goal", [None, Goal.G13])
+@mark.parametrize("query", [None, "climate"])
+@mark.parametrize("ids", [None, list(range(100))])
+def test_search(
+    path: Literal["signals", "trends"],
+    page: int | None,
+    per_page: int | None,
+    goal: Goal | None,
+    query: str | None,
+    ids: list[int] | None,
+    headers_with_jwt: dict,
+):
+    """Check that the search endpoints paginate and filter as expected."""
+    endpoint = f"/{path}/search"
+    params = {
+        "page": page,
+        "per_page": per_page,
+        "goal": goal,
+        "query": query,
+        "ids": ids,
+    }
+    # drop parameters that are not used in this request
+    params = {k: v for k, v in params.items() if v is not None}
+
+    # ensure the pagination values are set to defaults if not used in the request
+    page = page or Pagination.model_fields["page"].default
+    per_page = per_page or Pagination.model_fields["per_page"].default
+
+    response = client.get(endpoint, params=params, headers=headers_with_jwt)
+    assert response.status_code == 200
+    results = response.json()
+    assert results.get("current_page") == page
+    assert results.get("per_page") == per_page
+    assert isinstance(results.get("data"), list)
+    assert 0 < len(results["data"]) <= per_page
+    match path:
+        case "signals":
+            entity = Signal(**results["data"][0])
+        case "trends":
+            entity = Trend(**results["data"][0])
+        case _:
+            raise ValueError(f"Unknown path: {path}")
+    assert entity.id == results["data"][0]["id"]
+
+
+@mark.parametrize("path", ["signals", "trends"])
+@mark.parametrize("uid", list(range(10, 20)))
+def test_read_by_id(
+    path: Literal["signals", "trends"],
+    uid: int,
+    headers_with_jwt: dict,
+):
+    """Check that signals and trends can be read by ID."""
+    endpoint = f"/{path}/{uid}"
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code in {200, 404}
+    if response.status_code == 200:
+        data = response.json()
+        # validate against the entity model matching the path, not Signal in both cases
+        entity = Signal(**data) if path == "signals" else Trend(**data)
+        assert entity.id == data["id"]
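+
+
+# The CRUD test below exercises create -> read -> update -> delete in order,
+# so each step depends on the previous one having succeeded.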
+def test_crud(headers_with_jwt: dict):
+    """Currently, this tests signals only, as a staff role is required to manage trends."""
+    # instantiate a test object from the example in the model config
+    entity = Signal(**Signal.model_config["json_schema_extra"]["example"])
+
+    # create
+    endpoint = "/signals"
+    response = client.post(endpoint, json=entity.model_dump(), headers=headers_with_jwt)
+    assert response.status_code == 201
+    data = response.json()
+    assert entity.headline == data["headline"]
+    assert entity.description == data["description"]
+    assert entity.sdgs == data["sdgs"]
+
+    # read
+    endpoint = f"/signals/{data['id']}"
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 200
+    data = response.json()
+    assert entity.headline == data["headline"]
+    assert entity.description == data["description"]
+    assert entity.sdgs == data["sdgs"]
+
+    # update
+    endpoint = f"/signals/{data['id']}"
+    data |= {
+        "headline": "New Headline",
+        "description": "Lorem ipsum " * 10,
+        "sdgs": [Goal.G1, Goal.G17],
+    }
+    response = client.put(endpoint, json=data, headers=headers_with_jwt)
+    assert response.status_code == 200
+    data = response.json()
+    assert entity.headline != data["headline"]
+    assert entity.description != data["description"]
+    assert entity.sdgs != data["sdgs"]
+
+    # delete
+    endpoint = f"/signals/{data['id']}"
+    response = client.delete(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 200
+    # confirm the signal is gone
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 404
diff --git a/tests/test_users.py b/tests/test_users.py
new file mode 100644
index 0000000..2cbdca9
--- /dev/null
+++ b/tests/test_users.py
@@ -0,0 +1,57 @@
+"""
+Basic tests for user endpoints.
+"""
+
+from fastapi.testclient import TestClient
+from pytest import mark
+
+from main import app
+from src.entities import Role, User
+
+client = TestClient(app)
+
+
+def test_me(headers: dict):
+    """Check that the current user's profile can be read and has no elevated rights."""
+    endpoint = "/users/me"
+    response = client.get(endpoint, headers=headers)
+    assert response.status_code == 200
+
+    user = User(**response.json())
+    assert user.role in {Role.VISITOR, Role.USER}
+    assert not user.is_admin
+    assert not user.is_staff
+
+
+@mark.parametrize(
+    "unit",
+    [
+        "Bureau for Policy and Programme Support (BPPS)",
+        "Chief Digital Office (CDO)",
+        "Executive Office (ExO)",
+    ],
+)
+@mark.parametrize("acclab", [True, False])
+def test_update(unit: str, acclab: bool, headers_with_jwt: dict):
+    """Check that regular users can update their profile but not their role."""
+    # get the user data from the database
+    endpoint = "/users/me"
+    response = client.get(endpoint, headers=headers_with_jwt)
+    assert response.status_code == 200
+    user = User(**response.json())
+    assert user.role == Role.USER, "The JWT must belong to a regular user"
+
+    # regular users should be able to update their profile
+    endpoint = f"/users/{user.id}"
+    user.acclab = acclab
+    user.unit = unit
+    response = client.put(endpoint, json=user.model_dump(), headers=headers_with_jwt)
+    assert response.status_code == 200
+    data = response.json()
+    assert user.id == data["id"]
+    assert user.is_regular
+    assert user.acclab == acclab
+    assert user.unit == unit
+
+    # regular users should not be able to change their role
+    user.role = Role.ADMIN
+    response = client.put(endpoint, json=user.model_dump(), headers=headers_with_jwt)
+    assert response.status_code == 403