From d6310e01cddcf41e21c42cc5226f1ff30d933d4c Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:22:48 -0400 Subject: [PATCH] classification_dataset bool --- tsfm_public/toolkit/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 2e11a0a8..b305c80b 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -626,6 +626,7 @@ class values. has_class_labels_tag = True class_label_list = [token.strip() for token in tokens[2:]] metadata_started = True + classification_dataset = True elif line.startswith("@targetlabel"): if data_started: raise IOError("metadata must come before data") @@ -647,7 +648,6 @@ class values. ) has_class_labels_tag = True metadata_started = True - regression = True # Check if this line contains the start of data elif line.startswith("@data"): if line != "@data": @@ -705,7 +705,7 @@ class values. # Check if we have reached a class label if line[char_num] != "(" and class_labels: class_val = line[char_num:].strip() - if not regression: + if classification_dataset: if class_val not in class_label_list: raise IOError( "the class value '"