Skip to content

Commit

Permalink
Add some more ergonomics to instance syncing along with sqlite example (
Browse files Browse the repository at this point in the history
  • Loading branch information
erlendvollset authored Sep 8, 2023
1 parent c43e55b commit 097177f
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 38 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ Changes are grouped as follows
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.

## [6.22.0] - 2023-09-08
### Added
- `client.data_modeling.instances.subscribe` which lets you subscribe to a given
data modeling query and receive updates through a provided callback.
- Example on how to use the subscribe method to sync nodes to a local sqlite db.

## [6.21.1] - 2023-09-07
### Fixed
- Concurrent usage of the `CogniteClient` could result in API calls being made with the wrong value for `api_subversion`.
Expand Down
116 changes: 115 additions & 1 deletion cognite/client/_api/data_modeling/instances.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
from __future__ import annotations

import json
import logging
import random
import time
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Iterator, List, Literal, Sequence, Union, cast, overload
from datetime import datetime, timezone
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Literal,
Sequence,
Union,
cast,
overload,
)

from cognite.client._api_client import APIClient
from cognite.client._constants import DEFAULT_LIMIT_READ
Expand Down Expand Up @@ -37,6 +53,7 @@
NodeApplyResult,
NodeApplyResultList,
NodeList,
SubscriptionContext,
)
from cognite.client.data_classes.data_modeling.query import (
Query,
Expand All @@ -45,6 +62,8 @@
from cognite.client.data_classes.data_modeling.views import View
from cognite.client.data_classes.filters import Filter, _validate_filter
from cognite.client.utils._identifier import DataModelingIdentifierSequence
from cognite.client.utils._retry import Backoff
from cognite.client.utils._text import random_string

from ._data_modeling_executor import get_data_modeling_executor

Expand All @@ -70,6 +89,8 @@
}
)

_LOGGER = logging.getLogger(__name__)


