From 3b08105663688d9d12a272f66e85a437ede8f8e6 Mon Sep 17 00:00:00 2001 From: martynia Date: Tue, 21 May 2024 18:18:57 +0100 Subject: [PATCH 1/3] feat: enable remote pilot logging system --- diracx-core/src/diracx/core/exceptions.py | 8 ++ diracx-db/pyproject.toml | 1 + diracx-db/src/diracx/db/os/__init__.py | 6 +- diracx-db/src/diracx/db/os/pilot_logs.py | 21 ++++ diracx-db/src/diracx/db/os/utils.py | 19 +++ diracx-db/src/diracx/db/sql/job/db.py | 19 +-- .../src/diracx/db/sql/pilot_agents/db.py | 54 +++++++- diracx-db/src/diracx/db/sql/utils/__init__.py | 2 + diracx-db/src/diracx/db/sql/utils/base.py | 11 ++ diracx-routers/pyproject.toml | 2 + .../src/diracx/routers/dependencies.py | 8 +- .../src/diracx/routers/pilots/__init__.py | 11 ++ .../diracx/routers/pilots/access_policies.py | 89 +++++++++++++ .../src/diracx/routers/pilots/logging.py | 119 ++++++++++++++++++ .../tests/pilots/test_pilot_logger.py | 51 ++++++++ .../src/diracx/testing/mock_osdb.py | 25 +++- 16 files changed, 424 insertions(+), 22 deletions(-) create mode 100644 diracx-db/src/diracx/db/os/pilot_logs.py create mode 100644 diracx-routers/src/diracx/routers/pilots/__init__.py create mode 100644 diracx-routers/src/diracx/routers/pilots/access_policies.py create mode 100644 diracx-routers/src/diracx/routers/pilots/logging.py create mode 100644 diracx-routers/tests/pilots/test_pilot_logger.py diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 04b70192..2291f987 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -46,6 +46,14 @@ def __init__(self, job_id: int, detail: str | None = None): super().__init__(f"Job {job_id} not found" + (" ({detail})" if detail else "")) +class PilotNotFoundError(Exception): + def __init__(self, pilot_stamp: str, detail: str | None = None): + self.pilot_stamp: str = pilot_stamp + super().__init__( + f"Pilot (stamp) {pilot_stamp} not found" + (" ({detail})" if detail else "") + ) + + class SandboxNotFoundError(Exception): def __init__(self, pfn: str, se_name: str, detail: str | None = None): self.pfn: str = pfn diff --git a/diracx-db/pyproject.toml b/diracx-db/pyproject.toml index fc4ec487..1ebcefc1 100644 --- a/diracx-db/pyproject.toml +++ b/diracx-db/pyproject.toml @@ -37,6 +37,7 @@ TaskQueueDB = "diracx.db.sql:TaskQueueDB" [project.entry-points."diracx.db.os"] JobParametersDB = "diracx.db.os:JobParametersDB" +PilotLogsDB = "diracx.db.os:PilotLogsDB" [tool.setuptools.packages.find] where = ["src"] diff --git a/diracx-db/src/diracx/db/os/__init__.py b/diracx-db/src/diracx/db/os/__init__.py index 535e2a95..c1ce89bc 100644 --- a/diracx-db/src/diracx/db/os/__init__.py +++ b/diracx-db/src/diracx/db/os/__init__.py @@ -1,5 +1,9 @@ from __future__ import annotations -__all__ = ("JobParametersDB",) +__all__ = ( + "JobParametersDB", + "PilotLogsDB", +) from .job_parameters import JobParametersDB +from .pilot_logs import PilotLogsDB diff --git a/diracx-db/src/diracx/db/os/pilot_logs.py b/diracx-db/src/diracx/db/os/pilot_logs.py new file mode 100644 index 00000000..5c901191 --- /dev/null +++ b/diracx-db/src/diracx/db/os/pilot_logs.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from diracx.db.os.utils import BaseOSDB + + +class PilotLogsDB(BaseOSDB): + fields = { + "PilotStamp": {"type": "keyword"}, + "PilotID": {"type": "long"}, + "SubmissionTime": {"type": "date"}, + "LineNumber": {"type": "long"}, + "Message": {"type": "text"}, + "VO": {"type": "keyword"}, + "timestamp": {"type": "date"}, + } + index_prefix = "pilot_logs" + + def index_name(self, doc_id: int) -> str: + # TODO decide how to define the index name + # use pilot ID + return f"{self.index_prefix}_{doc_id // 1e6:.0f}" diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py index 431cceaa..8191954f 100644 --- a/diracx-db/src/diracx/db/os/utils.py +++ b/diracx-db/src/diracx/db/os/utils.py @@ -13,6 +13,7 @@ from typing import Any, Self from opensearchpy import AsyncOpenSearch +from opensearchpy.helpers import async_bulk from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension @@ -190,6 +191,13 @@ async def upsert(self, doc_id, document) -> None: ) print(f"{response=}") + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + """Bulk inserting to database.""" + n_inserted = await async_bulk( + self.client, actions=[doc | {"_index": index_name} for doc in docs] + ) + logger.info("Inserted %d documents to %r", n_inserted, index_name) + async def search( self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None ) -> list[dict[str, Any]]: @@ -231,6 +239,17 @@ async def search( return hits + async def delete(self, query: list[dict[str, Any]]) -> dict: + """Delete multiple documents by query.""" + body = {} + res = {} + if query: + body["query"] = apply_search_filters(self.fields, query) + res = await self.client.delete_by_query( + body=body, index=f"{self.index_prefix}*" + ) + return res + def require_type(operator, field_name, field_type, allowed_types): if field_type not in allowed_types: diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 7b918157..b87ddc5b 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -16,7 +16,7 @@ SortSpec, ) -from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints +from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints, get_columns from .schema import ( InputData, JobCommands, @@ -26,17 +26,6 @@ ) -def _get_columns(table, parameters): - columns = [x for x in table.columns] - if parameters: - if unrecognised_parameters := set(parameters) - set(table.columns.keys()): - raise InvalidQueryError( - f"Unrecognised parameters requested {unrecognised_parameters}" - ) - columns = [c for c in columns if c.name in parameters] - return columns - - class JobDB(BaseSQLDB): metadata = JobDBBase.metadata @@ -46,7 +35,7 @@ class JobDB(BaseSQLDB): jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"] async def summary(self, group_by, search) -> list[dict[str, str | int]]: - columns = _get_columns(Jobs.__table__, group_by) + columns = get_columns(Jobs.__table__, group_by) stmt = select(*columns, func.count(Jobs.job_id).label("count")) stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) @@ -70,7 +59,7 @@ async def search( page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: # Find which columns to select - columns = _get_columns(Jobs.__table__, parameters) + columns = get_columns(Jobs.__table__, parameters) stmt = select(*columns) @@ -328,7 +317,7 @@ async def set_properties( required_parameters = list(required_parameters_set)[0] update_parameters = [{"job_id": k, **v} for k, v in properties.items()] - columns = _get_columns(Jobs.__table__, required_parameters) + columns = get_columns(Jobs.__table__, required_parameters) values: dict[str, BindParameter[Any] | datetime] = { c.name: bindparam(c.name) for c in columns } diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index b4f801b7..51c0ce0f 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -1,10 +1,17 @@ from __future__ import annotations from datetime import datetime, timezone +from typing import Any -from sqlalchemy import insert +from sqlalchemy import func, insert, select -from ..utils import BaseSQLDB +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + SearchSpec, + SortSpec, +) + +from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints, get_columns from .schema import PilotAgents, PilotAgentsDBBase @@ -44,3 +51,46 @@ async def add_pilot_references( stmt = insert(PilotAgents).values(values) await self.conn.execute(stmt) return + + async def search( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + # Find which columns to select + columns = get_columns(PilotAgents.__table__, parameters) + + stmt = select(*columns) + + stmt = apply_search_filters( + PilotAgents.__table__.columns.__getitem__, stmt, search + ) + stmt = apply_sort_constraints( + PilotAgents.__table__.columns.__getitem__, stmt, sorts + ) + + if distinct: + stmt = stmt.distinct() + + # Calculate total count before applying pagination + total_count_subquery = stmt.alias() + total_count_stmt = select(func.count()).select_from(total_count_subquery) + total = (await self.conn.execute(total_count_stmt)).scalar_one() + + # Apply pagination + if page is not None: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") + stmt = stmt.offset((page - 1) * per_page).limit(per_page) + + # Execute the query + return total, [ + dict(row._mapping) async for row in (await self.conn.stream(stmt)) + ] diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index cd82d3c7..c48e95b8 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -5,6 +5,7 @@ SQLDBUnavailableError, apply_search_filters, apply_sort_constraints, + get_columns, ) from .functions import substract_date, utcnow from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn @@ -19,6 +20,7 @@ "EnumColumn", "apply_search_filters", "apply_sort_constraints", + "get_columns", "substract_date", "SQLDBUnavailableError", ) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index dfe6baa8..6fa73d76 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -258,6 +258,17 @@ def find_time_resolution(value): raise InvalidQueryError(f"Cannot parse {value=}") +def get_columns(table, parameters): + columns = [x for x in table.columns] + if parameters: + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): + raise InvalidQueryError( + f"Unrecognised parameters requested {unrecognised_parameters}" + ) + columns = [c for c in columns if c.name in parameters] + return columns + + def apply_search_filters(column_mapping, stmt, search): for query in search: try: diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index d99a51cc..41f0e942 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -46,6 +46,7 @@ types = [ ] [project.entry-points."diracx.services"] +pilots = "diracx.routers.pilots:router" jobs = "diracx.routers.jobs:router" config = "diracx.routers.configuration:router" auth = "diracx.routers.auth:router" @@ -54,6 +55,7 @@ auth = "diracx.routers.auth:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +PilotLogsAccessPolicy = "diracx.routers.pilots.access_policies:PilotLogsAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/dependencies.py b/diracx-routers/src/diracx/routers/dependencies.py index ab40190b..73d4c420 100644 --- a/diracx-routers/src/diracx/routers/dependencies.py +++ b/diracx-routers/src/diracx/routers/dependencies.py @@ -8,6 +8,7 @@ "SandboxMetadataDB", "TaskQueueDB", "PilotAgentsDB", + "PilotLogsDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -21,6 +22,7 @@ from diracx.core.properties import SecurityProperty from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings from diracx.db.os import JobParametersDB as _JobParametersDB +from diracx.db.os import PilotLogsDB as _PilotLogsDB from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB @@ -36,7 +38,7 @@ def add_settings_annotation(cls: T) -> T: return Annotated[cls, Depends(cls.create)] # type: ignore -# Databases +# SQL Databases AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] @@ -46,9 +48,9 @@ def add_settings_annotation(cls: T) -> T: ] TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] -# Opensearch databases +# OpenSearch Databases JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)] - +PilotLogsDB = Annotated[_PilotLogsDB, Depends(_PilotLogsDB.session)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 00000000..3e9084bc --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from logging import getLogger + +from ..fastapi_classes import DiracxRouter +from .logging import router as logging_router + +logger = getLogger(__name__) + +router = DiracxRouter() +router.include_router(logging_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 00000000..ea2b053b --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import logging +from enum import StrEnum, auto +from typing import Annotated, Callable + +from fastapi import Depends, HTTPException, status + +from diracx.core.models import ScalarSearchOperator, ScalarSearchSpec +from diracx.core.properties import ( + NORMAL_USER, +) +from diracx.routers.access_policies import BaseAccessPolicy + +from ..dependencies import PilotAgentsDB +from ..utils.users import AuthorizedUserInfo + +logger = logging.getLogger(__name__) + + +class ActionType(StrEnum): + #: Create/update pilot log records + CREATE = auto() + #: Search + QUERY = auto() + + +class PilotLogsAccessPolicy(BaseAccessPolicy): + """Rules: + Only NORMAL_USER in a correct VO and a diracAdmin VO member can query log records. + All other actions and users are explicitly denied access. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + pilot_agents_db: PilotAgentsDB | None = None, + pilot_id: int | None = None, + ): + assert pilot_agents_db + if action is None: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument" + ) + elif action == ActionType.QUERY: + if pilot_id is None: + logger.error("Pilot ID value is not provided (None)") + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=f"PilotID not provided: {pilot_id}", + ) + search_params = ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + + total, result = await pilot_agents_db.search(["VO"], [search_params], []) + # we expect exactly one row. + if total != 1: + logger.error( + "Cannot determine VO for requested PilotID: %d, found %d candidates.", + pilot_id, + total, + ) + raise HTTPException( + status.HTTP_400_BAD_REQUEST, detail=f"PilotID not found: {pilot_id}" + ) + vo = result[0]["VO"] + + if user_info.vo == "diracAdmin": + return + + if NORMAL_USER in user_info.properties and user_info.vo == vo: + return + + raise HTTPException( + status.HTTP_403_FORBIDDEN, + detail="You don't have permission to access this pilot's log.", + ) + else: + raise NotImplementedError(action) + + +CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/pilots/logging.py b/diracx-routers/src/diracx/routers/pilots/logging.py new file mode 100644 index 00000000..ecddc041 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/logging.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import logging + +from fastapi import HTTPException, status +from pydantic import BaseModel + +from diracx.core.models import ScalarSearchOperator, ScalarSearchSpec + +from ..access_policies import open_access +from ..dependencies import PilotAgentsDB, PilotLogsDB +from ..fastapi_classes import DiracxRouter +from .access_policies import ActionType, CheckPilotLogsPolicyCallable + +logger = logging.getLogger(__name__) +router = DiracxRouter() + + +class LogLine(BaseModel): + line_no: int + line: str + + +class LogMessage(BaseModel): + pilot_stamp: str + lines: list[LogLine] + vo: str + + +class DateRange(BaseModel): + min: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z") + max: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z") + + +@open_access +@router.post("/") +async def send_message( + data: LogMessage, + pilot_logs_db: PilotLogsDB, + pilot_agents_db: PilotAgentsDB, + # user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +) -> int: + + # expecting exactly one row: + search_params = ScalarSearchSpec( + parameter="PilotStamp", + operator=ScalarSearchOperator.EQUAL, + value=data.pilot_stamp, + ) + + total, result = await pilot_agents_db.search( + ["PilotID", "VO", "SubmissionTime"], [search_params], [] + ) + if total != 1: + logger.error( + "Cannot determine PilotID for requested PilotStamp: %r, (%d candidates)", + data.pilot_stamp, + total, + ) + raise HTTPException( + status.HTTP_400_BAD_REQUEST, detail=f"Number of rows !=1: {total}" + ) + pilot_id, vo, submission_time = ( + result[0]["PilotID"], + result[0]["VO"], + result[0]["SubmissionTime"], + ) + + # await check_permissions(action=ActionType.CREATE, pilot_agent_db, pilot_id), + + docs = [] + for line in data.lines: + docs.append( + { + "PilotStamp": data.pilot_stamp, + "PilotID": pilot_id, + "SubmissionTime": submission_time, + "VO": vo, + "LineNumber": line.line_no, + "Message": line.line, + } + ) + await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs) + """ + search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}] + + result = await pilot_logs_db.search( + ["Message"], + search_params, + [{"parameter": "LineNumber", "direction": "asc"}], + ) + """ + return pilot_id + + +@router.get("/logs") +async def get_logs( + pilot_id: int, + db: PilotLogsDB, + pilot_agents_db: PilotAgentsDB, + check_permissions: CheckPilotLogsPolicyCallable, +) -> list[dict]: + + logger.debug("Retrieving logs for pilot ID %d", pilot_id) + # users will only see logs from their own VO if enforced by a policy: + await check_permissions( + action=ActionType.QUERY, pilot_agents_db=pilot_agents_db, pilot_id=pilot_id + ) + + search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}] + + result = await db.search( + ["Message"], + search_params, + [{"parameter": "LineNumber", "direction": "asc"}], + ) + if not result: + return [{"Message": f"No logs for pilot ID = {pilot_id}"}] + return result diff --git a/diracx-routers/tests/pilots/test_pilot_logger.py b/diracx-routers/tests/pilots/test_pilot_logger.py new file mode 100644 index 00000000..e69e7db8 --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_logger.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from diracx.routers.utils.users import AuthSettings + +pytestmark = pytest.mark.enabled_dependencies( + ["AuthSettings", "PilotAgentsDB", "PilotLogsDB"] +) + + +@pytest.fixture +def normal_user_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +async def test_send_and_retrieve_logs( + normal_user_client: TestClient, test_auth_settings: AuthSettings +): + + from diracx.db.sql import PilotAgentsDB + + # Add a pilot reference + upper_limit = 6 + refs = [f"ref_{i}" for i in range(1, upper_limit)] + stamps = [f"stamp_{i}" for i in range(1, upper_limit)] + stamp_dict = dict(zip(refs, stamps)) + + db = normal_user_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db: + await db.add_pilot_references( + refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict + ) + + msg = ( + "2022-02-26 13:48:35.123456 UTC DEBUG [PilotParams] JSON file loaded: pilot.json\n" + "2022-02-26 13:48:36.123456 UTC DEBUG [PilotParams] JSON file analysed: pilot.json" + ) + # message dict + lines = [] + for i, line in enumerate(msg.split("\n")): + lines.append({"line_no": i, "line": line}) + msg_dict = {"lines": lines, "pilot_stamp": "stamp_1", "vo": "diracAdmin"} + + # send message + r = normal_user_client.post("/api/pilots/", json=msg_dict) + + assert r.status_code == 200, r.text diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 5b482102..ae4007c2 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -10,7 +10,7 @@ from functools import partial from typing import Any, AsyncIterator -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.dialects.sqlite import insert as sqlite_insert from diracx.core.models import SearchSpec, SortSpec @@ -100,6 +100,21 @@ async def upsert(self, doc_id, document) -> None: stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values) await self._sql_db.conn.execute(stmt) + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + async with self._sql_db: + rows = [] + for doc in docs: + # don't use doc_id column explicitly. This ensures that doc_id is unique. + values = {} + for key, value in doc.items(): + if key in self.fields: + values[key] = value + else: + values.setdefault("extra", {})[key] = value + rows.append(values) + stmt = sqlite_insert(self._table).values(rows) + await self._sql_db.conn.execute(stmt) + async def search( self, parameters: list[str] | None, @@ -153,6 +168,14 @@ async def search( results.append(result) return results + async def delete(self, query: list[dict[str, Any]]) -> None: + async with self._sql_db: + stmt = delete(self._table) + stmt = sql_utils.apply_search_filters( + self._table.columns.__getitem__, stmt, query + ) + await self._sql_db.conn.execute(stmt) + async def ping(self): async with self._sql_db: return await self._sql_db.ping() From ead27e53e390babe84c2ed47257acf8ba0e1cfa7 Mon Sep 17 00:00:00 2001 From: martynia Date: Thu, 16 Jan 2025 16:16:52 +0100 Subject: [PATCH 2/3] feat: Restructured pilot logger + test send_message and get_logs --- diracx-routers/tests/pilots/test_pilot_logger.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/diracx-routers/tests/pilots/test_pilot_logger.py b/diracx-routers/tests/pilots/test_pilot_logger.py index e69e7db8..40319466 100644 --- a/diracx-routers/tests/pilots/test_pilot_logger.py +++ b/diracx-routers/tests/pilots/test_pilot_logger.py @@ -6,7 +6,13 @@ from diracx.routers.utils.users import AuthSettings pytestmark = pytest.mark.enabled_dependencies( - ["AuthSettings", "PilotAgentsDB", "PilotLogsDB"] + [ + "AuthSettings", + "PilotAgentsDB", + "PilotLogsDB", + "PilotLogsAccessPolicy", + "DevelopmentSettings", + ] ) @@ -49,3 +55,9 @@ async def test_send_and_retrieve_logs( r = normal_user_client.post("/api/pilots/", json=msg_dict) assert r.status_code == 200, r.text + # it just returns the pilot id corresponding for pilot stamp. + assert r.json() == 1 + # get the message back: + r = normal_user_client.get("/api/pilots/logs?pilot_id=1") + assert r.status_code == 200, r.text + assert [next(iter(d.values())) for d in r.json()] == msg.split("\n") From be6c26a7ed8dabbef02881354838ca89c1710bcd Mon Sep 17 00:00:00 2001 From: martynia Date: Fri, 17 Jan 2025 10:33:22 +0100 Subject: [PATCH 3/3] fix: remove unused PilotNotFoundError exception --- diracx-core/src/diracx/core/exceptions.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 2291f987..04b70192 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -46,14 +46,6 @@ def __init__(self, job_id: int, detail: str | None = None): super().__init__(f"Job {job_id} not found" + (" ({detail})" if detail else "")) -class PilotNotFoundError(Exception): - def __init__(self, pilot_stamp: str, detail: str | None = None): - self.pilot_stamp: str = pilot_stamp - super().__init__( - f"Pilot (stamp) {pilot_stamp} not found" + (" ({detail})" if detail else "") - ) - - class SandboxNotFoundError(Exception): def __init__(self, pfn: str, se_name: str, detail: str | None = None): self.pfn: str = pfn