Skip to content

Commit

Permalink
[solvers] Add SpecificOptions sugar for solvers to map their options
Browse files Browse the repository at this point in the history
In the near future, we anticipate changing the SolverOptions API in
support of loading and saving (i.e., serialization). That's especially
troublesome for our solver back-ends that consume its information,
given the low-level and hodge-podge ways in which they hunt for and
apply their specific options.

This commit introduces a higher-level intermediary between solver
back-ends and the program's options. The goal is that SolverOptions
will solely be a user-facing aspect of defining a program; the solver
back-ends will never touch SolverOptions anymore, instead using the
SpecificOptions sugar to map our options API into the back-end.

The new API is designed to improve uniformity of common errors, such
as unknown names or wrongly-typed values.

The new API also lays the groundwork for more efficient processing, as
future work. It removes the need for the "Merge" (copy) operation in
the hot path -- since it only provides a *view* of the options, it can
easily keep track of several dictionaries and query them in order,
with no copying.
  • Loading branch information
jwnimmer-tri committed Oct 15, 2024
1 parent a60c85e commit 10ffba6
Show file tree
Hide file tree
Showing 5 changed files with 907 additions and 0 deletions.
22 changes: 22 additions & 0 deletions solvers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ drake_cc_package_library(
":solver_type_converter",
":sos_basis_generator",
":sparse_and_dense_matrix",
":specific_options",
":unrevised_lemke_solver",
],
)
Expand Down Expand Up @@ -314,6 +315,19 @@ drake_cc_library(
],
)

drake_cc_library(
name = "specific_options",
srcs = ["specific_options.cc"],
hdrs = ["specific_options.h"],
deps = [
":solver_id",
":solver_options",
"//common:name_value",
"//common:overloaded",
"//common:string_container",
],
)

drake_cc_library(
name = "indeterminate",
srcs = ["indeterminate.cc"],
Expand Down Expand Up @@ -2026,6 +2040,14 @@ drake_cc_googletest(
],
)

drake_cc_googletest(
name = "specific_options_test",
deps = [
":specific_options",
"//common/test_utilities:expect_throws_message",
],
)

