-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
28cbb86
commit 8b1a175
Showing
5 changed files
with
241 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .agent_parallel import MultiThreadAgentParallelRuner | ||
from .base import Runner | ||
from .sequential import SequentialRunner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import os | ||
import random | ||
import warnings | ||
from concurrent.futures import ThreadPoolExecutor | ||
from io import TextIOWrapper | ||
from multiprocessing import cpu_count | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
from typing import Type | ||
from typing import Union | ||
|
||
from pams.logs.base import Logger | ||
from pams.simulator import Simulator | ||
|
||
from ..order import Cancel | ||
from ..order import Order | ||
from ..session import Session | ||
from .sequential import SequentialRunner | ||
|
||
|
||
class MultiThreadAgentParallelRuner(SequentialRunner): | ||
"""Multi Thread Agent Parallel runner class. This is experimental. | ||
In this runner, only normal agents are parallelized in each steps. | ||
This means that the number of agents that can be parallelized is limited by MAX_NORMAL_ORDERS. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
settings: Union[Dict, TextIOWrapper, os.PathLike, str], | ||
prng: Optional[random.Random] = None, | ||
logger: Optional[Logger] = None, | ||
simulator_class: Type[Simulator] = Simulator, | ||
): | ||
super().__init__(settings, prng, logger, simulator_class) | ||
warnings.warn( | ||
"MultiThreadRuner is experimental. Future changes may occur disruptively." | ||
) | ||
self.num_parallel = max(cpu_count() - 1, 1) | ||
|
||
def _setup(self) -> None: | ||
super()._setup() | ||
if "numParallel" in self.settings["simulation"]: | ||
self.num_parallel = self.settings["simulation"]["numParallel"] | ||
max_notmal_orders = max( | ||
session.max_normal_orders for session in self.simulator.sessions | ||
) | ||
if self.num_parallel > max_notmal_orders: | ||
warnings.warn( | ||
f"When MultiThreadAgentParallelRuner is used, the maximum number of parallel agents" | ||
f" is limited by max_normal_orders ({max_notmal_orders}) evne if numParallel" | ||
f" ({self.num_parallel}) is set to a larger value." | ||
) | ||
self.thread_pool = ThreadPoolExecutor(max_workers=self.num_parallel) | ||
|
||
def _collect_orders_from_normal_agents( | ||
self, session: Session | ||
) -> List[List[Union[Order, Cancel]]]: | ||
"""collect orders from normal_agents. (Internal method) | ||
orders are corrected until the total number of orders reaches max_normal_orders | ||
Args: | ||
session (Session): session. | ||
Returns: | ||
List[List[Union[Order, Cancel]]]: orders lists. | ||
""" | ||
agents = self.simulator.normal_frequency_agents | ||
agents = self._prng.sample(agents, len(agents)) | ||
all_orders: List[List[Union[Order, Cancel]]] = [] | ||
# TODO: currently the original impl is used for order counting. | ||
# See more in the SequentialRunner class. | ||
futures = [] | ||
for agent in agents[: session.max_normal_orders]: | ||
future = self.thread_pool.submit( | ||
agent.submit_orders, self.simulator.markets | ||
) | ||
futures.append((future, agent)) | ||
for future, agent in futures: | ||
orders = future.result() | ||
if len(orders) > 0: | ||
if not session.with_order_placement: | ||
raise AssertionError("currently order is not accepted") | ||
if sum([order.agent_id != agent.agent_id for order in orders]) > 0: | ||
raise ValueError( | ||
"spoofing order is not allowed. please check agent_id in order" | ||
) | ||
all_orders.append(orders) | ||
return all_orders |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import time | ||
from typing import Dict, List, cast | ||
from typing import Type | ||
|
||
from pams.agents.fcn_agent import FCNAgent | ||
from pams.market import Market | ||
from pams.order import Cancel, Order | ||
from pams.runners import MultiThreadAgentParallelRuner | ||
from pams.runners import Runner | ||
from pams.runners.sequential import SequentialRunner | ||
from tests.pams.runners.test_base import TestRunner | ||
|
||
|
||
class TestMultiThreadAgentParallelRuner(TestRunner): | ||
runner_class: Type[Runner] = MultiThreadAgentParallelRuner | ||
default_setting: Dict = { | ||
"simulation": { | ||
"markets": ["Market"], | ||
"agents": ["FCNAgents"], | ||
"sessions": [ | ||
{ | ||
"sessionName": 0, | ||
"iterationSteps": 5, | ||
"withOrderPlacement": True, | ||
"withOrderExecution": True, | ||
"withPrint": True, | ||
"events": ["FundamentalPriceShock"], | ||
"maxNormalOrders": 3, | ||
} | ||
], | ||
"numParallel": 3, | ||
}, | ||
"Market": {"class": "Market", "tickSize": 0.00001, "marketPrice": 300.0}, | ||
"FCNAgents": { | ||
"class": "FCNAgent", | ||
"numAgents": 10, | ||
"markets": ["Market"], | ||
"assetVolume": 50, | ||
"cashAmount": 10000, | ||
"fundamentalWeight": {"expon": [1.0]}, | ||
"chartWeight": {"expon": [0.0]}, | ||
"noiseWeight": {"expon": [1.0]}, | ||
"meanReversionTime": {"uniform": [50, 100]}, | ||
"noiseScale": 0.001, | ||
"timeWindowSize": [100, 200], | ||
"orderMargin": [0.0, 0.1], | ||
}, | ||
"FundamentalPriceShock": { | ||
"class": "FundamentalPriceShock", | ||
"target": "Market", | ||
"triggerTime": 0, | ||
"priceChangeRate": -0.1, | ||
"shockTimeLength": 1, | ||
"enabled": True, | ||
}, | ||
} | ||
|
||
def test_parallel_efficiency(self) -> None: | ||
wait_time = 0.2 # seconds | ||
|
||
class FCNDelayAgent(FCNAgent): | ||
def submit_orders(self, markets: List[Market]) -> List[Order | Cancel]: | ||
time.sleep(wait_time) # Simulate a delay | ||
return super().submit_orders(markets) | ||
|
||
setting = self.default_setting.copy() | ||
setting["FCNAgents"]["class"] = "FCNDelayAgent" # Use the delayed agent | ||
|
||
runner_class_dummy = self.runner_class | ||
self.runner_class = SequentialRunner # Temporarily set to SequentialRunner | ||
sequential_runner = cast( | ||
SequentialRunner, | ||
self.test__init__( | ||
setting_mode="dict", logger=None, simulator_class=None, setting=setting | ||
), | ||
) | ||
self.runner_class = runner_class_dummy | ||
parallel_runner = cast( | ||
self.runner_class, | ||
self.test__init__( | ||
setting_mode="dict", logger=None, simulator_class=None, setting=setting | ||
), | ||
) | ||
|
||
sequential_runner.class_register(cls=FCNDelayAgent) | ||
parallel_runner.class_register(cls=FCNDelayAgent) | ||
start_time = time.time() | ||
sequential_runner.main() | ||
end_time = time.time() | ||
elps_time_sequential = end_time - start_time | ||
start_time = time.time() | ||
parallel_runner.main() | ||
end_time = time.time() | ||
elps_time_parallel = end_time - start_time | ||
assert elps_time_sequential < wait_time * 15 + 1 | ||
assert elps_time_sequential > wait_time * 15 | ||
assert elps_time_parallel < wait_time * 5 + 1 | ||
assert elps_time_parallel > wait_time * 5 |
Oops, something went wrong.