diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml
index f4019fd..2b1c655 100644
--- a/.github/workflows/unit_tests.yaml
+++ b/.github/workflows/unit_tests.yaml
@@ -35,7 +35,14 @@ jobs:
poetry-version: 1.1.13
- name: Install dependencies
run: poetry install
+ - name: Start MongoDB
+ uses: supercharge/mongodb-github-action@1.7.0
+ with:
+ mongodb-version: 5.0
+ mongodb-username: root
+ mongodb-password: example
+ mongodb-db: scim2UnitTest
+ mongodb-port: 27017
- name: Unit tests
run: |
- ISTORE_PG_SCHEMA_IGNORE_THIS=unit_test_$(date +%s) \
poetry run pytest tests/unit -p no:warnings --asyncio-mode=strict
diff --git a/Makefile b/Makefile
index 8bf556b..d2ae61d 100644
--- a/Makefile
+++ b/Makefile
@@ -31,6 +31,13 @@ integration-tests-cosmos-store:
$(POETRY_BIN) run pytest tests/integration -p no:warnings --verbose --asyncio-mode=strict ; \
$(POETRY_BIN) run tests/integration/scripts/cleanup.py
+.PHONY: integration-tests-mongo-store
+integration-tests-mongo-store: export CONFIG_PATH=./config/integration-tests-mongo-store.yaml
+integration-tests-mongo-store: export STORE_MONGO_DATABASE=scim_int_tsts_$(shell date +%s)
+integration-tests-mongo-store:
+ $(POETRY_BIN) run pytest tests/integration -p no:warnings --verbose --asyncio-mode=strict ; \
+ $(POETRY_BIN) run tests/integration/scripts/cleanup.py
+
.PHONY: security-tests
security-tests:
$(POETRY_BIN) run bandit -r ./keystone
diff --git a/README.md b/README.md
index d7dea81..aefd45a 100644
--- a/README.md
+++ b/README.md
@@ -49,20 +49,23 @@ operations with an identity manager that supports user provisioning (e.g., Azure
you can use **Keystone** to persist directory changes. Keystone v0.1.0 supports two
persistence layers: PostgreSQL and Azure Cosmos DB.
+
+

+
+
+
Key features:
* A compliant [SCIM 2.0 REST API](https://datatracker.ietf.org/doc/html/rfc7644)
implementation for Users and Groups.
-* Stateless container - deploy it anywhere you want (e.g., Kubernetes).
+* Stateless container - deploy it anywhere you want (e.g., Kubernetes) and bring your own storage.
* Pluggable store for users and groups. Current supported storage technologies:
* [Azure Cosmos DB](https://docs.microsoft.com/en-us/azure/cosmos-db/introduction)
* [PostgreSQL](https://www.postgresql.org) (version 10 or higher)
-
+ * [MongoDB](https://www.mongodb.com/docs/) (version 3.6 or higher)
* Azure Key Vault bearer token retrieval.
-* Extensible stores.
-
-Can't use Cosmos DB or PostgreSQL? Open an issue and/or consider
-[becoming a contributor](./CONTRIBUTING.md).
+* Extensible store: Can't use MongoDB, Cosmos DB, or PostgreSQL? Open an issue and/or consider
+ [becoming a contributor](./CONTRIBUTING.md) by implementing your own data store.
## Configure the API
diff --git a/config/README.md b/config/README.md
deleted file mode 100644
index 32f0a54..0000000
--- a/config/README.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# Keystone Configuration
-
-This page outlines the possible configurations that
-a Keystone container supports.
-
-## YAML File or Environment Variables?
-
-The short answer: You can use either a YAML file, environment variables,
-or **a combination of both**.
-
-You can populate some, all, or none of the configuration keys using environment
-variables. All configuration keys can be represented by an environment variable by
-the capitalizing the entire key name and replacing the nesting dot (`.`) annotation with
-an underscore (`_`).
-
-For example, `store.cosmos_account_key` can be populated with the
-`STORE_COSMOS_ACCOUNT_KEY` environment variable in the container
-the API is running in.
-
-## Configure Keystore with Environment Variables
-
-| **VARIABLE** | **Type** | **Description** | **Default Value** |
-|-----------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------|------------------------|
-| store.
type | string | The persistence layer type. Supported values: `CosmosDB`, `InMemory` | `CosmosDB` |
-| store.
tenant_id | string | Azure Tenant ID, if using a Cosmos DB store with Client Secret Credentials auth. | - |
-| store.
client_id | string | Azure Client ID, if using a Cosmos DB store with Client Secret Credentials auth. | - |
-| store.
secret | string | Azure Client Secret, if using a Cosmos DB store with Client Secret Credentials auth. | - |
-| store.
cosmos_account_uri | string | Cosmos Account URI, if using a Cosmos DB store | - |
-| store.
cosmos_account_key | string | Cosmos DB account key, if using a Cosmos DB store with Account Key auth. | - |
-| store.
cosmos_db_name | string | Cosmos DB database name, if using a Cosmos DB store | `scim_2_identity_pool` |
-| authentication.
secret | string | Plain secret bearer token | - |
-| authentication.
akv.
vault_name | string | AKV name, if bearer token is stored in AKV. | - |
-| authentication.
akv.
secret_name | string | AKV secret name, if bearer token is stored in AKV. | `scim-2-api-token` |
-| authentication.
akv.
credentials_client | string | Credentials client type, if bearer token is stored in AKV. | `default` |
-| authentication.
akv.
force_create | bool | Try to create an AKV secret on startup, if bearer token to be stored in AKV. | `false` |
-
-## Configure Keystone with a YAML File
-
-
-| **Key** | **Type** | **Description** | **Default Value** |
-|-----------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------|------------------------|
-| store.
type | string | The persistence layer type. Supported values: `CosmosDB`, `InMemory` | `CosmosDB` |
-| store.
tenant_id | string | Azure Tenant ID, if using a Cosmos DB store with Client Secret Credentials auth. | - |
-| store.
client_id | string | Azure Client ID, if using a Cosmos DB store with Client Secret Credentials auth. | - |
-| store.
secret | string | Azure Client Secret, if using a Cosmos DB store with Client Secret Credentials auth. | - |
-| store.
cosmos_account_uri | string | Cosmos Account URI, if using a Cosmos DB store | - |
-| store.
cosmos_account_key | string | Cosmos DB account key, if using a Cosmos DB store with Account Key auth. | - |
-| store.
cosmos_db_name | string | Cosmos DB database name, if using a Cosmos DB store | `scim_2_identity_pool` |
-| authentication.
secret | string | Plain secret bearer token | - |
-| authentication.
akv.
vault_name | string | AKV name, if bearer token is stored in AKV. | - |
-| authentication.
akv.
secret_name | string | AKV secret name, if bearer token is stored in AKV. | `scim-2-api-token` |
-| authentication.
akv.
credentials_client | string | Credentials client type, if bearer token is stored in AKV. | `default` |
-| authentication.
akv.
force_create | bool | Try to create an AKV secret on startup, if bearer token to be stored in AKV. | `false` |
\ No newline at end of file
diff --git a/config/dev-cosmos.yaml b/config/dev-cosmos.yaml
index e253143..cea66c9 100644
--- a/config/dev-cosmos.yaml
+++ b/config/dev-cosmos.yaml
@@ -1,6 +1,7 @@
store:
type: CosmosDB
- cosmos_account_uri:
- cosmos_account_key:
+ cosmos:
+ account_uri:
+ account_key:
authentication:
secret: not-so-secret
diff --git a/config/integration-tests-mongo-store.yaml b/config/integration-tests-mongo-store.yaml
new file mode 100644
index 0000000..ab03e56
--- /dev/null
+++ b/config/integration-tests-mongo-store.yaml
@@ -0,0 +1,8 @@
+store:
+ type: MongoDb
+ mongo:
+ host: localhost
+ port: 27017
+ tls: false
+authentication:
+ secret: not-so-secret
diff --git a/config/integration-tests-pg-store.yaml b/config/integration-tests-pg-store.yaml
index 7aee0b3..a23997a 100644
--- a/config/integration-tests-pg-store.yaml
+++ b/config/integration-tests-pg-store.yaml
@@ -1,5 +1,6 @@
store:
type: PostgreSQL
- pg_schema: public
+ pg:
+ schema: public
authentication:
secret: not-so-secret
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..2cefa2d
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,23 @@
+# Use root/example as user/password credentials
+version: '3.1'
+
+services:
+
+ mongo:
+ image: mongo
+ restart: always
+ ports:
+ - 27017:27017
+ environment:
+ MONGO_INITDB_ROOT_USERNAME: root
+ MONGO_INITDB_ROOT_PASSWORD: example
+
+ mongo-express:
+ image: mongo-express
+ restart: always
+ ports:
+ - 8081:8081
+ environment:
+ ME_CONFIG_MONGODB_ADMINUSERNAME: root
+ ME_CONFIG_MONGODB_ADMINPASSWORD: example
+ ME_CONFIG_MONGODB_URL: mongodb://root:example@mongo:27017/
diff --git a/keystone/cmd.py b/keystone/cmd.py
index f0de761..27fa5ad 100644
--- a/keystone/cmd.py
+++ b/keystone/cmd.py
@@ -3,6 +3,7 @@
import os
from keystone import VERSION, LOGO, InterceptHandler
+from keystone.store.mongodb_store import set_up
from keystone.store.postgresql_store import set_up_schema
from keystone.util.logger import get_log_handler
@@ -46,8 +47,10 @@ async def print_logo(_logger):
async def serve(host: str = "0.0.0.0", port: int = 5001):
- if CONFIG.get("store.type") == "PostgreSQL":
+ if CONFIG.get("store.pg.host") is not None:
set_up_schema()
+ elif CONFIG.get("store.mongo.host") or CONFIG.get("store.mongo.dsn"):
+ await set_up()
error_handling_mw = await get_error_handling_mw()
diff --git a/keystone/rest/__init__.py b/keystone/rest/__init__.py
index 3824a93..c885adb 100644
--- a/keystone/rest/__init__.py
+++ b/keystone/rest/__init__.py
@@ -1,6 +1,7 @@
from aiohttp_catcher import Catcher, canned, catch
from azure.cosmos.exceptions import CosmosResourceNotFoundError
from psycopg2.errors import UniqueViolation
+from pymongo.errors import DuplicateKeyError
from keystone.models import DEFAULT_ERROR_SCHEMA
from keystone.util.exc import ResourceNotFound, ResourceAlreadyExists, UnauthorizedRequest
@@ -9,14 +10,17 @@
async def get_error_handling_mw():
catcher = Catcher(code="status", envelope="detail")
err_schemas = {"schemas": [DEFAULT_ERROR_SCHEMA]}
- await catcher.add_scenarios(*[sc.with_additional_fields(err_schemas) for sc in canned.AIOHTTP_SCENARIOS])
await catcher.add_scenarios(
+ *[sc.with_additional_fields(err_schemas) for sc in canned.AIOHTTP_SCENARIOS],
+
catch(ResourceNotFound).with_status_code(404).and_stringify().with_additional_fields(err_schemas),
+
catch(CosmosResourceNotFoundError).with_status_code(404).and_return(
"Resource not found").with_additional_fields(err_schemas),
- catch(UniqueViolation).with_status_code(409).and_return(
+
+ catch(UniqueViolation, DuplicateKeyError, ResourceAlreadyExists).with_status_code(409).and_return(
"Resource already exists").with_additional_fields(err_schemas),
- catch(ResourceAlreadyExists).with_status_code(409).and_stringify().with_additional_fields(err_schemas),
+
catch(UnauthorizedRequest).with_status_code(401).and_return("Unauthorized request").with_additional_fields(
err_schemas)
)
diff --git a/keystone/rest/group.py b/keystone/rest/group.py
index 7e2dcb7..e58f4c3 100644
--- a/keystone/rest/group.py
+++ b/keystone/rest/group.py
@@ -12,7 +12,7 @@
from keystone.models import ListQueryParams, ErrorResponse, DEFAULT_LIST_SCHEMA
from keystone.models.group import Group, PatchGroupOp, ListGroupsResponse
-from keystone.store import BaseStore, RDBMSStore
+from keystone.store import BaseStore, RDBMSStore, DocumentStore, DatabaseStore
from keystone.util.store_util import Stores
LOGGER = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def _execute_group_operation(self, operation: Dict) -> Dict:
]
}
"""
- is_rdbms = isinstance(group_store, RDBMSStore)
+ is_dbs = isinstance(group_store, DatabaseStore)
+ is_rdbmss = isinstance(group_store, RDBMSStore)
+ is_docs = isinstance(group_store, DocumentStore)
group_id = self.request.match_info["group_id"]
if hasattr(group_store, "resource_db"):
group = group_store.resource_db[group_id]
@@ -113,20 +115,23 @@ async def _execute_group_operation(self, operation: Dict) -> Dict:
if op_path and op_path.startswith("members[") and not op_value:
# Remove members with a path:
_filter = op_path.strip("members[").strip("]")
- if is_rdbms:
+ if is_dbs:
if op_type == "remove":
- selected_members = await group_store.search_members(
- _filter=_filter, group_id=group_id
- )
- _ = await group_store.remove_users_from_group(
- user_ids=[m.get("id") for m in selected_members], group_id=group_id
- )
+ if is_rdbmss:
+ selected_members = await group_store.search_members(
+ _filter=_filter, group_id=group_id
+ )
+ _ = await group_store.remove_users_from_group(
+ user_ids=[m.get("id") for m in selected_members], group_id=group_id
+ )
+ else:
+ _filter = _filter.replace("value", "id").replace("display", "userName")
+ u, _ = await user_store.search(_filter=_filter)
+ _ = await group_store.remove_users_from_group(
+ user_ids=[m.get("id") for m in u], group_id=group_id
+ )
elif op_type == "add":
selected_members, _ = await user_store.search(_filter=_filter.replace("value", "id"))
- LOGGER.debug(str(selected_members))
- LOGGER.debug(op_type)
- LOGGER.debug(_filter)
- # TODO: Need to handle:
for m in selected_members:
_ = await group_store.add_user_to_group(
user_id=m.get("id"), group_id=group_id
@@ -143,7 +148,7 @@ async def _execute_group_operation(self, operation: Dict) -> Dict:
_ = await group["members_store"].delete(member.get("value"))
return group
if op_path == "members" and op_type == "replace" and op_value:
- if is_rdbms:
+ if is_dbs:
_ = await group_store.set_group_members(users=op_value, group_id=group_id)
return group
else:
@@ -151,7 +156,7 @@ async def _execute_group_operation(self, operation: Dict) -> Dict:
if op_path == "members" and op_value:
for member in op_value:
member_id = member.get("value")
- if is_rdbms:
+ if is_dbs:
if op_type == "add":
_ = await group_store.add_user_to_group(member.get("value"), group_id)
elif op_type == "remove":
diff --git a/keystone/store/__init__.py b/keystone/store/__init__.py
index 37695c4..8cd7587 100644
--- a/keystone/store/__init__.py
+++ b/keystone/store/__init__.py
@@ -50,8 +50,7 @@ async def _sanitize(self, resource: Dict) -> Dict:
return s_resource
-class RDBMSStore(BaseStore, ABC):
-
+class DatabaseStore(BaseStore, ABC):
async def remove_users_from_group(self, user_ids: List[str], group_id: str):
raise NotImplementedError("Method 'remove_user_from_group' not implemented")
@@ -61,5 +60,12 @@ async def add_user_to_group(self, user_id: str, group_id: str):
async def set_group_members(self, users: List[Dict], group_id: str):
raise NotImplementedError("Method 'set_group_members' not implemented")
+
+class RDBMSStore(DatabaseStore, ABC):
+
async def search_members(self, _filter: str, group_id: str):
raise NotImplementedError("Method 'search_members' not implemented")
+
+
+class DocumentStore(DatabaseStore, ABC):
+ pass
diff --git a/keystone/store/cosmos_db_store.py b/keystone/store/cosmos_db_store.py
index cc313f8..0980e4e 100644
--- a/keystone/store/cosmos_db_store.py
+++ b/keystone/store/cosmos_db_store.py
@@ -28,12 +28,12 @@ async def get_client_credentials(async_client: bool = True):
# forces the usage of the main module to run aggregate queries with the
# 'VALUE' keyword. The bug doesn't exist in the main module, therefore
# this function can currently produce non-async credentials.
- cosmos_account_key = CONFIG.get("store.cosmos_account_key")
+ cosmos_account_key = CONFIG.get("store.cosmos.account_key")
if cosmos_account_key:
return cosmos_account_key
- tenant_id = CONFIG.get("store.tenant_id")
- client_id = CONFIG.get("store.client_id")
- client_secret = CONFIG.get("store.client_secret")
+ tenant_id = CONFIG.get("store.cosmos.tenant_id")
+ client_id = CONFIG.get("store.cosmos.client_id")
+ client_secret = CONFIG.get("store.cosmos.client_secret")
if tenant_id and client_id and client_secret:
aad_credentials = AsyncClientSecretCredential(
tenant_id=tenant_id,
@@ -71,7 +71,7 @@ class CosmosDbStore(BaseStore):
def __init__(self, entity_name: str, key_attr: str = "id", unique_attribute: str = None):
self.entity_name = entity_name
self.key_attr = key_attr
- self.account_uri = CONFIG.get("store.cosmos_account_uri")
+ self.account_uri = CONFIG.get("store.cosmos.account_uri")
self.unique_attribute = unique_attribute
self.container_name = f"scim2{self.entity_name}"
self.init_client()
@@ -80,7 +80,7 @@ async def get_by_id(self, resource_id: str):
client_creds = await get_client_credentials()
uri = self.account_uri
async with AsyncCosmosClient(uri, credential=client_creds) as client:
- database = client.get_database_client(CONFIG.get("store.cosmos_db_name"))
+ database = client.get_database_client(CONFIG.get("store.cosmos.db_name"))
container = database.get_container_client(self.container_name)
resource = await container.read_item(item=resource_id, partition_key=resource_id)
return await remove_cosmos_metadata(resource)
@@ -93,7 +93,7 @@ async def _get_query_count(self, query: str, params: Dict):
client_creds = await get_client_credentials(async_client=False)
try:
c = CosmosClient(self.account_uri, credential=client_creds)
- db = c.get_database_client(CONFIG.get("store.cosmos_db_name"))
+ db = c.get_database_client(CONFIG.get("store.cosmos.db_name"))
co = db.get_container_client(self.container_name)
count = 0
for res in co.query_items(query=query, parameters=params, enable_cross_partition_query=True):
@@ -129,7 +129,7 @@ async def search(self, _filter: str, start_index: int = 1, count: int = 100) ->
resources = []
async with AsyncCosmosClient(uri, credential=client_creds) as client:
try:
- database = client.get_database_client(CONFIG.get("store.cosmos_db_name"))
+ database = client.get_database_client(CONFIG.get("store.cosmos.db_name"))
container = database.get_container_client(self.container_name)
iterator = container.query_items(query=query, parameters=params, populate_query_metrics=True)
async for resource in iterator:
@@ -146,7 +146,7 @@ async def update(self, resource_id: str, **kwargs: Dict):
client_creds = await get_client_credentials()
uri = self.account_uri
async with AsyncCosmosClient(uri, credential=client_creds) as client:
- database = client.get_database_client(CONFIG.get("store.cosmos_db_name"))
+ database = client.get_database_client(CONFIG.get("store.cosmos.db_name"))
container = database.get_container_client(self.container_name)
try:
resource = await remove_cosmos_metadata(
@@ -162,7 +162,7 @@ async def create(self, resource: Dict) -> Dict:
client_creds = await get_client_credentials()
uri = self.account_uri
async with AsyncCosmosClient(uri, credential=client_creds) as client:
- database = client.get_database_client(CONFIG.get("store.cosmos_db_name"))
+ database = client.get_database_client(CONFIG.get("store.cosmos.db_name"))
container = database.get_container_client(self.container_name)
resource_id = resource.get(self.key_attr) or str(uuid.uuid4())
try:
@@ -189,7 +189,7 @@ async def delete(self, resource_id: str):
client_creds = await get_client_credentials()
uri = self.account_uri
async with AsyncCosmosClient(uri, credential=client_creds) as client:
- database = client.get_database_client(CONFIG.get("store.cosmos_db_name"))
+ database = client.get_database_client(CONFIG.get("store.cosmos.db_name"))
container = database.get_container_client(self.container_name)
try:
_ = await container.delete_item(item=resource_id, partition_key=resource_id)
@@ -198,15 +198,15 @@ async def delete(self, resource_id: str):
return
def init_client(self):
- account_uri = CONFIG.get("store.cosmos_account_uri")
+ account_uri = CONFIG.get("store.cosmos.account_uri")
if not account_uri:
raise ValueError(
- "Could not initialize Cosmos DB store. Missing configuration: 'store.cosmos_account_uri'"
+ "Could not initialize Cosmos DB store. Missing configuration: 'store.cosmos.account_uri'"
)
- tenant_id = CONFIG.get("store.tenant_id")
- client_id = CONFIG.get("store.client_id")
- client_secret = CONFIG.get("store.client_secret")
- cosmos_account_key = CONFIG.get("store.cosmos_account_key")
+ tenant_id = CONFIG.get("store.cosmos.tenant_id")
+ client_id = CONFIG.get("store.cosmos.client_id")
+ client_secret = CONFIG.get("store.cosmos.client_secret")
+ cosmos_account_key = CONFIG.get("store.cosmos.account_key")
if cosmos_account_key:
client = CosmosClient(account_uri, credential=cosmos_account_key, consistency_level="Session")
elif tenant_id and client_id and client_secret:
@@ -218,7 +218,7 @@ def init_client(self):
client = CosmosClient(account_uri, credential=aad_credentials, consistency_level="Session")
else:
client = CosmosClient(account_uri, credential=DefaultAzureCredential(), consistency_level="Session")
- cosmos_db_name = CONFIG.get("store.cosmos_db_name")
+ cosmos_db_name = CONFIG.get("store.cosmos.db_name")
try:
database = client.create_database(cosmos_db_name)
except (exceptions.CosmosResourceExistsError, exceptions.CosmosHttpResponseError):
diff --git a/keystone/store/memory_store.py b/keystone/store/memory_store.py
index 7f12bd3..464e599 100644
--- a/keystone/store/memory_store.py
+++ b/keystone/store/memory_store.py
@@ -180,13 +180,13 @@ async def evaluate_filter(self, parsed_filter: Dict, node: Dict):
expr = parsed_filter.get("expr")
f = expr.get("func")
if f:
- f = expr["func"] # eq()
- pred = expr["pred"] # "f1aa2630-6343-41fa-bae4-384a46bc2ed3"
- attr = expr["attr"].lower() # "value"
- op = expr["op"].lower() # "eq"
- attr_parts = attr.split(".") # ["value"]
+ f = expr["func"]
+ pred = expr["pred"]
+ attr = expr["attr"].lower()
+ op = expr["op"].lower()
+ attr_parts = attr.split(".")
node_attr_value = node
- namespace = expr.get("namespace") # "members"
+ namespace = expr.get("namespace")
if namespace:
f = self.filter_map[f"{op}_lst"]
lst = node.get(namespace)
diff --git a/keystone/store/mongodb_store.py b/keystone/store/mongodb_store.py
new file mode 100644
index 0000000..ad04373
--- /dev/null
+++ b/keystone/store/mongodb_store.py
@@ -0,0 +1,272 @@
+import asyncio
+import urllib.parse
+from typing import Dict, List
+
+from bson.objectid import ObjectId
+from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo.collation import Collation
+from scim2_filter_parser import ast
+from scim2_filter_parser.ast import LogExpr, Filter, AttrExpr, CompValue, AttrPath, AST
+from scim2_filter_parser.lexer import SCIMLexer
+from scim2_filter_parser.parser import SCIMParser
+
+from keystone.store import DocumentStore
+from keystone.util.config import Config
+from keystone.util.exc import ResourceNotFound
+
+CONFIG = Config()
+
+
+def build_dsn(**kwargs):
+ dsn = kwargs.get("dsn", CONFIG.get("store.mongo.dsn"))
+ if dsn:
+ return dsn
+ host = kwargs.get("host", CONFIG.get("store.mongo.host"))
+    port = kwargs.get("port", CONFIG.get("store.mongo.port", 27017))
+ username = kwargs.get("username", CONFIG.get("store.mongo.username"))
+ password = kwargs.get("password", CONFIG.get("store.mongo.password"))
+ tls = kwargs.get("tls", CONFIG.get("store.mongo.tls", "true"))
+ if type(tls) == bool:
+ tls = "true" if tls is True else "false"
+ replica_set = kwargs.get("replica_set", CONFIG.get("store.mongo.replica_set"))
+    cred = f"{username}@" if username else ""
+    if username and password:
+        cred = f"{username}:{urllib.parse.quote(password)}@"
+    query_params = {}
+    if tls:
+        query_params["tls"] = tls
+    if replica_set:
+        query_params["replicaSet"] = replica_set
+    return f"mongodb://{cred}{host}:{port}/?{urllib.parse.urlencode(query_params)}"
+
+
+async def set_up(**kwargs):
+ client = AsyncIOMotorClient(build_dsn(**kwargs))
+ db_name = kwargs.get("database", CONFIG.get("store.mongo.database"))
+ users_collection = client[db_name]["users"]
+ groups_collection = client[db_name]["groups"]
+ _ = await users_collection.create_index([("userName", 1)], unique=True,
+ collation=Collation(locale="en", strength=2))
+ _ = await users_collection.create_index([("emails.value", 1)], collation=Collation(locale="en", strength=2))
+ _ = await groups_collection.create_index([("displayName", 1)], unique=True,
+ collation=Collation(locale="en", strength=2))
+
+
+async def _transform_user(item: Dict) -> Dict:
+ item_id: ObjectId = item.get("_id")
+ user = {**item}
+ if item_id:
+ user["id"] = str(item_id)
+ del user["_id"]
+ return user
+
+
+async def _transform_group(item: Dict) -> Dict:
+ return {
+ "id": str(item.get("_id")),
+ "schemas": item.get("schemas"),
+ "displayName": item.get("displayName"),
+ "meta": item.get("meta"),
+ "members": [{"value": str(m.get("_id")), "display": m.get("userName")} for m in item.get("userMembers", [])]
+ }
+
+
+class MongoDbStore(DocumentStore):
+ client: AsyncIOMotorClient
+ entity_type: str
+
+ def __init__(self, entity_type: str, **conn_args):
+ self.entity_type = entity_type
+ self.client = AsyncIOMotorClient(build_dsn(**conn_args))
+ self.db_name = conn_args.get("database", CONFIG.get("store.mongo.database"))
+
+ async def _get_group_by_id(self, group_id: ObjectId) -> Dict:
+ aggregate = [
+ {
+ "$match": {"_id": group_id},
+ },
+ {
+ "$lookup": {
+ "from": "users",
+ "localField": "members",
+ "foreignField": "_id",
+ "as": "userMembers",
+ }
+ }
+ ]
+ async for group in self.collection.aggregate(aggregate, collation={"locale": "en", "strength": 2}):
+ return await _transform_group(group)
+
+ raise ResourceNotFound("group", str(group_id))
+
+ async def _get_user_by_id(self, user_id: ObjectId) -> Dict:
+ resource = await self.collection.find_one({"_id": user_id})
+ if resource:
+ return await _transform_user(
+ await self._sanitize(resource)
+ )
+ raise ResourceNotFound("User", str(user_id))
+
+ async def get_by_id(self, resource_id: str) -> Dict:
+ _resource_id = ObjectId(resource_id)
+ if self.entity_type == "users":
+ return await self._get_user_by_id(_resource_id)
+ return await self._get_group_by_id(_resource_id)
+
+ async def search(self, _filter: str = None, start_index: int = 1, count: int = 100) -> tuple[list[Dict], int]:
+ parsed_filter = {}
+ if _filter:
+ token_stream = SCIMLexer().tokenize(_filter)
+ ast_nodes = SCIMParser().parse(token_stream)
+ # We only need the root node, which contains all the references in the tree for traversal:
+ _, root = ast.flatten(ast_nodes)[0]
+ parsed_filter = await self.parse_scim_filter(root)
+ aggregate = [
+ {"$facet": {
+ "data": [
+ {"$match": parsed_filter},
+ {"$skip": start_index - 1},
+ {"$limit": count},
+ ],
+ "totalCount": [
+ {"$match": parsed_filter},
+ {"$count": "count"},
+ ]
+ }}
+ ]
+ res = []
+ total = 0
+ async for resource in self.collection.aggregate(aggregate, collation={"locale": "en", "strength": 2}):
+ res = [await _transform_user(r) for r in resource.get("data")]
+ total = resource.get("totalCount")[0]["count"] if len(resource.get("totalCount")) > 0 else 0
+ break
+
+ return res, total
+
+ async def update(self, resource_id: str, **kwargs: Dict):
+ resource = await self.collection.find_one({"_id": ObjectId(resource_id)})
+ if not resource:
+            raise ResourceNotFound("User", resource_id)
+ _ = await self.collection.replace_one({"_id": ObjectId(resource_id)}, kwargs, True)
+ return await self.get_by_id(resource_id)
+
+ async def create(self, resource: Dict):
+ sanitized = await self._sanitize(resource)
+ if "id" in sanitized:
+ del sanitized["id"]
+ return await self._create_user(sanitized) if self.entity_type == "users" else await self._create_group(
+ sanitized)
+
+ @property
+ def collection(self):
+ return self.client[self.db_name][self.entity_type]
+
+ async def _create_user(self, user: Dict):
+ inserted_id = (await self.collection.insert_one(user)).inserted_id
+ inserted_user = await self.collection.find_one(inserted_id)
+ return await _transform_user(inserted_user)
+
+ async def _create_group(self, group: Dict):
+ group["members"] = [ObjectId(m.get("value")) for m in group.get("members", [])]
+ inserted_id = (await self.collection.insert_one(group)).inserted_id
+ inserted_group = await self.collection.find_one(inserted_id)
+ return await _transform_group(inserted_group)
+
+ async def delete(self, resource_id: str):
+ resource = await self.collection.find_one({"_id": ObjectId(resource_id)})
+ if not resource:
+ raise ResourceNotFound(self.entity_type, resource_id)
+ _ = await self.collection.delete_one({"_id": ObjectId(resource_id)})
+ return {}
+
+ async def clean_up_store(self):
+ return await self.collection.drop()
+
+ async def parse_scim_filter(self, node: AST, namespace: str = None) -> Dict:
+ if isinstance(node, Filter):
+ ns = node.namespace.attr_name if node.namespace else None
+ expr = await self.parse_scim_filter(node.expr, ns or namespace)
+ return {"$not": expr} if node.negated else expr
+ if isinstance(node, AttrExpr):
+ # Parse an atomic comparison operation:
+ operator = node.value.lower()
+ attr_path: AttrPath = node.attr_path
+ attr = attr_path.attr_name
+ if attr_path.sub_attr:
+ sub_attr = attr_path.sub_attr.value
+ attr = f"{attr}.{sub_attr}"
+ comp_value: CompValue = node.comp_value
+ value = comp_value.value if comp_value else None
+ if value:
+ if operator.lower() == "eq" and attr == "id":
+ return {
+ "_id": ObjectId(value)
+ }
+ if operator.lower() == "co" and attr.endswith("emails"):
+ return {
+ "emails.value": value
+ }
+ if operator.lower() == "sw":
+ operator = "regex"
+ value = f"^{value}"
+ elif operator.lower() == "ew":
+ operator = "regex"
+ value = f"{value}$"
+ elif operator.lower() == "co":
+ operator = "regex"
+ value = f"{value}"
+ if namespace:
+ attr = f"{namespace}.{attr}"
+ return {
+ attr: {f"${operator}": value}
+ }
+ if isinstance(node, LogExpr):
+ # Parse a logical expression:
+ operator = node.op.lower()
+ l_exp = await self.parse_scim_filter(node.expr1, namespace)
+ r_exp = await self.parse_scim_filter(node.expr2, namespace)
+ return {
+ f"${operator}": [
+ l_exp,
+ r_exp,
+ ]
+ }
+
+ async def _update_group(self, group_id: str, **kwargs) -> Dict:
+ group = await self.collection.find_one({"_id": ObjectId(group_id)})
+ if not group:
+ raise ResourceNotFound("group", group_id)
+ _ = await self.collection.replace_one(
+ {"_id": ObjectId(group_id)},
+ kwargs,
+ True
+ )
+ return await _transform_group(await self.collection.find_one({"_id": ObjectId(group_id)}))
+
+ async def remove_users_from_group(self, user_ids: List[str], group_id: str):
+ group = await self.collection.find_one({"_id": ObjectId(group_id)})
+ if not group:
+ raise ResourceNotFound("group", group_id)
+ _ = await self.collection.update_one(
+ {"_id": ObjectId(group_id)},
+ {"$pull": {"members": {"$in": [ObjectId(user_id) for user_id in user_ids]}}},
+ )
+
+ async def add_user_to_group(self, user_id: str, group_id: str):
+ group = await self.collection.find_one({"_id": ObjectId(group_id)})
+ if not group:
+ raise ResourceNotFound("group", group_id)
+ members = {str(g): None for g in group.get("members", [])}
+        if user_id not in members:
+ _ = await self.collection.update_one({"_id": ObjectId(group_id)},
+ {"$push": {"members": ObjectId(user_id)}})
+
+    async def set_group_members(self, users: List[Dict], group_id: str):
+        group = await self.collection.find_one({"_id": ObjectId(group_id)})
+        if not group:
+            raise ResourceNotFound("group", group_id)
+        new_members = [ObjectId(u.get("value")) for u in users]
+        _ = await self.collection.update_one(
+            {"_id": ObjectId(group_id)},
+            {"$set": {"members": new_members}}
+        )
diff --git a/keystone/store/pg_models.py b/keystone/store/pg_models.py
index d0a8cfa..3fa0e12 100644
--- a/keystone/store/pg_models.py
+++ b/keystone/store/pg_models.py
@@ -6,7 +6,7 @@
metadata = sa.MetaData()
CONFIG = Config()
-_schema = CONFIG.get("store.pg_schema", "public")
+_schema = CONFIG.get("store.pg.schema", "public")
users = sa.Table(
diff --git a/keystone/store/postgresql_store.py b/keystone/store/postgresql_store.py
index f936a13..3d0846d 100644
--- a/keystone/store/postgresql_store.py
+++ b/keystone/store/postgresql_store.py
@@ -27,12 +27,12 @@
def build_dsn(**kwargs):
- host = kwargs.get("host", CONFIG.get("store.pg_host"))
- port = kwargs.get("port", CONFIG.get("store.pg_port", 5432))
- username = kwargs.get("username", CONFIG.get("store.pg_username"))
- password = kwargs.get("password", CONFIG.get("store.pg_password"))
- database = kwargs.get("database", CONFIG.get("store.pg_database"))
- ssl_mode = kwargs.get("ssl_mode", CONFIG.get("store.pg_ssl_mode"))
+ host = kwargs.get("host", CONFIG.get("store.pg.host"))
+ port = kwargs.get("port", CONFIG.get("store.pg.port", 5432))
+ username = kwargs.get("username", CONFIG.get("store.pg.username"))
+ password = kwargs.get("password", CONFIG.get("store.pg.password"))
+ database = kwargs.get("database", CONFIG.get("store.pg.database"))
+ ssl_mode = kwargs.get("ssl_mode", CONFIG.get("store.pg.ssl_mode"))
cred = username
if password:
cred = f"{cred}:{urllib.parse.quote(password)}"
@@ -43,7 +43,7 @@ def set_up_schema(**kwargs):
conn = psycopg2.connect(
dsn=build_dsn(**kwargs)
)
- schema = CONFIG.get("store.pg_schema", "public")
+ schema = CONFIG.get("store.pg.schema", "public")
cursor = conn.cursor()
for q in ddl_queries:
cursor.execute(q.format(schema))
@@ -104,7 +104,7 @@ class PostgresqlStore(RDBMSStore):
}
def __init__(self, entity_type: str, **conn_args):
- self.schema = CONFIG.get("store.pg_schema")
+ self.schema = CONFIG.get("store.pg.schema")
self.entity_type = entity_type
self.conn_args = conn_args
@@ -248,10 +248,11 @@ async def search(self, _filter: str, start_index: int = 1, count: int = 100) ->
return await self._search_groups(_filter, start_index, count)
async def update(self, resource_id: str, **kwargs: Dict):
+ sanitized = await self._sanitize(kwargs)
if self.entity_type == "users":
- return await self._update_user(resource_id, **kwargs)
+ return await self._update_user(resource_id, **sanitized)
if self.entity_type == "groups":
- return await self._update_group(resource_id, **kwargs)
+ return await self._update_group(resource_id, **sanitized)
async def _update_user(self, user_id: str, **kwargs: Dict) -> Dict:
# "id" is immutable, and "groups" are updated through the groups API:
@@ -301,10 +302,11 @@ async def _update_group(self, group_id: str, **kwargs: Dict) -> Dict:
return await self._get_group_by_id(group_id)
async def create(self, resource: Dict):
+ sanitized = await self._sanitize(resource)
if self.entity_type == "users":
- return await self._create_user(resource)
+ return await self._create_user(sanitized)
if self.entity_type == "groups":
- return await self._create_group(resource)
+ return await self._create_group(sanitized)
async def _create_group(self, resource: Dict) -> Dict:
group_id = resource.get("id") or str(uuid.uuid4())
@@ -421,7 +423,7 @@ async def add_user_to_group(self, user_id: str, group_id: str):
_ = await conn.execute(insert_q)
return
- async def set_group_members(self, user_ids: List[Dict], group_id: str):
+ async def set_group_members(self, user_ids: List[str], group_id: str):
delete_q = delete(tbl.users_groups).where(tbl.users_groups.c.groupId == group_id)
insert_q = insert(tbl.users_groups).values(
[{"userId": uid, "groupId": group_id} for uid in user_ids]
diff --git a/keystone/util/config.py b/keystone/util/config.py
index 6be9320..4fd1441 100644
--- a/keystone/util/config.py
+++ b/keystone/util/config.py
@@ -10,21 +10,34 @@
LOGGER = logging.getLogger(__name__)
SCHEMA = Schema({
Optional("store", default={}): Schema({
- Optional("type", default="CosmosDB"): str,
- Optional("tenant_id"): str,
- Optional("client_id"): str,
- Optional("client_secret"): str,
- Optional("cosmos_account_uri"): str,
- Optional("cosmos_account_key"): str,
- Optional("cosmos_db_name", default="scim_2_identity_pool"): str,
-
- Optional("pg_host", "localhost"): str,
- Optional("pg_port", default=5432): int,
- Optional("pg_ssl_mode", default="require"): str,
- Optional("pg_username"): str,
- Optional("pg_password"): str,
- Optional("pg_database", "postgres"): str,
- Optional("pg_schema", "public"): str,
+ Optional("type", default="InMemory"): str,
+ Optional("cosmos", default=None): Schema({
+ Optional("tenant_id"): str,
+ Optional("client_id"): str,
+ Optional("client_secret"): str,
+ Optional("account_uri"): str,
+ Optional("account_key"): str,
+ Optional("db_name", default="scim_2_db"): str,
+ }),
+ Optional("pg", default=None): Schema({
+ Optional("host"): str,
+ Optional("port", default=5432): int,
+ Optional("ssl_mode", default="require"): str,
+ Optional("username"): str,
+ Optional("password"): str,
+ Optional("database", default="postgres"): str,
+ Optional("schema", default="public"): str,
+ }),
+ Optional("mongo", default=None): Schema({
+ Optional("host"): str,
+ Optional("port", default=27017): int,
+ Optional("username"): str,
+ Optional("password"): str,
+ Optional("database", default="scim2Db"): str,
+ Optional("tls", default=True): bool,
+            Optional("replica_set"): str,
+ Optional("dsn"): str,
+ })
}),
Optional("authentication", default={}): Schema({
Optional("akv", default={}): Schema({
diff --git a/keystone/util/store_util.py b/keystone/util/store_util.py
index 3f41f75..90956bb 100644
--- a/keystone/util/store_util.py
+++ b/keystone/util/store_util.py
@@ -4,6 +4,7 @@
from keystone.store import BaseStore
from keystone.store.memory_store import MemoryStore
from keystone.store.cosmos_db_store import CosmosDbStore
+from keystone.store.mongodb_store import MongoDbStore
from keystone.store.postgresql_store import PostgresqlStore
from keystone.util import ThreadSafeSingleton
from keystone.util.config import Config
@@ -27,19 +28,24 @@ def get(self, store_name: str):
def init_stores():
store_type = CONFIG.get("store.type", "InMemory")
store_impl: BaseStore
- if store_type == "CosmosDB":
- stores = Stores(
- users=CosmosDbStore("users", unique_attribute="userName"),
- groups=CosmosDbStore("groups", unique_attribute="displayName")
- )
- elif store_type == "PostgreSQL":
+ if CONFIG.get("store.pg.host") is not None:
user_store = PostgresqlStore("users")
group_store = PostgresqlStore("groups")
stores = Stores(
users=user_store,
groups=group_store
)
- elif store_type == "InMemory":
+ elif CONFIG.get("store.cosmos.account_uri"):
+ stores = Stores(
+ users=CosmosDbStore("users", unique_attribute="userName"),
+ groups=CosmosDbStore("groups", unique_attribute="displayName")
+ )
+ elif CONFIG.get("store.mongo.host") or CONFIG.get("store.mongo.dsn"):
+ stores = Stores(
+ users=MongoDbStore("users"),
+ groups=MongoDbStore("groups")
+ )
+ else:
stores = Stores(
users=MemoryStore("User"),
groups=MemoryStore(
@@ -49,7 +55,5 @@ def init_stores():
nested_store_attr="members"
)
)
- else:
- raise ValueError(f"Invalid store type: '{store_type}'")
LOGGER.debug("Using '%s' store for users and groups", store_type)
return stores
diff --git a/logo/how-it-works.png b/logo/how-it-works.png
new file mode 100644
index 0000000..667dd6c
Binary files /dev/null and b/logo/how-it-works.png differ
diff --git a/poetry.lock b/poetry.lock
index 2ce126a..94c0d91 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -284,7 +284,7 @@ python-versions = ">=3.6"
[[package]]
name = "coverage"
-version = "6.4.3"
+version = "6.4.4"
description = "Code coverage measurement for Python"
category = "dev"
optional = false
@@ -458,6 +458,26 @@ docs = ["sphinx (==4.5.0)", "sphinx-issues (==3.0.1)", "alabaster (==0.7.12)", "
lint = ["mypy (==0.961)", "flake8 (==4.0.1)", "flake8-bugbear (==22.6.22)", "pre-commit (>=2.4,<3.0)"]
tests = ["pytest", "pytz", "simplejson"]
+[[package]]
+name = "motor"
+version = "3.0.0"
+description = "Non-blocking MongoDB driver for Tornado or asyncio"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+pymongo = ">=4.1,<5"
+
+[package.extras]
+zstd = ["pymongo[zstd] (>=4.1,<5)"]
+srv = ["pymongo[srv] (>=4.1,<5)"]
+snappy = ["pymongo[snappy] (>=4.1,<5)"]
+ocsp = ["pymongo[ocsp] (>=4.1,<5)"]
+gssapi = ["pymongo[gssapi] (>=4.1,<5)"]
+encryption = ["pymongo[encryption] (>=4.1,<5)"]
+aws = ["pymongo[aws] (>=4.1,<5)"]
+
[[package]]
name = "msal"
version = "1.18.0"
@@ -665,6 +685,23 @@ dev = ["sphinx", "sphinx-rtd-theme", "zope.interface", "cryptography (>=3.3.1)",
docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"]
tests = ["pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)"]
+[[package]]
+name = "pymongo"
+version = "4.2.0"
+description = "Python driver for MongoDB "
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+aws = ["pymongo-auth-aws (<2.0.0)"]
+encryption = ["pymongocrypt (>=1.3.0,<2.0.0)"]
+gssapi = ["pykerberos"]
+ocsp = ["pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)", "certifi"]
+snappy = ["python-snappy"]
+srv = ["dnspython (>=1.16.0,<3.0.0)"]
+zstd = ["zstandard"]
+
[[package]]
name = "pyparsing"
version = "3.0.9"
@@ -1001,7 +1038,7 @@ multidict = ">=4.0"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "131a0ddb6b5c0d9282577e00051178b696fdb69a9f48431dab741f7b70bf1893"
+content-hash = "602e8b2972b0c7ded20b4ca8890896984feff7f5b2ad143089187c12f96af216"
[metadata.files]
aiohttp = [
@@ -1329,6 +1366,7 @@ marshmallow = [
{file = "marshmallow-3.17.0-py3-none-any.whl", hash = "sha256:00040ab5ea0c608e8787137627a8efae97fabd60552a05dc889c888f814e75eb"},
{file = "marshmallow-3.17.0.tar.gz", hash = "sha256:635fb65a3285a31a30f276f30e958070f5214c7196202caa5c7ecf28f5274bc7"},
]
+motor = []
msal = [
{file = "msal-1.18.0-py2.py3-none-any.whl", hash = "sha256:9c10e6cb32e0b6b8eaafc1c9a68bc3b2ff71505e0c5b8200799582d8b9f22947"},
{file = "msal-1.18.0.tar.gz", hash = "sha256:576af55866038b60edbcb31d831325a1bd8241ed272186e2832968fd4717d202"},
@@ -1446,6 +1484,7 @@ pyjwt = [
{file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"},
{file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"},
]
+pymongo = []
pyparsing = [
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
{file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
diff --git a/pyproject.toml b/pyproject.toml
index de86f3d..2e656a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ azure-cosmos = "^4.3.0"
azure-identity = "^1.10.0"
azure-keyvault-secrets = "^4.4.0"
loguru = "^0.6.0"
+motor = "^3.0.0"
psycopg2 = "^2.9.3"
pyyaml = "^6.0"
python-json-logger = "^2.0.2"
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 2b7cd03..ae2ad50 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -1,4 +1,4 @@
-import uuid
+import random
from typing import Callable
import asyncio
@@ -7,11 +7,16 @@
from aiohttp.test_utils import BaseTestServer, TestServer, TestClient
from aiohttp.web_app import Application
-from keystone.store.postgresql_store import set_up_schema
+from keystone.store import postgresql_store
+from keystone.store import mongodb_store
from keystone.util.config import Config
from keystone.util.store_util import init_stores
+def gen_random_hex(length: int = 24):
+ return f"%0{length}x" % random.randrange(16**length)
+
+
def build_user(first_name, last_name, guid):
email = f"{first_name[0].lower()}{last_name.lower()}@company.com"
return {
@@ -45,23 +50,26 @@ def build_user(first_name, last_name, guid):
@pytest.fixture(scope="module")
def initial_user():
- return build_user("Alex", "Smith", "58c08e90-fbe7-4460-970b-9b3a9840d661")
+ return build_user("Alex", "Smith", "6303270163fc418d32450cd9")
@pytest.fixture(scope="module")
def second_user():
- return build_user("Daniel", "Gonzales", "458c6efc-f3c3-4606-b189-e347120068e6")
+ return build_user("Daniel", "Gonzales", "63033e376d9d20f702bc0a11")
@pytest.fixture(scope="module")
def module_scoped_event_loop():
- loop = asyncio.get_event_loop()
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="module")
-def module_scoped_aiohttp_client(module_scoped_event_loop): # loop: asyncio.AbstractEventLoop):
+def module_scoped_aiohttp_client(module_scoped_event_loop):
loop = module_scoped_event_loop
clients = []
@@ -112,8 +120,10 @@ def scim_api(module_scoped_aiohttp_client, module_scoped_event_loop, cfg, initia
app.middlewares.append(
module_scoped_event_loop.run_until_complete(get_error_handling_mw())
)
- if cfg.get("store.type") == "PostgreSQL":
- set_up_schema()
+ if cfg.get("store.pg.host") is not None:
+ postgresql_store.set_up_schema()
+ elif cfg.get("store.mongo.host") or cfg.get("store.mongo.dsn"):
+ module_scoped_event_loop.run_until_complete(mongodb_store.set_up())
c = module_scoped_event_loop.run_until_complete(module_scoped_aiohttp_client(app))
module_scoped_event_loop.run_until_complete(c.post("/scim/Users", json=initial_user, headers=headers))
return c
@@ -163,14 +173,14 @@ def invalid_user_by_username_response(run_async, scim_api, headers):
@pytest.fixture(scope="module")
def invalid_user_by_id_response(run_async, scim_api, headers):
- invalid_user_id = str(uuid.uuid4())
+    invalid_user_id = gen_random_hex()
url = f"/scim/Users/{invalid_user_id}"
return run_async(scim_api.get(url, headers=headers))
@pytest.fixture(scope="module")
def random_user_nonexistent_response(run_async, scim_api, headers):
- random_username = f"{str(uuid.uuid4())}@organization.org"
+ random_username = f"{gen_random_hex()}@organization.org"
url = f"/scim/Users?filter=userName eq \"{random_username}\""
return run_async(scim_api.get(url, headers=headers))
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 9b129dd..921db46 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,17 +1,26 @@
import random
import uuid
-import asyncio
import names
import pytest
from aiohttp import web
from keystone.store.memory_store import MemoryStore
from keystone.store.postgresql_store import PostgresqlStore, set_up_schema
+from keystone.store import mongodb_store
from keystone.util.config import Config
from keystone.util.store_util import init_stores
+def gen_random_hex(length: int = 24):
+ return f"%0{length}x" % random.randrange(16**length)
+
+
+@pytest.fixture
+def rand_id():
+ return gen_random_hex(24)
+
+
@pytest.fixture
def postgresql_stores(event_loop):
conn_args = dict(
@@ -31,6 +40,24 @@ def postgresql_stores(event_loop):
event_loop.run_until_complete(group_store.term_connection())
+@pytest.fixture
+def mongodb_stores(event_loop):
+ conn_args = dict(
+ host="localhost",
+ port=27017,
+ username="root",
+ password="example",
+ tls=False,
+ database="scim2UnitTest",
+ )
+ event_loop.run_until_complete(mongodb_store.set_up(**conn_args))
+ user_store = mongodb_store.MongoDbStore("users", **conn_args)
+ group_store = mongodb_store.MongoDbStore("groups", **conn_args)
+ yield user_store, group_store
+ event_loop.run_until_complete(user_store.clean_up_store())
+ event_loop.run_until_complete(group_store.clean_up_store())
+
+
def generate_random_user():
first_name = names.get_first_name()
last_name = names.get_last_name()
@@ -57,7 +84,7 @@ def generate_random_user():
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User"
],
- "id": str(uuid.uuid4()),
+ "id": gen_random_hex(24),
"userName": email,
"active": True,
"displayName": f"{first_name} {last_name}"
diff --git a/tests/unit/store/test_mongodb_store.py b/tests/unit/store/test_mongodb_store.py
new file mode 100644
index 0000000..c058798
--- /dev/null
+++ b/tests/unit/store/test_mongodb_store.py
@@ -0,0 +1,302 @@
+from random import choice
+
+import asyncio
+import pytest
+from pymongo.errors import DuplicateKeyError
+
+from keystone.util.exc import ResourceNotFound
+
+
+class TestMongoDbStore:
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_get_user_by_id_fails_for_nonexistent_user(mongodb_stores, rand_id):
+ user_store, _ = mongodb_stores
+ exc_thrown = False
+ try:
+ _ = await user_store.get_by_id(rand_id)
+ except ResourceNotFound:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_delete_user_by_id_fails_for_nonexistent_user(mongodb_stores, rand_id):
+ user_store, _ = mongodb_stores
+ exc_thrown = False
+ try:
+ _ = await user_store.delete(rand_id)
+ except ResourceNotFound:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_create_user_success(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ returned_user = await user_store.create(single_user)
+ user_id = returned_user.get("id")
+ looked_up_user = await user_store.get_by_id(user_id)
+ assert looked_up_user.get("id") == user_id
+ assert looked_up_user.get("userName") == returned_user.get("userName")
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_create_user_fails_on_duplicate_username(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ _ = await user_store.create(single_user)
+ duplicate_user = {**single_user}
+ del duplicate_user["id"]
+ exc_thrown = False
+ try:
+ _ = await user_store.create(duplicate_user)
+ except DuplicateKeyError:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_delete_user_success(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ user = await user_store.create(single_user)
+ user_id = user.get("id")
+ _ = await user_store.delete(user_id)
+ exc_thrown = False
+ try:
+ _ = await user_store.get_by_id(user_id)
+ except ResourceNotFound:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_search_user_by_username(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ username = single_user.get("userName")
+ _ = await user_store.create(single_user)
+ _filter = f"userName Eq \"{username}\""
+ res, count = await user_store.search(_filter)
+ assert 1 == count == len(res)
+ assert res[0].get("userName") == username
+
+ mixed_case_username = "".join(choice((str.upper, str.lower))(c) for c in username)
+ _filter = f"userName Eq \"{mixed_case_username}\""
+ res, count = await user_store.search(_filter)
+ assert 1 == count == len(res)
+ assert res[0].get("userName") == username
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_search_user_by_id(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ user = await user_store.create(single_user)
+ user_id = user.get("id")
+ _filter = f"id Eq \"{user_id}\""
+ res, count = await user_store.search(_filter)
+ assert 1 == count == len(res)
+ assert res[0].get("userName") == single_user.get("userName")
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_search_user_by_email(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ email = single_user.get("userName")
+ _ = await user_store.create(single_user)
+ _filter = f"emails.value Eq \"{email}\""
+ res, count = await user_store.search(_filter)
+ assert 1 == count == len(res)
+ assert res[0].get("userName") == single_user.get("userName")
+
+ _filter = f"emails Co \"{email}\""
+ res, count = await user_store.search(_filter)
+ assert 1 == count == len(res)
+ assert res[0].get("userName") == single_user.get("userName")
+
+ email_username = email.split("@")[0]
+ _filter = f"emails.value Sw \"{email_username}\""
+ res, count = await user_store.search(_filter)
+ assert 1 == count == len(res)
+ assert res[0].get("userName") == single_user.get("userName")
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_search_user_pagination(mongodb_stores, users):
+ user_store, _ = mongodb_stores
+ _ = await asyncio.gather(*[user_store.create(u) for u in users])
+ email = users[0].get("userName")
+ email_domain = email.split("@")[1]
+ _filter = f"emails.value co \"{email_domain}\""
+ res, count = await user_store.search(_filter, start_index=1, count=3)
+ assert len(users) == count
+ assert 3 == len(res)
+
+ res, count = await user_store.search(_filter, start_index=4, count=3)
+ assert len(users) == count
+ assert 2 == len(res)
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_update_user_success(mongodb_stores, single_user):
+ user_store, _ = mongodb_stores
+ res = await user_store.create(single_user)
+ user_id = res.get("id")
+ update_attr = {
+ "groups": [], # To be ignored
+ "id": user_id, # To be ignored
+ "invalidAttribute": "foo", # To be ignored
+ "name": {
+ "formatted": "John Doe",
+ "givenName": "Doe",
+ "familyName": "John"
+ },
+ "locale": "pt-BR",
+ "displayName": "John Doe",
+ "emails": single_user.get("emails") + [{
+ "value": "johndoe@emailprovider.com",
+ "primary": False,
+ "type": "home"
+ }],
+ }
+ updated_user = await user_store.update(user_id, **update_attr)
+ assert "pt-BR" == updated_user.get("locale")
+ assert 2 == len(updated_user.get("emails"))
+ assert "John Doe" == updated_user.get("displayName") == updated_user.get("name").get("formatted")
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_create_group_success(mongodb_stores, single_group):
+ _, group_store = mongodb_stores
+ res = await group_store.create(single_group)
+ group_id = res.get("id")
+
+ group = await group_store.get_by_id(group_id)
+
+ assert single_group.get("displayName") == group.get("displayName")
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_create_group_with_members_success(mongodb_stores, single_group, users):
+ user_store, group_store = mongodb_stores
+ user_res = await asyncio.gather(*[user_store.create(u) for u in users])
+ group_payload = {**single_group, "members": [{
+ "value": u.get("id"),
+ "display": u.get("userName"),
+ } for u in user_res]}
+ res = await group_store.create(group_payload)
+ group_id = res.get("id")
+ group = await group_store.get_by_id(group_id)
+ assert len(users) == len(group.get("members"))
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_get_group_by_id_fails_for_nonexistent_group(mongodb_stores, rand_id):
+ _, group_store = mongodb_stores
+ exc_thrown = False
+ try:
+ _ = await group_store.get_by_id(rand_id)
+ except ResourceNotFound:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_update_group_metadata_success(mongodb_stores, single_group):
+ _, group_store = mongodb_stores
+ res = await group_store.create(single_group)
+ group_id = res.get("id")
+ update_attr = {
+ "id": group_id, # To be ignored
+ "displayName": "New Group Name"
+ }
+ group = await group_store.update(group_id, **update_attr)
+ assert "New Group Name" == group.get("displayName")
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_delete_group_success(mongodb_stores, single_group):
+ _, group_store = mongodb_stores
+ group = await group_store.create(single_group)
+ group_id = group.get("id")
+ _ = await group_store.delete(group_id)
+ exc_thrown = False
+ try:
+ _ = await group_store.get_by_id(group_id)
+ except ResourceNotFound:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_delete_group_fails_for_nonexistent_group(mongodb_stores, rand_id):
+ _, group_store = mongodb_stores
+ exc_thrown = False
+ try:
+ _ = await group_store.delete(rand_id)
+ except ResourceNotFound:
+ exc_thrown = True
+ assert exc_thrown
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_add_users_to_group(mongodb_stores, single_group, users):
+ user_store, group_store = mongodb_stores
+ _ = await asyncio.gather(*[user_store.create(u) for u in users])
+ ret_users, _ = await user_store.search()
+ res = await group_store.create(single_group)
+ group_id = res.get("id")
+ _ = await asyncio.gather(*[group_store.add_user_to_group(u.get("id"), group_id) for u in ret_users])
+ group = await group_store.get_by_id(group_id)
+ assert len(users) == len(group.get("members"))
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_remove_users_from_group(mongodb_stores, single_group, users):
+ user_store, group_store = mongodb_stores
+ _ = await asyncio.gather(*[user_store.create(u) for u in users])
+ ret_users, _ = await user_store.search()
+ res = await group_store.create(single_group)
+ group_id = res.get("id")
+ _ = await asyncio.gather(*[group_store.add_user_to_group(
+ user_id=u.get("id"),
+ group_id=group_id
+ ) for u in ret_users])
+ group = await group_store.get_by_id(group_id)
+ assert len(users) == len(group.get("members"))
+
+ _ = await group_store.remove_users_from_group(
+ user_ids=[u.get("id") for u in ret_users],
+ group_id=group_id
+ )
+ group = await group_store.get_by_id(group_id)
+ assert 0 == len(group.get("members"))
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_set_group_members(mongodb_stores, single_group, users):
+ user_store, group_store = mongodb_stores
+ uc_res = await asyncio.gather(*[user_store.create(u) for u in users])
+ cohort_1 = uc_res[:2]
+ cohort_2 = uc_res[2:len(uc_res)]
+ res = await group_store.create(single_group)
+ group_id = res.get("id")
+ _ = await asyncio.gather(*[group_store.add_user_to_group(u.get("id"), group_id) for u in cohort_1])
+ group = await group_store.get_by_id(group_id)
+ assert 2 == len(group.get("members"))
+ _ = await group_store.set_group_members(
+ user_ids=[u.get("id") for u in cohort_2],
+ group_id=group_id
+ )
+ group = await group_store.get_by_id(group_id)
+ assert len(users) - 2 == len(group.get("members"))
+
+ @staticmethod
+ @pytest.mark.asyncio
+ async def test_search_groups(mongodb_stores, groups):
+ _, group_store = mongodb_stores
+ _ = await asyncio.gather(*[group_store.create(g) for g in groups])
+        _filter = "displayName Eq \"Human Resources\""
+ res, count = await group_store.search(_filter)
+ assert 1 == count
+ assert "Human Resources" == res[0].get("displayName")
+