drake_cc_googletest(
name = "augmented_lagrangian_test",
deps = [
Expand Down
14 changes: 14 additions & 0 deletions solvers/common_solver_option.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <optional>
#include <ostream>
#include <string>

#include "drake/common/fmt_ostream.h"

Expand Down Expand Up @@ -56,6 +58,18 @@ enum class CommonSolverOption {

std::ostream& operator<<(std::ostream& os,
CommonSolverOption common_solver_option);

namespace internal {

/* Aggregated values for CommonSolverOption, for Drake-internal use only. */
struct CommonSolverOptionValues {
std::string print_file_name;
bool print_to_console{false};
std::string standalone_reproduction_file_name;
std::optional<int> max_threads;
};

} // namespace internal
} // namespace solvers
} // namespace drake

Expand Down
311 changes: 311 additions & 0 deletions solvers/specific_options.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
#include "drake/solvers/specific_options.h"

#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>

#include "drake/common/overloaded.h"

namespace drake {
namespace solvers {
namespace internal {

using OptionValue = SolverOptions::OptionValue;

SpecificOptions::SpecificOptions(const SolverId* id,
const SolverOptions* all_options)
: id_{id}, all_options_{all_options} {
DRAKE_DEMAND(id != nullptr);
DRAKE_DEMAND(all_options != nullptr);
}

SpecificOptions::~SpecificOptions() = default;

void SpecificOptions::Respell(
const std::function<void(const CommonSolverOptionValues&,
string_unordered_map<OptionValue>*)>& respell) {
DRAKE_DEMAND(respell != nullptr);
DRAKE_DEMAND(respelled_.empty());
respell(
CommonSolverOptionValues{
.print_file_name = all_options_->get_print_file_name(),
.print_to_console = all_options_->get_print_to_console(),
.standalone_reproduction_file_name =
all_options_->get_standalone_reproduction_file_name(),
.max_threads = all_options_->get_max_threads(),
},
&respelled_);
}

template <typename Result>
std::optional<Result> SpecificOptions::Pop(std::string_view key) {
if (popped_.contains(key)) {
return std::nullopt;
}
const auto& typed_options = all_options_->template GetOptions<Result>(*id_);
// TODO(jwnimmer-tri) Nix this string copy after we fix SolverOptions to use
// sensible representation choices.
if (auto iter = typed_options.find(std::string{key});
iter != typed_options.end()) {
popped_.emplace(key);
return iter->second;
}
if (auto iter = respelled_.find(key); iter != respelled_.end()) {
const OptionValue& value = iter->second;
if (std::holds_alternative<Result>(value)) {
popped_.emplace(key);
return std::get<Result>(value);
}
throw std::logic_error(fmt::format(
"{}: internal error: option {} was respelled to the wrong type",
id_->name(), key));
}
return {};
}

template std::optional<double> SpecificOptions::Pop(std::string_view);
template std::optional<int> SpecificOptions::Pop(std::string_view);
template std::optional<std::string> SpecificOptions::Pop(std::string_view);

void SpecificOptions::CopyToCallbacks(
const std::function<void(const std::string& key, double)>& set_double,
const std::function<void(const std::string& key, int)>& set_int,
const std::function<void(const std::string& key, const std::string&)>&
set_string) const {
// Bail out early when we have no options at all for this solver.
const std::unordered_map<std::string, double>& options_double =
all_options_->GetOptionsDouble(*id_);
const std::unordered_map<std::string, int>& options_int =
all_options_->GetOptionsInt(*id_);
const std::unordered_map<std::string, std::string>& options_str =
all_options_->GetOptionsStr(*id_);
if (options_double.empty() && options_int.empty() && options_str.empty() &&
respelled_.empty()) {
return;
}

// Wrap the solver's set_{type} callbacks with error-reporting sugar, and
// logic to promote integers to doubles.
auto on_double = [this, &set_double](const std::string& key, double value) {
if (popped_.contains(key)) {
return;
}
if (set_double != nullptr) {
set_double(key, value);
return;
}
throw std::logic_error(fmt::format(
"{}: floating-point options are not supported; the option {}={} is "
"invalid",
id_->name(), key, value));
};
auto on_int = [this, &set_int, &set_double](const std::string& key,
int value) {
if (popped_.contains(key)) {
return;
}
if (set_int != nullptr) {
set_int(key, value);
return;
}
if (set_double != nullptr) {
set_double(key, value);
return;
}
throw std::logic_error(fmt::format(
"{}: integer and floating-point options are not supported; the option "
"{}={} is invalid",
id_->name(), key, value));
};
auto on_string = [this, &set_string](const std::string& key,
const std::string& value) {
if (popped_.contains(key)) {
return;
}
if (set_string != nullptr) {
set_string(key, value);
return;
}
throw std::logic_error(fmt::format(
"{}: string options are not supported; the option {}='{}' is invalid",
id_->name(), key, value));
};

// Handle solver-specific options.
for (const auto& [key, value] : options_double) {
on_double(key, value);
}
for (const auto& [key, value] : options_int) {
on_int(key, value);
}
for (const auto& [key, value] : options_str) {
on_string(key, value);
}

// Handle any respelled options, being careful not to set anything that has
// already been set.
for (const auto& [respelled_key, boxed_value] : respelled_) {
// Pedantially, lambdas cannot capture a structured binding so we need to
// make a local variable that we can capture.
const auto& key = respelled_key;
std::visit(
overloaded{[&key, &on_double, &options_double](double value) {
if (!options_double.contains(key)) {
on_double(key, value);
}
},
[&key, &on_int, &options_int](int value) {
if (!options_int.contains(key)) {
on_int(key, value);
}
},
[&key, &on_string, &options_str](const std::string& value) {
if (!options_str.contains(key)) {
on_string(key, value);
}
}},
boxed_value);
}
}

void SpecificOptions::InitializePending() {
pending_keys_.clear();
for (const auto& [key, _] : all_options_->GetOptionsDouble(*id_)) {
pending_keys_.insert(key);
}
for (const auto& [key, _] : all_options_->GetOptionsInt(*id_)) {
pending_keys_.insert(key);
}
for (const auto& [key, _] : all_options_->GetOptionsStr(*id_)) {
pending_keys_.insert(key);
}
for (const auto& [key, _] : respelled_) {
pending_keys_.insert(key);
}
for (const auto& key : popped_) {
pending_keys_.erase(key);
}
}

void SpecificOptions::CheckNoPending() const {
// Identify any unsupported names (i.e., leftovers in `pending_`).
if (!pending_keys_.empty()) {
std::vector<std::string_view> unknown_names;
for (const auto& name : pending_keys_) {
unknown_names.push_back(name);
}
std::sort(unknown_names.begin(), unknown_names.end());
throw std::logic_error(fmt::format(
"{}: the following solver option names were not recognized: {}",
id_->name(), fmt::join(unknown_names, ", ")));
}
}

std::optional<OptionValue> SpecificOptions::PrepareToCopy(const char* name) {
DRAKE_DEMAND(name != nullptr);
const std::unordered_map<std::string, double>& options_double =
all_options_->GetOptionsDouble(*id_);
// TODO(jwnimmer-tri) Nix these string copies after we fix SolverOptions to
// use sensible representation choices.
if (auto iter = options_double.find(std::string{name});
iter != options_double.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
const std::unordered_map<std::string, int>& options_int =
all_options_->GetOptionsInt(*id_);
if (auto iter = options_int.find(std::string{name});
iter != options_int.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
const std::unordered_map<std::string, std::string>& options_str =
all_options_->GetOptionsStr(*id_);
if (auto iter = options_str.find(std::string{name});
iter != options_str.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
if (auto iter = respelled_.find(name); iter != respelled_.end()) {
pending_keys_.erase(iter->first);
return iter->second;
}
return {};
}

template <typename T>
void SpecificOptions::CopyFloatingPointOption(const char* name, T* output) {
DRAKE_DEMAND(output != nullptr);
if (auto boxed_value = PrepareToCopy(name)) {
if (std::holds_alternative<double>(*boxed_value)) {
*output = std::get<double>(*boxed_value);
return;
}
if (std::holds_alternative<int>(*boxed_value)) {
*output = std::get<int>(*boxed_value);
return;
}
throw std::logic_error(
fmt::format("{}: Expected a floating-point value for option {}",
id_->name(), name));
}
}
template void SpecificOptions::CopyFloatingPointOption(const char*, double*);
template void SpecificOptions::CopyFloatingPointOption(const char*, float*);

template <typename T>
void SpecificOptions::CopyIntegralOption(const char* name, T* output) {
DRAKE_DEMAND(output != nullptr);
if (auto boxed_value = PrepareToCopy(name)) {
if (std::holds_alternative<int>(*boxed_value)) {
const int value = std::get<int>(*boxed_value);
if constexpr (std::is_same_v<T, int>) {
*output = value;
} else if constexpr (std::is_same_v<T, bool>) {
if (!(value == 0 || value == 1)) {
throw std::logic_error(fmt::format(
"{}: Expected a boolean value (0 or 1) for int option {}={}",
id_->name(), name, value));
}
*output = value;
} else {
static_assert(std::is_same_v<T, uint32_t>);
if (value < 0) {
throw std::logic_error(fmt::format(
"{}: Expected a non-negative value for unsigned int option {}={}",
id_->name(), name, value));
}
if (static_cast<int64_t>(value) >
static_cast<int64_t>(std::numeric_limits<uint32_t>::max())) {
throw std::logic_error(fmt::format(
"{}: Too-large value for uint32 option {}={}",
id_->name(), name, value));
}
*output = value;
}
return;
}
throw std::logic_error(fmt::format(
"{}: Expected an integer value for option {}", id_->name(), name));
}
}
template void SpecificOptions::CopyIntegralOption(const char*, int*);
template void SpecificOptions::CopyIntegralOption(const char*, bool*);
template void SpecificOptions::CopyIntegralOption(const char*, uint32_t*);

void SpecificOptions::CopyStringOption(const char* name, std::string* output) {
DRAKE_DEMAND(output != nullptr);
if (auto boxed_value = PrepareToCopy(name)) {
if (std::holds_alternative<std::string>(*boxed_value)) {
*output = std::get<std::string>(*boxed_value);
return;
}
throw std::logic_error(fmt::format(
"{}: Expected a string value for option {}", id_->name(), name));
}
}

} // namespace internal
} // namespace solvers
} // namespace drake
Loading

0 comments on commit 10ffba6

Please sign in to comment.