Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Dev2 #143

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
91 changes: 91 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
{
"files.associations": {
"*.ipp": "cpp",
"memory": "cpp",
"regex": "cpp",
"utility": "cpp",
"__bit_reference": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__functional_base": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__nullptr": "cpp",
"__split_buffer": "cpp",
"__string": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__tuple": "cpp",
"algorithm": "cpp",
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"codecvt": "cpp",
"complex": "cpp",
"condition_variable": "cpp",
"csignal": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"exception": "cpp",
"forward_list": "cpp",
"fstream": "cpp",
"functional": "cpp",
"future": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"ios": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"iterator": "cpp",
"limits": "cpp",
"list": "cpp",
"locale": "cpp",
"map": "cpp",
"mutex": "cpp",
"new": "cpp",
"numeric": "cpp",
"optional": "cpp",
"ostream": "cpp",
"queue": "cpp",
"random": "cpp",
"ratio": "cpp",
"set": "cpp",
"sstream": "cpp",
"stack": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"thread": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"typeindex": "cpp",
"typeinfo": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"valarray": "cpp",
"variant": "cpp",
"vector": "cpp",
"__functional_03": "cpp",
"filesystem": "cpp"
},
"python.pythonPath": "/usr/local/Caskroom/miniconda/base/envs/moRL/bin/python"
}
Empty file modified README.md
100644 → 100755
Empty file.
6 changes: 4 additions & 2 deletions elf/ai.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace elf {

using namespace std;

template <typename S, typename A>
template <typename S, typename A> // RTSState RTSMCAction
class AI_T {
public:
using Action = A;
Expand All @@ -43,6 +43,8 @@ class AI_T {

virtual ~AI_T() { }

void Print(){std::cout<<"AIName: _name: "<<_name<<" id: "<<_id<<std::endl;}

private:
const std::string _name;
int _id;
Expand All @@ -51,7 +53,7 @@ class AI_T {
virtual void on_set_id() { }
};

template <typename S, typename A, typename AIComm>
template <typename S, typename A, typename AIComm> //带有AICommT的AI
class AIWithCommT : public AI_T<S, A> {
public:
using AI = AI_T<S, A>;
Expand Down
14 changes: 9 additions & 5 deletions elf/comm_template.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,16 @@ class CommT {
_signal->use_queue_per_group(_groups.size());
}

//std::cout<<"-------CommT CollectorsReady---------------"<<std::endl;
for (auto &p : _map) {
p.second.InitCond(_exclusive_groups.size());
}

// std::cout<<"pool_size 1: "<<_pool.size()<<std::endl; // 初始为 0
_pool.resize(_groups.size());
//std::cout<<"pool_size 2: "<<_pool.size()<<std::endl; // 设为16
for (auto &g : _groups) {
CollectorGroup *p = g.get();
_pool.push([p, this](int) { p->MainLoop(); });
_pool.push([p, this](int) { p->MainLoop(); }); // 16 个 Batch Collector
}
}

Expand Down Expand Up @@ -343,13 +345,15 @@ class ContextT {
const Options &options() const { return _options; }

void Start(GameStartFunc game_start_func) {
_comm.CollectorsReady();

std::cout<<"--------ContextT Start---------------"<<std::endl;
_comm.CollectorsReady(); //设置 BatchCollector

std::cout<<"_pool.size"<<_pool.size()<<std::endl; // 1024 - num_games
// Now we start all jobs.
for (int i = 0; i < _pool.size(); ++i) {
_pool.push([i, this, &game_start_func](int){
elf::Signal signal(_done.flag(), _prepare_stop);
game_start_func(i, _context_options, _options, signal, &_comm);
game_start_func(i, _context_options, _options, signal, &_comm); //每个线程调用一次GameStartFunc
// std::cout << "G[" << i << "] is ending" << std::endl;
_done.notify();
});
Expand Down
28 changes: 14 additions & 14 deletions elf/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ def __init__(self):
call_from = self,
define_args = [
("num_games", 1024),
("batchsize", 128),
("game_multi", dict(type=int, default=None)),
("T", 6),
("eval", dict(action="store_true")),
("wait_per_group", dict(action="store_true")),
("num_collectors", 0),
("verbose_comm", dict(action="store_true")),
("verbose_collector", dict(action="store_true")),
("mcts_threads", 0),
("batchsize", 128), # 64
("game_multi", dict(type=int, default=None)),
("T", 6), # 20
("eval", dict(action="store_true")), # False
("wait_per_group", dict(action="store_true")), # False
("num_collectors", 0),
("verbose_comm", dict(action="store_true")), # False
("verbose_collector", dict(action="store_true")), # False
("mcts_threads", 0),
("mcts_rollout_per_thread", 1),
("mcts_verbose", dict(action="store_true")),
("mcts_save_tree_filename", ""),
("mcts_verbose_time", dict(action="store_true")),
("mcts_verbose", dict(action="store_true")), # False
("mcts_save_tree_filename", ""),
("mcts_verbose_time", dict(action="store_true")), # False

("mcts_use_prior", dict(action="store_true")),
("mcts_pseudo_games", 0),
("mcts_use_prior", dict(action="store_true")), # False
("mcts_pseudo_games", 0),
("mcts_pick_method", "most_visited"),
],
on_get_args = self._on_get_args
Expand Down
15 changes: 13 additions & 2 deletions elf/game_base.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,25 @@ class GameBaseT {
return res;
}

void MainLoop(const std::atomic_bool *done = nullptr) {
_state->Init();
void MainLoop(const std::atomic_bool *done = nullptr,bool isPrint = false) {
if(isPrint)
std::cout<<"-------MainLoop----------"<<std::endl;
_state->Init(isPrint); // 初始化游戏
// if(isPrint){
// std::cout<<"--------Start PleyerInfo"<<std::endl;
// std::cout<<_state->env().PrintPlayerInfo()<<std::endl;
// }
while (true) {
if (Step(done) != GAME_NORMAL) break;
if (done != nullptr && done->load()) break;
}
// Send message to AIs.
_act(false, done);
_game_end();
// if(isPrint){
// std::cout<<"--------End PleyerInfo"<<std::endl;
// std::cout<<_state->env().PrintPlayerInfo()<<std::endl;
// }
_state->Finalize();
}

Expand Down Expand Up @@ -122,6 +132,7 @@ class GameBaseT {
bot.ai->GameEnd();
}
if (_spectator != nullptr) {
std::cout<<"_spectator"<<std::endl;
_spectator->GameEnd();
}
}
Expand Down
8 changes: 4 additions & 4 deletions elf/python_options_utils_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

struct ContextOptions {
// How many simulation threads we are running.
int num_games = 1;
int num_games = 1; // 1024

// The maximum number of threads per game
int max_num_threads = 0;
int max_num_threads = 0; //0

// History length. How long we should keep the history.
int T = 1;
int T = 1; // 20

// verbose options.
bool verbose_comm = false;
Expand All @@ -34,7 +34,7 @@ struct ContextOptions {
// Whether we wait for each group or we wait jointly.
bool wait_per_group = false;

int num_collectors = 1;
int num_collectors = 1; // 0

mcts::TSOptions mcts_options;

Expand Down
8 changes: 4 additions & 4 deletions elf/tree_search_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ using namespace std;

struct TSOptions {
int max_num_moves = 0;
int num_threads = 16;
int num_rollout_per_thread = 100;
bool verbose = false;
int num_threads = 16; // 0
int num_rollout_per_thread = 100; // 1
bool verbose = false;
bool verbose_time = false;

string save_tree_filename;
string save_tree_filename; // ""

bool persistent_tree = false;
// [TODO] Not a good design.
Expand Down
21 changes: 16 additions & 5 deletions elf/utils_elf.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -223,20 +223,22 @@ def __init__(self, GC, co, descriptions, use_numpy=False, gpu=None, params=dict(
gpu(int): gpu to use.
params(dict): additional parameters
'''


#self.isPrint = False
self._init_collectors(GC, co, descriptions, use_gpu=gpu is not None, use_numpy=use_numpy)
self.gpu = gpu
self.inputs_gpu = [ self.inputs[gids[0]].cpu2gpu(gpu=gpu) for gids in self.gpu2gid ] if gpu is not None else None
self.params = params
self._cb = { }


def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False):
num_games = co.num_games

total_batchsize = 0
for key, v in descriptions.items():
total_batchsize += v["batchsize"]

if co.num_collectors > 0:
num_recv_thread = co.num_collectors
else:
Expand Down Expand Up @@ -269,11 +271,11 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False):
for i in range(num_recv_thread):
group_id = GC.AddCollectors(batchsize, len(gpu2gid) - 1, timeout_usec, gstat)

input_batch = Batch.load(GC, "input", input, group_id, use_gpu=use_gpu, use_numpy=use_numpy)
input_batch = Batch.load(GC, "input", input, group_id, use_gpu=use_gpu, use_numpy=use_numpy) # 加载输入Batch
input_batch.batchsize = batchsize
inputs.append(input_batch)
if reply is not None:
reply_batch = Batch.load(GC, "reply", reply, group_id, use_gpu=use_gpu, use_numpy=use_numpy)
reply_batch = Batch.load(GC, "reply", reply, group_id, use_gpu=use_gpu, use_numpy=use_numpy) # 加载回复Batch
reply_batch.batchsize= batchsize
replies.append(reply_batch)
else:
Expand All @@ -298,6 +300,14 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False):
self.name2idx = name2idx
self.gid2gpu = gid2gpu
self.gpu2gid = gpu2gid
# if not self.isPrint:
# print("idx2name",self.idx2name)
# print("name2idx",self.name2idx)
# print("gid2gpu",self.gid2gpu)
# print("gpu2gid",self.gpu2gid)
# print("num_collectors: ",co.num_collectors)
# self.isPrint = True


def reg_has_callback(self, key):
return key in self.name2idx
Expand All @@ -311,6 +321,7 @@ def reg_callback_if_exists(self, key, cb):

def reg_callback(self, key, cb):
'''Set callback function for key
注册回调函数,有符合要求和数量的Batch到来时,调用对应的函数

Parameters:
key(str): the key used to register the callback function.
Expand All @@ -332,7 +343,7 @@ def _call(self, infos):
raise ValueError("info.gid[%d] is not in callback functions" % infos.gid)

if self._cb[infos.gid] is None:
return;
return

batchsize = len(infos.s)

Expand Down
9 changes: 5 additions & 4 deletions rlpytorch/sampler/sample_methods.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def sample_with_check(probs, greedy=False):
'''
num_action = probs.size(1)
if greedy:
_, actions = probs.max(1)
_, actions = probs.max(1) # 贪婪算法,每次取概率最大的动作
return actions
while True:
actions = probs.multinomial(1)[:,0]
actions = probs.multinomial(1)[:,0] # 按照概率选择一个动作
cond1 = (actions < 0).sum()
cond2 = (actions >= num_action).sum()
if cond1 == 0 and cond2 == 0:
Expand Down Expand Up @@ -74,8 +74,9 @@ def sample_eps_with_check(probs, epsilon, greedy=False):
rej_p = probs.new().resize_(2)
rej_p[0] = 1 - epsilon
rej_p[1] = epsilon
# rej 按照概率取 0 或 1(batchsize次),取到1时(epsilon)表示此次不选择该动作并随机取样
rej = rej_p.multinomial(batchsize, replacement=True).byte()

# 随机取样
uniform_p = probs.new().resize_(num_action).fill_(1.0 / num_action)
uniform_sampling = uniform_p.multinomial(batchsize, replacement=True)
actions[rej] = uniform_sampling[rej]
Expand Down Expand Up @@ -110,7 +111,7 @@ def sample_multinomial(state_curr, args, node="pi", greedy=False):
return actions
else:
probs = state_curr[node].data
return sample_eps_with_check(probs, args.epsilon, greedy=greedy)
return sample_eps_with_check(probs, args.epsilon, greedy=greedy) # probs 0 False

def epsilon_greedy(state_curr, args, node="pi"):
''' epsilon greedy sampling
Expand Down
Empty file modified rlpytorch/sampler/sampler.py
100644 → 100755
Empty file.
Loading