Skip to content

Commit

Permalink
fixd bug in healthcheck, added suspensionhandler and more
Browse files Browse the repository at this point in the history
  • Loading branch information
Deutscher775 committed Oct 19, 2024
1 parent dbed987 commit fd86c7f
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 17 deletions.
8 changes: 4 additions & 4 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async def get_endpoint(endpoint: int,
if token is not None:
if token == data_token or token == Bot.config.MASTER_TOKEN:
try:
return fastapi.responses.JSONResponse(status_code=200, content=await astroidapi.surrealdb_handler.get_endpoint(endpoint))
return fastapi.responses.JSONResponse(status_code=200, content=await astroidapi.surrealdb_handler.get_endpoint(endpoint, __file__))
except astroidapi.errors.SurrealDBHandler.EndpointNotFoundError as e:
return fastapi.responses.JSONResponse(status_code=404, content={"message": f"Endpoint {endpoint} not found."})
except astroidapi.errors.SurrealDBHandler.GetEndpointError as e:
Expand Down Expand Up @@ -284,7 +284,7 @@ async def get_bridges(endpoint: int,
if token is not None:
if token == data_token or token == Bot.config.MASTER_TOKEN:
try:
bridges_json = await astroidapi.surrealdb_handler.get_endpoint(endpoint)
bridges_json = await astroidapi.surrealdb_handler.get_endpoint(endpoint, __file__)
bridges_discord = []
bridges_guilded = []
bridges_revolt = []
Expand Down Expand Up @@ -410,7 +410,7 @@ async def post_endpoint(
beta=beta,
only_check=only_check,
)
return await astroidapi.surrealdb_handler.get_endpoint(endpoint)
return await astroidapi.surrealdb_handler.get_endpoint(endpoint, __file__)


@api.patch("/sync", description="Sync the local files with the database.")
Expand Down Expand Up @@ -627,7 +627,7 @@ async def delete_enpoint_data(endpoint: int,
if token is not None:
if token == data_token or token == Bot.config.MASTER_TOKEN:
try:
json_data = await astroidapi.surrealdb_handler.get_endpoint(endpoint)
json_data = await astroidapi.surrealdb_handler.get_endpoint(endpoint, __file__)
if webhook_discord:
json_data["config"]["webhooks"]["discord"].pop(json_data["config"]["webhooks"]["discord"].index(webhook_discord))
if webhook_guilded:
Expand Down
6 changes: 3 additions & 3 deletions src/astroidapi/endpoint_update_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def update_endpoint(

if token is not None:
if token == config.MASTER_TOKEN or token == data_token:
endpoint_data = await surrealdb_handler.get_endpoint(endpoint)
endpoint_data = await surrealdb_handler.get_endpoint(endpoint, __file__)
if endpoint_data:
if only_check:
if endpoint_data["meta"]["read"]["discord"] and endpoint_data["meta"]["read"]["guilded"] and endpoint_data["meta"]["read"]["revolt"]:
Expand Down Expand Up @@ -234,7 +234,7 @@ async def update_endpoint(
waiting_secs = 0
max_secs = 10
while True:
check_json = await surrealdb_handler.get_endpoint(endpoint)
check_json = await surrealdb_handler.get_endpoint(endpoint, __file__)
if (check_json["meta"]["read"]["discord"] == True
and check_json["meta"]["read"]["guilded"] == True
and check_json["meta"]["read"]["revolt"] == True
Expand Down Expand Up @@ -291,7 +291,7 @@ async def update_endpoint(
waiting_secs = 0
max_secs = 10
while True:
check_json = await surrealdb_handler.get_endpoint(endpoint)
check_json = await surrealdb_handler.get_endpoint(endpoint, __file__)
if check_json["meta"]["trigger"] is False and check_json["meta"]["message"]["content"] is None:
return fastapi.responses.JSONResponse(status_code=200, content=check_json)
if (check_json["meta"]["read"]["discord"] == True
Expand Down
6 changes: 3 additions & 3 deletions src/astroidapi/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
},
"blacklist": "config.blacklist",
"allowed-ids": "config.`allowed-ids`",
"isbeta": "config.isbeta`"
"isbeta": "config.isbeta"
},
"meta": {
"sender-channel": "meta.`sender-channel`",
Expand Down Expand Up @@ -122,7 +122,7 @@ async def check(cls, endpoint):
}
}
try:
endpoint_data = await surrealdb_handler.get_endpoint(endpoint)
endpoint_data = await surrealdb_handler.get_endpoint(endpoint, __file__)
for key in healthy_endpoint_data["config"].keys():
if key not in endpoint_data["config"]:
raise errors.HealtCheckError.EndpointCheckError.EndpointConfigError(f"'{key}' not found in endpoint config '{endpoint}'")
Expand Down Expand Up @@ -189,7 +189,7 @@ async def repair_structure(cls, endpoint):
}
}
try:
endpoint_data = await surrealdb_handler.get_endpoint(endpoint)
endpoint_data = await surrealdb_handler.get_endpoint(endpoint, __file__)
summary = []
try:
self_user = endpoint_data["config"]["self-user"]
Expand Down
4 changes: 2 additions & 2 deletions src/astroidapi/read_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class ReadHandler:
async def mark_read(cls, endpoint, platform):
try:
print(f"Marking {platform} read")
endpoint_data = await surrealdb_handler.get_endpoint(endpoint)
endpoint_data = await surrealdb_handler.get_endpoint(endpoint, __file__)
if endpoint_data is None:
raise errors.SurrealDBHandler.EndpointNotFoundError(f"'{endpoint}' not found")
if await cls.check_read(endpoint, platform):
Expand All @@ -30,7 +30,7 @@ async def check_read(cls, endpoint, platform = "all", data: dict = None):
raise errors.ReadHandlerError.InvalidPlatformError(f"Invalid platform '{platform}'")
read = read[platform]
return read
endpoint = await surrealdb_handler.get_endpoint(endpoint)
endpoint = await surrealdb_handler.get_endpoint(endpoint, __file__)
if endpoint is None:
raise errors.SurrealDBHandler.EndpointNotFoundError(f"'{endpoint}' not found")
read = endpoint["meta"]["read"]
Expand Down
41 changes: 37 additions & 4 deletions src/astroidapi/surrealdb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ async def sync_server_relations():



async def get_endpoint(endpoint: int):
async def get_endpoint(endpoint: int, caller: str):
try:
print(f"{endpoint} called by {__file__}")
print(f"{endpoint} called by {caller}")
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
Expand Down Expand Up @@ -400,7 +400,8 @@ async def unsuspend(cls, endpoint_id):
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
await db.delete(f"suspensions:`{endpoint_id}`")
await db.delete(endpoint_id)
print(f"Endpoint {endpoint_id} has been unsuspended")
return True
except Exception as e:
raise errors.SurrealDBHandler.UnsuspendEndpointError(e)
Expand All @@ -425,4 +426,36 @@ async def update(cls, endpoint_id, reason: str = None, suspended_by: int = None,
await db.update(f"suspensions:`{endpoint_id}`", current_data)
return await db.select(f"suspensions:`{endpoint_id}`")
except Exception as e:
raise errors.SurrealDBHandler.SuspendEndpointError(e)
raise errors.SurrealDBHandler.SuspendEndpointError(e)

@staticmethod
async def _checkexpireloop():
try:
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
data = await db.select("suspensions")
for suspension in data:
print(f"Checking {suspension['id']}")
if datetime.datetime.now() >= datetime.datetime.fromtimestamp(suspension["suspendedAt"]):
print(f"Endpoint {suspension['id']} has expired")
await Suspension.Endpoints.unsuspend(suspension["id"])
return True
except Exception as e:
raise errors.SurrealDBHandler.SuspensionHandlerError(e)

@staticmethod
async def _checkendpointdatadeletionloop():
try:
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
data = await db.select("endpoints")
for endpoint in data:
print(f"Checking {endpoint['id']}")
if not endpoint["expireAt"] and datetime.datetime.now() <= datetime.datetime.fromtimestamp(endpoint["suspendetAt"]) + datetime.timedelta(weeks=1):
await delete(endpoint["id"])
return True
except Exception as e:
raise errors.SurrealDBHandler.SuspensionHandlerError(e)

36 changes: 35 additions & 1 deletion src/astroidapi/suspension_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import astroidapi.surrealdb_handler as surrealdb_handler
import astroidapi.errors as errors
import asyncio
import threading


class Endpoint():
def __init__(self, endpoint_id):
self.endpoint_id = endpoint_id

stop_event = asyncio.Event()

@classmethod
async def is_suspended(cls, endpoint_id):
try:
Expand All @@ -32,4 +36,34 @@ async def unsuspend(cls, endpoint_id):
try:
await surrealdb_handler.Suspension.Endpoints.unsuspend(endpoint_id)
except errors.SurrealDBHandler.UnsuspendEndpointError as e:
raise errors.SuspensionHandlerError.UnsuspendEndpointError(e)
raise errors.SuspensionHandlerError.UnsuspendEndpointError(e)


@classmethod
async def check_expirations(cls):
while not cls.stop_event.is_set():
print("[Suspension handler] Checking expirations...")
await surrealdb_handler.Suspension.Endpoints._checkexpireloop()
await asyncio.sleep(60 * 10) # 10 minutes

print("Stopping expiration checks...")

@classmethod
def stop_check_expirations(cls):
cls.stop_event.set()


def run_async_in_thread(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(coro)
except KeyboardInterrupt:
pass
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()


thread = threading.Thread(target=run_async_in_thread, args=(Endpoint.check_expirations()))
thread.start()

0 comments on commit fd86c7f

Please sign in to comment.