diff --git a/code/classification/run_classifier.py b/code/classification/run_classifier.py
index 414e0ce5..48631f73 100644
--- a/code/classification/run_classifier.py
+++ b/code/classification/run_classifier.py
@@ -10,12 +10,14 @@
 import argparse, pickle
 from sklearn.dummy import DummyClassifier
+
 from sklearn.metrics import accuracy_score, cohen_kappa_score
 from sklearn.preprocessing import StandardScaler
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.pipeline import make_pipeline
 from mlflow import log_metric, log_param, set_tracking_uri
+
 # setting up CLI
 parser = argparse.ArgumentParser(description = "Classifier")
 parser.add_argument("input_file", help = "path to the input pickle file")
@@ -26,8 +28,12 @@
 parser.add_argument("-f", "--frequency", action = "store_true", help = "label frequency classifier")
 parser.add_argument("--knn", type = int, help = "k nearest neighbor classifier with the specified value of k", default = None)
 parser.add_argument("-a", "--accuracy", action = "store_true", help = "evaluate using accuracy")
+
+
+
 parser.add_argument("-k", "--kappa", action = "store_true", help = "evaluate using Cohen's kappa")
 parser.add_argument("--log_folder", help = "where to log the mlflow results", default = "data/classification/mlflow")
+
 args = parser.parse_args()

 # load data
@@ -83,9 +89,11 @@
 evaluation_metrics = []
 if args.accuracy:
     evaluation_metrics.append(("accuracy", accuracy_score))
+
 if args.kappa:
     evaluation_metrics.append(("Cohen_kappa", cohen_kappa_score))
+
 # compute and print them
 for metric_name, metric in evaluation_metrics:
     metric_value = metric(data["labels"], prediction)
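
For reference, the metric-selection pattern visible in the last hunk can be exercised on its own. The following is a minimal sketch with made-up labels and predictions (not taken from the repository), using only functions that already appear in the script:

from sklearn.metrics import accuracy_score, cohen_kappa_score

# toy ground-truth labels and classifier predictions, for illustration only
labels = [0, 1, 1, 0, 1, 1]
prediction = [0, 1, 0, 0, 1, 1]

# same pattern as run_classifier.py: collect (name, metric function) pairs
# depending on which evaluation flags were set on the command line
evaluation_metrics = [("accuracy", accuracy_score), ("Cohen_kappa", cohen_kappa_score)]

# compute each metric on the true labels and the predictions
for metric_name, metric in evaluation_metrics:
    metric_value = metric(labels, prediction)
    print(metric_name, metric_value)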