Skip to content

Commit

Permalink
a bug in rl softmax...forgot to use the values for discrete_distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfix committed Sep 29, 2018
1 parent 80f51db commit 026885a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
27 changes: 27 additions & 0 deletions examples/example-000-003-agents.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,32 @@ class Q {

// Let us define functions for plotting histograms

template<typename QFUNCTION,
typename ACTION_ITERATOR>
void plotQ(std::string title, const QFUNCTION q_function,
const ACTION_ITERATOR& a_begin,
const ACTION_ITERATOR& a_end,
std::string filename) {
std::ofstream file;

file.open(filename.c_str());
if(!file) {
std::cerr << "Cannot open \"" << filename << "\". Aborting";
return;
}
file << "set title '" << title << "';" << std::endl
<< "set xrange [0:" << NB_ARMS-1 << "];" << std::endl
<< "set yrange [0:1];" << std::endl
<< "set xlabel 'Actions'" << std::endl
<< "plot '-' with lines notitle" << std::endl;
S dummy;
for(auto ait = a_begin; ait != a_end ; ++ait)
file << *ait << ' ' << q_function(dummy, *ait) << std::endl;

file.close();
std::cout << "\"" << filename << "\" generated." << std::endl;
};

#define HISTO_NB_SAMPLES 20000

template<typename POLICY>
Expand Down Expand Up @@ -176,6 +202,7 @@ int main(int argc, char* argv[]) {
x = a/(double)NB_ARMS;
q.tabular_values[a] = (1-.2*x)*pow(sin(5*(x+.15)),2);
}
plotQ("Q values", q, a_begin, a_end, "Qvalues.plot");

// Let us plot histograms of policys.
plot1D("Random policy choices",random_policy,"RandomPolicy.plot");
Expand Down
11 changes: 5 additions & 6 deletions src/rlAlgo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ namespace rl {
RANDOM_DEVICE& rd)
-> decltype(*begin) {
auto size = end-begin;
std::vector<double> cum(size);
std::vector<double> fvalues(size);
auto iter = begin;
auto citer = cum.begin();
for(; iter != end; ++iter, ++citer)
*citer = f(*iter);
std::discrete_distribution<decltype(end - begin)> d;
auto fvaluesiter = fvalues.begin();
for(; iter != end; ++iter, ++fvaluesiter)
*fvaluesiter = f(*iter);
std::discrete_distribution<decltype(end - begin)> d(fvalues.begin(), fvalues.end());
return *(begin + d(rd));

}
Expand All @@ -220,7 +220,6 @@ namespace rl {
fmax = std::max(fmax, f_values[*it]);
}


auto shifted_exp_values = [&temperature, &f_values, &fmax](const decltype(*begin)& a) -> double {
return exp((f_values[a] - fmax)/temperature);
};
Expand Down

0 comments on commit 026885a

Please sign in to comment.