diff --git a/docker-compose.yaml b/docker-compose.yaml index 5bbfac5..f293940 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -115,11 +115,13 @@ services: condition: service_healthy mysql: condition: service_healthy + wait_for_api: image: alpine:latest container_name: wait_for_api depends_on: public_api: condition: service_healthy + networks: botdetector-network: diff --git a/mysql/docker-entrypoint-initdb.d/01_tables.sql b/mysql/docker-entrypoint-initdb.d/01_tables.sql index b61b79c..b73c3f4 100644 --- a/mysql/docker-entrypoint-initdb.d/01_tables.sql +++ b/mysql/docker-entrypoint-initdb.d/01_tables.sql @@ -3,7 +3,7 @@ USE playerdata; CREATE TABLE Players ( id INT PRIMARY KEY AUTO_INCREMENT, name TEXT, - created_at TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP, possible_ban BOOLEAN, confirmed_ban BOOLEAN, diff --git a/requirements.txt b/requirements.txt index 3ace0c9..afb045d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ anyio==4.3.0 async-timeout==4.0.3 asyncmy==0.2.9 attrs==23.2.0 -black==24.4.2 certifi==2024.2.2 cffi==1.16.0 cfgv==3.4.0 @@ -50,14 +49,11 @@ pydantic-settings==2.2.1 pydantic_core==2.18.2 Pygments==2.18.0 PyMySQL==1.1.1 -pytest==8.2.1 -pytest-asyncio==0.23.7 python-dotenv==1.0.1 python-multipart==0.0.9 PyYAML==6.0.1 requests==2.32.2 rich==13.7.1 -ruff==0.4.5 shellingham==1.5.4 sniffio==1.3.1 sortedcontainers==2.4.0 diff --git a/src/api/v2/report.py b/src/api/v2/report.py index 83097a8..8c3019d 100644 --- a/src/api/v2/report.py +++ b/src/api/v2/report.py @@ -39,10 +39,23 @@ async def post_reports( _data = [] for d in data: _d = d.model_dump() + # get reported_id from name reported = player_repo.sanitize_name(_d.pop("reported")) + reported_id = players.get(reported) + + # get reporter_id from name reporter = player_repo.sanitize_name(_d.pop("reporter")) - _d["reported_id"] = players.get(reported) - _d["reporter_id"] = players.get(reporter) + reporter_id = players.get(reporter) + + # some validation + if reporter_id is None or reported_id is None: + logger.warning(msg=f"{reported_id=}, {reporter_id=}, {d}") + raise HTTPException( + status.HTTP_400_BAD_REQUEST, detail="something went wrong" + ) + _d["reported_id"] = reported_id + _d["reporter_id"] = reporter_id + _data.append(ParsedDetection(**_d)) await report_repo.send_to_kafka(data=_data) return Ok() diff --git a/src/app/repositories/player.py b/src/app/repositories/player.py index 350de34..84d30da 100644 --- a/src/app/repositories/player.py +++ b/src/app/repositories/player.py @@ -108,14 +108,14 @@ async def get_prediction(self, player_names: list[str]): return jsonable_encoder(result) async def get(self, player_name: str) -> PlayerInDB: + assert isinstance(player_name, str) player_name = self.sanitize_name(player_name) sql = sqla.select(dbPlayer).where(dbPlayer.name == player_name) result = await self.session.execute(sql) data = result.scalars().all() - - return PlayerInDB(**model_to_dict(data[0])) if data else None + return PlayerInDB(**model_to_dict(data[0])) if len(data) > 0 else None async def get_cache(self, player_name: str) -> PlayerInDB: player_name = self.sanitize_name(player_name) @@ -136,6 +136,7 @@ async def insert(self, player: PlayerCreate) -> PlayerInDB: player.name = self.sanitize_name(player.name) sql = sqla.insert(dbPlayer).values(player.model_dump()).prefix_with("IGNORE") await self.session.execute(sql) + # await self.session.commit() return await self.get(player_name=player.name) async def get_or_insert(self, player_name: str, cached=True) -> PlayerInDB: diff --git a/src/app/views/player.py b/src/app/views/player.py index 8dad8fc..482b971 100644 --- a/src/app/views/player.py +++ b/src/app/views/player.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, field_validator class PlayerCreate(BaseModel): @@ -11,9 +11,9 @@ class PlayerCreate(BaseModel): confirmed_player: Optional[bool] = 0 label_id: Optional[int] = 0 label_jagex: Optional[int] = 0 - ironman: Optional[bool] = None - hardcore_ironman: Optional[bool] = None - ultimate_ironman: Optional[bool] = None + ironman: Optional[int] = None + hardcore_ironman: Optional[int] = None + ultimate_ironman: Optional[int] = None normalized_name: Optional[str] = None @@ -24,9 +24,9 @@ class PlayerUpdate(BaseModel): confirmed_player: Optional[bool] = None label_id: Optional[int] = None label_jagex: Optional[int] = None - ironman: Optional[bool] = None - hardcore_ironman: Optional[bool] = None - ultimate_ironman: Optional[bool] = None + ironman: Optional[int] = None + hardcore_ironman: Optional[int] = None + ultimate_ironman: Optional[int] = None normalized_name: Optional[str] = None @@ -35,6 +35,14 @@ class PlayerInDB(PlayerCreate): created_at: datetime updated_at: datetime | None + @field_validator("created_at", mode="before") + def parse_created_at(cls, value): + if isinstance(value, str): + return datetime.fromisoformat(value) + if value is None: + raise ValueError("created_at cannot be None") + return value + class Player(PlayerInDB): pass diff --git a/tests/test_report.py b/tests/test_report.py index 5d09ab9..78ddecb 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -67,6 +67,40 @@ async def test_valid_report_weird_name(custom_client): assert response.status_code == 201 +@pytest.mark.asyncio +async def test_valid_report_unkown_reported(custom_client): + global example_data + endpoint = "/v2/report" + _data = example_data.copy() + _data["ts"] = int(time.time()) + _data["reported"] = "new reported" + + # Example of a valid detection data + detection_data = [_data] + + async with custom_client as client: + client: AsyncClient + response = await client.post(endpoint, json=detection_data) + assert response.status_code == 201 + + +@pytest.mark.asyncio +async def test_valid_report_unkown_reporter(custom_client): + global example_data + endpoint = "/v2/report" + _data = example_data.copy() + _data["ts"] = int(time.time()) + _data["reported"] = "new reporter" + + # Example of a valid detection data + detection_data = [_data] + + async with custom_client as client: + client: AsyncClient + response = await client.post(endpoint, json=detection_data) + assert response.status_code == 201 + + @pytest.mark.asyncio async def test_invalid_ts_high_report(custom_client): global example_data