Skip to content

Commit

Permalink
[solvers] Add SpecificOptions sugar for solvers to map their options (R…
Browse files Browse the repository at this point in the history
…obotLocomotion#22043)

In the near future, we anticipate changing the SolverOptions API in
support of loading and saving (i.e., serialization). That's especially
troublesome for our solver back-ends that consume its information,
given the low-level and hodge-podge ways in which they hunt for and
apply their specific options.

This commit introduces a higher-level intermediary between solver
back-ends and the program's options. The goal is that SolverOptions
will solely be a user-facing aspect of defining a program; the solver
back-ends will never touch SolverOptions anymore, instead using the
SpecificOptions sugar to map our options API into the back-end.

The new API is designed to improve uniformity of common errors, such
as unknown names or wrongly-typed values.

The new API also lays the groundwork for more efficient processing, as
future work. It removes the need for the "Merge" (copy) operation in
the hot path -- since it only provides a *view* of the options, it can
easily keep track of several dictionaries and query them in order,
with no copying.
  • Loading branch information
jwnimmer-tri authored and RussTedrake committed Dec 15, 2024
1 parent 4e2ac66 commit 609f7db
Show file tree
Hide file tree
Showing 5 changed files with 884 additions and 0 deletions.
22 changes: 22 additions & 0 deletions solvers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ drake_cc_package_library(
":solver_type_converter",
":sos_basis_generator",
":sparse_and_dense_matrix",
":specific_options",
":unrevised_lemke_solver",
],
)
Expand Down Expand Up @@ -314,6 +315,19 @@ drake_cc_library(
],
)

drake_cc_library(
name = "specific_options",
srcs = ["specific_options.cc"],
hdrs = ["specific_options.h"],
deps = [
":solver_id",
":solver_options",
"//common:name_value",
"//common:overloaded",
"//common:string_container",
],
)

drake_cc_library(
name = "indeterminate",
srcs = ["indeterminate.cc"],
Expand Down Expand Up @@ -2027,6 +2041,14 @@ drake_cc_googletest(
],
)

drake_cc_googletest(
name = "specific_options_test",
deps = [
":specific_options",
"//common/test_utilities:expect_throws_message",
],
)

