Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TPS limiter to mini-ann #2

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions aerospike/aerospikehdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import statistics
import json

from enum import Flag, auto
from typing import Iterable, List, Union, Any
from importlib.metadata import version
from typing import List, Union, Any
from math import sqrt

from aerospike_vector_search import types as vectorTypes, Client as vectorSyncClient
from aerospike_vector_search import types as vectorTypes
from aerospike_vector_search.aio import AdminClient as vectorASyncAdminClient, Client as vectorASyncClient
from aerospike_vector_search.shared.proto_generated.types_pb2_grpc import grpc as vectorResultCodes

from dynamic_throttle import DynamicThrottle
from baseaerospike import BaseAerospike, _distanceNameToAerospikeType as DistanceMaps, _distanceAerospikeTypeToAnn as DistanceMapsAnn, OperationActions
from datasets import DATASETS, load_and_transform_dataset, get_dataset_fn
from metrics import all_metrics as METRICS, DummyMetric
Expand Down Expand Up @@ -239,6 +238,12 @@ def parse_arguments_query(parser: argparse.ArgumentParser) -> None:
help="Don't adjust the distance returned by Aerospike based on the distance type (e.g., Square-Euclidean)",
action='store_true'
)
parser.add_argument(
"--target-tps",
help="Target TPS for query",
default=0,
type=int
)

BaseAerospike.parse_arguments(parser)

Expand Down Expand Up @@ -304,6 +309,14 @@ def __init__(self, runtimeArgs: argparse.Namespace, actions: OperationActions):
self._idx_name = runtimeArgs.idxname
self._query_runs = runtimeArgs.runs
self._query_parallel = runtimeArgs.parallel
self._target_tps = runtimeArgs.target_tps

self._throttler: list[DynamicThrottle] = []

num_threads = self._query_runs if self._query_parallel else 1
for thread in range(self._query_runs):
self._throttler.append(DynamicThrottle(self._target_tps, num_threads))

self._query_check = runtimeArgs.check
self._query_nbrlimit = runtimeArgs.limit
self._query_metric = METRICS[runtimeArgs.metric]
Expand Down Expand Up @@ -590,7 +603,7 @@ async def populate(self) -> None:
self.prometheus_status()

async with vectorASyncAdminClient(seeds=self._host,
listener_name=self._listern,
listener_name=self._listener,
is_loadbalancer=self._useloadbalancer
) as adminClient:

Expand Down Expand Up @@ -624,7 +637,7 @@ async def populate(self) -> None:
await self.create_index(adminClient)

async with vectorASyncClient(seeds=self._host,
listener_name=self._listern,
listener_name=self._listener,
is_loadbalancer=self._useloadbalancer
) as client:

Expand Down Expand Up @@ -791,7 +804,7 @@ async def query(self) -> None:
self.prometheus_status()

async with vectorASyncAdminClient(seeds=self._host,
listener_name=self._listern,
listener_name=self._listener,
is_loadbalancer=self._useloadbalancer
) as adminClient:

Expand Down Expand Up @@ -823,14 +836,13 @@ async def query(self) -> None:
distancemetric = DISTANCES[self._query_distancecalc]

async with vectorASyncClient(seeds=self._host,
listener_name=self._listern,
listener_name=self._listener,
is_loadbalancer=self._useloadbalancer
) as client:

self._heartbeat_stage = 3
self.prometheus_status()

s = time.time()
totalquerytime : float = 0.0
taskPuts = []
queries = []
Expand Down Expand Up @@ -887,7 +899,6 @@ async def query(self) -> None:
totalquerytime += sum(times)
i += 1

