diff --git a/xgensemble_io.go b/xgensemble_io.go index 192ec71..42df681 100644 --- a/xgensemble_io.go +++ b/xgensemble_io.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "os" + "strings" "github.com/dmitryikh/leaves/internal/xgbin" "github.com/dmitryikh/leaves/transformation" @@ -212,6 +213,8 @@ func XGEnsembleFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensem if loadTransformation { if header.NameObj == "binary:logistic" { transform = &transformation.TransformLogistic{} + } else if strings.HasPrefix(header.NameObj, "multi:soft") { + transform = &transformation.TransformSoftmax{NClasses: e.nRawOutputGroups} } else { return nil, fmt.Errorf("unknown transformation function '%s'", header.NameObj) }