diff --git a/src/main/scala/nak/core/Classifier.scala b/src/main/scala/nak/core/Classifier.scala index 887e8a5..4930819 100644 --- a/src/main/scala/nak/core/Classifier.scala +++ b/src/main/scala/nak/core/Classifier.scala @@ -102,15 +102,17 @@ trait LiblinearClassifier extends IndexedClassifier[String] { /** * Implement the apply method of Classifier by transforming the tuples into - * Liblinear Features and then calling Linear.predictProbability. - * - * TODO: This should be made more general so that the SVM solvers can be used - * by Nak. + * Liblinear Features and then calling Linear.predictProbability or + * Linear.predictValues, depending if the model supports probability or not. */ def apply(context: Array[(Int,Double)]): Array[Double] = { val ctxt = context.map(c=>new FeatureNode(c._1,c._2).asInstanceOf[LiblinearFeature]) val labelScores = Array.fill(numLabels)(0.0) - Linear.predictProbability(model, ctxt, labelScores) + if (model.isProbabilityModel) { + Linear.predictProbability(classifier.model, ctxt, labelScores) + } else { + Linear.predictValues(model, ctxt, labelScores) + } labelScores }