t = time.time()
self._query_metric_value = statistics.mean(metricValues)
self._aerospike_metric_value = statistics.mean(metricValuesAS)
self._query_metric_big_value = statistics.mean(metricValuesBig)
Expand Down Expand Up @@ -1032,6 +1043,9 @@ async def vector_search(self, client:vectorASyncClient, query:List[float], runNb
latency = (t-s)*math.pow(10,-6)
self._query_counter.add(1, {"type": "Vector Search","ns":self._idx_namespace,"idx":self._idx_name, "run": runNbr})
self._query_histogram.record(latency, {"ns":self._idx_namespace,"idx": self._idx_name, "run": runNbr})

await self._throttler[runNbr-1].throttle()

except vectorTypes.AVSServerError as e:
if "unknown vector" in e.rpc_error.details():
self._exception_counter.add(1, {"exception_type":f"vector_search: {e.rpc_error.details()}", "handled_by_user":True,"ns":self._idx_namespace,"set":self._idx_name,"run":runNbr})
Expand Down
12 changes: 5 additions & 7 deletions aerospike/baseaerospike.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import os
import numpy as np
import time
import logging
Expand All @@ -23,8 +22,7 @@
from aerospike_vector_search import types as vectorTypes
from aerospike_vector_search import AdminClient as vectorAdminClient

from metrics import all_metrics as METRICS
from helpers import set_hnsw_params_attrs, hnswstr
from helpers import hnswstr
from dsiterator import DSIterator

_distanceNameToAerospikeType: Dict[str, vectorTypes.VectorDistanceMetric] = {
Expand Down Expand Up @@ -168,7 +166,7 @@ def __init__(self, runtimeArgs: argparse.Namespace, logger: logging.Logger):
from dshdfiterator import DSHDFIterator
DSHDFIterator.set_storage_threshold(runtimeArgs.storagethreshold)

self._listern = None
self._listener = None
self._useloadbalancer = runtimeArgs.vectorloadbalancer

self._namespace : str = None
Expand Down Expand Up @@ -451,7 +449,7 @@ def vector_queue_status(self, adminclient : vectorAdminClient, queryapi:bool = T
try:
self._vector_queue_depth = adminclient.index_get_status(namespace=self._namespace,
name=self._idx_name,
timeout=2)
timeout=2).unmerged_record_count

except vectorTypes.AVSServerError as avse:
self._vector_queue_depth = None
Expand All @@ -477,7 +475,7 @@ def _vector_queue_heartbeat(self) -> None:

try:
with vectorAdminClient(seeds=self._host,
listener_name=self._listern,
listener_name=self._listener,
is_loadbalancer=self._useloadbalancer
) as adminClient:
self._logger.debug(f"Vector Heartbeating Start")
Expand Down Expand Up @@ -597,4 +595,4 @@ def __str__(self):
else:
fullName = f"{self._namespace}.{self._setName}.{self._idx_namespace}.{self._idx_name}"

return f"{fullName}({self._datasetname})"
return f"{fullName}({self._datasetname})"
82 changes: 82 additions & 0 deletions aerospike/dynamic_throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import asyncio
import time


class DynamicThrottle:
_throttle_startup_count: int = 20
_throttle_alpha: float = 1.0 / _throttle_startup_count

def __init__(self, tps: float, num_threads: int = 1) -> None:
"""
Initialize a DynamicThrottle instance with a target TPS (transactions per second).

:param tps: Target transactions per second (0 means no throttling)
:param num_threads: Number of concurrent threads (default: 1)
"""
if tps == 0:
# No throttling
self.target_period: float = 0
else:
# Calculate the target period in seconds
self.target_period: float = num_threads / tps

self.avg_fn_delay: float = 0.0
self.n_records: int = 0
self.last_record_timestamp: float = 0.0

@staticmethod
def ramp(value: float) -> float:
"""
Ensure non-negative value for pause.

:param value: The value to validate.
:return: Non-negative value.
"""
return max(0.0, value)

def pause_for_duration(self) -> float:
"""
Calculate pause duration based on the current record timestamp.

:return: Pause duration in seconds.
"""

# Get the current time in seconds
current_record_timestamp: float = time.time()

if self.n_records < self._throttle_startup_count:
# During initial calls
if self.n_records == 0:
pause_for = self.target_period
else:
alpha = 1.0 / self.n_records
avg = self.avg_fn_delay
avg = (1 - alpha) * avg + alpha * (current_record_timestamp - self.last_record_timestamp)
self.avg_fn_delay = avg
pause_for = self.target_period - avg
else:
# After sufficient records have been logged
avg = self.avg_fn_delay
avg = (1 - self._throttle_alpha) * avg + self._throttle_alpha * (current_record_timestamp - self.last_record_timestamp)
self.avg_fn_delay = avg
pause_for = self.target_period - avg

# Ensure non-negative pause
pause_for = self.ramp(pause_for)

# Update last record and record count
self.last_record_timestamp = current_record_timestamp + pause_for
self.n_records += 1
return pause_for

async def throttle(self) -> None:
"""
Throttle execution to maintain the target period.
"""
if self.target_period == 0:
return

pause_duration: float = self.pause_for_duration()

# Sleep for the calculated duration in seconds
await asyncio.sleep(pause_duration)
4 changes: 2 additions & 2 deletions aerospike/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def set_hnsw_params_attrs(__obj :object, __dict: dict) -> object:
def hnswstr(hnswparams : vectorTypes.HnswParams) -> str:
if hnswparams is None:
return ''
if hnswparams.batching_params is None:
if not hasattr(hnswparams.batching_params, 'max_records') or hnswparams.batching_params is None:
batchingparams = ''
else:
batchingparams = f"maxrecs:{hnswparams.batching_params.max_records}, interval:{hnswparams.batching_params.interval}"
if hnswparams.caching_params is None:
if not hasattr(hnswparams, 'caching_params') or hnswparams.caching_params is None:
cachingparams = ''
else:
cachingparams = f"max_entries:{hnswparams.caching_params.max_entries}, expiry:{hnswparams.caching_params.expiry}"
Expand Down
Loading