Skip to content

Commit

Permalink
Cli funcs init and shutdown ray
Browse files Browse the repository at this point in the history
  • Loading branch information
rusu24edward committed Nov 7, 2023
1 parent aee532e commit 8ed4dad
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 11 deletions.
3 changes: 3 additions & 0 deletions abmarl/scripts/analyze_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ def run(full_trained_directory, full_subscript, parameters):
from abmarl.tools import utils as adu
params = adu.find_params_from_output_dir(full_trained_directory)
analysis_func = adu.custom_import_module(full_subscript).run
import ray
ray.init()
analyze(params, analysis_func, **parameters)
ray.shutdown()
3 changes: 3 additions & 0 deletions abmarl/scripts/train_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ def run(full_config_path):
shutil.copy(full_config_path, output_dir)

# Train the policy
import ray
ray.init()
_train_rllib(params)
ray.shutdown()
3 changes: 3 additions & 0 deletions abmarl/scripts/visualize_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ def run(full_trained_directory, parameters):
from abmarl.tools import utils as adu
from abmarl.stage import visualize
params = adu.find_params_from_output_dir(full_trained_directory)
import ray
ray.init()
visualize(params, **parameters)
ray.shutdown()
7 changes: 0 additions & 7 deletions abmarl/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import ray
import ray.rllib
from ray.rllib.env import MultiAgentEnv
from ray.tune.registry import get_trainable_cls

Expand All @@ -14,7 +12,6 @@


def _stage_setup(params, seed=None, checkpoint=None):
# adu.register_env_from_params(params)
full_trained_directory = params['ray_tune']['local_dir']
# Modify the number of workers in the configuration
params['ray_tune']['config']['num_workers'] = 1
Expand Down Expand Up @@ -62,8 +59,6 @@ def analyze(
# Run the analysis function
analysis_func(sim, trainer)

# ray.shutdown()


def visualize(
params,
Expand Down Expand Up @@ -165,5 +160,3 @@ def animate(i):
while not all_done:
plt.pause(1)
plt.close(fig)

# ray.shutdown()
4 changes: 0 additions & 4 deletions abmarl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@ def _train_rllib(params):
"""
Train MARL policies with RLlib using parameters dictionary.
"""
# adu.register_env_from_params(params)
# import ray
from ray import tune
# ray.init()
tune.run(**params['ray_tune'])
# ray.shutdown()


def train(params):
Expand Down

0 comments on commit 8ed4dad

Please sign in to comment.