From b10bed7e7c3e91e474bf4f743267e929590fc1e6 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Fri, 22 Nov 2024 11:59:19 +0100 Subject: [PATCH] feat (WMS): Improve caching performance of Limiter --- .../Client/Limiter.py | 134 ++++++++++++++++-- 1 file changed, 121 insertions(+), 13 deletions(-) diff --git a/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py b/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py index 9037a4aa669..d0c005613b7 100644 --- a/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py +++ b/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py @@ -2,6 +2,15 @@ Utilities and classes here are used by the Matcher """ +import threading +from collections import defaultdict +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, wait, Future +from functools import partial +from typing import Any + +from cachetools import TTLCache + from DIRAC import S_OK, S_ERROR from DIRAC import gLogger @@ -12,10 +21,109 @@ from DIRAC.WorkloadManagementSystem.Client import JobStatus +class TwoLevelCache: + """A two-level caching system with soft and hard time-to-live (TTL) expiration. + + This cache implements a two-tier caching mechanism to allow for background refresh + of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback, + which helps in reducing latency and maintaining data freshness. + + Attributes: + soft_cache (TTLCache): A cache with a shorter TTL for quick access. + hard_cache (TTLCache): A cache with a longer TTL as a fallback. + locks (defaultdict): Thread-safe locks for each cache key. + futures (dict): Stores ongoing asynchronous population tasks. + pool (ThreadPoolExecutor): Thread pool for executing cache population tasks. + + Args: + soft_ttl (int): Time-to-live in seconds for the soft cache. + hard_ttl (int): Time-to-live in seconds for the hard cache. + max_workers (int): Maximum number of workers in the thread pool. + max_items (int): Maximum number of items in the cache. + + Example: + >>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300) + >>> def populate_func(): + ... return "cached_value" + >>> value = cache.get("key", populate_func) + + Note: + The cache uses a ThreadPoolExecutor with a maximum of 10 workers to + handle concurrent cache population requests. + """ + + def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000): + """Initialize the TwoLevelCache with specified TTLs.""" + self.soft_cache = TTLCache(max_items, soft_ttl) + self.hard_cache = TTLCache(max_items, hard_ttl) + self.locks = defaultdict(threading.Lock) + self.futures: dict[str, Future] = {} + self.pool = ThreadPoolExecutor(max_workers=max_workers) + + def get(self, key: str, populate_func: Callable[[], Any]): + """Retrieve a value from the cache, populating it if necessary. + + This method first checks the soft cache for the key. If not found, + it checks the hard cache while initiating a background refresh. + If the key is not in either cache, it waits for the populate_func + to complete and stores the result in both caches. + + Locks are used to ensure there is never more than one concurrent + population task for a given key. + + Args: + key (str): The cache key to retrieve or populate. + populate_func (Callable[[], Any]): A function to call to populate the cache + if the key is not found. + + Returns: + Any: The cached value associated with the key. + + Note: + This method is thread-safe and handles concurrent requests for the same key. + """ + if result := self.soft_cache.get(key): + return result + with self.locks[key]: + if key not in self.futures: + self.futures[key] = self.pool.submit(self._work, key, populate_func) + if result := self.hard_cache.get(key): + self.soft_cache[key] = result + return result + # It is critical that ``future`` is waited for outside of the lock as + # _work aquires the lock before filling the caches. This also means + # we can gaurentee that the future has not yet been removed from the + # futures dict. + future = self.futures[key] + wait([future]) + return self.hard_cache[key] + + def _work(self, key: str, populate_func: Callable[[], Any]) -> None: + """Internal method to execute the populate_func and update caches. + + This method is intended to be run in a separate thread. It calls the + populate_func, stores the result in both caches, and cleans up the + associated future. + + Args: + key (str): The cache key to populate. + populate_func (Callable[[], Any]): The function to call to get the value. + + Note: + This method is not intended to be called directly by users of the class. + """ + result = populate_func() + with self.locks[key]: + self.futures.pop(key) + self.hard_cache[key] = result + self.soft_cache[key] = result + + class Limiter: # static variables shared between all instances of this class csDictCache = DictCache() condCache = DictCache() + newCache = TwoLevelCache(10, 300) delayMem = {} def __init__(self, jobDB=None, opsHelper=None, pilotRef=None): @@ -177,19 +285,7 @@ def __getRunningCondition(self, siteName, gridCE=None): if attName not in self.jobDB.jobAttributeNames: self.log.error("Attribute does not exist", f"({attName}). Check the job limits") continue - cK = f"Running:{siteName}:{attName}" - data = self.condCache.get(cK) - if not data: - result = self.jobDB.getCounters( - "Jobs", - [attName], - {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]}, - ) - if not result["OK"]: - return result - data = result["Value"] - data = {k[0][attName]: k[1] for k in data} - self.condCache.add(cK, 10, data) + data = self.newCache.get(f"Running:{siteName}:{attName}", partial(self._countsByJobType, siteName, attName)) for attValue in limitsDict[attName]: limit = limitsDict[attName][attValue] running = data.get(attValue, 0) @@ -249,3 +345,15 @@ def __getDelayCondition(self, siteName): negCond[attName] = [] negCond[attName].append(attValue) return S_OK(negCond) + + def _countsByJobType(self, siteName, attName): + result = self.jobDB.getCounters( + "Jobs", + [attName], + {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]}, + ) + if not result["OK"]: + return result + data = result["Value"] + data = {k[0][attName]: k[1] for k in data} + return data