Skip to content

Commit

Permalink
dmitryikh#51 obtain the leaf index of gbdt tree
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Grasevski committed Aug 16, 2019
1 parent 3fa8b3b commit 01175c6
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
10 changes: 10 additions & 0 deletions leaves.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type ensembleBaseInterface interface {
Name() string
adjustNEstimators(nEstimators int) int
predictInner(fvals []float64, nEstimators int, predictions []float64, startIndex int)
predictLeaves(fvals []float64, predictions []int) error
resetFVals(fvals []float64)
}

Expand Down Expand Up @@ -75,6 +76,15 @@ func (e *Ensemble) Predict(fvals []float64, nEstimators int, predictions []float
return nil
}

// PredictLeaves validates the argument sizes and then delegates to
// e.predictLeaves(fvals, predictions).
//
// `predictions` must hold exactly NRawOutputGroups()*NEstimators() elements
// and `fvals` must hold exactly NFeatures() elements; otherwise an error is
// returned and e.predictLeaves is not called.
func (e *Ensemble) PredictLeaves(fvals []float64, predictions []int) error {
	want := e.NRawOutputGroups() * e.NEstimators()
	switch {
	case len(predictions) != want:
		return fmt.Errorf("predictions slice wrong size (expected %d)", want)
	case len(fvals) != e.NFeatures():
		return fmt.Errorf("incorrect number of features (%d)", len(fvals))
	}
	return e.predictLeaves(fvals, predictions)
}

// PredictCSR calculates predictions from the ensemble. `indptr`, `cols`, `vals`
// represent data structures from the Compressed Sparse Row Matrix format (see
// CSRMat). Only the first `nEstimators` estimators (trees) will be used.
Expand Down
7 changes: 7 additions & 0 deletions lgensemble.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ func (e *lgEnsemble) predictInner(fvals []float64, nEstimators int, predictions
}
}

// predictLeaves stores, for each tree in e.Trees, the result of that tree's
// predictLeaf(fvals) at the same position in predictions. It always returns
// nil. No length check is done here, so predictions needs at least
// len(e.Trees) elements.
func (e *lgEnsemble) predictLeaves(fvals []float64, predictions []int) error {
	for treeIdx := range e.Trees {
		predictions[treeIdx] = e.Trees[treeIdx].predictLeaf(fvals)
	}
	return nil
}

func (e *lgEnsemble) adjustNEstimators(nEstimators int) int {
if nEstimators > 0 {
nEstimators = util.MinInt(nEstimators, e.NEstimators())
Expand Down
12 changes: 8 additions & 4 deletions lgtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,32 @@ func (t *lgTree) decision(node *lgNode, fval float64) bool {
return t.numericalDecision(node, fval)
}

// NOTE(review): this span is a rendered diff hunk without +/- markers. The
// page header reports 4 deletions for this file; they appear to be the old
// `predict` header directly below and the three `return t.leafValues[...]`
// lines, each shown immediately above the line that replaced it — confirm
// against the repository before treating this as compilable code.
func (t *lgTree) predict(fvals []float64) float64 {
// predictLeaf walks t.nodes starting at index 0 and returns, as an int, the
// Left/Right value of the first node whose chosen side is flagged as a leaf.
// It returns 0 when t.nodes is empty.
func (t *lgTree) predictLeaf(fvals []float64) int {
if len(t.nodes) == 0 {
// Old body (presumably removed): returned the leaf value itself.
return t.leafValues[0]
return 0
}
idx := uint32(0)
for {
node := &t.nodes[idx]
// t.decision reports whether to take the left side for this node's feature.
left := t.decision(node, fvals[node.Feature])
if left {
if node.Flags&leftLeaf > 0 {
// Old body (presumably removed): looked the value up in t.leafValues.
return t.leafValues[node.Left]
return int(node.Left)
}
// Left side is not a leaf: node.Left is used as the next index into t.nodes.
idx = node.Left
} else {
if node.Flags&rightLeaf > 0 {
// Old body (presumably removed): looked the value up in t.leafValues.
return t.leafValues[node.Right]
return int(node.Right)
}
// Right side is not a leaf: continue with the next entry of t.nodes.
idx++
}
}
}

// predict returns the entry of t.leafValues at the index that predictLeaf
// yields for fvals.
func (t *lgTree) predict(fvals []float64) float64 {
	leaf := t.predictLeaf(fvals)
	return t.leafValues[leaf]
}

func (t *lgTree) findInBitset(idx uint32, pos uint32) bool {
i1 := pos / 32
idxS := t.catBoundaries[idx]
Expand Down
6 changes: 6 additions & 0 deletions xgblinear.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package leaves

import "fmt"

// xgLinear is XGBoost model (gblinear)
type xgLinear struct {
NumFeature int
Expand Down Expand Up @@ -38,6 +40,10 @@ func (e *xgLinear) predictInner(fvals []float64, nIterations int, predictions []
}
}

// predictLeaves always returns an error: leaf prediction is not supported
// for xgLinear. Neither fvals nor predictions is read or modified.
func (e *xgLinear) predictLeaves(fvals []float64, predictions []int) error {
	err := fmt.Errorf("leaf prediction not supported for xgboost linear models")
	return err
}

func (e *xgLinear) resetFVals(fvals []float64) {
for j := 0; j < len(fvals); j++ {
fvals[j] = 0.0
Expand Down
7 changes: 7 additions & 0 deletions xgensemble.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ func (e *xgEnsemble) predictInner(fvals []float64, nEstimators int, predictions
}
}

// predictLeaves stores, for each tree in e.Trees, the result of that tree's
// predictLeaf(fvals) at the same position in predictions. It always returns
// nil. No length check is done here, so predictions needs at least
// len(e.Trees) elements.
func (e *xgEnsemble) predictLeaves(fvals []float64, predictions []int) error {
	for treeIdx := range e.Trees {
		predictions[treeIdx] = e.Trees[treeIdx].predictLeaf(fvals)
	}
	return nil
}

func (e *xgEnsemble) resetFVals(fvals []float64) {
for j := 0; j < len(fvals); j++ {
fvals[j] = math.NaN()
Expand Down

0 comments on commit 01175c6

Please sign in to comment.