drake_cc_googletest(
name = "augmented_lagrangian_test",
deps = [
Expand Down
14 changes: 14 additions & 0 deletions solvers/common_solver_option.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <optional>
#include <ostream>
#include <string>

#include "drake/common/fmt_ostream.h"

Expand Down Expand Up @@ -56,6 +58,18 @@ enum class CommonSolverOption {

std::ostream& operator<<(std::ostream& os,
CommonSolverOption common_solver_option);

namespace internal {

/* Aggregated values for CommonSolverOption, for Drake-internal use only. */
struct CommonSolverOptionValues {
std::string print_file_name;
bool print_to_console{false};
std::string standalone_reproduction_file_name;
std::optional<int> max_threads;
};

} // namespace internal
} // namespace solvers
} // namespace drake

Expand Down
308 changes: 308 additions & 0 deletions solvers/specific_options.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
#include "drake/solvers/specific_options.h"

#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>

#include "drake/common/overloaded.h"

namespace drake {
namespace solvers {
namespace internal {

using OptionValue = SolverOptions::OptionValue;

SpecificOptions::SpecificOptions(const SolverId* id,
const SolverOptions* all_options)
: id_{id}, all_options_{all_options} {
DRAKE_DEMAND(id != nullptr);
DRAKE_DEMAND(all_options != nullptr);
}

SpecificOptions::~SpecificOptions() = default;

void SpecificOptions::Respell(
const std::function<void(const CommonSolverOptionValues&,
string_unordered_map<OptionValue>*)>& respell) {
DRAKE_DEMAND(respell != nullptr);
DRAKE_DEMAND(respelled_.empty());
respell(
CommonSolverOptionValues{
.print_file_name = all_options_->get_print_file_name(),
.print_to_console = all_options_->get_print_to_console(),
.standalone_reproduction_file_name =
all_options_->get_standalone_reproduction_file_name(),
.max_threads = all_options_->get_max_threads(),
},
&respelled_);
}

template <typename Result>
std::optional<Result> SpecificOptions::Pop(std::string_view key) {
if (popped_.contains(key)) {
return std::nullopt;
}
const auto& typed_options = all_options_->template GetOptions<Result>(*id_);
// TODO(jwnimmer-tri) Nix this string copy after we fix SolverOptions to use
// sensible representation choices.
if (auto iter = typed_options.find(std::string{key});
iter != typed_options.end()) {
popped_.emplace(key);
return iter->second;
}
if (auto iter = respelled_.find(key); iter != respelled_.end()) {
const OptionValue& value = iter->second;
if (std::holds_alternative<Result>(value)) {
popped_.emplace(key);
return std::get<Result>(value);
}
throw std::logic_error(fmt::format(
"{}: internal error: option {} was respelled to the wrong type",
id_->name(), key));
}
return {};
}

template std::optional<double> SpecificOptions::Pop(std::string_view);
template std::optional<int> SpecificOptions::Pop(std::string_view);
template std::optional<std::string> SpecificOptions::Pop(std::string_view);

void SpecificOptions::CopyToCallbacks(
const std::function<void(const std::string& key, double)>& set_double,
const std::function<void(const std::string& key, int)>& set_int,
const std::function<void(const std::string& key, const std::string&)>&
set_string) const {
// Wrap the solver's set_{type} callbacks with error-reporting sugar, and
// logic to promote integers to doubles.
auto on_double = [this, &set_double](const std::string& key, double value) {
if (popped_.contains(key)) {
return;
}
if (set_double != nullptr) {
set_double(key, value);
return;
}
throw std::logic_error(fmt::format(
"{}: floating-point options are not supported; the option {}={} is "
"invalid",
id_->name(), key, value));
};
auto on_int = [this, &set_int, &set_double](const std::string& key,
int value) {
if (popped_.contains(key)) {
return;
}
if (set_int != nullptr) {
set_int(key, value);
return;
}
if (set_double != nullptr) {
set_double(key, value);
return;
}
throw std::logic_error(fmt::format(
"{}: integer and floating-point options are not supported; the option "
"{}={} is invalid",
id_->name(), key, value));
};
auto on_string = [this, &set_string](const std::string& key,
const std::string& value) {
if (popped_.contains(key)) {
return;
}
if (set_string != nullptr) {
set_string(key, value);
return;
}
throw std::logic_error(fmt::format(
"{}: string options are not supported; the option {}='{}' is invalid",
id_->name(), key, value));
};

// Handle solver-specific options.
const std::unordered_map<std::string, double>& options_double =
all_options_->GetOptionsDouble(*id_);
const std::unordered_map<std::string, int>& options_int =
all_options_->GetOptionsInt(*id_);
const std::unordered_map<std::string, std::string>& options_str =
all_options_->GetOptionsStr(*id_);
for (const auto& [key, value] : options_double) {
on_double(key, value);
}
for (const auto& [key, value] : options_int) {
on_int(key, value);
}
for (const auto& [key, value] : options_str) {
on_string(key, value);
}

// Handle any respelled options, being careful not to set anything that has
// already been set.
for (const auto& [respelled_key, boxed_value] : respelled_) {
// Pedantically, lambdas cannot capture a structured binding so we need to
// make a local variable that we can capture.
const auto& key = respelled_key;
std::visit(
overloaded{[&key, &on_double, &options_double](double value) {
if (!options_double.contains(key)) {
on_double(key, value);
}
},
[&key, &on_int, &options_int](int value) {
if (!options_int.contains(key)) {
on_int(key, value);
}
},
[&key, &on_string, &options_str](const std::string& value) {
if (!options_str.contains(key)) {
on_string(key, value);
}
}},
boxed_value);
}
}

void SpecificOptions::InitializePending() {
pending_keys_.clear();
for (const auto& [key, _] : all_options_->GetOptionsDouble(*id_)) {
pending_keys_.insert(key);
}
for (const auto& [key, _] : all_options_->GetOptionsInt(*id_)) {
pending_keys_.insert(key);
}
for (const auto& [key, _] : all_options_->GetOptionsStr(*id_)) {
pending_keys_.insert(key);
}
for (const auto& [key, _] : respelled_) {
pending_keys_.insert(key);
}
for (const auto& key : popped_) {
pending_keys_.erase(key);
}
}

void SpecificOptions::CheckNoPending() const {
// Identify any unsupported names (i.e., leftovers in `pending_`).
if (!pending_keys_.empty()) {
std::vector<std::string_view> unknown_names;
for (const auto& name : pending_keys_) {
unknown_names.push_back(name);
}
std::sort(unknown_names.begin(), unknown_names.end());
throw std::logic_error(fmt::format(
"{}: the following solver option names were not recognized: {}",
id_->name(), fmt::join(unknown_names, ", ")));
}
}

std::optional<OptionValue> SpecificOptions::PrepareToCopy(const char* name) {
DRAKE_DEMAND(name != nullptr);
const std::unordered_map<std::string, double>& options_double =
all_options_->GetOptionsDouble(*id_);
// TODO(jwnimmer-tri) Nix these string copies after we fix SolverOptions to
// use sensible representation choices.
if (auto iter = options_double.find(std::string{name});
iter != options_double.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
const std::unordered_map<std::string, int>& options_int =
all_options_->GetOptionsInt(*id_);
if (auto iter = options_int.find(std::string{name});
iter != options_int.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
const std::unordered_map<std::string, std::string>& options_str =
all_options_->GetOptionsStr(*id_);
if (auto iter = options_str.find(std::string{name});
iter != options_str.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
if (auto iter = respelled_.find(name); iter != respelled_.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
return {};
}

template <typename T>
void SpecificOptions::CopyFloatingPointOption(const char* name, T* output) {
DRAKE_DEMAND(output != nullptr);
if (auto boxed_value = PrepareToCopy(name)) {
if (std::holds_alternative<double>(*boxed_value)) {
*output = std::get<double>(*boxed_value);
return;
}
if (std::holds_alternative<int>(*boxed_value)) {
*output = std::get<int>(*boxed_value);
return;
}
throw std::logic_error(
fmt::format("{}: Expected a floating-point value for option {}",
id_->name(), name));
}
}
template void SpecificOptions::CopyFloatingPointOption(const char*, double*);
template void SpecificOptions::CopyFloatingPointOption(const char*, float*);

template <typename T>
void SpecificOptions::CopyIntegralOption(const char* name, T* output) {
DRAKE_DEMAND(output != nullptr);
if (auto boxed_value = PrepareToCopy(name)) {
if (std::holds_alternative<int>(*boxed_value)) {
const int value = std::get<int>(*boxed_value);
if constexpr (std::is_same_v<T, int>) {
*output = value;
} else if constexpr (std::is_same_v<T, bool>) {
if (!(value == 0 || value == 1)) {
throw std::logic_error(fmt::format(
"{}: Expected a boolean value (0 or 1) for int option {}={}",
id_->name(), name, value));
}
*output = value;
} else {
static_assert(std::is_same_v<T, uint32_t>);
if (value < 0) {
throw std::logic_error(fmt::format(
"{}: Expected a non-negative value for unsigned int option {}={}",
id_->name(), name, value));
}
// In practice it's unlikely that sizeof(int) > 4, but better safe than
// sorry. This also serves as a reminder to sanity check other casts if
// we add more template instantiations than just uint32_t.
if (static_cast<int64_t>(value) >
static_cast<int64_t>(std::numeric_limits<uint32_t>::max())) {
throw std::logic_error(
fmt::format("{}: Too-large value for uint32 option {}={}",
id_->name(), name, value));
}
*output = value;
}
return;
}
throw std::logic_error(fmt::format(
"{}: Expected an integer value for option {}", id_->name(), name));
}
}
template void SpecificOptions::CopyIntegralOption(const char*, int*);
template void SpecificOptions::CopyIntegralOption(const char*, bool*);
template void SpecificOptions::CopyIntegralOption(const char*, uint32_t*);

void SpecificOptions::CopyStringOption(const char* name, std::string* output) {
DRAKE_DEMAND(output != nullptr);
if (auto boxed_value = PrepareToCopy(name)) {
if (std::holds_alternative<std::string>(*boxed_value)) {
*output = std::get<std::string>(*boxed_value);
return;
}
throw std::logic_error(fmt::format(
"{}: Expected a string value for option {}", id_->name(), name));
}
}

} // namespace internal
} // namespace solvers
} // namespace drake
Loading

0 comments on commit 609f7db

Please sign in to comment.