Skip to content

Commit

Permalink
added a singleton coordinator so that ships and markets are each othe…
Browse files Browse the repository at this point in the history
…r instead of divergent clones.
  • Loading branch information
Ctri-The-Third committed Apr 1, 2024
1 parent afe2381 commit 2df8671
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 8 deletions.
28 changes: 21 additions & 7 deletions straders_sdk/client_mediator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .utils import get_and_validate, get_and_validate_paginated, post_and_validate, _url
from .utils import ApiConfig, _log_response, waypoint_to_system
from .utils import ApiConfig, _log_response, waypoint_to_system, get_name_from_token
from .client_interface import SpaceTradersInteractive, SpaceTradersClient
import time
from .responses import SpaceTradersResponse
Expand All @@ -16,8 +16,8 @@
JumpGate,
)
import psycopg2
from .models_misc import Shipyard, System
from .models_ship import Ship
from .models_misc import Shipyard, System, SingletonMarkets
from .models_ship import Ship, SingletonShips
from .client_api import SpaceTradersApiClient
from .client_stub import SpaceTradersStubClient
from .client_postgres import SpaceTradersPostgresClient
Expand Down Expand Up @@ -60,7 +60,10 @@ def __init__(
self.logger = logging.getLogger(__name__)

self.token = token
self.current_agent = current_agent_symbol
if token and not current_agent_symbol:
self.current_agent_symbol = get_name_from_token(token)
else:
self.current_agent_symbol = current_agent_symbol
if db_host and db_name and db_user and db_pass:
self.db_client = SpaceTradersPostgresClient(
db_host=db_host,
Expand Down Expand Up @@ -224,15 +227,21 @@ def ships_view(self, force=False) -> dict[str, Ship] or SpaceTradersResponse:
resp = self.db_client.ships_view()
if resp:
self.ships = self.ships | resp
for ship in self.ships.values():

ship = SingletonShips().add_ship(ship)

return resp
start = datetime.now()
resp = self.api_client.ships_view()
self.logging_client.ships_view(resp, (datetime.now() - start).total_seconds())
if resp:

new_ships = resp
self.ships = self.ships | new_ships
for ship in self.ships.values():
ship: Ship
ship = SingletonShips().add_ship(ship)
ship.dirty = True # force a refresh of the ship into the DB
self.db_client.update(ship)
return new_ships
Expand All @@ -249,7 +258,9 @@ def ships_view_one(self, symbol: str, force=False):
resp = self.db_client.ships_view_one(symbol)
if resp:
resp: Ship
self.ships[symbol] = resp
SingletonShips().add_ship(resp)


return resp
start = datetime.now()
resp = self.api_client.ships_view_one(symbol)
Expand All @@ -259,7 +270,8 @@ def ships_view_one(self, symbol: str, force=False):
if resp:
resp: Ship
resp.dirty = True
self.ships[symbol] = resp
SingletonShips().add_ship(resp)

self.db_client.update(resp)
return resp

Expand All @@ -277,6 +289,7 @@ def ships_purchase(
self.set_connections()
start = datetime.now()
resp = self.api_client.ships_purchase(ship_type, waypoint)
resp[0] = SingletonShips().add_ship(resp[0])
self.logging_client.ships_purchase(
ship_type, waypoint, resp, (datetime.now() - start).total_seconds()
)
Expand Down Expand Up @@ -369,7 +382,7 @@ def update(self, json_data):
self.surveys[json_data.signature] = json_data
self.db_client.update(json_data)
if isinstance(json_data, Ship):
self.ships[json_data.name] = json_data
json_data = SingletonShips().add_ship(json_data)
self.db_client.update(json_data)
if isinstance(json_data, Waypoint):
self.waypoints[json_data.symbol] = json_data
Expand Down Expand Up @@ -597,6 +610,7 @@ def system_market(
resp = self.api_client.system_market(wp)
self.logging_client.system_market(wp, (datetime.now() - start).total_seconds())
if bool(resp):
resp = SingletonMarkets().add_market(resp)
self.db_client.update(resp)
return resp
return resp
Expand Down
34 changes: 34 additions & 0 deletions straders_sdk/models_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,36 @@ class MarketTradeGood:
description: str


class SingletonMarkets:
"A singleton dict of markets, keyed by symbol."

def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(SingletonMarkets, cls).__new__(cls)

pass
return cls.instance

def __init__(self):
if not hasattr(self, "markets"):
self.markets = {}
pass

def add_market(self, market: "Market"):
"""Add a market to the singleton dict. If the market already exists, merge the data.
The returned object should be the authoritative market object."""
if market.symbol not in self.markets:
self.markets[market.symbol] = market
return market
else:
self.markets[market.symbol].merge(market)
return self.markets[market.symbol]
pass

def get_market(self, symbol: str) -> "Market":
return self.markets.get(symbol, None)


@dataclass
class Market:
symbol: str
Expand All @@ -864,6 +894,10 @@ def from_json(cls, json_data: dict):
]
return cls(json_data["symbol"], exports, imports, exchange, listings)

def merge(self, other_market: "Market"):
self.__dict__.update(other_market.__dict__)
return self

def is_stale(self, age: int = 60):
if not self.listings:
return True
Expand Down
32 changes: 32 additions & 0 deletions straders_sdk/models_ship.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,33 @@ def from_json(cls, json_data: dict):
return cls(*json_data.values())


class SingletonShips:

def __new__(cls):
if not hasattr(cls, "_instance"):
cls._instance = super(SingletonShips, cls).__new__(cls)
pass
return cls._instance

def __init__(self) -> None:
if not hasattr(self, "ships"):
self.ships = {}

def get_ship(self, ship_name: str):
if ship_name in self.ships:
return self.ships[ship_name]
else:
return None

def add_ship(self, ship: "Ship"):
"use this method to guarantee the ship object you're using is the same as the one in the dictionary."
if ship.name not in self.ships:
self.ships[ship.name] = ship
return ship
else:
return self.ships[ship.name].merge(ship)


class Ship:
name: str
role: str
Expand Down Expand Up @@ -285,6 +312,11 @@ def receive_cargo(self, trade_symbol, units):
)
)

def merge(self, new_ship_data: "Ship"):
"Merge the data from a new ship object into this one."
self.__dict__.update(new_ship_data.__dict__)
return self

def update(self, json_data: dict):
"Update the ship with the contents of a response object"
if json_data is None:
Expand Down
55 changes: 54 additions & 1 deletion tests/test_mediator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from straders_sdk import SpaceTraders
from straders_sdk.utils import try_execute_select
from straders_sdk.models_misc import Waypoint, Market
from straders_sdk.models_misc import Waypoint, Market, SingletonMarkets
from straders_sdk.models_misc import JumpGate, JumpGateConnection
from straders_sdk.models_misc import ConstructionSite, ConstructionSiteMaterial
from straders_sdk.models import Ship, SingletonShips
import pytest

ST_HOST = os.getenv("ST_DB_HOST")
Expand All @@ -27,3 +28,55 @@ def test_init_possible():
db_user=ST_USER,
db_port=ST_PORT,
)


