diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index f4019fd..2b1c655 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -35,7 +35,14 @@ jobs: poetry-version: 1.1.13 - name: Install dependencies run: poetry install + - name: Start MongoDB + uses: supercharge/mongodb-github-action@1.7.0 + with: + mongodb-version: 5.0 + mongodb-username: root + mongodb-password: example + mongodb-db: scim2UnitTest + mongodb-port: 27017 - name: Unit tests run: | - ISTORE_PG_SCHEMA_IGNORE_THIS=unit_test_$(date +%s) \ poetry run pytest tests/unit -p no:warnings --asyncio-mode=strict diff --git a/Makefile b/Makefile index 8bf556b..d2ae61d 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,13 @@ integration-tests-cosmos-store: $(POETRY_BIN) run pytest tests/integration -p no:warnings --verbose --asyncio-mode=strict ; \ $(POETRY_BIN) run tests/integration/scripts/cleanup.py +.PHONY: integration-tests-mongo-store +integration-tests-mongo-store: export CONFIG_PATH=./config/integration-tests-mongo-store.yaml +integration-tests-mongo-store: export STORE_MONGO_DATABASE=scim_int_tsts_$(shell date +%s) +integration-tests-mongo-store: + $(POETRY_BIN) run pytest tests/integration -p no:warnings --verbose --asyncio-mode=strict ; \ + $(POETRY_BIN) run tests/integration/scripts/cleanup.py + .PHONY: security-tests security-tests: $(POETRY_BIN) run bandit -r ./keystone diff --git a/README.md b/README.md index d7dea81..aefd45a 100644 --- a/README.md +++ b/README.md @@ -49,20 +49,23 @@ operations with an identity manager that supports user provisioning (e.g., Azure you can use **Keystone** to persist directory changes. Keystone v0.1.0 supports two persistence layers: PostgreSQL and Azure Cosmos DB. +
+ <img src="logo/how-it-works.png" alt="logo" /> +
+ + Key features: * A compliant [SCIM 2.0 REST API](https://datatracker.ietf.org/doc/html/rfc7644) implementation for Users and Groups. -* Stateless container - deploy it anywhere you want (e.g., Kubernetes). +* Stateless container - deploy it anywhere you want (e.g., Kubernetes) and bring your own storage. * Pluggable store for users and groups. Current supported storage technologies: * [Azure Cosmos DB](https://docs.microsoft.com/en-us/azure/cosmos-db/introduction) * [PostgreSQL](https://www.postgresql.org) (version 10 or higher) - + * [MongoDB](https://www.mongodb.com/docs/) (version 3.6 or higher) * Azure Key Vault bearer token retrieval. -* Extensible stores. - -Can't use Cosmos DB or PostgreSQL? Open an issue and/or consider -[becoming a contributor](./CONTRIBUTING.md). +* Extensible store: Can't use MongoDB, Cosmos DB, or PostgreSQL? Open an issue and/or consider + [becoming a contributor](./CONTRIBUTING.md) by implementing your own data store. ## Configure the API diff --git a/config/README.md b/config/README.md deleted file mode 100644 index 32f0a54..0000000 --- a/config/README.md +++ /dev/null @@ -1,53 +0,0 @@ -# Keystone Configuration - -This page outlines the possible configurations that -a Keystone container supports. - -## YAML File or Environment Variables? - -The short answer: You can use either a YAML file, environment variables, -or **a combination of both**. - -You can populate some, all, or none of the configuration keys using environment -variables. All configuration keys can be represented by an environment variable by -the capitalizing the entire key name and replacing the nesting dot (`.`) annotation with -an underscore (`_`). - -For example, `store.cosmos_account_key` can be populated with the -`STORE_COSMOS_ACCOUNT_KEY` environment variable in the container -the API is running in. - -## Configure Keystore with Environment Variables - -| **VARIABLE** | **Type** | **Description** | **Default Value** | -|-----------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------|------------------------| -| store.
  type | string | The persistence layer type. Supported values: `CosmosDB`, `InMemory` | `CosmosDB` | -| store.
  tenant_id | string | Azure Tenant ID, if using a Cosmos DB store with Client Secret Credentials auth. | - | -| store.
  client_id | string | Azure Client ID, if using a Cosmos DB store with Client Secret Credentials auth. | - | -| store.
  secret | string | Azure Client Secret, if using a Cosmos DB store with Client Secret Credentials auth. | - | -| store.
  cosmos_account_uri | string | Cosmos Account URI, if using a Cosmos DB store | - | -| store.
  cosmos_account_key | string | Cosmos DB account key, if using a Cosmos DB store with Account Key auth. | - | -| store.
  cosmos_db_name | string | Cosmos DB database name, if using a Cosmos DB store | `scim_2_identity_pool` | -| authentication.
  secret | string | Plain secret bearer token | - | -| authentication.
  akv.
    vault_name | string | AKV name, if bearer token is stored in AKV. | - | -| authentication.
  akv.
    secret_name | string | AKV secret name, if bearer token is stored in AKV. | `scim-2-api-token` | -| authentication.
  akv.
    credentials_client | string | Credentials client type, if bearer token is stored in AKV. | `default` | -| authentication.
  akv.
    force_create | bool | Try to create an AKV secret on startup, if bearer token to be stored in AKV. | `false` | - -## Configure Keystone with a YAML File - - -| **Key** | **Type** | **Description** | **Default Value** | -|-----------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------|------------------------| -| store.
  type | string | The persistence layer type. Supported values: `CosmosDB`, `InMemory` | `CosmosDB` | -| store.
  tenant_id | string | Azure Tenant ID, if using a Cosmos DB store with Client Secret Credentials auth. | - | -| store.
  client_id | string | Azure Client ID, if using a Cosmos DB store with Client Secret Credentials auth. | - | -| store.
  secret | string | Azure Client Secret, if using a Cosmos DB store with Client Secret Credentials auth. | - | -| store.
  cosmos_account_uri | string | Cosmos Account URI, if using a Cosmos DB store | - | -| store.
  cosmos_account_key | string | Cosmos DB account key, if using a Cosmos DB store with Account Key auth. | - | -| store.
  cosmos_db_name | string | Cosmos DB database name, if using a Cosmos DB store | `scim_2_identity_pool` | -| authentication.
  secret | string | Plain secret bearer token | - | -| authentication.
  akv.
    vault_name | string | AKV name, if bearer token is stored in AKV. | - | -| authentication.
  akv.
    secret_name | string | AKV secret name, if bearer token is stored in AKV. | `scim-2-api-token` | -| authentication.
  akv.
    credentials_client | string | Credentials client type, if bearer token is stored in AKV. | `default` | -| authentication.
  akv.
    force_create | bool | Try to create an AKV secret on startup, if bearer token to be stored in AKV. | `false` | \ No newline at end of file diff --git a/config/dev-cosmos.yaml b/config/dev-cosmos.yaml index e253143..cea66c9 100644 --- a/config/dev-cosmos.yaml +++ b/config/dev-cosmos.yaml @@ -1,6 +1,7 @@ store: type: CosmosDB - cosmos_account_uri: - cosmos_account_key: + cosmos: + account_uri: + account_key: authentication: secret: not-so-secret diff --git a/config/integration-tests-mongo-store.yaml b/config/integration-tests-mongo-store.yaml new file mode 100644 index 0000000..ab03e56 --- /dev/null +++ b/config/integration-tests-mongo-store.yaml @@ -0,0 +1,8 @@ +store: + type: MongoDb + mongo: + host: localhost + port: 27017 + tls: false +authentication: + secret: not-so-secret diff --git a/config/integration-tests-pg-store.yaml b/config/integration-tests-pg-store.yaml index 7aee0b3..a23997a 100644 --- a/config/integration-tests-pg-store.yaml +++ b/config/integration-tests-pg-store.yaml @@ -1,5 +1,6 @@ store: type: PostgreSQL - pg_schema: public + pg: + schema: public authentication: secret: not-so-secret diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..2cefa2d --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,23 @@ +# Use root/example as user/password credentials +version: '3.1' + +services: + + mongo: + image: mongo + restart: always + ports: + - 27017:27017 + environment: + MONGO_INITDB_ROOT_USERNAME: root + MONGO_INITDB_ROOT_PASSWORD: example + + mongo-express: + image: mongo-express + restart: always + ports: + - 8081:8081 + environment: + ME_CONFIG_MONGODB_ADMINUSERNAME: root + ME_CONFIG_MONGODB_ADMINPASSWORD: example + ME_CONFIG_MONGODB_URL: mongodb://root:example@mongo:27017/ diff --git a/keystone/cmd.py b/keystone/cmd.py index f0de761..27fa5ad 100644 --- a/keystone/cmd.py +++ b/keystone/cmd.py @@ -3,6 +3,7 @@ import os from keystone import VERSION, LOGO, InterceptHandler +from keystone.store.mongodb_store import set_up from keystone.store.postgresql_store import set_up_schema from keystone.util.logger import get_log_handler @@ -46,8 +47,10 @@ async def print_logo(_logger): async def serve(host: str = "0.0.0.0", port: int = 5001): - if CONFIG.get("store.type") == "PostgreSQL": + if CONFIG.get("store.pg.host") is not None: set_up_schema() + elif CONFIG.get("store.mongo.host") or CONFIG.get("store.mongo.dsn"): + await set_up() error_handling_mw = await get_error_handling_mw() diff --git a/keystone/rest/__init__.py b/keystone/rest/__init__.py index 3824a93..c885adb 100644 --- a/keystone/rest/__init__.py +++ b/keystone/rest/__init__.py @@ -1,6 +1,7 @@ from aiohttp_catcher import Catcher, canned, catch from azure.cosmos.exceptions import CosmosResourceNotFoundError from psycopg2.errors import UniqueViolation +from pymongo.errors import DuplicateKeyError from keystone.models import DEFAULT_ERROR_SCHEMA from keystone.util.exc import ResourceNotFound, ResourceAlreadyExists, UnauthorizedRequest @@ -9,14 +10,17 @@ async def get_error_handling_mw(): catcher = Catcher(code="status", envelope="detail") err_schemas = {"schemas": [DEFAULT_ERROR_SCHEMA]} - await catcher.add_scenarios(*[sc.with_additional_fields(err_schemas) for sc in canned.AIOHTTP_SCENARIOS]) await catcher.add_scenarios( + *[sc.with_additional_fields(err_schemas) for sc in canned.AIOHTTP_SCENARIOS], + catch(ResourceNotFound).with_status_code(404).and_stringify().with_additional_fields(err_schemas), + catch(CosmosResourceNotFoundError).with_status_code(404).and_return( "Resource not 
found").with_additional_fields(err_schemas), - catch(UniqueViolation).with_status_code(409).and_return( + + catch(UniqueViolation, DuplicateKeyError, ResourceAlreadyExists).with_status_code(409).and_return( "Resource already exists").with_additional_fields(err_schemas), - catch(ResourceAlreadyExists).with_status_code(409).and_stringify().with_additional_fields(err_schemas), + catch(UnauthorizedRequest).with_status_code(401).and_return("Unauthorized request").with_additional_fields( err_schemas) ) diff --git a/keystone/rest/group.py b/keystone/rest/group.py index 7e2dcb7..e58f4c3 100644 --- a/keystone/rest/group.py +++ b/keystone/rest/group.py @@ -12,7 +12,7 @@ from keystone.models import ListQueryParams, ErrorResponse, DEFAULT_LIST_SCHEMA from keystone.models.group import Group, PatchGroupOp, ListGroupsResponse -from keystone.store import BaseStore, RDBMSStore +from keystone.store import BaseStore, RDBMSStore, DocumentStore, DatabaseStore from keystone.util.store_util import Stores LOGGER = logging.getLogger(__name__) @@ -97,7 +97,9 @@ async def _execute_group_operation(self, operation: Dict) -> Dict: ] } """ - is_rdbms = isinstance(group_store, RDBMSStore) + is_dbs = isinstance(group_store, DatabaseStore) + is_rdbmss = isinstance(group_store, RDBMSStore) + is_docs = isinstance(group_store, DocumentStore) group_id = self.request.match_info["group_id"] if hasattr(group_store, "resource_db"): group = group_store.resource_db[group_id] @@ -113,20 +115,23 @@ async def _execute_group_operation(self, operation: Dict) -> Dict: if op_path and op_path.startswith("members[") and not op_value: # Remove members with a path: _filter = op_path.strip("members[").strip("]") - if is_rdbms: + if is_dbs: if op_type == "remove": - selected_members = await group_store.search_members( - _filter=_filter, group_id=group_id - ) - _ = await group_store.remove_users_from_group( - user_ids=[m.get("id") for m in selected_members], group_id=group_id - ) + if is_rdbmss: + selected_members = await group_store.search_members( + _filter=_filter, group_id=group_id + ) + _ = await group_store.remove_users_from_group( + user_ids=[m.get("id") for m in selected_members], group_id=group_id + ) + else: + _filter = _filter.replace("value", "id").replace("display", "userName") + u, _ = await user_store.search(_filter=_filter) + _ = await group_store.remove_users_from_group( + user_ids=[m.get("id") for m in u], group_id=group_id + ) elif op_type == "add": selected_members, _ = await user_store.search(_filter=_filter.replace("value", "id")) - LOGGER.debug(str(selected_members)) - LOGGER.debug(op_type) - LOGGER.debug(_filter) - # TODO: Need to handle: for m in selected_members: _ = await group_store.add_user_to_group( user_id=m.get("id"), group_id=group_id @@ -143,7 +148,7 @@ async def _execute_group_operation(self, operation: Dict) -> Dict: _ = await group["members_store"].delete(member.get("value")) return group if op_path == "members" and op_type == "replace" and op_value: - if is_rdbms: + if is_dbs: _ = await group_store.set_group_members(users=op_value, group_id=group_id) return group else: @@ -151,7 +156,7 @@ async def _execute_group_operation(self, operation: Dict) -> Dict: if op_path == "members" and op_value: for member in op_value: member_id = member.get("value") - if is_rdbms: + if is_dbs: if op_type == "add": _ = await group_store.add_user_to_group(member.get("value"), group_id) elif op_type == "remove": diff --git a/keystone/store/__init__.py b/keystone/store/__init__.py index 37695c4..8cd7587 100644 --- 
a/keystone/store/__init__.py +++ b/keystone/store/__init__.py @@ -50,8 +50,7 @@ async def _sanitize(self, resource: Dict) -> Dict: return s_resource -class RDBMSStore(BaseStore, ABC): - +class DatabaseStore(BaseStore, ABC): async def remove_users_from_group(self, user_ids: List[str], group_id: str): raise NotImplementedError("Method 'remove_user_from_group' not implemented") @@ -61,5 +60,12 @@ async def add_user_to_group(self, user_id: str, group_id: str): async def set_group_members(self, users: List[Dict], group_id: str): raise NotImplementedError("Method 'set_group_members' not implemented") + +class RDBMSStore(DatabaseStore, ABC): + async def search_members(self, _filter: str, group_id: str): raise NotImplementedError("Method 'search_members' not implemented") + + +class DocumentStore(DatabaseStore, ABC): + pass diff --git a/keystone/store/cosmos_db_store.py b/keystone/store/cosmos_db_store.py index cc313f8..0980e4e 100644 --- a/keystone/store/cosmos_db_store.py +++ b/keystone/store/cosmos_db_store.py @@ -28,12 +28,12 @@ async def get_client_credentials(async_client: bool = True): # forces the usage of the main module to run aggregate queries with the # 'VALUE' keyword. The bug doesn't exist in the main module, therefore # this function can currently produce non-async credentials. - cosmos_account_key = CONFIG.get("store.cosmos_account_key") + cosmos_account_key = CONFIG.get("store.cosmos.account_key") if cosmos_account_key: return cosmos_account_key - tenant_id = CONFIG.get("store.tenant_id") - client_id = CONFIG.get("store.client_id") - client_secret = CONFIG.get("store.client_secret") + tenant_id = CONFIG.get("store.cosmos.tenant_id") + client_id = CONFIG.get("store.cosmos.client_id") + client_secret = CONFIG.get("store.cosmos.client_secret") if tenant_id and client_id and client_secret: aad_credentials = AsyncClientSecretCredential( tenant_id=tenant_id, @@ -71,7 +71,7 @@ class CosmosDbStore(BaseStore): def __init__(self, entity_name: str, key_attr: str = "id", unique_attribute: str = None): self.entity_name = entity_name self.key_attr = key_attr - self.account_uri = CONFIG.get("store.cosmos_account_uri") + self.account_uri = CONFIG.get("store.cosmos.account_uri") self.unique_attribute = unique_attribute self.container_name = f"scim2{self.entity_name}" self.init_client() @@ -80,7 +80,7 @@ async def get_by_id(self, resource_id: str): client_creds = await get_client_credentials() uri = self.account_uri async with AsyncCosmosClient(uri, credential=client_creds) as client: - database = client.get_database_client(CONFIG.get("store.cosmos_db_name")) + database = client.get_database_client(CONFIG.get("store.cosmos.db_name")) container = database.get_container_client(self.container_name) resource = await container.read_item(item=resource_id, partition_key=resource_id) return await remove_cosmos_metadata(resource) @@ -93,7 +93,7 @@ async def _get_query_count(self, query: str, params: Dict): client_creds = await get_client_credentials(async_client=False) try: c = CosmosClient(self.account_uri, credential=client_creds) - db = c.get_database_client(CONFIG.get("store.cosmos_db_name")) + db = c.get_database_client(CONFIG.get("store.cosmos.db_name")) co = db.get_container_client(self.container_name) count = 0 for res in co.query_items(query=query, parameters=params, enable_cross_partition_query=True): @@ -129,7 +129,7 @@ async def search(self, _filter: str, start_index: int = 1, count: int = 100) -> resources = [] async with AsyncCosmosClient(uri, credential=client_creds) as client: try: - 
database = client.get_database_client(CONFIG.get("store.cosmos_db_name")) + database = client.get_database_client(CONFIG.get("store.cosmos.db_name")) container = database.get_container_client(self.container_name) iterator = container.query_items(query=query, parameters=params, populate_query_metrics=True) async for resource in iterator: @@ -146,7 +146,7 @@ async def update(self, resource_id: str, **kwargs: Dict): client_creds = await get_client_credentials() uri = self.account_uri async with AsyncCosmosClient(uri, credential=client_creds) as client: - database = client.get_database_client(CONFIG.get("store.cosmos_db_name")) + database = client.get_database_client(CONFIG.get("store.cosmos.db_name")) container = database.get_container_client(self.container_name) try: resource = await remove_cosmos_metadata( @@ -162,7 +162,7 @@ async def create(self, resource: Dict) -> Dict: client_creds = await get_client_credentials() uri = self.account_uri async with AsyncCosmosClient(uri, credential=client_creds) as client: - database = client.get_database_client(CONFIG.get("store.cosmos_db_name")) + database = client.get_database_client(CONFIG.get("store.cosmos.db_name")) container = database.get_container_client(self.container_name) resource_id = resource.get(self.key_attr) or str(uuid.uuid4()) try: @@ -189,7 +189,7 @@ async def delete(self, resource_id: str): client_creds = await get_client_credentials() uri = self.account_uri async with AsyncCosmosClient(uri, credential=client_creds) as client: - database = client.get_database_client(CONFIG.get("store.cosmos_db_name")) + database = client.get_database_client(CONFIG.get("store.cosmos.db_name")) container = database.get_container_client(self.container_name) try: _ = await container.delete_item(item=resource_id, partition_key=resource_id) @@ -198,15 +198,15 @@ async def delete(self, resource_id: str): return def init_client(self): - account_uri = CONFIG.get("store.cosmos_account_uri") + account_uri = CONFIG.get("store.cosmos.account_uri") if not account_uri: raise ValueError( - "Could not initialize Cosmos DB store. Missing configuration: 'store.cosmos_account_uri'" + "Could not initialize Cosmos DB store. 
Missing configuration: 'store.cosmos.account_uri'" ) - tenant_id = CONFIG.get("store.tenant_id") - client_id = CONFIG.get("store.client_id") - client_secret = CONFIG.get("store.client_secret") - cosmos_account_key = CONFIG.get("store.cosmos_account_key") + tenant_id = CONFIG.get("store.cosmos.tenant_id") + client_id = CONFIG.get("store.cosmos.client_id") + client_secret = CONFIG.get("store.cosmos.client_secret") + cosmos_account_key = CONFIG.get("store.cosmos.account_key") if cosmos_account_key: client = CosmosClient(account_uri, credential=cosmos_account_key, consistency_level="Session") elif tenant_id and client_id and client_secret: @@ -218,7 +218,7 @@ def init_client(self): client = CosmosClient(account_uri, credential=aad_credentials, consistency_level="Session") else: client = CosmosClient(account_uri, credential=DefaultAzureCredential(), consistency_level="Session") - cosmos_db_name = CONFIG.get("store.cosmos_db_name") + cosmos_db_name = CONFIG.get("store.cosmos.db_name") try: database = client.create_database(cosmos_db_name) except (exceptions.CosmosResourceExistsError, exceptions.CosmosHttpResponseError): diff --git a/keystone/store/memory_store.py b/keystone/store/memory_store.py index 7f12bd3..464e599 100644 --- a/keystone/store/memory_store.py +++ b/keystone/store/memory_store.py @@ -180,13 +180,13 @@ async def evaluate_filter(self, parsed_filter: Dict, node: Dict): expr = parsed_filter.get("expr") f = expr.get("func") if f: - f = expr["func"] # eq() - pred = expr["pred"] # "f1aa2630-6343-41fa-bae4-384a46bc2ed3" - attr = expr["attr"].lower() # "value" - op = expr["op"].lower() # "eq" - attr_parts = attr.split(".") # ["value"] + f = expr["func"] + pred = expr["pred"] + attr = expr["attr"].lower() + op = expr["op"].lower() + attr_parts = attr.split(".") node_attr_value = node - namespace = expr.get("namespace") # "members" + namespace = expr.get("namespace") if namespace: f = self.filter_map[f"{op}_lst"] lst = node.get(namespace) diff --git a/keystone/store/mongodb_store.py b/keystone/store/mongodb_store.py new file mode 100644 index 0000000..ad04373 --- /dev/null +++ b/keystone/store/mongodb_store.py @@ -0,0 +1,272 @@ +import asyncio +import urllib.parse +from typing import Dict, List + +from bson.objectid import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo.collation import Collation +from scim2_filter_parser import ast +from scim2_filter_parser.ast import LogExpr, Filter, AttrExpr, CompValue, AttrPath, AST +from scim2_filter_parser.lexer import SCIMLexer +from scim2_filter_parser.parser import SCIMParser + +from keystone.store import DocumentStore +from keystone.util.config import Config +from keystone.util.exc import ResourceNotFound + +CONFIG = Config() + + +def build_dsn(**kwargs): + dsn = kwargs.get("dsn", CONFIG.get("store.mongo.dsn")) + if dsn: + return dsn + host = kwargs.get("host", CONFIG.get("store.mongo.host")) + port = kwargs.get("port", CONFIG.get("store.mongo.port", 5432)) + username = kwargs.get("username", CONFIG.get("store.mongo.username")) + password = kwargs.get("password", CONFIG.get("store.mongo.password")) + tls = kwargs.get("tls", CONFIG.get("store.mongo.tls", "true")) + if type(tls) == bool: + tls = "true" if tls is True else "false" + replica_set = kwargs.get("replica_set", CONFIG.get("store.mongo.replica_set")) + cred = username + if password: + cred = f"{cred}:{urllib.parse.quote(password)}" + query_params = {} + if tls: + query_params["tls"] = tls + if replica_set: + query_params["replicaSet"] = replica_set + return 
f"mongodb://{cred}@{host}:{port}/?tls={tls}" + + +async def set_up(**kwargs): + client = AsyncIOMotorClient(build_dsn(**kwargs)) + db_name = kwargs.get("database", CONFIG.get("store.mongo.database")) + users_collection = client[db_name]["users"] + groups_collection = client[db_name]["groups"] + _ = await users_collection.create_index([("userName", 1)], unique=True, + collation=Collation(locale="en", strength=2)) + _ = await users_collection.create_index([("emails.value", 1)], collation=Collation(locale="en", strength=2)) + _ = await groups_collection.create_index([("displayName", 1)], unique=True, + collation=Collation(locale="en", strength=2)) + + +async def _transform_user(item: Dict) -> Dict: + item_id: ObjectId = item.get("_id") + user = {**item} + if item_id: + user["id"] = str(item_id) + del user["_id"] + return user + + +async def _transform_group(item: Dict) -> Dict: + return { + "id": str(item.get("_id")), + "schemas": item.get("schemas"), + "displayName": item.get("displayName"), + "meta": item.get("meta"), + "members": [{"value": str(m.get("_id")), "display": m.get("userName")} for m in item.get("userMembers", [])] + } + + +class MongoDbStore(DocumentStore): + client: AsyncIOMotorClient + entity_type: str + + def __init__(self, entity_type: str, **conn_args): + self.entity_type = entity_type + self.client = AsyncIOMotorClient(build_dsn(**conn_args)) + self.db_name = conn_args.get("database", CONFIG.get("store.mongo.database")) + + async def _get_group_by_id(self, group_id: ObjectId) -> Dict: + aggregate = [ + { + "$match": {"_id": group_id}, + }, + { + "$lookup": { + "from": "users", + "localField": "members", + "foreignField": "_id", + "as": "userMembers", + } + } + ] + async for group in self.collection.aggregate(aggregate, collation={"locale": "en", "strength": 2}): + return await _transform_group(group) + + raise ResourceNotFound("group", str(group_id)) + + async def _get_user_by_id(self, user_id: ObjectId) -> Dict: + resource = await self.collection.find_one({"_id": user_id}) + if resource: + return await _transform_user( + await self._sanitize(resource) + ) + raise ResourceNotFound("User", str(user_id)) + + async def get_by_id(self, resource_id: str) -> Dict: + _resource_id = ObjectId(resource_id) + if self.entity_type == "users": + return await self._get_user_by_id(_resource_id) + return await self._get_group_by_id(_resource_id) + + async def search(self, _filter: str = None, start_index: int = 1, count: int = 100) -> tuple[list[Dict], int]: + parsed_filter = {} + if _filter: + token_stream = SCIMLexer().tokenize(_filter) + ast_nodes = SCIMParser().parse(token_stream) + # We only need the root node, which contains all the references in the tree for traversal: + _, root = ast.flatten(ast_nodes)[0] + parsed_filter = await self.parse_scim_filter(root) + aggregate = [ + {"$facet": { + "data": [ + {"$match": parsed_filter}, + {"$skip": start_index - 1}, + {"$limit": count}, + ], + "totalCount": [ + {"$match": parsed_filter}, + {"$count": "count"}, + ] + }} + ] + res = [] + total = 0 + async for resource in self.collection.aggregate(aggregate, collation={"locale": "en", "strength": 2}): + res = [await _transform_user(r) for r in resource.get("data")] + total = resource.get("totalCount")[0]["count"] if len(resource.get("totalCount")) > 0 else 0 + break + + return res, total + + async def update(self, resource_id: str, **kwargs: Dict): + resource = await self.collection.find_one({"_id": ObjectId(resource_id)}) + if not resource: + ResourceNotFound("User", resource_id) + _ = await 
self.collection.replace_one({"_id": ObjectId(resource_id)}, kwargs, True) + return await self.get_by_id(resource_id) + + async def create(self, resource: Dict): + sanitized = await self._sanitize(resource) + if "id" in sanitized: + del sanitized["id"] + return await self._create_user(sanitized) if self.entity_type == "users" else await self._create_group( + sanitized) + + @property + def collection(self): + return self.client[self.db_name][self.entity_type] + + async def _create_user(self, user: Dict): + inserted_id = (await self.collection.insert_one(user)).inserted_id + inserted_user = await self.collection.find_one(inserted_id) + return await _transform_user(inserted_user) + + async def _create_group(self, group: Dict): + group["members"] = [ObjectId(m.get("value")) for m in group.get("members", [])] + inserted_id = (await self.collection.insert_one(group)).inserted_id + inserted_group = await self.collection.find_one(inserted_id) + return await _transform_group(inserted_group) + + async def delete(self, resource_id: str): + resource = await self.collection.find_one({"_id": ObjectId(resource_id)}) + if not resource: + raise ResourceNotFound(self.entity_type, resource_id) + _ = await self.collection.delete_one({"_id": ObjectId(resource_id)}) + return {} + + async def clean_up_store(self): + return await self.collection.drop() + + async def parse_scim_filter(self, node: AST, namespace: str = None) -> Dict: + if isinstance(node, Filter): + ns = node.namespace.attr_name if node.namespace else None + expr = await self.parse_scim_filter(node.expr, ns or namespace) + return {"$not": expr} if node.negated else expr + if isinstance(node, AttrExpr): + # Parse an atomic comparison operation: + operator = node.value.lower() + attr_path: AttrPath = node.attr_path + attr = attr_path.attr_name + if attr_path.sub_attr: + sub_attr = attr_path.sub_attr.value + attr = f"{attr}.{sub_attr}" + comp_value: CompValue = node.comp_value + value = comp_value.value if comp_value else None + if value: + if operator.lower() == "eq" and attr == "id": + return { + "_id": ObjectId(value) + } + if operator.lower() == "co" and attr.endswith("emails"): + return { + "emails.value": value + } + if operator.lower() == "sw": + operator = "regex" + value = f"^{value}" + elif operator.lower() == "ew": + operator = "regex" + value = f"{value}$" + elif operator.lower() == "co": + operator = "regex" + value = f"{value}" + if namespace: + attr = f"{namespace}.{attr}" + return { + attr: {f"${operator}": value} + } + if isinstance(node, LogExpr): + # Parse a logical expression: + operator = node.op.lower() + l_exp = await self.parse_scim_filter(node.expr1, namespace) + r_exp = await self.parse_scim_filter(node.expr2, namespace) + return { + f"${operator}": [ + l_exp, + r_exp, + ] + } + + async def _update_group(self, group_id: str, **kwargs) -> Dict: + group = await self.collection.find_one({"_id": ObjectId(group_id)}) + if not group: + raise ResourceNotFound("group", group_id) + _ = await self.collection.replace_one( + {"_id": ObjectId(group_id)}, + kwargs, + True + ) + return await _transform_group(await self.collection.find_one({"_id": ObjectId(group_id)})) + + async def remove_users_from_group(self, user_ids: List[str], group_id: str): + group = await self.collection.find_one({"_id": ObjectId(group_id)}) + if not group: + raise ResourceNotFound("group", group_id) + _ = await self.collection.update_one( + {"_id": ObjectId(group_id)}, + {"$pull": {"members": {"$in": [ObjectId(user_id) for user_id in user_ids]}}}, + ) + + async def 
add_user_to_group(self, user_id: str, group_id: str): + group = await self.collection.find_one({"_id": ObjectId(group_id)}) + if not group: + raise ResourceNotFound("group", group_id) + members = {str(g): None for g in group.get("members", [])} + if True or user_id not in members: + _ = await self.collection.update_one({"_id": ObjectId(group_id)}, + {"$push": {"members": ObjectId(user_id)}}) + + async def set_group_members(self, user_ids: List[str], group_id: str): + group = await self.collection.find_one({"_id": ObjectId(group_id)}) + if not group: + raise ResourceNotFound("group", group_id) + _ = await self.collection.replace_one( + {"_id": ObjectId(group_id)}, + {"members": [ObjectId(user_id) for user_id in user_ids]}, + True + ) diff --git a/keystone/store/pg_models.py b/keystone/store/pg_models.py index d0a8cfa..3fa0e12 100644 --- a/keystone/store/pg_models.py +++ b/keystone/store/pg_models.py @@ -6,7 +6,7 @@ metadata = sa.MetaData() CONFIG = Config() -_schema = CONFIG.get("store.pg_schema", "public") +_schema = CONFIG.get("store.pg.schema", "public") users = sa.Table( diff --git a/keystone/store/postgresql_store.py b/keystone/store/postgresql_store.py index f936a13..3d0846d 100644 --- a/keystone/store/postgresql_store.py +++ b/keystone/store/postgresql_store.py @@ -27,12 +27,12 @@ def build_dsn(**kwargs): - host = kwargs.get("host", CONFIG.get("store.pg_host")) - port = kwargs.get("port", CONFIG.get("store.pg_port", 5432)) - username = kwargs.get("username", CONFIG.get("store.pg_username")) - password = kwargs.get("password", CONFIG.get("store.pg_password")) - database = kwargs.get("database", CONFIG.get("store.pg_database")) - ssl_mode = kwargs.get("ssl_mode", CONFIG.get("store.pg_ssl_mode")) + host = kwargs.get("host", CONFIG.get("store.pg.host")) + port = kwargs.get("port", CONFIG.get("store.pg.port", 5432)) + username = kwargs.get("username", CONFIG.get("store.pg.username")) + password = kwargs.get("password", CONFIG.get("store.pg.password")) + database = kwargs.get("database", CONFIG.get("store.pg.database")) + ssl_mode = kwargs.get("ssl_mode", CONFIG.get("store.pg.ssl_mode")) cred = username if password: cred = f"{cred}:{urllib.parse.quote(password)}" @@ -43,7 +43,7 @@ def set_up_schema(**kwargs): conn = psycopg2.connect( dsn=build_dsn(**kwargs) ) - schema = CONFIG.get("store.pg_schema", "public") + schema = CONFIG.get("store.pg.schema", "public") cursor = conn.cursor() for q in ddl_queries: cursor.execute(q.format(schema)) @@ -104,7 +104,7 @@ class PostgresqlStore(RDBMSStore): } def __init__(self, entity_type: str, **conn_args): - self.schema = CONFIG.get("store.pg_schema") + self.schema = CONFIG.get("store.pg.schema") self.entity_type = entity_type self.conn_args = conn_args @@ -248,10 +248,11 @@ async def search(self, _filter: str, start_index: int = 1, count: int = 100) -> return await self._search_groups(_filter, start_index, count) async def update(self, resource_id: str, **kwargs: Dict): + sanitized = await self._sanitize(kwargs) if self.entity_type == "users": - return await self._update_user(resource_id, **kwargs) + return await self._update_user(resource_id, **sanitized) if self.entity_type == "groups": - return await self._update_group(resource_id, **kwargs) + return await self._update_group(resource_id, **sanitized) async def _update_user(self, user_id: str, **kwargs: Dict) -> Dict: # "id" is immutable, and "groups" are updated through the groups API: @@ -301,10 +302,11 @@ async def _update_group(self, group_id: str, **kwargs: Dict) -> Dict: return await 
self._get_group_by_id(group_id) async def create(self, resource: Dict): + sanitized = await self._sanitize(resource) if self.entity_type == "users": - return await self._create_user(resource) + return await self._create_user(sanitized) if self.entity_type == "groups": - return await self._create_group(resource) + return await self._create_group(sanitized) async def _create_group(self, resource: Dict) -> Dict: group_id = resource.get("id") or str(uuid.uuid4()) @@ -421,7 +423,7 @@ async def add_user_to_group(self, user_id: str, group_id: str): _ = await conn.execute(insert_q) return - async def set_group_members(self, user_ids: List[Dict], group_id: str): + async def set_group_members(self, user_ids: List[str], group_id: str): delete_q = delete(tbl.users_groups).where(tbl.users_groups.c.groupId == group_id) insert_q = insert(tbl.users_groups).values( [{"userId": uid, "groupId": group_id} for uid in user_ids] diff --git a/keystone/util/config.py b/keystone/util/config.py index 6be9320..4fd1441 100644 --- a/keystone/util/config.py +++ b/keystone/util/config.py @@ -10,21 +10,34 @@ LOGGER = logging.getLogger(__name__) SCHEMA = Schema({ Optional("store", default={}): Schema({ - Optional("type", default="CosmosDB"): str, - Optional("tenant_id"): str, - Optional("client_id"): str, - Optional("client_secret"): str, - Optional("cosmos_account_uri"): str, - Optional("cosmos_account_key"): str, - Optional("cosmos_db_name", default="scim_2_identity_pool"): str, - - Optional("pg_host", "localhost"): str, - Optional("pg_port", default=5432): int, - Optional("pg_ssl_mode", default="require"): str, - Optional("pg_username"): str, - Optional("pg_password"): str, - Optional("pg_database", "postgres"): str, - Optional("pg_schema", "public"): str, + Optional("type", default="InMemory"): str, + Optional("cosmos", default=None): Schema({ + Optional("tenant_id"): str, + Optional("client_id"): str, + Optional("client_secret"): str, + Optional("account_uri"): str, + Optional("account_key"): str, + Optional("db_name", default="scim_2_db"): str, + }), + Optional("pg", default=None): Schema({ + Optional("host"): str, + Optional("port", default=5432): int, + Optional("ssl_mode", default="require"): str, + Optional("username"): str, + Optional("password"): str, + Optional("database", default="postgres"): str, + Optional("schema", default="public"): str, + }), + Optional("mongo", default=None): Schema({ + Optional("host"): str, + Optional("port", default=27017): int, + Optional("username"): str, + Optional("password"): str, + Optional("database", default="scim2Db"): str, + Optional("tls", default=True): bool, + Optional("replica_set", default=True): str, + Optional("dsn"): str, + }) }), Optional("authentication", default={}): Schema({ Optional("akv", default={}): Schema({ diff --git a/keystone/util/store_util.py b/keystone/util/store_util.py index 3f41f75..90956bb 100644 --- a/keystone/util/store_util.py +++ b/keystone/util/store_util.py @@ -4,6 +4,7 @@ from keystone.store import BaseStore from keystone.store.memory_store import MemoryStore from keystone.store.cosmos_db_store import CosmosDbStore +from keystone.store.mongodb_store import MongoDbStore from keystone.store.postgresql_store import PostgresqlStore from keystone.util import ThreadSafeSingleton from keystone.util.config import Config @@ -27,19 +28,24 @@ def get(self, store_name: str): def init_stores(): store_type = CONFIG.get("store.type", "InMemory") store_impl: BaseStore - if store_type == "CosmosDB": - stores = Stores( - users=CosmosDbStore("users", 
unique_attribute="userName"), - groups=CosmosDbStore("groups", unique_attribute="displayName") - ) - elif store_type == "PostgreSQL": + if CONFIG.get("store.pg.host") is not None: user_store = PostgresqlStore("users") group_store = PostgresqlStore("groups") stores = Stores( users=user_store, groups=group_store ) - elif store_type == "InMemory": + elif CONFIG.get("store.cosmos.account_uri"): + stores = Stores( + users=CosmosDbStore("users", unique_attribute="userName"), + groups=CosmosDbStore("groups", unique_attribute="displayName") + ) + elif CONFIG.get("store.mongo.host") or CONFIG.get("store.mongo.dsn"): + stores = Stores( + users=MongoDbStore("users"), + groups=MongoDbStore("groups") + ) + else: stores = Stores( users=MemoryStore("User"), groups=MemoryStore( @@ -49,7 +55,5 @@ def init_stores(): nested_store_attr="members" ) ) - else: - raise ValueError(f"Invalid store type: '{store_type}'") LOGGER.debug("Using '%s' store for users and groups", store_type) return stores diff --git a/logo/how-it-works.png b/logo/how-it-works.png new file mode 100644 index 0000000..667dd6c Binary files /dev/null and b/logo/how-it-works.png differ diff --git a/poetry.lock b/poetry.lock index 2ce126a..94c0d91 100644 --- a/poetry.lock +++ b/poetry.lock @@ -284,7 +284,7 @@ python-versions = ">=3.6" [[package]] name = "coverage" -version = "6.4.3" +version = "6.4.4" description = "Code coverage measurement for Python" category = "dev" optional = false @@ -458,6 +458,26 @@ docs = ["sphinx (==4.5.0)", "sphinx-issues (==3.0.1)", "alabaster (==0.7.12)", " lint = ["mypy (==0.961)", "flake8 (==4.0.1)", "flake8-bugbear (==22.6.22)", "pre-commit (>=2.4,<3.0)"] tests = ["pytest", "pytz", "simplejson"] +[[package]] +name = "motor" +version = "3.0.0" +description = "Non-blocking MongoDB driver for Tornado or asyncio" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pymongo = ">=4.1,<5" + +[package.extras] +zstd = ["pymongo[zstd] (>=4.1,<5)"] +srv = ["pymongo[srv] (>=4.1,<5)"] +snappy = ["pymongo[snappy] (>=4.1,<5)"] +ocsp = ["pymongo[ocsp] (>=4.1,<5)"] +gssapi = ["pymongo[gssapi] (>=4.1,<5)"] +encryption = ["pymongo[encryption] (>=4.1,<5)"] +aws = ["pymongo[aws] (>=4.1,<5)"] + [[package]] name = "msal" version = "1.18.0" @@ -665,6 +685,23 @@ dev = ["sphinx", "sphinx-rtd-theme", "zope.interface", "cryptography (>=3.3.1)", docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] tests = ["pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)"] +[[package]] +name = "pymongo" +version = "4.2.0" +description = "Python driver for MongoDB " +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +aws = ["pymongo-auth-aws (<2.0.0)"] +encryption = ["pymongocrypt (>=1.3.0,<2.0.0)"] +gssapi = ["pykerberos"] +ocsp = ["pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)", "certifi"] +snappy = ["python-snappy"] +srv = ["dnspython (>=1.16.0,<3.0.0)"] +zstd = ["zstandard"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -1001,7 +1038,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "131a0ddb6b5c0d9282577e00051178b696fdb69a9f48431dab741f7b70bf1893" +content-hash = "602e8b2972b0c7ded20b4ca8890896984feff7f5b2ad143089187c12f96af216" [metadata.files] aiohttp = [ @@ -1329,6 +1366,7 @@ marshmallow = [ {file = "marshmallow-3.17.0-py3-none-any.whl", hash = "sha256:00040ab5ea0c608e8787137627a8efae97fabd60552a05dc889c888f814e75eb"}, {file = "marshmallow-3.17.0.tar.gz", hash = 
"sha256:635fb65a3285a31a30f276f30e958070f5214c7196202caa5c7ecf28f5274bc7"}, ] +motor = [] msal = [ {file = "msal-1.18.0-py2.py3-none-any.whl", hash = "sha256:9c10e6cb32e0b6b8eaafc1c9a68bc3b2ff71505e0c5b8200799582d8b9f22947"}, {file = "msal-1.18.0.tar.gz", hash = "sha256:576af55866038b60edbcb31d831325a1bd8241ed272186e2832968fd4717d202"}, @@ -1446,6 +1484,7 @@ pyjwt = [ {file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"}, {file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"}, ] +pymongo = [] pyparsing = [ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, diff --git a/pyproject.toml b/pyproject.toml index de86f3d..2e656a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ azure-cosmos = "^4.3.0" azure-identity = "^1.10.0" azure-keyvault-secrets = "^4.4.0" loguru = "^0.6.0" +motor = "^3.0.0" psycopg2 = "^2.9.3" pyyaml = "^6.0" python-json-logger = "^2.0.2" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 2b7cd03..ae2ad50 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,4 +1,4 @@ -import uuid +import random from typing import Callable import asyncio @@ -7,11 +7,16 @@ from aiohttp.test_utils import BaseTestServer, TestServer, TestClient from aiohttp.web_app import Application -from keystone.store.postgresql_store import set_up_schema +from keystone.store import postgresql_store +from keystone.store import mongodb_store from keystone.util.config import Config from keystone.util.store_util import init_stores +def gen_random_hex(length: int = 24): + return f"%0{length}x" % random.randrange(16**length) + + def build_user(first_name, last_name, guid): email = f"{first_name[0].lower()}{last_name.lower()}@company.com" return { @@ -45,23 +50,26 @@ def build_user(first_name, last_name, guid): @pytest.fixture(scope="module") def initial_user(): - return build_user("Alex", "Smith", "58c08e90-fbe7-4460-970b-9b3a9840d661") + return build_user("Alex", "Smith", "6303270163fc418d32450cd9") @pytest.fixture(scope="module") def second_user(): - return build_user("Daniel", "Gonzales", "458c6efc-f3c3-4606-b189-e347120068e6") + return build_user("Daniel", "Gonzales", "63033e376d9d20f702bc0a11") @pytest.fixture(scope="module") def module_scoped_event_loop(): - loop = asyncio.get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() yield loop loop.close() @pytest.fixture(scope="module") -def module_scoped_aiohttp_client(module_scoped_event_loop): # loop: asyncio.AbstractEventLoop): +def module_scoped_aiohttp_client(module_scoped_event_loop): loop = module_scoped_event_loop clients = [] @@ -112,8 +120,10 @@ def scim_api(module_scoped_aiohttp_client, module_scoped_event_loop, cfg, initia app.middlewares.append( module_scoped_event_loop.run_until_complete(get_error_handling_mw()) ) - if cfg.get("store.type") == "PostgreSQL": - set_up_schema() + if cfg.get("store.pg.host") is not None: + postgresql_store.set_up_schema() + elif cfg.get("store.mongo.host") or cfg.get("store.mongo.dsn"): + module_scoped_event_loop.run_until_complete(mongodb_store.set_up()) c = module_scoped_event_loop.run_until_complete(module_scoped_aiohttp_client(app)) 
module_scoped_event_loop.run_until_complete(c.post("/scim/Users", json=initial_user, headers=headers)) return c @@ -163,14 +173,14 @@ def invalid_user_by_username_response(run_async, scim_api, headers): @pytest.fixture(scope="module") def invalid_user_by_id_response(run_async, scim_api, headers): - invalid_user_id = str(uuid.uuid4()) + invalid_user_id = gen_random_hex() # str(uuid.uuid4()) url = f"/scim/Users/{invalid_user_id}" return run_async(scim_api.get(url, headers=headers)) @pytest.fixture(scope="module") def random_user_nonexistent_response(run_async, scim_api, headers): - random_username = f"{str(uuid.uuid4())}@organization.org" + random_username = f"{gen_random_hex()}@organization.org" url = f"/scim/Users?filter=userName eq \"{random_username}\"" return run_async(scim_api.get(url, headers=headers)) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9b129dd..921db46 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,17 +1,26 @@ import random import uuid -import asyncio import names import pytest from aiohttp import web from keystone.store.memory_store import MemoryStore from keystone.store.postgresql_store import PostgresqlStore, set_up_schema +from keystone.store import mongodb_store from keystone.util.config import Config from keystone.util.store_util import init_stores +def gen_random_hex(length: int = 24): + return f"%0{length}x" % random.randrange(16**length) + + +@pytest.fixture +def rand_id(): + return gen_random_hex(24) + + @pytest.fixture def postgresql_stores(event_loop): conn_args = dict( @@ -31,6 +40,24 @@ def postgresql_stores(event_loop): event_loop.run_until_complete(group_store.term_connection()) +@pytest.fixture +def mongodb_stores(event_loop): + conn_args = dict( + host="localhost", + port=27017, + username="root", + password="example", + tls=False, + database="scim2UnitTest", + ) + event_loop.run_until_complete(mongodb_store.set_up(**conn_args)) + user_store = mongodb_store.MongoDbStore("users", **conn_args) + group_store = mongodb_store.MongoDbStore("groups", **conn_args) + yield user_store, group_store + event_loop.run_until_complete(user_store.clean_up_store()) + event_loop.run_until_complete(group_store.clean_up_store()) + + def generate_random_user(): first_name = names.get_first_name() last_name = names.get_last_name() @@ -57,7 +84,7 @@ def generate_random_user(): "schemas": [ "urn:ietf:params:scim:schemas:core:2.0:User" ], - "id": str(uuid.uuid4()), + "id": gen_random_hex(24), "userName": email, "active": True, "displayName": f"{first_name} {last_name}" diff --git a/tests/unit/store/test_mongodb_store.py b/tests/unit/store/test_mongodb_store.py new file mode 100644 index 0000000..c058798 --- /dev/null +++ b/tests/unit/store/test_mongodb_store.py @@ -0,0 +1,304 @@ +import uuid +from random import choice + +import asyncio +import pytest +from psycopg2.errors import UniqueViolation +from pymongo.errors import DuplicateKeyError + +from keystone.util.exc import ResourceNotFound + + +class TestMongoDbStore: + + @staticmethod + @pytest.mark.asyncio + async def test_get_user_by_id_fails_for_nonexistent_user(mongodb_stores, rand_id): + user_store, _ = mongodb_stores + exc_thrown = False + try: + _ = await user_store.get_by_id(rand_id) + except ResourceNotFound: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_delete_user_by_id_fails_for_nonexistent_user(mongodb_stores, rand_id): + user_store, _ = mongodb_stores + exc_thrown = False + try: + _ = await user_store.delete(rand_id) + 
except ResourceNotFound: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_create_user_success(mongodb_stores, single_user): + user_store, _ = mongodb_stores + returned_user = await user_store.create(single_user) + user_id = returned_user.get("id") + looked_up_user = await user_store.get_by_id(user_id) + assert looked_up_user.get("id") == user_id + assert looked_up_user.get("userName") == returned_user.get("userName") + + @staticmethod + @pytest.mark.asyncio + async def test_create_user_fails_on_duplicate_username(mongodb_stores, single_user): + user_store, _ = mongodb_stores + _ = await user_store.create(single_user) + duplicate_user = {**single_user} + del duplicate_user["id"] + exc_thrown = False + try: + _ = await user_store.create(duplicate_user) + except DuplicateKeyError: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_delete_user_success(mongodb_stores, single_user): + user_store, _ = mongodb_stores + user = await user_store.create(single_user) + user_id = user.get("id") + _ = await user_store.delete(user_id) + exc_thrown = False + try: + _ = await user_store.get_by_id(user_id) + except ResourceNotFound: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_search_user_by_username(mongodb_stores, single_user): + user_store, _ = mongodb_stores + username = single_user.get("userName") + _ = await user_store.create(single_user) + _filter = f"userName Eq \"{username}\"" + res, count = await user_store.search(_filter) + assert 1 == count == len(res) + assert res[0].get("userName") == username + + mixed_case_username = "".join(choice((str.upper, str.lower))(c) for c in username) + _filter = f"userName Eq \"{mixed_case_username}\"" + res, count = await user_store.search(_filter) + assert 1 == count == len(res) + assert res[0].get("userName") == username + + @staticmethod + @pytest.mark.asyncio + async def test_search_user_by_id(mongodb_stores, single_user): + user_store, _ = mongodb_stores + user = await user_store.create(single_user) + user_id = user.get("id") + _filter = f"id Eq \"{user_id}\"" + res, count = await user_store.search(_filter) + assert 1 == count == len(res) + assert res[0].get("userName") == single_user.get("userName") + + @staticmethod + @pytest.mark.asyncio + async def test_search_user_by_email(mongodb_stores, single_user): + user_store, _ = mongodb_stores + email = single_user.get("userName") + _ = await user_store.create(single_user) + _filter = f"emails.value Eq \"{email}\"" + res, count = await user_store.search(_filter) + assert 1 == count == len(res) + assert res[0].get("userName") == single_user.get("userName") + + _filter = f"emails Co \"{email}\"" + res, count = await user_store.search(_filter) + assert 1 == count == len(res) + assert res[0].get("userName") == single_user.get("userName") + + email_username = email.split("@")[0] + _filter = f"emails.value Sw \"{email_username}\"" + res, count = await user_store.search(_filter) + assert 1 == count == len(res) + assert res[0].get("userName") == single_user.get("userName") + + @staticmethod + @pytest.mark.asyncio + async def test_search_user_pagination(mongodb_stores, users): + user_store, _ = mongodb_stores + _ = await asyncio.gather(*[user_store.create(u) for u in users]) + email = users[0].get("userName") + email_domain = email.split("@")[1] + _filter = f"emails.value co \"{email_domain}\"" + res, count = await user_store.search(_filter, start_index=1, count=3) + assert 
len(users) == count + assert 3 == len(res) + + res, count = await user_store.search(_filter, start_index=4, count=3) + assert len(users) == count + assert 2 == len(res) + + @staticmethod + @pytest.mark.asyncio + async def test_update_user_success(mongodb_stores, single_user): + user_store, _ = mongodb_stores + res = await user_store.create(single_user) + user_id = res.get("id") + update_attr = { + "groups": [], # To be ignored + "id": user_id, # To be ignored + "invalidAttribute": "foo", # To be ignored + "name": { + "formatted": "John Doe", + "givenName": "Doe", + "familyName": "John" + }, + "locale": "pt-BR", + "displayName": "John Doe", + "emails": single_user.get("emails") + [{ + "value": "johndoe@emailprovider.com", + "primary": False, + "type": "home" + }], + } + updated_user = await user_store.update(user_id, **update_attr) + assert "pt-BR" == updated_user.get("locale") + assert 2 == len(updated_user.get("emails")) + assert "John Doe" == updated_user.get("displayName") == updated_user.get("name").get("formatted") + + @staticmethod + @pytest.mark.asyncio + async def test_create_group_success(mongodb_stores, single_group): + _, group_store = mongodb_stores + res = await group_store.create(single_group) + group_id = res.get("id") + + group = await group_store.get_by_id(group_id) + + assert single_group.get("displayName") == group.get("displayName") + + @staticmethod + @pytest.mark.asyncio + async def test_create_group_with_members_success(mongodb_stores, single_group, users): + user_store, group_store = mongodb_stores + user_res = await asyncio.gather(*[user_store.create(u) for u in users]) + group_payload = {**single_group, "members": [{ + "value": u.get("id"), + "display": u.get("userName"), + } for u in user_res]} + res = await group_store.create(group_payload) + group_id = res.get("id") + group = await group_store.get_by_id(group_id) + assert len(users) == len(group.get("members")) + + @staticmethod + @pytest.mark.asyncio + async def test_get_group_by_id_fails_for_nonexistent_group(mongodb_stores, rand_id): + _, group_store = mongodb_stores + exc_thrown = False + try: + _ = await group_store.get_by_id(rand_id) + except ResourceNotFound: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_update_group_metadata_success(mongodb_stores, single_group): + _, group_store = mongodb_stores + res = await group_store.create(single_group) + group_id = res.get("id") + update_attr = { + "id": group_id, # To be ignored + "displayName": "New Group Name" + } + group = await group_store.update(group_id, **update_attr) + assert "New Group Name" == group.get("displayName") + + @staticmethod + @pytest.mark.asyncio + async def test_delete_group_success(mongodb_stores, single_group): + _, group_store = mongodb_stores + group = await group_store.create(single_group) + group_id = group.get("id") + _ = await group_store.delete(group_id) + exc_thrown = False + try: + _ = await group_store.get_by_id(group_id) + except ResourceNotFound: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_delete_group_fails_for_nonexistent_group(mongodb_stores, rand_id): + _, group_store = mongodb_stores + exc_thrown = False + try: + _ = await group_store.delete(rand_id) + except ResourceNotFound: + exc_thrown = True + assert exc_thrown + + @staticmethod + @pytest.mark.asyncio + async def test_add_users_to_group(mongodb_stores, single_group, users): + user_store, group_store = mongodb_stores + _ = await 
asyncio.gather(*[user_store.create(u) for u in users]) + ret_users, _ = await user_store.search() + res = await group_store.create(single_group) + group_id = res.get("id") + _ = await asyncio.gather(*[group_store.add_user_to_group(u.get("id"), group_id) for u in ret_users]) + group = await group_store.get_by_id(group_id) + assert len(users) == len(group.get("members")) + + @staticmethod + @pytest.mark.asyncio + async def test_remove_users_from_group(mongodb_stores, single_group, users): + user_store, group_store = mongodb_stores + _ = await asyncio.gather(*[user_store.create(u) for u in users]) + ret_users, _ = await user_store.search() + res = await group_store.create(single_group) + group_id = res.get("id") + _ = await asyncio.gather(*[group_store.add_user_to_group( + user_id=u.get("id"), + group_id=group_id + ) for u in ret_users]) + group = await group_store.get_by_id(group_id) + assert len(users) == len(group.get("members")) + + _ = await group_store.remove_users_from_group( + user_ids=[u.get("id") for u in ret_users], + group_id=group_id + ) + group = await group_store.get_by_id(group_id) + assert 0 == len(group.get("members")) + + @staticmethod + @pytest.mark.asyncio + async def test_set_group_members(mongodb_stores, single_group, users): + user_store, group_store = mongodb_stores + uc_res = await asyncio.gather(*[user_store.create(u) for u in users]) + cohort_1 = uc_res[:2] + cohort_2 = uc_res[2:len(uc_res)] + res = await group_store.create(single_group) + group_id = res.get("id") + _ = await asyncio.gather(*[group_store.add_user_to_group(u.get("id"), group_id) for u in cohort_1]) + group = await group_store.get_by_id(group_id) + assert 2 == len(group.get("members")) + _ = await group_store.set_group_members( + user_ids=[u.get("id") for u in cohort_2], + group_id=group_id + ) + group = await group_store.get_by_id(group_id) + assert len(users) - 2 == len(group.get("members")) + + @staticmethod + @pytest.mark.asyncio + async def test_search_groups(mongodb_stores, groups): + _, group_store = mongodb_stores + _ = await asyncio.gather(*[group_store.create(g) for g in groups]) + _filter = f"displayName Eq \"Human Resources\"" + res, count = await group_store.search(_filter) + assert 1 == count + assert "Human Resources" == res[0].get("displayName") +
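
Example: a local development config for the new MongoDB store, matching the docker-compose service above. This is a sketch based on the schema added in keystone/util/config.py and the shape of config/integration-tests-mongo-store.yaml; the root/example credentials are the ones from docker-compose.yaml, and scim2Db is just the schema's default database name.

store:
  type: MongoDb
  mongo:
    host: localhost
    port: 27017
    username: root
    password: example
    tls: false
    database: scim2Db
authentication:
  secret: not-so-secret

Note that, per the updated init_stores() in keystone/util/store_util.py, the MongoDB store is selected when store.mongo.host or store.mongo.dsn is set, so either of those keys (or the dsn form) must be present for this config to take effect.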