class _NodeOrEdgeList(CogniteResourceList):
_RESOURCE = (Node, Edge) # type: ignore[assignment]
Expand Down Expand Up @@ -394,6 +415,99 @@ def delete(
edge_ids = [EdgeId.load(item) for item in deleted_instances if item["instanceType"] == "edge"]
return InstancesDeleteResult(node_ids, edge_ids)

def subscribe(
self,
query: Query,
callback: Callable[[QueryResult], None],
poll_delay_seconds: float = 30,
throttle_seconds: float = 1,
) -> SubscriptionContext:
"""Subscribe to a query and get updates when the result set changes. This invokes the sync() method in a loop
in a background thread, and only invokes the callback when there are actual changes to the result set(s).
We do not support chaining result sets when subscribing to a query.
Args:
query (Query): The query to subscribe to.
callback (Callable[[QueryResult], None]): The callback function to call when the result set changes.
poll_delay_seconds (float): The time to wait between polls when no data is present. Defaults to 30 seconds.
throttle_seconds (float): The time to wait between polls despite data being present.
Returns:
SubscriptionContext: An object that can be used to cancel the subscription.
Examples:
Subscrie to a given query and print the changed data:
>>> from cognite.client import CogniteClient
>>> from cognite.client.data_classes.data_modeling.query import Query, QueryResult, NodeResultSetExpression, Select, SourceSelector
>>> from cognite.client.data_classes.data_modeling import ViewId
>>> from cognite.client.data_classes.filters import Range
>>>
>>> c = CogniteClient()
>>> def just_print_the_result(result: QueryResult) -> None:
... print(result)
...
>>> view_id = ViewId("someSpace", "someView", "v1")
>>> filter = Range(view_id.as_property_ref("releaseYear"), lt=2000)
>>> query = Query(
... with_={"movies": NodeResultSetExpression(filter=filter)},
... select={"movies": Select([SourceSelector(view_id, ["releaseYear"])])}
... )
>>> subscription_context = c.data_modeling.instances.subscribe(query, just_print_the_result)
>>> subscription_context.cancel()
"""
for result_set_expression in query.with_.values():
if result_set_expression.from_ is not None:
raise ValueError("Cannot chain result sets when subscribing to a query")

subscription_context = SubscriptionContext()

def _poll_delay(seconds: float) -> None:
if not hasattr(_poll_delay, "has_been_invoked"):
# smear if first invocation
delay = random.uniform(0, poll_delay_seconds)
setattr(_poll_delay, "has_been_invoked", True)
else:
delay = seconds
_LOGGER.debug(f"Waiting {delay} seconds before polling sync endpoint again...")
time.sleep(delay)

def _do_subscribe() -> None:
cursors = query.cursors
error_backoff = Backoff(max_wait=30)
while not subscription_context._canceled:
# No need to resync if we encountered an error in the callback last iteration
if not error_backoff.has_progressed():
query.cursors = cursors
result = self.sync(query)
subscription_context.last_successful_sync = datetime.now(tz=timezone.utc)

try:
callback(result)
except Exception:
_LOGGER.exception("Unhandled exception in sync subscriber callback. Backing off and retrying...")
time.sleep(next(error_backoff))
continue

subscription_context.last_successful_callback = datetime.now(tz=timezone.utc)
# only progress the cursor if the callback executed successfully
cursors = result.cursors

data_is_present = any(len(instances) > 0 for instances in result.data.values())
if data_is_present:
_poll_delay(throttle_seconds)
else:
_poll_delay(poll_delay_seconds)

error_backoff.reset()

thread_name = f"instances-sync-subscriber-{random_string(10)}"
thread = Thread(target=_do_subscribe, name=thread_name, daemon=True)
thread.start()

return subscription_context

@classmethod
def _create_other_params(
cls,
Expand Down
2 changes: 1 addition & 1 deletion cognite/client/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

__version__ = "6.21.1"
__version__ = "6.22.0"
__api_subversion__ = "V20220125"
23 changes: 19 additions & 4 deletions cognite/client/data_classes/data_modeling/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -691,9 +692,11 @@ def as_ids(self) -> list[NodeId]:


class NodeListWithCursor(NodeList):
def __init__(self, resources: Collection[Any], cognite_client: CogniteClient | None = None) -> None:
def __init__(
self, resources: Collection[Any], cursor: str | None, cognite_client: CogniteClient | None = None
) -> None:
super().__init__(resources, cognite_client)
self.cursor: str | None = None
self.cursor = cursor


class EdgeApplyResultList(CogniteResourceList[EdgeApplyResult]):
Expand Down Expand Up @@ -736,9 +739,11 @@ def as_ids(self) -> list[EdgeId]:


class EdgeListWithCursor(EdgeList):
def __init__(self, resources: Collection[Any], cognite_client: CogniteClient | None = None) -> None:
def __init__(
self, resources: Collection[Any], cursor: str | None, cognite_client: CogniteClient | None = None
) -> None:
super().__init__(resources, cognite_client)
self.cursor: str | None = None
self.cursor = cursor


@dataclass
Expand Down Expand Up @@ -809,3 +814,13 @@ class InstancesDeleteResult:

nodes: list[NodeId]
edges: list[EdgeId]


@dataclass
class SubscriptionContext:
last_successful_sync: datetime | None = None
last_successful_callback: datetime | None = None
_canceled: bool = False

def cancel(self) -> None:
self._canceled = True
46 changes: 23 additions & 23 deletions cognite/client/data_classes/data_modeling/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from cognite.client.data_classes.data_modeling.ids import ViewId
from cognite.client.data_classes.data_modeling.instances import (
EdgeList,
Edge,
EdgeListWithCursor,
InstanceSort,
NodeList,
Node,
NodeListWithCursor,
PropertyValue,
)
Expand Down Expand Up @@ -95,8 +95,11 @@ def __init__(
self.parameters = parameters
self.cursors = cursors or {k: None for k in select}

def instance_type_by_result_expression(self) -> dict[str, type[NodeList] | type[EdgeList]]:
return {k: NodeList if isinstance(v, NodeResultSetExpression) else EdgeList for k, v in self.with_.items()}
def instance_type_by_result_expression(self) -> dict[str, type[NodeListWithCursor] | type[EdgeListWithCursor]]:
return {
k: NodeListWithCursor if isinstance(v, NodeResultSetExpression) else EdgeListWithCursor
for k, v in self.with_.items()
}

def dump(self, camel_case: bool = False) -> dict[str, Any]:
output: dict[str, Any] = {
Expand Down Expand Up @@ -137,6 +140,12 @@ def __eq__(self, other: Any) -> bool:


class ResultSetExpression(ABC):
def __init__(self, from_: str | None, filter: Filter | None, limit: int | None, sort: list[InstanceSort] | None):
self.from_ = from_
self.filter = filter
self.limit = limit
self.sort = sort

@abstractmethod
def dump(self, camel_case: bool = False) -> dict[str, Any]:
...
Expand Down Expand Up @@ -184,10 +193,7 @@ def __init__(
sort: list[InstanceSort] | None = None,
limit: int | None = None,
):
self.from_ = from_
self.filter = filter
self.sort = sort
self.limit = limit
super().__init__(from_=from_, filter=filter, limit=limit, sort=sort)

def dump(self, camel_case: bool = False) -> dict[str, Any]:
output: dict[str, Any] = {"nodes": {}}
Expand Down Expand Up @@ -219,16 +225,13 @@ def __init__(
post_sort: list[InstanceSort] | None = None,
limit: int | None = None,
):
self.from_ = from_
super().__init__(from_=from_, filter=filter, limit=limit, sort=sort)
self.max_distance = max_distance
self.direction = direction
self.filter = filter
self.node_filter = node_filter
self.termination_filter = termination_filter
self.limit_each = limit_each
self.sort = sort
self.post_sort = post_sort
self.limit = limit

def dump(self, camel_case: bool = False) -> dict[str, Any]:
output: dict[str, Any] = {"edges": {}}
Expand Down Expand Up @@ -270,22 +273,19 @@ def cursors(self) -> dict[str, str]:
def load(
cls,
data: dict[str, Any] | str,
default_by_reference: dict[str, type[NodeList] | type[EdgeList]],
cursors: dict[str, Any] | None = None,
instance_list_type_by_result_expression_name: dict[str, type[NodeListWithCursor] | type[EdgeListWithCursor]],
cursors: dict[str, Any],
) -> QueryResult:
data = json.loads(data) if isinstance(data, str) else data
instance = cls()
for key, values in data.items():
cursor = cursors.get(key)
if not values:
instance[key] = default_by_reference[key]([])
elif values[0].get("instanceType") == "node":
instance[key] = NodeListWithCursor._load(values)
if cursors:
instance[key].cursor = cursors.get(key)
elif values[0].get("instanceType") == "edge":
instance[key] = EdgeListWithCursor._load(values)
if cursors:
instance[key].cursor = cursors.get(key)
instance[key] = instance_list_type_by_result_expression_name[key]([], cursor)
elif values[0]["instanceType"] == "node":
instance[key] = NodeListWithCursor([Node._load(node) for node in values], cursor)
elif values[0]["instanceType"] == "edge":
instance[key] = EdgeListWithCursor([Edge._load(edge) for edge in values], cursor)
else:
raise ValueError(f"Unexpected instance type {values[0].get('instanceType')}")

Expand Down
33 changes: 33 additions & 0 deletions cognite/client/utils/_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from random import uniform
from typing import Iterator


class Backoff(Iterator[float]):
"""Iterator that emits how long to wait, according to the "Full Jitter" approach
described in this post: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
Args:
multiplier (float): No description.
max_wait (float): No description.
base (int): No description."""

def __init__(self, multiplier: float = 0.5, max_wait: float = 60.0, base: int = 2) -> None:
self._multiplier = multiplier
self._max_wait = max_wait
self._base = base
self._past_attempts = 0

def __next__(self) -> float:
# 100 is an arbitrary limit at which point most sensible parameters are likely to
# be capped by max anyway.
wait = uniform(0, min(self._multiplier * (self._base ** min(100, self._past_attempts)), self._max_wait))
self._past_attempts += 1
return wait

def reset(self) -> None:
self._past_attempts = 0

def has_progressed(self) -> bool:
return self._past_attempts > 0
Loading

0 comments on commit 097177f

Please sign in to comment.