forked from RobotLocomotion/drake
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A new implementation and unit test strategy
- Loading branch information
1 parent
9874fc2
commit e6bb249
Showing
6 changed files
with
117 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ | |
#include <charconv> | ||
#include <cmath> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <limits> | ||
#include <optional> | ||
#include <stdexcept> | ||
|
@@ -1042,7 +1041,14 @@ class GurobiSolver::License { | |
"Could not locate Gurobi license key file because GRB_LICENSE_FILE " | ||
"environment variable was not set."); | ||
} | ||
|
||
if (const char* filename = std::getenv("GRB_LICENSE_FILE")) { | ||
// For unit testing, we employ a hack to keep env_ uninitialized so that | ||
// we don't need a valid license file. | ||
if (std::string_view{filename}.find("DRAKE_UNIT_TEST_NO_LICENSE") != | ||
std::string_view::npos) { | ||
return; | ||
} | ||
} | ||
const int num_tries = 3; | ||
int grb_load_env_error = 1; | ||
for (int i = 0; grb_load_env_error && i < num_tries; ++i) { | ||
|
@@ -1057,59 +1063,56 @@ class GurobiSolver::License { | |
"\"."); | ||
} | ||
DRAKE_DEMAND(env_ != nullptr); | ||
|
||
// We use the existence of the string HOSTID in the license file as | ||
// confirmation that the license is associated with the local host. | ||
const char* grb_license_filename = std::getenv("GRB_LICENSE_FILE"); | ||
std::ifstream grb_license_file(grb_license_filename); | ||
if (!grb_license_file) { | ||
throw std::runtime_error( | ||
"Could not read Gurobi license file specified in the " | ||
"GRB_LICENSE_FILE environment variable"); | ||
} | ||
const std::string grb_license_file_contents( | ||
(std::istreambuf_iterator<char>(grb_license_file)), | ||
std::istreambuf_iterator<char>()); | ||
is_local_license_ = | ||
grb_license_file_contents.find("HOSTID") != std::string::npos; | ||
} | ||
|
||
~License() { | ||
GRBfreeenv(env_); | ||
env_ = nullptr; | ||
} | ||
|
||
bool is_local_license() const { return is_local_license_; } | ||
|
||
GRBenv* GurobiEnv() { return env_; } | ||
|
||
private: | ||
bool is_local_license_{false}; | ||
GRBenv* env_ = nullptr; | ||
}; | ||
|
||
/* Gurobi recommends acquiring the license only once per program to avoid | ||
overhead from acquiring the license (and console spew for academic license | ||
users; see #19657). However, if users are using a shared network license from a | ||
limited pool, then we risk them checking out the license and not giving it back | ||
(for instance, if they are working in a jupyter notebook). As a compromise, we | ||
hold on to the license beyond the lifetime of the GurobiSolver iff we can | ||
confirm that the license is associated with the local host. */ | ||
std::shared_ptr<GurobiSolver::License> local_host_gurobi_license{}; | ||
|
||
std::shared_ptr<GurobiSolver::License> GurobiSolver::AcquireLicense() { | ||
if (local_host_gurobi_license) { | ||
return local_host_gurobi_license; | ||
namespace { | ||
bool IsGrbLicenseFileLocalHost() { | ||
// We use the existence of the string HOSTID in the license file as | ||
// confirmation that the license is associated with the local host. | ||
const char* grb_license_file = std::getenv("GRB_LICENSE_FILE"); | ||
if (grb_license_file == nullptr) { | ||
return false; | ||
} | ||
auto license = GetScopedSingleton<GurobiSolver::License>(); | ||
if (license->is_local_license()) { | ||
local_host_gurobi_license = license; | ||
std::ifstream stream{grb_license_file}; | ||
const std::string contents{std::istreambuf_iterator<char>{stream}, | ||
std::istreambuf_iterator<char>{}}; | ||
if (stream.fail()) { | ||
return false; | ||
} | ||
return license; | ||
return contents.find("HOSTID") != std::string::npos; | ||
} | ||
} // namespace | ||
|
||
bool GurobiSolver::has_acquired_local_license() const { | ||
return license_ && license_->is_local_license(); | ||
std::shared_ptr<GurobiSolver::License> GurobiSolver::AcquireLicense() { | ||
// Gurobi recommends acquiring the license only once per program to avoid | ||
// overhead from acquiring the license (and console spew for academic license | ||
// users; see #19657). However, if users are using a shared network license | ||
// from a limited pool, then we risk them checking out the license and not | ||
// giving it back (e.g., if they are working in a jupyter notebook). As a | ||
// compromise, we extend license beyond the lifetime of the GurobiSolver iff | ||
// we can confirm that the license is associated with the local host. | ||
// | ||
// The first time the anyone calls GurobiSolver::AcquireLicense, we check | ||
// whether the license is local. If yes, the local_host_holder keeps the | ||
// license's use_count lower bounded to 1. If no, the local_hold_holder is | ||
// null and the usual GetScopedSingleton workflow applies. | ||
static never_destroyed<std::shared_ptr<void>> local_host_holder{[]() { | ||
return IsGrbLicenseFileLocalHost() | ||
? GetScopedSingleton<GurobiSolver::License>() | ||
: nullptr; | ||
}()}; | ||
return GetScopedSingleton<GurobiSolver::License>(); | ||
} | ||
|
||
// TODO([email protected]): break this large DoSolve function to smaller | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import copy | ||
import os | ||
from pathlib import Path | ||
import subprocess | ||
import unittest | ||
|
||
|
||
class TestGurobiSolverLicenseRetention(unittest.TestCase): | ||
|
||
def _subprocess_license_use_count(self, license_file_content): | ||
"""Sets GRB_LICENSE_FILE to a temp file with the given content, runs | ||
our test helper, and then returns the license pointer use_count. | ||
""" | ||
# Create a dummy license file. Note that the license filename is magic. | ||
# The License code in gurobi_solver.cc treats this filename specially. | ||
tmpdir = Path(os.environ["TEST_TMPDIR"]) | ||
license_file = tmpdir / "DRAKE_UNIT_TEST_NO_LICENSE.lic" | ||
with open(license_file, "w", encoding="utf-8") as f: | ||
f.write(license_file_content) | ||
|
||
# Override the built-in license file. | ||
env = copy.copy(os.environ) | ||
env["GRB_LICENSE_FILE"] = str(license_file) | ||
|
||
# Run the helper and return the poitner use_count. | ||
output = subprocess.check_output( | ||
["solvers/gurobi_solver_license_retention_test_helper"]) | ||
return int(output) | ||
|
||
def test_local_license(self): | ||
"""When the file named by GRB_LICENSE_FILE contains 'HOSTID', the | ||
license object is held in two places: the test helper main(), and | ||
a global variable within GurobiSolver::AcquireLicense. | ||
""" | ||
content = "HOSTID=foobar\n" | ||
self.assertEqual(self._subprocess_license_use_count(content), 2) | ||
|
||
def test_nonlocal_license(self): | ||
"""When the file named by GRB_LICENSE_FILE doesn't contain 'HOSTID', | ||
the license object is only held by main(), not any global variable. | ||
""" | ||
content = "TOKENSERVER=foobar.invalid.\n" | ||
self.assertEqual(self._subprocess_license_use_count(content), 1) |
11 changes: 11 additions & 0 deletions
11
solvers/test/gurobi_solver_license_retention_test_helper.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#include "drake/solvers/gurobi_solver.h" | ||
|
||
using drake::solvers::GurobiSolver; | ||
|
||
/* Acquire a license and report the overall use_count. */ | ||
int main() { | ||
std::shared_ptr<GurobiSolver::License> license = | ||
GurobiSolver::AcquireLicense(); | ||
fmt::print("{}\n", license.use_count()); | ||
return 0; | ||
} |