Skip to content

Commit

Permalink
A new implementation and unit test strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri committed Jul 1, 2023
1 parent 9874fc2 commit e6bb249
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 75 deletions.
22 changes: 22 additions & 0 deletions solvers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ load(
"drake_cc_package_library",
"drake_cc_test",
)
load(
"@drake//tools/skylark:drake_py.bzl",
"drake_py_unittest",
)
load(
":defs.bzl",
"drake_cc_optional_googletest",
Expand Down Expand Up @@ -1250,6 +1254,24 @@ drake_cc_googletest(
],
)

drake_py_unittest(
name = "gurobi_solver_license_retention_test",
data = [
":gurobi_solver_license_retention_test_helper",
],
tags = gurobi_test_tags(),
)

drake_cc_binary(
name = "gurobi_solver_license_retention_test_helper",
testonly = True,
srcs = ["test/gurobi_solver_license_retention_test_helper.cc"],
visibility = ["//visibility:private"],
deps = [
":gurobi_solver",
],
)

drake_cc_googletest(
name = "integer_optimization_util_test",
deps = [
Expand Down
79 changes: 41 additions & 38 deletions solvers/gurobi_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <charconv>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <optional>
#include <stdexcept>
Expand Down Expand Up @@ -1042,7 +1041,14 @@ class GurobiSolver::License {
"Could not locate Gurobi license key file because GRB_LICENSE_FILE "
"environment variable was not set.");
}

if (const char* filename = std::getenv("GRB_LICENSE_FILE")) {
// For unit testing, we employ a hack to keep env_ uninitialized so that
// we don't need a valid license file.
if (std::string_view{filename}.find("DRAKE_UNIT_TEST_NO_LICENSE") !=
std::string_view::npos) {
return;
}
}
const int num_tries = 3;
int grb_load_env_error = 1;
for (int i = 0; grb_load_env_error && i < num_tries; ++i) {
Expand All @@ -1057,59 +1063,56 @@ class GurobiSolver::License {
"\".");
}
DRAKE_DEMAND(env_ != nullptr);

// We use the existence of the string HOSTID in the license file as
// confirmation that the license is associated with the local host.
const char* grb_license_filename = std::getenv("GRB_LICENSE_FILE");
std::ifstream grb_license_file(grb_license_filename);
if (!grb_license_file) {
throw std::runtime_error(
"Could not read Gurobi license file specified in the "
"GRB_LICENSE_FILE environment variable");
}
const std::string grb_license_file_contents(
(std::istreambuf_iterator<char>(grb_license_file)),
std::istreambuf_iterator<char>());
is_local_license_ =
grb_license_file_contents.find("HOSTID") != std::string::npos;
}

~License() {
GRBfreeenv(env_);
env_ = nullptr;
}

bool is_local_license() const { return is_local_license_; }

GRBenv* GurobiEnv() { return env_; }

private:
bool is_local_license_{false};
GRBenv* env_ = nullptr;
};

/* Gurobi recommends acquiring the license only once per program to avoid
overhead from acquiring the license (and console spew for academic license
users; see #19657). However, if users are using a shared network license from a
limited pool, then we risk them checking out the license and not giving it back
(for instance, if they are working in a jupyter notebook). As a compromise, we
hold on to the license beyond the lifetime of the GurobiSolver iff we can
confirm that the license is associated with the local host. */
std::shared_ptr<GurobiSolver::License> local_host_gurobi_license{};

