Skip to content

Commit

Permalink
Bug fix: hyper-tuning with the cost-sensitive task
Browse files Browse the repository at this point in the history
  • Loading branch information
kjgm committed Jan 16, 2025
1 parent bddc86d commit a3e93c3
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 8 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ Note that STreeD provides an optimal decision tree for the given binarization. T

See [examples/binarize_example.py](examples/binarize_example.py) for an example.

See also our latest paper on optimizing trees with continuous features:
* Briţa, Cătălin E., Jacobus G. M. van der Linden, and Emir Demirović. "Optimal Classification Trees for Continuous Feature Data Using Dynamic Programming with Branch-and-Bound." In _Proceedings of AAAI-25_ (2025). [pdf](https://arxiv.org/pdf/2501.07903) / [source](https://github.com/consol-Lab/contree)

## Overfitting and tuning

To prevent overfitting the size of the tree can be tuned. This can be done in the standard way using `scikit-learn` methods, see [examples/gridsearch_example.py](examples/gridsearch_example.py).
Expand Down
8 changes: 7 additions & 1 deletion include/tasks/cost_sensitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ namespace STreeD {
static const int worst = INT32_MAX;
static const int best = 0;

CostSensitive(const ParameterHandler& parameters);
// Construct the cost-sensitive task and read its task-specific settings from the parameter handler.
CostSensitive(const ParameterHandler& parameters) : Classification(parameters) { UpdateParameters(parameters); }
// Re-read tunable settings from the parameter handler (invoked again when hyper-tuning re-configures the task).
inline void UpdateParameters(const ParameterHandler& parameters) {
cost_filename = parameters.GetStringParameter("cost-file");
}
// Copy task-specific state (the cost specifier) from another task instance.
// NOTE(review): unchecked static_cast — assumes `task` is actually a CostSensitive*;
// confirm all callers guarantee this before relying on it.
inline void CopyTaskInfoFrom(const OptimizationTask* task) {
UpdateCostSpecifier(static_cast<const CostSensitive*>(task)->cost_specifier);
}
void InformTrainData(const ADataView& train_data, const DataSummary& train_summary);
// Replace this task's cost specifier with the given one.
void UpdateCostSpecifier(const CostSpecifier& cost_specifier) { this->cost_specifier = cost_specifier; }

Expand Down
1 change: 1 addition & 0 deletions include/tasks/optimization_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ namespace STreeD {

// Inform the optimization task (OT) of updated parameters (when hyper-tuning).
// Default is a no-op; tasks with tunable settings override this.
inline void UpdateParameters(const ParameterHandler& parameters) { return; }
// Copy task-specific state from another task instance of the same type.
// Default is a no-op; tasks that carry extra state (e.g. a cost specifier) override this.
inline void CopyTaskInfoFrom(const OptimizationTask* task) {}

// Addition and subtraction functions
// Return the sum of the two integer values.
inline static int Add(const int left, const int right) {
	const int total = left + right;
	return total;
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pystreed"
version = "1.3.3"
version = "1.3.4"
requires-python = ">=3.8"
description = "Python Implementation of STreeD: Dynamic Programming Approach for Optimal Decision Trees with Separable objectives and Constraints"
license= {file = "LICENSE"}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Define package metadata
package_name = 'pystreed'
extension_name = 'cstreed'
__version__ = "1.3.3"
__version__ = "1.3.4"

ext_modules = [
Pybind11Extension(package_name + '.' + extension_name,
Expand Down
4 changes: 3 additions & 1 deletion src/solver/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,10 +902,11 @@ namespace STreeD {
}

template <class OT>
std::shared_ptr<SolverResult> Solver<OT>::HyperSolve(const ADataView& train_data) {
std::shared_ptr<SolverResult> Solver<OT>::HyperSolve(const ADataView& _train_data) {
using ScoreType = std::shared_ptr<Score>;
runtime_assert(parameters.GetBooleanParameter("hyper-tune"));
stopwatch.Initialise(parameters.GetFloatParameter("time"));
InitializeSolver(_train_data);

bool verbose = parameters.GetBooleanParameter("verbose");
const int max_num_nodes = int(parameters.GetIntegerParameter("max-num-nodes"));
Expand All @@ -929,6 +930,7 @@ namespace STreeD {
Solver<OT> solver(parameters, rng);
solver.solver_parameters.verbose = false;
solver.redundant_features = redundant_features;
solver.task->CopyTaskInfoFrom(this->task);
//ADataView::TrainTestSplitData<typename OT::LabelType>(train_data, sub_train_data, sub_test_data, rng, validation_percentage, true);
ADataView& sub_train_data = sub_train_datas[r], &sub_test_data = sub_test_datas[r];
solver.InitializeSolver(sub_train_data); // Initialize with max-depth
Expand Down
1 change: 1 addition & 0 deletions src/tasks/accuracy/cost_complex_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ namespace STreeD {
alphas.push_back(base_alpha * a);
for(double alpha = 100*base_alpha; alpha < 0.01; alpha += 0.001)
alphas.push_back(alpha);
std::sort(alphas.begin(), alphas.end(), std::greater<>());
for (auto a: alphas) {
if (a > 0.1) continue;
ParameterHandler params = default_config;
Expand Down
4 changes: 0 additions & 4 deletions src/tasks/cost_sensitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,6 @@ namespace STreeD {
return max;
}

CostSensitive::CostSensitive(const ParameterHandler& parameters) : Classification(parameters) {
cost_filename = parameters.GetStringParameter("cost-file");
}

void CostSensitive::InformTrainData(const ADataView& train_data, const DataSummary& train_summary) {
OptimizationTask::InformTrainData(train_data, train_summary);
runtime_assert(cost_filename != "" || cost_specifier.IsInitialized());
Expand Down

0 comments on commit a3e93c3

Please sign in to comment.