diff --git a/aerospike/aerospikehdf.py b/aerospike/aerospikehdf.py index 1db32b8e3..a2bfec35c 100644 --- a/aerospike/aerospikehdf.py +++ b/aerospike/aerospikehdf.py @@ -6,15 +6,14 @@ import statistics import json -from enum import Flag, auto -from typing import Iterable, List, Union, Any -from importlib.metadata import version +from typing import List, Union, Any from math import sqrt -from aerospike_vector_search import types as vectorTypes, Client as vectorSyncClient +from aerospike_vector_search import types as vectorTypes from aerospike_vector_search.aio import AdminClient as vectorASyncAdminClient, Client as vectorASyncClient from aerospike_vector_search.shared.proto_generated.types_pb2_grpc import grpc as vectorResultCodes +from dynamic_throttle import DynamicThrottle from baseaerospike import BaseAerospike, _distanceNameToAerospikeType as DistanceMaps, _distanceAerospikeTypeToAnn as DistanceMapsAnn, OperationActions from datasets import DATASETS, load_and_transform_dataset, get_dataset_fn from metrics import all_metrics as METRICS, DummyMetric @@ -239,6 +238,12 @@ def parse_arguments_query(parser: argparse.ArgumentParser) -> None: help="Don't adjust the distance returned by Aerospike based on the distance type (e.g., Square-Euclidean)", action='store_true' ) + parser.add_argument( + "--target-tps", + help="Target TPS for query", + default=0, + type=int + ) BaseAerospike.parse_arguments(parser) @@ -304,6 +309,14 @@ def __init__(self, runtimeArgs: argparse.Namespace, actions: OperationActions): self._idx_name = runtimeArgs.idxname self._query_runs = runtimeArgs.runs self._query_parallel = runtimeArgs.parallel + self._target_tps = runtimeArgs.target_tps + + self._throttler: list[DynamicThrottle] = [] + + num_threads = self._query_runs if self._query_parallel else 1 + for thread in range(self._query_runs): + self._throttler.append(DynamicThrottle(self._target_tps, num_threads)) + self._query_check = runtimeArgs.check self._query_nbrlimit = runtimeArgs.limit self._query_metric = METRICS[runtimeArgs.metric] @@ -590,7 +603,7 @@ async def populate(self) -> None: self.prometheus_status() async with vectorASyncAdminClient(seeds=self._host, - listener_name=self._listern, + listener_name=self._listener, is_loadbalancer=self._useloadbalancer ) as adminClient: @@ -624,7 +637,7 @@ async def populate(self) -> None: await self.create_index(adminClient) async with vectorASyncClient(seeds=self._host, - listener_name=self._listern, + listener_name=self._listener, is_loadbalancer=self._useloadbalancer ) as client: @@ -791,7 +804,7 @@ async def query(self) -> None: self.prometheus_status() async with vectorASyncAdminClient(seeds=self._host, - listener_name=self._listern, + listener_name=self._listener, is_loadbalancer=self._useloadbalancer ) as adminClient: @@ -823,14 +836,13 @@ async def query(self) -> None: distancemetric = DISTANCES[self._query_distancecalc] async with vectorASyncClient(seeds=self._host, - listener_name=self._listern, + listener_name=self._listener, is_loadbalancer=self._useloadbalancer ) as client: self._heartbeat_stage = 3 self.prometheus_status() - s = time.time() totalquerytime : float = 0.0 taskPuts = [] queries = [] @@ -887,7 +899,6 @@ async def query(self) -> None: totalquerytime += sum(times) i += 1 - t = time.time() self._query_metric_value = statistics.mean(metricValues) self._aerospike_metric_value = statistics.mean(metricValuesAS) self._query_metric_big_value = statistics.mean(metricValuesBig) @@ -1032,6 +1043,9 @@ async def vector_search(self, client:vectorASyncClient, query:List[float], runNb latency = (t-s)*math.pow(10,-6) self._query_counter.add(1, {"type": "Vector Search","ns":self._idx_namespace,"idx":self._idx_name, "run": runNbr}) self._query_histogram.record(latency, {"ns":self._idx_namespace,"idx": self._idx_name, "run": runNbr}) + + await self._throttler[runNbr-1].throttle() + except vectorTypes.AVSServerError as e: if "unknown vector" in e.rpc_error.details(): self._exception_counter.add(1, {"exception_type":f"vector_search: {e.rpc_error.details()}", "handled_by_user":True,"ns":self._idx_namespace,"set":self._idx_name,"run":runNbr}) diff --git a/aerospike/baseaerospike.py b/aerospike/baseaerospike.py index d5389163d..8d38054f6 100644 --- a/aerospike/baseaerospike.py +++ b/aerospike/baseaerospike.py @@ -1,5 +1,4 @@ import asyncio -import os import numpy as np import time import logging @@ -23,8 +22,7 @@ from aerospike_vector_search import types as vectorTypes from aerospike_vector_search import AdminClient as vectorAdminClient -from metrics import all_metrics as METRICS -from helpers import set_hnsw_params_attrs, hnswstr +from helpers import hnswstr from dsiterator import DSIterator _distanceNameToAerospikeType: Dict[str, vectorTypes.VectorDistanceMetric] = { @@ -168,7 +166,7 @@ def __init__(self, runtimeArgs: argparse.Namespace, logger: logging.Logger): from dshdfiterator import DSHDFIterator DSHDFIterator.set_storage_threshold(runtimeArgs.storagethreshold) - self._listern = None + self._listener = None self._useloadbalancer = runtimeArgs.vectorloadbalancer self._namespace : str = None @@ -451,7 +449,7 @@ def vector_queue_status(self, adminclient : vectorAdminClient, queryapi:bool = T try: self._vector_queue_depth = adminclient.index_get_status(namespace=self._namespace, name=self._idx_name, - timeout=2) + timeout=2).unmerged_record_count except vectorTypes.AVSServerError as avse: self._vector_queue_depth = None @@ -477,7 +475,7 @@ def _vector_queue_heartbeat(self) -> None: try: with vectorAdminClient(seeds=self._host, - listener_name=self._listern, + listener_name=self._listener, is_loadbalancer=self._useloadbalancer ) as adminClient: self._logger.debug(f"Vector Heartbeating Start") @@ -597,4 +595,4 @@ def __str__(self): else: fullName = f"{self._namespace}.{self._setName}.{self._idx_namespace}.{self._idx_name}" - return f"{fullName}({self._datasetname})" \ No newline at end of file + return f"{fullName}({self._datasetname})" diff --git a/aerospike/dynamic_throttle.py b/aerospike/dynamic_throttle.py new file mode 100644 index 000000000..e25d72a94 --- /dev/null +++ b/aerospike/dynamic_throttle.py @@ -0,0 +1,82 @@ +import asyncio +import time + + +class DynamicThrottle: + _throttle_startup_count: int = 20 + _throttle_alpha: float = 1.0 / _throttle_startup_count + + def __init__(self, tps: float, num_threads: int = 1) -> None: + """ + Initialize a DynamicThrottle instance with a target TPS (transactions per second). + + :param tps: Target transactions per second (0 means no throttling) + :param num_threads: Number of concurrent threads (default: 1) + """ + if tps == 0: + # No throttling + self.target_period: float = 0 + else: + # Calculate the target period in seconds + self.target_period: float = num_threads / tps + + self.avg_fn_delay: float = 0.0 + self.n_records: int = 0 + self.last_record_timestamp: float = 0.0 + + @staticmethod + def ramp(value: float) -> float: + """ + Ensure non-negative value for pause. + + :param value: The value to validate. + :return: Non-negative value. + """ + return max(0.0, value) + + def pause_for_duration(self) -> float: + """ + Calculate pause duration based on the current record timestamp. + + :return: Pause duration in seconds. + """ + + # Get the current time in seconds + current_record_timestamp: float = time.time() + + if self.n_records < self._throttle_startup_count: + # During initial calls + if self.n_records == 0: + pause_for = self.target_period + else: + alpha = 1.0 / self.n_records + avg = self.avg_fn_delay + avg = (1 - alpha) * avg + alpha * (current_record_timestamp - self.last_record_timestamp) + self.avg_fn_delay = avg + pause_for = self.target_period - avg + else: + # After sufficient records have been logged + avg = self.avg_fn_delay + avg = (1 - self._throttle_alpha) * avg + self._throttle_alpha * (current_record_timestamp - self.last_record_timestamp) + self.avg_fn_delay = avg + pause_for = self.target_period - avg + + # Ensure non-negative pause + pause_for = self.ramp(pause_for) + + # Update last record and record count + self.last_record_timestamp = current_record_timestamp + pause_for + self.n_records += 1 + return pause_for + + async def throttle(self) -> None: + """ + Throttle execution to maintain the target period. + """ + if self.target_period == 0: + return + + pause_duration: float = self.pause_for_duration() + + # Sleep for the calculated duration in seconds + await asyncio.sleep(pause_duration) diff --git a/aerospike/helpers.py b/aerospike/helpers.py index 21695e4a8..4208f044d 100644 --- a/aerospike/helpers.py +++ b/aerospike/helpers.py @@ -50,11 +50,11 @@ def set_hnsw_params_attrs(__obj :object, __dict: dict) -> object: def hnswstr(hnswparams : vectorTypes.HnswParams) -> str: if hnswparams is None: return '' - if hnswparams.batching_params is None: + if not hasattr(hnswparams.batching_params, 'max_records') or hnswparams.batching_params is None: batchingparams = '' else: batchingparams = f"maxrecs:{hnswparams.batching_params.max_records}, interval:{hnswparams.batching_params.interval}" - if hnswparams.caching_params is None: + if not hasattr(hnswparams, 'caching_params') or hnswparams.caching_params is None: cachingparams = '' else: cachingparams = f"max_entries:{hnswparams.caching_params.max_entries}, expiry:{hnswparams.caching_params.expiry}"