std::shared_ptr<GurobiSolver::License> GurobiSolver::AcquireLicense() {
if (local_host_gurobi_license) {
return local_host_gurobi_license;
namespace {
bool IsGrbLicenseFileLocalHost() {
// We use the existence of the string HOSTID in the license file as
// confirmation that the license is associated with the local host.
const char* grb_license_file = std::getenv("GRB_LICENSE_FILE");
if (grb_license_file == nullptr) {
return false;
}
auto license = GetScopedSingleton<GurobiSolver::License>();
if (license->is_local_license()) {
local_host_gurobi_license = license;
std::ifstream stream{grb_license_file};
const std::string contents{std::istreambuf_iterator<char>{stream},
std::istreambuf_iterator<char>{}};
if (stream.fail()) {
return false;
}
return license;
return contents.find("HOSTID") != std::string::npos;
}
} // namespace

bool GurobiSolver::has_acquired_local_license() const {
return license_ && license_->is_local_license();
std::shared_ptr<GurobiSolver::License> GurobiSolver::AcquireLicense() {
// Gurobi recommends acquiring the license only once per program to avoid
// overhead from acquiring the license (and console spew for academic license
// users; see #19657). However, if users are using a shared network license
// from a limited pool, then we risk them checking out the license and not
// giving it back (e.g., if they are working in a jupyter notebook). As a
// compromise, we extend license beyond the lifetime of the GurobiSolver iff
// we can confirm that the license is associated with the local host.
//
// The first time the anyone calls GurobiSolver::AcquireLicense, we check
// whether the license is local. If yes, the local_host_holder keeps the
// license's use_count lower bounded to 1. If no, the local_hold_holder is
// null and the usual GetScopedSingleton workflow applies.
static never_destroyed<std::shared_ptr<void>> local_host_holder{[]() {
return IsGrbLicenseFileLocalHost()
? GetScopedSingleton<GurobiSolver::License>()
: nullptr;
}()};
return GetScopedSingleton<GurobiSolver::License>();
}

// TODO([email protected]): break this large DoSolve function to smaller
Expand Down
7 changes: 0 additions & 7 deletions solvers/gurobi_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,6 @@ class GurobiSolver final : public SolverBase {
static std::string UnsatisfiedProgramAttributes(const MathematicalProgram&);
//@}

// Returns true if this solver has acquired a license during a Solve(), and
// this license was confirmed to be tied to the local host. If the license can
// be confirmed to be local, then it will never be destroyed. If the license
// cannot be confirmed to be local, then the license will stay valid as long
// as at least one shared_ptr returned by AcquireLicense() is alive.
bool has_acquired_local_license() const;

// A using-declaration adds these methods into our class's Doxygen.
using SolverBase::Solve;

Expand Down
30 changes: 0 additions & 30 deletions solvers/test/gurobi_solver_grb_license_file_test.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#include <cstdlib>
#include <fstream>
#include <optional>
#include <stdexcept>
#include <string>

#include <gtest/gtest.h>

#include "drake/common/temp_directory.h"
#include "drake/common/test_utilities/expect_no_throw.h"
#include "drake/common/test_utilities/expect_throws_message.h"
#include "drake/solvers/gurobi_solver.h"
Expand Down Expand Up @@ -41,16 +39,6 @@ class GrbLicenseFileTest : public ::testing::Test {
prog_.NewContinuousVariables<1>();
}

void WriteTempLicenseFile(std::string contents) {
std::string test_license = temp_directory() + "/test.lic";
std::ofstream file(test_license);
ASSERT_TRUE(file);
file << contents;
const int setenv_result =
::setenv("GRB_LICENSE_FILE", test_license.c_str(), 1);
ASSERT_EQ(setenv_result, 0);
}

void TearDown() override {
if (orig_grb_license_file_) {
const int setenv_result =
Expand Down Expand Up @@ -80,24 +68,6 @@ TEST_F(GrbLicenseFileTest, GrbLicenseFileUnset) {
".*GurobiSolver has not been properly configured.*");
}


TEST_F(GrbLicenseFileTest, LocalLicenseFile) {
EXPECT_EQ(solver_.enabled(), true);
WriteTempLicenseFile(
"license file contents that contains the string HOSTID and perhaps some "
"other info.");
solver_.Solve(prog_);
EXPECT_TRUE(solver_.has_acquired_local_license());
}

TEST_F(GrbLicenseFileTest, ServerLicenseFile) {
EXPECT_EQ(solver_.enabled(), true);
WriteTempLicenseFile(
"license file contents without the magic keyword.");
solver_.Solve(prog_);
EXPECT_FALSE(solver_.has_acquired_local_license());
}

} // namespace
} // namespace solvers
} // namespace drake
43 changes: 43 additions & 0 deletions solvers/test/gurobi_solver_license_retention_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import copy
import os
from pathlib import Path
import subprocess
import unittest


class TestGurobiSolverLicenseRetention(unittest.TestCase):

def _subprocess_license_use_count(self, license_file_content):
"""Sets GRB_LICENSE_FILE to a temp file with the given content, runs
our test helper, and then returns the license pointer use_count.
"""
# Create a dummy license file. Note that the license filename is magic.
# The License code in gurobi_solver.cc treats this filename specially.
tmpdir = Path(os.environ["TEST_TMPDIR"])
license_file = tmpdir / "DRAKE_UNIT_TEST_NO_LICENSE.lic"
with open(license_file, "w", encoding="utf-8") as f:
f.write(license_file_content)

# Override the built-in license file.
env = copy.copy(os.environ)
env["GRB_LICENSE_FILE"] = str(license_file)

# Run the helper and return the poitner use_count.
output = subprocess.check_output(
["solvers/gurobi_solver_license_retention_test_helper"])
return int(output)

def test_local_license(self):
"""When the file named by GRB_LICENSE_FILE contains 'HOSTID', the
license object is held in two places: the test helper main(), and
a global variable within GurobiSolver::AcquireLicense.
"""
content = "HOSTID=foobar\n"
self.assertEqual(self._subprocess_license_use_count(content), 2)

def test_nonlocal_license(self):
"""When the file named by GRB_LICENSE_FILE doesn't contain 'HOSTID',
the license object is only held by main(), not any global variable.
"""
content = "TOKENSERVER=foobar.invalid.\n"
self.assertEqual(self._subprocess_license_use_count(content), 1)
11 changes: 11 additions & 0 deletions solvers/test/gurobi_solver_license_retention_test_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "drake/solvers/gurobi_solver.h"

using drake::solvers::GurobiSolver;

/* Acquire a license and report the overall use_count. */
int main() {
std::shared_ptr<GurobiSolver::License> license =
GurobiSolver::AcquireLicense();
fmt::print("{}\n", license.use_count());
return 0;
}

0 comments on commit e6bb249

Please sign in to comment.