def test_market_singleton():
"""Test if the Market singleton is working."""

markets = SingletonMarkets()
market = Market("X1-MARKET-1", [], [], [])
market.exchange = "PLACEHOLDER"
market = markets.add_market(market)
market2 = Market("X1-MARKET-1", [], [], [])
market2 = markets.add_market(market2)
assert market2.exchange != "PLACEHOLDER"


def test_ship_singleton():
"""Test if the Ship singleton is working."""

ships = SingletonShips()
ship = Ship()
ship.name = "X1-SHIP-1"
ship = ships.add_ship(ship)
ship2 = Ship()
ship2.name = "X1-SHIP-1"
ship2.cargo_units_used = 30

ship2 = ships.add_ship(ship2)
assert ship.cargo_units_used == 30


def test_ships():
"""test all the ships we're getting out of the DB are being singletonned"""

st = SpaceTraders(
"",
db_host=ST_HOST,
db_name=ST_NAME,
db_pass=ST_PASS,
db_user=ST_USER,
db_port=ST_PORT,
)
ships = st.ships_view()
if (not ships) or len(ships) == 0:
# skip the test
pytest.skip("No ships in the DB")
for ship in ships:
ship: Ship
assert ship == SingletonShips().get_ship(ship.name)

other_ship = st.ships_view_one(ship.name)
assert ship == other_ship
other_ship.cargo_capacity = 5
assert ship.cargo_capacity == 5

0 comments on commit 2df8671

Please sign in to comment.