diff --git a/bubuku/amazon.py b/bubuku/amazon.py index 2cfbd54..fec8c79 100644 --- a/bubuku/amazon.py +++ b/bubuku/amazon.py @@ -35,7 +35,7 @@ def get_own_ip(self) -> str: doc = self._get_document() return doc['privateIp'] if doc else '127.0.0.1' - def get_addresses_by_lb_name(self, lb_name): + def get_addresses_by_lb_name(self, lb_name) -> list: region = self.get_aws_region() private_ips = [] diff --git a/bubuku/daemon.py b/bubuku/daemon.py index 29af1f9..eb8d637 100644 --- a/bubuku/daemon.py +++ b/bubuku/daemon.py @@ -16,6 +16,7 @@ from bubuku.id_generator import get_broker_id_policy from bubuku.utils import CmdHelper from bubuku.zookeeper import BukuExhibitor, load_exhibitor_proxy +from bubuku.zookeeper.exhibior import AWSExhibitorAddressProvider _LOG = logging.getLogger('bubuku.main') @@ -51,8 +52,10 @@ def main(): amazon = Amazon() + address_provider = AWSExhibitorAddressProvider(amazon, config.zk_stack_name) + _LOG.info("Loading exhibitor configuration") - buku_proxy = load_exhibitor_proxy(amazon.get_addresses_by_lb_name(config.zk_stack_name), config.zk_prefix) + buku_proxy = load_exhibitor_proxy(address_provider, config.zk_prefix) _LOG.info("Loading broker_id policy") broker_id_manager = get_broker_id_policy(config.id_policy, buku_proxy, kafka_properties, amazon) diff --git a/bubuku/zookeeper.py b/bubuku/zookeeper/__init__.py similarity index 70% rename from bubuku/zookeeper.py rename to bubuku/zookeeper/__init__.py index c6b2573..4d85454 100644 --- a/bubuku/zookeeper.py +++ b/bubuku/zookeeper/__init__.py @@ -1,68 +1,14 @@ import json import logging -import random import threading import time -import requests from kazoo.client import KazooClient -from kazoo.exceptions import NoNodeError, NodeExistsError, ConnectionLossException -from requests.exceptions import RequestException +from kazoo.exceptions import NodeExistsError, NoNodeError, ConnectionLossException _LOG = logging.getLogger('bubuku.exhibitor') -class ExhibitorEnsembleProvider: - TIMEOUT = 3.1 - - def __init__(self, hosts, port, uri_path='/exhibitor/v1/cluster/list', poll_interval=300): - self._exhibitor_port = port - self._uri_path = uri_path - self._poll_interval = poll_interval - self._exhibitors = hosts - self._master_exhibitors = hosts - self._zookeeper_hosts = '' - self._next_poll = None - while not self.poll(): - _LOG.info('waiting on exhibitor') - time.sleep(5) - - def poll(self): - if self._next_poll and self._next_poll > time.time(): - return False - - json_ = self._query_exhibitors(self._exhibitors) - if not json_: - json_ = self._query_exhibitors(self._master_exhibitors) - - if isinstance(json_, dict) and 'servers' in json_ and 'port' in json_: - self._next_poll = time.time() + self._poll_interval - zookeeper_hosts = ','.join([h + ':' + str(json_['port']) for h in sorted(json_['servers'])]) - if self._zookeeper_hosts != zookeeper_hosts: - _LOG.info('ZooKeeper connection string has changed: %s => %s', self._zookeeper_hosts, zookeeper_hosts) - self._zookeeper_hosts = zookeeper_hosts - self._exhibitors = json_['servers'] - return True - return False - - def _query_exhibitors(self, exhibitors): - if exhibitors == [None]: - return {'servers': ['localhost'], 'port': 2181} - random.shuffle(exhibitors) - for host in exhibitors: - uri = 'http://{}:{}{}'.format(host, self._exhibitor_port, self._uri_path) - try: - response = requests.get(uri, timeout=self.TIMEOUT) - return response.json() - except RequestException: - pass - return None - - @property - def zookeeper_hosts(self): - return self._zookeeper_hosts - - class WaitingCounter(object): def __init__(self, limit=100): self.limit = limit @@ -81,34 +27,84 @@ def decrement(self): self.cv.notify() -class _Exhibitor: - def __init__(self, hosts, port, prefix): - self.prefix = prefix +class SlowlyUpdatedCache(object): + def __init__(self, load_func, update_func, refresh_timeout, delay): + self.load_func = load_func + self.update_func = update_func + self.refresh_timeout = refresh_timeout + self.delay = delay + self.value = None + self.last_check = None + self.next_apply = None + self.force = True + + def __str__(self): + return 'SlowCache(refresh={}, delay={}, last_check={}, next_apply={})'.format( + self.refresh_timeout, self.delay, self.last_check, self.next_apply) + + def touch(self): + now = time.time() + if self.last_check is None or (now - self.last_check) > self.refresh_timeout: + value = None + if self.force: + while value is None: + value = self.load_func() + self.force = False + else: + value = self.load_func() + if value is not None and value != self.value: + self.value = value + self.next_apply = (now + self.delay) if self.last_check is not None else now + self.last_check = now + if self.next_apply is not None and self.next_apply - now <= 0: + self.update_func(self.value) + self.next_apply = None + + +class AddressListProvider(object): + def get_latest_address(self) -> (list, int): + """ + Loads current address list from service. Can return None if value can't be refreshed at the moment + :return: tuple of hosts, port for zookeeper + """ + raise NotImplementedError + + +class _ZookeeperProxy(object): + def __init__(self, address_provider: AddressListProvider, prefix: str): + self.address_provider = address_provider self.async_counter = WaitingCounter(limit=100) - self.exhibitor = ExhibitorEnsembleProvider(hosts, port, poll_interval=30) - self.client = KazooClient(hosts=self.exhibitor.zookeeper_hosts + self.prefix, - command_retry={ - 'deadline': 10, - 'max_delay': 1, - 'max_tries': -1}, - connection_retry={'max_delay': 1, 'max_tries': -1}) - self.client.add_listener(self.session_listener) + self.conn_str = None + self.client = None + self.prefix = prefix + self.hosts_cache = SlowlyUpdatedCache( + self.address_provider.get_latest_address, + self._update_hosts, + 30, # Refresh every 30 seconds + 3 * 60) # Update only after 180 seconds of stability + + def _update_hosts(self, value): + hosts, port = value + self.conn_str = ','.join(['{}:{}'.format(h, port) for h in hosts]) + self.prefix + + if self.client is None: + self.client = KazooClient(hosts=self.conn_str, + command_retry={'deadline': 10, 'max_delay': 1, 'max_tries': -1}, + connection_retry={'max_delay': 1, 'max_tries': -1}) + self.client.add_listener(self.session_listener) + else: + self.client.stop() + self.client.set_hosts(self.conn_str) self.client.start() def session_listener(self, state): pass def get_conn_str(self): - return self.exhibitor.zookeeper_hosts + self.prefix - - def _poll_exhibitor(self): - if self.exhibitor.poll(): - self.client.stop() - self.client.set_hosts(self.get_conn_str()) - self.client.start() + return self.conn_str def get(self, *params): - self._poll_exhibitor() + self.hosts_cache.touch() return self.client.retry(self.client.get, *params) def get_async(self, *params): @@ -123,19 +119,19 @@ def get_async(self, *params): raise e def set(self, *args, **kwargs): - self._poll_exhibitor() + self.hosts_cache.touch() return self.client.retry(self.client.set, *args, **kwargs) def create(self, *args, **kwargs): - self._poll_exhibitor() + self.hosts_cache.touch() return self.client.retry(self.client.create, *args, **kwargs) def delete(self, *args, **kwargs): - self._poll_exhibitor() + self.hosts_cache.touch() return self.client.retry(self.client.delete, *args, **kwargs) def get_children(self, *params): - self._poll_exhibitor() + self.hosts_cache.touch() try: return self.client.retry(self.client.get_children, *params) except NoNodeError: @@ -144,14 +140,14 @@ def get_children(self, *params): def take_lock(self, *args, **kwargs): while True: try: - self._poll_exhibitor() + self.hosts_cache.touch() return self.client.Lock(*args, **kwargs) except Exception as e: _LOG.error('Failed to obtain lock for exhibitor, retrying', exc_info=e) class BukuExhibitor(object): - def __init__(self, exhibitor: _Exhibitor, async=True): + def __init__(self, exhibitor: _ZookeeperProxy, async=True): self.exhibitor = exhibitor self.async = async try: @@ -288,5 +284,6 @@ def unregister_change(self, name): self.exhibitor.delete('/bubuku/changes/{}'.format(name), recursive=True) -def load_exhibitor_proxy(initial_hosts: list, zookeeper_prefix) -> BukuExhibitor: - return BukuExhibitor(_Exhibitor(initial_hosts, 8181, zookeeper_prefix)) +def load_exhibitor_proxy(address_provider: AddressListProvider, prefix: str) -> BukuExhibitor: + proxy = _ZookeeperProxy(address_provider, prefix) + return BukuExhibitor(proxy) diff --git a/bubuku/zookeeper/exhibior.py b/bubuku/zookeeper/exhibior.py new file mode 100644 index 0000000..25fe489 --- /dev/null +++ b/bubuku/zookeeper/exhibior.py @@ -0,0 +1,36 @@ +import logging +import random + +import requests +from requests import RequestException + +from bubuku.amazon import Amazon +from bubuku.zookeeper import AddressListProvider + +_LOG = logging.getLogger('bubuku.zookeeper.exhibitor') + + +class AWSExhibitorAddressProvider(AddressListProvider): + def __init__(self, amazon: Amazon, zk_stack_name: str): + self.master_exhibitors = amazon.get_addresses_by_lb_name(zk_stack_name) + self.exhibitors = list(self.master_exhibitors) + + def get_latest_address(self) -> (list, int): + json_ = self._query_exhibitors(self.exhibitors) + if not json_: + json_ = self._query_exhibitors(self.master_exhibitors) + if isinstance(json_, dict) and 'servers' in json_ and 'port' in json_: + self.exhibitors = json_['servers'] + return json_['servers'], int(json_['port']) + return None + + def _query_exhibitors(self, exhibitors): + random.shuffle(exhibitors) + for host in exhibitors: + url = 'http://{}:{}{}'.format(host, 8181, '/exhibitor/v1/cluster/list') + try: + response = requests.get(url, timeout=3.1) + return response.json() + except RequestException as e: + _LOG.warn('Failed to query zookeeper list information from {}'.format(url), exc_info=e) + return None diff --git a/tests/test_exhibitor.py b/tests/test_zookeeper.py similarity index 69% rename from tests/test_exhibitor.py rename to tests/test_zookeeper.py index 806124a..e9a3f2d 100644 --- a/tests/test_exhibitor.py +++ b/tests/test_zookeeper.py @@ -1,10 +1,13 @@ import json +import math import re +import time +import unittest from unittest.mock import MagicMock from kazoo.exceptions import NoNodeError, NodeExistsError -from bubuku.zookeeper import BukuExhibitor +from bubuku.zookeeper import BukuExhibitor, SlowlyUpdatedCache def test_get_broker_ids(): @@ -197,3 +200,91 @@ def _create(path, value=None, **kwargs): assert buku.reallocate_partition('t01', 0, [1, 2, 3]) # Node exists assert not buku.reallocate_partition('t01', 0, [1, 2, 3]) + + +class SlowlyUpdatedCacheTest(unittest.TestCase): + def test_initial_update_fast(self): + result = [None] + + def _update(value_): + result[0] = value_ + + cache = SlowlyUpdatedCache(lambda: (['test'], 1), _update, 0, 0) + + cache.touch() + assert result[0] == (['test'], 1) + + def test_initial_update_slow(self): + result = [None] + call_count = [0] + + def _load(): + call_count[0] += 1 + if call_count[0] == 100: + return ['test'], 1 + return None + + def _update(value_): + result[0] = value_ + + cache = SlowlyUpdatedCache(_load, _update, 0, 0) + + cache.touch() + assert call_count[0] == 100 + assert result[0] == (['test'], 1) + + def test_delays_illegal(self): + result = [None] + load_calls = [] + update_calls = [] + + def _load(): + load_calls.append(time.time()) + return ['test'], 0 if len(load_calls) > 1 else 1 + + def _update(value_): + update_calls.append(time.time()) + result[0] = value_ + + # refresh every 1 second, delay 0.5 second + cache = SlowlyUpdatedCache(_load, _update, 0.5, 0.25) + + while len(update_calls) != 2: + time.sleep(0.1) + cache.touch() + print(cache) + + assert math.fabs(update_calls[0] - load_calls[0]) <= 0.15 # 0.1 + 0.1/2 + # Verify that load calls were made one by another + assert math.fabs(load_calls[1] - load_calls[0] - .5) <= 0.15 + # Verity that update call was made in correct interval + + assert load_calls[1] + 0.25 <= update_calls[1] <= load_calls[1] + 0.25 + 0.15 + + def test_delays_legal(self): + result = [None] + main_call = [] + load_calls = [] + update_calls = [] + + def _load(): + load_calls.append(time.time()) + if len(load_calls) == 5: + main_call.append(time.time()) + return ['test'], 0 if len(load_calls) >= 5 else len(load_calls) + + def _update(value_): + update_calls.append(time.time()) + result[0] = value_ + + # refresh every 1 second, delay 5 second - in case where situation is constantly changing - wait for + # last stable update + cache = SlowlyUpdatedCache(_load, _update, 0.5, 3) + + while len(update_calls) != 2: + time.sleep(0.1) + cache.touch() + print(cache) + + assert len(main_call) == 1 + assert main_call[0] + 3 - .15 < update_calls[1] < main_call[0] + 3 + .15