Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
nairbenrekia committed Dec 11, 2024
1 parent df90b0e commit 747d9c4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
50 changes: 50 additions & 0 deletions khiops/samples/samples_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,53 @@ def khiops_classifier():
# kh.visualize_report("report.khj")


def khiops_classifier_float_target():
    """Trains a `.KhiopsClassifier` on a monotable dataframe with a float target

    The ``class`` column of the Adult sample dataset is recoded from its string
    labels to the floats ``0.0`` and ``1.0`` before training. The classifier is
    then fitted on a 70% train split, and the predicted classes and class
    probabilities of the first 10 test rows are printed.
    """
    # Imports
    import os
    import pandas as pd
    from khiops import core as kh
    from khiops.sklearn import KhiopsClassifier
    from sklearn.model_selection import train_test_split

    # Load the dataset into a pandas dataframe
    adult_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.txt")
    adult_df = pd.read_csv(adult_path, sep="\t")

    # Recode the string labels as floats.
    # `map` + an explicit `astype` is used instead of `replace`: `replace` on an
    # object column only yields a float dtype through implicit downcasting,
    # which is deprecated in recent pandas versions. The explicit cast
    # guarantees that the target really has a float dtype.
    # Note: any label other than "less"/"more" would become NaN here.
    label_to_float = {"less": 0.0, "more": 1.0}
    adult_df["class"] = adult_df["class"].map(label_to_float).astype("float64")

    # Split the whole dataframe into train and test (70%-30%)
    adult_train_df, adult_test_df = train_test_split(
        adult_df, test_size=0.3, random_state=1
    )

    # Split the dataset into:
    # - the X feature table
    # - the y target vector ("class" column)
    X_train = adult_train_df.drop("class", axis=1)
    X_test = adult_test_df.drop("class", axis=1)
    y_train = adult_train_df["class"]

    # Create the classifier object
    khc = KhiopsClassifier()

    # Train the classifier
    khc.fit(X_train, y_train)

    # Predict the classes on the test dataset
    y_test_pred = khc.predict(X_test)
    print("Predicted classes (first 10):")
    print(y_test_pred[0:10])
    print("---")

    # Predict the class probabilities on the test dataset
    y_test_probas = khc.predict_proba(X_test)
    print(f"Class order: {khc.classes_}")
    print("Predicted class probabilities (first 10):")
    print(y_test_probas[0:10])
    print("---")


def khiops_classifier_multiclass():
"""Trains a multiclass `.KhiopsClassifier` on a monotable dataframe"""
# Imports
Expand Down Expand Up @@ -1061,6 +1108,8 @@ def khiops_classifier_multitable_star_file():
print(f"Test auc = {test_auc}")


exported_samples = [khiops_classifier_float_target]
"""
exported_samples = [
khiops_classifier,
khiops_classifier_multiclass,
Expand All @@ -1080,6 +1129,7 @@ def khiops_classifier_multitable_star_file():
khiops_classifier_multitable_list,
khiops_classifier_multitable_star_file,
]
"""


def execute_samples(args):
Expand Down
6 changes: 4 additions & 2 deletions khiops/sklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,13 +2259,11 @@ def predict(self, X):
"""
# Call the parent's method
y_pred = super().predict(X)

# Adjust the data type according to the original target type
# Note: String is coerced explicitly because astype does not work as expected
if isinstance(y_pred, pd.DataFrame):
# Transform to numpy.ndarray
y_pred = y_pred.to_numpy(copy=False).ravel()

# If integer and string just transform
if pd.api.types.is_integer_dtype(self._original_target_dtype):
y_pred = y_pred.astype(self._original_target_dtype)
Expand All @@ -2275,6 +2273,10 @@ def predict(self, X):
self._original_target_dtype
):
y_pred = y_pred.astype(str, copy=False)
elif pd.api.types.is_float_dtype(self._original_target_type):
print(self._original_target_type)
y_pred = y_pred.astype(str, copy=False)
print(y_pred)
# If category first coerce the type to the categories' type
else:
assert isinstance(self._original_target_dtype, pd.CategoricalDtype), (
Expand Down

0 comments on commit 747d9c4

Please sign in to comment.