From 026885a8a65139a40ffaa1624f7d2098f8941c59 Mon Sep 17 00:00:00 2001 From: Jeremy Fix Date: Sat, 29 Sep 2018 08:47:24 +0200 Subject: [PATCH] a bug in rl softmax...forgot to use the values for discrete_distribution --- examples/example-000-003-agents.cc | 27 +++++++++++++++++++++++++++ src/rlAlgo.hpp | 11 +++++------ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/examples/example-000-003-agents.cc b/examples/example-000-003-agents.cc index 6bcd0fc..e07bf27 100644 --- a/examples/example-000-003-agents.cc +++ b/examples/example-000-003-agents.cc @@ -70,6 +70,32 @@ class Q { // Let us define functions for plotting histograms +template + void plotQ(std::string title, const QFUNCTION q_function, + const ACTION_ITERATOR& a_begin, + const ACTION_ITERATOR& a_end, + std::string filename) { + std::ofstream file; + + file.open(filename.c_str()); + if(!file) { + std::cerr << "Cannot open \"" << filename << "\". Aborting"; + return; + } + file << "set title '" << title << "';" << std::endl + << "set xrange [0:" << NB_ARMS-1 << "];" << std::endl + << "set yrange [0:1];" << std::endl + << "set xlabel 'Actions'" << std::endl + << "plot '-' with lines notitle" << std::endl; + S dummy; + for(auto ait = a_begin; ait != a_end ; ++ait) + file << *ait << ' ' << q_function(dummy, *ait) << std::endl; + + file.close(); + std::cout << "\"" << filename << "\" generated." << std::endl; + }; + #define HISTO_NB_SAMPLES 20000 template @@ -176,6 +202,7 @@ int main(int argc, char* argv[]) { x = a/(double)NB_ARMS; q.tabular_values[a] = (1-.2*x)*pow(sin(5*(x+.15)),2); } + plotQ("Q values", q, a_begin, a_end, "Qvalues.plot"); // Let us plot histograms of policys. plot1D("Random policy choices",random_policy,"RandomPolicy.plot"); diff --git a/src/rlAlgo.hpp b/src/rlAlgo.hpp index e1a3668..6efb85f 100644 --- a/src/rlAlgo.hpp +++ b/src/rlAlgo.hpp @@ -190,12 +190,12 @@ namespace rl { RANDOM_DEVICE& rd) -> decltype(*begin) { auto size = end-begin; - std::vector cum(size); + std::vector fvalues(size); auto iter = begin; - auto citer = cum.begin(); - for(; iter != end; ++iter, ++citer) - *citer = f(*iter); - std::discrete_distribution d; + auto fvaluesiter = fvalues.begin(); + for(; iter != end; ++iter, ++fvaluesiter) + *fvaluesiter = f(*iter); + std::discrete_distribution d(fvalues.begin(), fvalues.end()); return *(begin + d(rd)); } @@ -220,7 +220,6 @@ namespace rl { fmax = std::max(fmax, f_values[*it]); } - auto shifted_exp_values = [&temperature, &f_values, &fmax](const decltype(*begin)& a) -> double { return exp((f_values[a] - fmax)/temperature); };