Skip to content

Commit

Permalink
add multi thread agent parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
masanorihirano committed Oct 23, 2024
1 parent 28cbb86 commit 8b1a175
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 51 deletions.
3 changes: 2 additions & 1 deletion docs/source/user_guide/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ Json config
["MarketName1", "MarketName2", float], # fundamentalVolatility is required in both markets
...
]
}
},
"numParallel": int (Optional; default 1; only for MultiThreadedRunner),
},
"FundamentalPriceShock": {
"class": "FundamentalPriceShock",
Expand Down
1 change: 1 addition & 0 deletions pams/runners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .agent_parallel import MultiThreadAgentParallelRuner
from .base import Runner
from .sequential import SequentialRunner
90 changes: 90 additions & 0 deletions pams/runners/agent_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import random
import warnings
from concurrent.futures import ThreadPoolExecutor
from io import TextIOWrapper
from multiprocessing import cpu_count
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import Union

from pams.logs.base import Logger
from pams.simulator import Simulator

from ..order import Cancel
from ..order import Order
from ..session import Session
from .sequential import SequentialRunner


class MultiThreadAgentParallelRuner(SequentialRunner):
"""Multi Thread Agent Parallel runner class. This is experimental.
In this runner, only normal agents are parallelized in each steps.
This means that the number of agents that can be parallelized is limited by MAX_NORMAL_ORDERS.
"""

def __init__(
self,
settings: Union[Dict, TextIOWrapper, os.PathLike, str],
prng: Optional[random.Random] = None,
logger: Optional[Logger] = None,
simulator_class: Type[Simulator] = Simulator,
):
super().__init__(settings, prng, logger, simulator_class)
warnings.warn(
"MultiThreadRuner is experimental. Future changes may occur disruptively."
)
self.num_parallel = max(cpu_count() - 1, 1)

def _setup(self) -> None:
super()._setup()
if "numParallel" in self.settings["simulation"]:
self.num_parallel = self.settings["simulation"]["numParallel"]
max_notmal_orders = max(
session.max_normal_orders for session in self.simulator.sessions
)
if self.num_parallel > max_notmal_orders:
warnings.warn(
f"When MultiThreadAgentParallelRuner is used, the maximum number of parallel agents"
f" is limited by max_normal_orders ({max_notmal_orders}) evne if numParallel"
f" ({self.num_parallel}) is set to a larger value."
)
self.thread_pool = ThreadPoolExecutor(max_workers=self.num_parallel)

def _collect_orders_from_normal_agents(
self, session: Session
) -> List[List[Union[Order, Cancel]]]:
"""collect orders from normal_agents. (Internal method)
orders are corrected until the total number of orders reaches max_normal_orders
Args:
session (Session): session.
Returns:
List[List[Union[Order, Cancel]]]: orders lists.
"""
agents = self.simulator.normal_frequency_agents
agents = self._prng.sample(agents, len(agents))
all_orders: List[List[Union[Order, Cancel]]] = []
# TODO: currently the original impl is used for order counting.
# See more in the SequentialRunner class.
futures = []
for agent in agents[: session.max_normal_orders]:
future = self.thread_pool.submit(
agent.submit_orders, self.simulator.markets
)
futures.append((future, agent))
for future, agent in futures:
orders = future.result()
if len(orders) > 0:
if not session.with_order_placement:
raise AssertionError("currently order is not accepted")
if sum([order.agent_id != agent.agent_id for order in orders]) > 0:
raise ValueError(
"spoofing order is not allowed. please check agent_id in order"
)
all_orders.append(orders)
return all_orders
98 changes: 98 additions & 0 deletions tests/pams/runners/test_agent_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import time
from typing import Dict, List, cast
from typing import Type

from pams.agents.fcn_agent import FCNAgent
from pams.market import Market
from pams.order import Cancel, Order
from pams.runners import MultiThreadAgentParallelRuner
from pams.runners import Runner
from pams.runners.sequential import SequentialRunner
from tests.pams.runners.test_base import TestRunner


class TestMultiThreadAgentParallelRuner(TestRunner):
runner_class: Type[Runner] = MultiThreadAgentParallelRuner
default_setting: Dict = {
"simulation": {
"markets": ["Market"],
"agents": ["FCNAgents"],
"sessions": [
{
"sessionName": 0,
"iterationSteps": 5,
"withOrderPlacement": True,
"withOrderExecution": True,
"withPrint": True,
"events": ["FundamentalPriceShock"],
"maxNormalOrders": 3,
}
],
"numParallel": 3,
},
"Market": {"class": "Market", "tickSize": 0.00001, "marketPrice": 300.0},
"FCNAgents": {
"class": "FCNAgent",
"numAgents": 10,
"markets": ["Market"],
"assetVolume": 50,
"cashAmount": 10000,
"fundamentalWeight": {"expon": [1.0]},
"chartWeight": {"expon": [0.0]},
"noiseWeight": {"expon": [1.0]},
"meanReversionTime": {"uniform": [50, 100]},
"noiseScale": 0.001,
"timeWindowSize": [100, 200],
"orderMargin": [0.0, 0.1],
},
"FundamentalPriceShock": {
"class": "FundamentalPriceShock",
"target": "Market",
"triggerTime": 0,
"priceChangeRate": -0.1,
"shockTimeLength": 1,
"enabled": True,
},
}

def test_parallel_efficiency(self) -> None:
wait_time = 0.2 # seconds

class FCNDelayAgent(FCNAgent):
def submit_orders(self, markets: List[Market]) -> List[Order | Cancel]:
time.sleep(wait_time) # Simulate a delay
return super().submit_orders(markets)

setting = self.default_setting.copy()
setting["FCNAgents"]["class"] = "FCNDelayAgent" # Use the delayed agent

runner_class_dummy = self.runner_class
self.runner_class = SequentialRunner # Temporarily set to SequentialRunner
sequential_runner = cast(
SequentialRunner,
self.test__init__(
setting_mode="dict", logger=None, simulator_class=None, setting=setting
),
)
self.runner_class = runner_class_dummy
parallel_runner = cast(
self.runner_class,
self.test__init__(
setting_mode="dict", logger=None, simulator_class=None, setting=setting
),
)

sequential_runner.class_register(cls=FCNDelayAgent)
parallel_runner.class_register(cls=FCNDelayAgent)
start_time = time.time()
sequential_runner.main()
end_time = time.time()
elps_time_sequential = end_time - start_time
start_time = time.time()
parallel_runner.main()
end_time = time.time()
elps_time_parallel = end_time - start_time
assert elps_time_sequential < wait_time * 15 + 1
assert elps_time_sequential > wait_time * 15
assert elps_time_parallel < wait_time * 5 + 1
assert elps_time_parallel > wait_time * 5
Loading

0 comments on commit 8b1a175

Please sign in to comment.