diff --git a/leaves.go b/leaves.go
index c6d5a86..06b335b 100644
--- a/leaves.go
+++ b/leaves.go
@@ -19,6 +19,7 @@ type ensembleBaseInterface interface {
 	Name() string
 	adjustNEstimators(nEstimators int) int
 	predictInner(fvals []float64, nEstimators int, predictions []float64, startIndex int)
+	predictLeaves(fvals []float64, predictions []int) error
 	resetFVals(fvals []float64)
 }
 
@@ -75,6 +76,15 @@ func (e *Ensemble) Predict(fvals []float64, nEstimators int, predictions []float
 	return nil
 }
 
+func (e *Ensemble) PredictLeaves(fvals []float64, predictions []int) error {
+	if len(predictions) != e.NRawOutputGroups()*e.NEstimators() {
+		return fmt.Errorf("predictions slice wrong size (expected %d)", e.NRawOutputGroups()*e.NEstimators())
+	} else if len(fvals) != e.NFeatures() {
+		return fmt.Errorf("incorrect number of features (%d)", len(fvals))
+	}
+	return e.predictLeaves(fvals, predictions)
+}
+
 // PredictCSR calculates predictions from ensemble. `indptr`, `cols`, `vals`
 // represent data structures from Compressed Sparse Row Matrix format (see
 // CSRMat). Only `nEstimators` first estimators (trees) will be used.
diff --git a/lgensemble.go b/lgensemble.go
index fa5143b..a977b4a 100644
--- a/lgensemble.go
+++ b/lgensemble.go
@@ -52,6 +52,13 @@ func (e *lgEnsemble) predictInner(fvals []float64, nEstimators int, predictions
 	}
 }
 
+func (e *lgEnsemble) predictLeaves(fvals []float64, predictions []int) error {
+	for i, t := range e.Trees {
+		predictions[i] = t.predictLeaf(fvals)
+	}
+	return nil
+}
+
 func (e *lgEnsemble) adjustNEstimators(nEstimators int) int {
 	if nEstimators > 0 {
 		nEstimators = util.MinInt(nEstimators, e.NEstimators())
diff --git a/lgtree.go b/lgtree.go
index 9d290e4..fb4f45b 100644
--- a/lgtree.go
+++ b/lgtree.go
@@ -71,9 +71,9 @@ func (t *lgTree) decision(node *lgNode, fval float64) bool {
 	return t.numericalDecision(node, fval)
 }
 
-func (t *lgTree) predict(fvals []float64) float64 {
+func (t *lgTree) predictLeaf(fvals []float64) int {
 	if len(t.nodes) == 0 {
-		return t.leafValues[0]
+		return 0
 	}
 	idx := uint32(0)
 	for {
@@ -81,18 +81,22 @@ func (t *lgTree) predict(fvals []float64) float64 {
 		left := t.decision(node, fvals[node.Feature])
 		if left {
 			if node.Flags&leftLeaf > 0 {
-				return t.leafValues[node.Left]
+				return int(node.Left)
 			}
 			idx = node.Left
 		} else {
 			if node.Flags&rightLeaf > 0 {
-				return t.leafValues[node.Right]
+				return int(node.Right)
 			}
 			idx++
 		}
 	}
 }
 
+func (t *lgTree) predict(fvals []float64) float64 {
+	return t.leafValues[t.predictLeaf(fvals)]
+}
+
 func (t *lgTree) findInBitset(idx uint32, pos uint32) bool {
 	i1 := pos / 32
 	idxS := t.catBoundaries[idx]
diff --git a/xgblinear.go b/xgblinear.go
index dd03bda..98041e5 100644
--- a/xgblinear.go
+++ b/xgblinear.go
@@ -1,5 +1,7 @@
 package leaves
 
+import "fmt"
+
 // xgLinear is XGBoost model (gblinear)
 type xgLinear struct {
 	NumFeature int
@@ -38,6 +40,10 @@ func (e *xgLinear) predictInner(fvals []float64, nIterations int, predictions []
 	}
 }
 
+func (e *xgLinear) predictLeaves(fvals []float64, predictions []int) error {
+	return fmt.Errorf("leaf prediction not supported for xgboost linear models")
+}
+
 func (e *xgLinear) resetFVals(fvals []float64) {
 	for j := 0; j < len(fvals); j++ {
 		fvals[j] = 0.0
diff --git a/xgensemble.go b/xgensemble.go
index 06873e6..8d208c6 100644
--- a/xgensemble.go
+++ b/xgensemble.go
@@ -57,6 +57,13 @@ func (e *xgEnsemble) predictInner(fvals []float64, nEstimators int, predictions
 	}
 }
 
+func (e *xgEnsemble) predictLeaves(fvals []float64, predictions []int) error {
+	for i, t := range e.Trees {
+		predictions[i] = t.predictLeaf(fvals)
+	}
+	return nil
+}
+
 func (e *xgEnsemble) resetFVals(fvals []float64) {
 	for j := 0; j < len(fvals); j++ {
 		fvals[j] = math.NaN()
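
For context, below is a minimal usage sketch of the new Ensemble.PredictLeaves method added in this diff, not part of the change itself. It assumes a LightGBM text-format model on disk (the file name lightgbm_model.txt is hypothetical) loaded via leaves.LGEnsembleFromFile; the loader's second argument (load output transformation) exists in recent releases but may be absent in older ones. The predictions slice must have length NRawOutputGroups()*NEstimators(), matching the size check introduced in leaves.go.

package main

import (
	"fmt"
	"log"

	"github.com/dmitryikh/leaves"
)

func main() {
	// Hypothetical model file; any LightGBM text-format model works here.
	// Older releases of the loader take only the path, without the bool flag.
	model, err := leaves.LGEnsembleFromFile("lightgbm_model.txt", true)
	if err != nil {
		log.Fatalf("failed to load model: %v", err)
	}

	// One dense feature row; zeros are just placeholder feature values.
	fvals := make([]float64, model.NFeatures())

	// One leaf index per tree: NRawOutputGroups()*NEstimators() entries,
	// matching the length check in Ensemble.PredictLeaves.
	leafIndices := make([]int, model.NRawOutputGroups()*model.NEstimators())
	if err := model.PredictLeaves(fvals, leafIndices); err != nil {
		log.Fatalf("PredictLeaves: %v", err)
	}
	fmt.Println(leafIndices) // leaf reached in each tree for this row
}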