Skip to content

Commit

Permalink
Remove expired JWTs as needed
Browse files Browse the repository at this point in the history
  • Loading branch information
kesmit13 committed Jul 10, 2024
1 parent be4c9c8 commit ec4452f
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions singlestoredb/management/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def get(self, name_or_id: str, *default: Any) -> Any:
def _setup_authentication_info_handler() -> Callable[..., Dict[str, Any]]:
"""Setup authentication info event handler."""

authentication_info: List[Tuple[str, Any]] = []
authentication_info: Dict[str, Any] = {}

def handle_authentication_info(msg: Dict[str, Any]) -> None:
"""Handle authentication info events."""
nonlocal authentication_info
if msg.get('name', '') != 'singlestore.portal.authentication_updated':
return
authentication_info = list(msg.get('data', {}).items())
authentication_info = dict(msg.get('data', {}))

events.subscribe(handle_authentication_info)

Expand All @@ -145,11 +145,27 @@ def handle_connection_info(msg: Dict[str, Any]) -> None:
out['user'] = data['user']
if 'password' in data:
out['password'] = data['password']
authentication_info = list(out.items())
authentication_info = out

events.subscribe(handle_authentication_info)

def retrieve_current_authentication_info() -> List[Tuple[str, Any]]:
"""Retrieve JWT if not expired."""
nonlocal authentication_info
password = authentication_info.get('password')
if password:
expires = datetime.datetime.fromtimestamp(
jwt.decode(
password,
options={'verify_signature': False},
)['exp'],
)
if datetime.datetime.now() > expires:
authentication_info = {}
return list(authentication_info.items())

def get_env() -> List[Tuple[str, Any]]:
"""Retrieve JWT from environment."""
conn = {}
url = os.environ.get('SINGLESTOREDB_URL') or get_option('host')
if url:
Expand All @@ -170,7 +186,7 @@ def get_authentication_info(include_env: bool = True) -> Dict[str, Any]:
return dict(
itertools.chain(
(get_env() if include_env else []),
authentication_info,
retrieve_current_authentication_info(),
),
)

Expand Down

0 comments on commit ec4452f

Please sign in to comment.