Create a hashing state store (#329)
Create a state store that hashes the ingested content and stores the
digest, instead of the low/high watermarks used by the old version.

This will be useful for things like generic RAW rows and data models,
where we don't always have an incremental field to use as a high
watermark.
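
For example, a minimal sketch of the intended usage (the fetch_rows and
process helpers below are hypothetical stand-ins for an extractor's own
source and sink):

```python
from cognite.extractorutils.statestore import LocalHashStateStore

def fetch_rows():
    # Hypothetical source: yields (external_id, data) pairs
    yield "row-1", {"value": 42}
    yield "row-2", {"value": 7}

def process(data):
    # Hypothetical sink for changed items
    print("ingesting", data)

with LocalHashStateStore("states.json", save_interval=60) as states:
    for external_id, data in fetch_rows():
        # Only re-ingest items whose content hash differs from the stored digest
        if states.has_changed(external_id, data):
            process(data)
            states.set_state(external_id, data)
```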
mathialo authored May 22, 2024
1 parent 0f98d81 commit 349fa23
Showing 7 changed files with 404 additions and 72 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,12 @@ Changes are grouped as follows
called at all if the cancellation token was set. This resulted in errors
with for example upload queues.

## Next

### Added

* A new type of state store that stores hashes of ingested items. This can be used to detect changed RAW rows or data model instances.

## 7.1.6

### Changed
12 changes: 12 additions & 0 deletions cognite/extractorutils/statestore/__init__.py
@@ -0,0 +1,12 @@
from .hashing import AbstractHashStateStore, LocalHashStateStore, RawHashStateStore
from .watermark import AbstractStateStore, LocalStateStore, NoStateStore, RawStateStore

__all__ = [
"AbstractStateStore",
"RawStateStore",
"LocalStateStore",
"NoStateStore",
"AbstractHashStateStore",
"RawHashStateStore",
"LocalHashStateStore",
]
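
Both store families are re-exported at the package root, so imports like the
following work (a quick illustration, not part of the diff):

```python
from cognite.extractorutils.statestore import (
    LocalHashStateStore,  # new: hash-based
    RawStateStore,        # existing: watermark-based
)
```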
81 changes: 81 additions & 0 deletions cognite/extractorutils/statestore/_base.py
@@ -0,0 +1,81 @@
import logging
import threading
from abc import ABC, abstractmethod
from typing import Optional

from cognite.extractorutils._inner_util import _resolve_log_level
from cognite.extractorutils.threading import CancellationToken

RETRY_BACKOFF_FACTOR = 1.5
RETRY_MAX_DELAY = 60
RETRY_DELAY = 1
RETRIES = 10


class _BaseStateStore(ABC):
def __init__(
self,
save_interval: Optional[int] = None,
trigger_log_level: str = "DEBUG",
thread_name: Optional[str] = None,
cancellation_token: Optional[CancellationToken] = None,
) -> None:
self._initialized = False

self.logger = logging.getLogger(__name__)
self.trigger_log_level = _resolve_log_level(trigger_log_level)
self.save_interval = save_interval

self.thread = threading.Thread(target=self._run, daemon=cancellation_token is None, name=thread_name)
self.lock = threading.RLock()
self.cancellation_token = cancellation_token.create_child_token() if cancellation_token else CancellationToken()

def start(self, initialize: bool = True) -> None:
"""
Start saving state periodically if save_interval is set.
This calls the synchronize method every save_interval seconds.
"""
if initialize and not self._initialized:
self.initialize()
if self.save_interval is not None:
self.thread.start()

def stop(self, ensure_synchronize: bool = True) -> None:
"""
        Stop the synchronize thread if it is running, and ensure the state is saved if ensure_synchronize is True.

        Args:
            ensure_synchronize (bool): (Optional) Call synchronize one last time after shutting down the thread.
"""
self.cancellation_token.cancel()
if ensure_synchronize:
self.synchronize()

def _run(self) -> None:
"""
Internal run method for synchronize thread
"""
self.initialize()
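        # wait() returns True once the token is cancelled; otherwise it times out after save_interval seconds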
while not self.cancellation_token.wait(timeout=self.save_interval):
try:
self.logger.log(self.trigger_log_level, "Triggering scheduled state store synchronization")
self.synchronize()
except Exception as e:
self.logger.error("Unexpected error while synchronizing state store: %s.", str(e))

        # call stop explicitly so the state is synchronized one final time before exiting
self.stop(ensure_synchronize=True)

@abstractmethod
def initialize(self, force: bool = False) -> None:
"""
Get states from remote store
"""
pass

@abstractmethod
def synchronize(self) -> None:
"""
Upload states to remote store
"""
pass
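
initialize and synchronize are the only integration points a concrete store
must supply. A minimal sketch of a hypothetical in-memory subclass, for
illustration only (not part of the library):

```python
from cognite.extractorutils.statestore._base import _BaseStateStore

class InMemoryStateStore(_BaseStateStore):
    """Hypothetical subclass: state lives in memory only, so there is nothing to load or save."""

    def __init__(self) -> None:
        super().__init__(save_interval=None)  # no periodic synchronization thread

    def initialize(self, force: bool = False) -> None:
        # No remote store to read from; just mark the store as ready
        self._initialized = True

    def synchronize(self) -> None:
        # No remote store to write to
        pass
```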
258 changes: 258 additions & 0 deletions cognite/extractorutils/statestore/hashing.py
@@ -0,0 +1,258 @@
import hashlib
import json
from abc import ABC
from types import TracebackType
from typing import Any, Dict, Iterable, Iterator, Optional, Set, Type

import orjson

from cognite.client import CogniteClient
from cognite.client.data_classes import Row
from cognite.client.exceptions import CogniteAPIError
from cognite.extractorutils._inner_util import _DecimalDecoder, _DecimalEncoder
from cognite.extractorutils.threading import CancellationToken
from cognite.extractorutils.util import cognite_exceptions, retry

from ._base import RETRIES, RETRY_BACKOFF_FACTOR, RETRY_DELAY, RETRY_MAX_DELAY, _BaseStateStore


class AbstractHashStateStore(_BaseStateStore, ABC):
def __init__(
self,
save_interval: Optional[int] = None,
trigger_log_level: str = "DEBUG",
thread_name: Optional[str] = None,
cancellation_token: Optional[CancellationToken] = None,
) -> None:
super().__init__(
save_interval=save_interval,
trigger_log_level=trigger_log_level,
thread_name=thread_name,
cancellation_token=cancellation_token,
)

self._local_state: Dict[str, Dict[str, str]] = {}
self._seen: Set[str] = set()

def get_state(self, external_id: str) -> Optional[str]:
with self.lock:
return self._local_state.get(external_id, {}).get("digest")

def _hash_row(self, data: Dict[str, Any]) -> str:
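        # Serialize with sorted keys so that logically equal dicts always produce the same digest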
return hashlib.sha256(orjson.dumps(data, option=orjson.OPT_SORT_KEYS)).hexdigest()

def set_state(self, external_id: str, data: Dict[str, Any]) -> None:
with self.lock:
self._local_state[external_id] = {"digest": self._hash_row(data)}

def has_changed(self, external_id: str, data: Dict[str, Any]) -> bool:
with self.lock:
if external_id not in self._local_state:
return True

return self._hash_row(data) != self._local_state[external_id]["digest"]

def __getitem__(self, external_id: str) -> Optional[str]:
return self.get_state(external_id)

def __setitem__(self, key: str, value: Dict[str, Any]) -> None:
self.set_state(external_id=key, data=value)

def __contains__(self, external_id: str) -> bool:
return external_id in self._local_state

def __len__(self) -> int:
return len(self._local_state)

def __iter__(self) -> Iterator[str]:
with self.lock:
for key in self._local_state:
yield key


class RawHashStateStore(AbstractHashStateStore):
def __init__(
self,
cdf_client: CogniteClient,
database: str,
table: str,
save_interval: Optional[int] = None,
trigger_log_level: str = "DEBUG",
thread_name: Optional[str] = None,
cancellation_token: Optional[CancellationToken] = None,
) -> None:
super().__init__(
save_interval=save_interval,
trigger_log_level=trigger_log_level,
thread_name=thread_name,
cancellation_token=cancellation_token,
)
self._cdf_client = cdf_client
self.database = database
self.table = table

def synchronize(self) -> None:
@retry(
exceptions=cognite_exceptions(),
cancellation_token=self.cancellation_token,
tries=RETRIES,
delay=RETRY_DELAY,
max_delay=RETRY_MAX_DELAY,
backoff=RETRY_BACKOFF_FACTOR,
)
def impl() -> None:
"""
Upload local state store to CDF
"""
with self.lock:
self._cdf_client.raw.rows.insert(
db_name=self.database,
table_name=self.table,
row=self._local_state,
ensure_parent=True,
)

impl()

def initialize(self, force: bool = False) -> None:
@retry(
exceptions=cognite_exceptions(),
cancellation_token=self.cancellation_token,
tries=RETRIES,
delay=RETRY_DELAY,
max_delay=RETRY_MAX_DELAY,
backoff=RETRY_BACKOFF_FACTOR,
)
def impl() -> None:
"""
            Get all known states from the RAW table.

            Args:
                force: Enable re-initialization, i.e. overwrite when called multiple times
"""
if self._initialized and not force:
return

rows: Iterable[Row]
try:
rows = self._cdf_client.raw.rows.list(db_name=self.database, table_name=self.table, limit=None)
except CogniteAPIError as e:
if e.code == 404:
rows = []
else:
raise e

with self.lock:
self._local_state.clear()
for row in rows:
if row.key is None or row.columns is None:
self.logger.warning(f"None encountered in row: {str(row)}")
                        # should never happen, but the types from the SDK are Optional
continue
state = row.columns.get("digest")
if state:
self._local_state[row.key] = {"digest": state}

self._initialized = True

impl()

def __enter__(self) -> "RawHashStateStore":
"""
        Wraps around the start method, for use as a context manager.

        Returns:
            self
"""
self.start()
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""
        Wraps around the stop method, for use as a context manager.

        Args:
            exc_type: Exception type
            exc_val: Exception value
            exc_tb: Traceback
"""
self.stop()


class LocalHashStateStore(AbstractHashStateStore):
def __init__(
self,
file_path: str,
save_interval: Optional[int] = None,
trigger_log_level: str = "DEBUG",
thread_name: Optional[str] = None,
cancellation_token: Optional[CancellationToken] = None,
) -> None:
super().__init__(
save_interval=save_interval,
trigger_log_level=trigger_log_level,
thread_name=thread_name,
cancellation_token=cancellation_token,
)

self._file_path = file_path

def initialize(self, force: bool = False) -> None:
"""
        Load states from the specified JSON file.

        Args:
            force: Enable re-initialization, i.e. overwrite when called multiple times
"""
if self._initialized and not force:
return

with self.lock:
try:
with open(self._file_path, "r") as f:
self._local_state = json.load(f, cls=_DecimalDecoder)
except FileNotFoundError:
pass
except json.decoder.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in state store file: {str(e)}") from e

self._initialized = True

def synchronize(self) -> None:
"""
        Save states to the specified JSON file.
"""
with self.lock:
with open(self._file_path, "w") as f:
json.dump(self._local_state, f, cls=_DecimalEncoder)

def __enter__(self) -> "LocalHashStateStore":
"""
        Wraps around the start method, for use as a context manager.

        Returns:
            self
"""
self.start()
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""
        Wraps around the stop method, for use as a context manager.

        Args:
            exc_type: Exception type
            exc_val: Exception value
            exc_tb: Traceback
"""
self.stop()
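
Taken together, a hedged end-to-end sketch of the RAW-backed store (the
database and table names here are hypothetical, and an authenticated
CogniteClient is assumed):

```python
from cognite.client import CogniteClient
from cognite.extractorutils.statestore import RawHashStateStore

client = CogniteClient()  # assumes credentials are configured elsewhere

# "extractor-states"/"digests" and the source table are hypothetical names
with RawHashStateStore(client, "extractor-states", "digests", save_interval=30) as states:
    for row in client.raw.rows.list(db_name="src-db", table_name="src-table", limit=None):
        if row.key is None or row.columns is None:
            continue  # should not happen, but the SDK types are Optional
        if states.has_changed(row.key, row.columns):
            # ... transform and upload the changed row here ...
            states.set_state(row.key, row.columns)
# leaving the with-block calls stop(), which runs one final synchronize to RAW
```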