-
Notifications
You must be signed in to change notification settings - Fork 364
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add comprehensive error handling across ML-CaPsule
- Loading branch information
1 parent
b2df3d0
commit 7f6a064
Showing
4 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from utils.error_handler import error_handler, DataValidationError, ModelError | ||
import pandas as pd | ||
import numpy as np | ||
|
||
@error_handler
def load_data(file_path):
    """Read a CSV file into a DataFrame, rejecting missing or empty datasets.

    Args:
        file_path: Location of the CSV file, passed straight to
            ``pd.read_csv``.

    Returns:
        The loaded, non-empty ``pandas.DataFrame``.

    Raises:
        DataValidationError: If the file does not exist, if it contains
            nothing to parse, or if it parses to a frame with no data.
    """
    try:
        df = pd.read_csv(file_path)
    except FileNotFoundError as err:
        raise DataValidationError(f"Dataset not found at {file_path}") from err
    except pd.errors.EmptyDataError as err:
        # read_csv raises (rather than returning an empty frame) when the
        # file has no columns to parse; report it like any other empty
        # dataset instead of letting it surface as a generic failure.
        raise DataValidationError("Empty dataset loaded") from err

    if df.empty:
        raise DataValidationError("Empty dataset loaded")
    return df
|
||
@error_handler
def preprocess_data(df):
    """Impute missing numeric values and reject negative numeric data.

    Args:
        df: The dataset to clean; must be a ``pandas.DataFrame``.

    Returns:
        The input frame when it has no missing values, otherwise a new frame
        in which missing values of numeric columns are replaced by the
        column mean. Missing values in non-numeric columns are left as-is
        because no mean exists for them.

    Raises:
        DataValidationError: If ``df`` is not a DataFrame, or if any numeric
            column contains a negative value.
    """
    # Local import: the import section visible for this module does not
    # bring ``logging`` into scope, so the warning below would otherwise
    # fail with a NameError.
    import logging

    if not isinstance(df, pd.DataFrame):
        raise DataValidationError("Input must be a pandas DataFrame")

    # Check for missing values
    if df.isnull().values.any():
        logging.warning("Missing values detected in the dataset")
        # numeric_only=True keeps string/object columns from making
        # DataFrame.mean raise; only numeric columns get imputed.
        df = df.fillna(df.mean(numeric_only=True))

    # Check for invalid values
    numeric_columns = df.select_dtypes(include=[np.number]).columns
    for col in numeric_columns:
        if (df[col] < 0).any():
            raise DataValidationError(f"Negative values found in column {col}")

    return df
|
||
@error_handler
def train_model(X_train, y_train, model_type='random_forest'):
    """Fit a classifier of the requested type on the training data.

    Args:
        X_train: Training features.
        y_train: Training targets; must have the same length as ``X_train``.
        model_type: Name of the model to build. Only ``'random_forest'`` is
            handled; it builds a ``RandomForestClassifier`` with
            ``random_state=42``.

    Returns:
        The fitted model.

    Raises:
        DataValidationError: If ``X_train`` and ``y_train`` differ in length.
        ModelError: If the model type is not supported or anything fails
            while importing, building or fitting the model.
    """
    n_samples, n_targets = len(X_train), len(y_train)
    if n_samples != n_targets:
        raise DataValidationError("Feature and target dimensions do not match")

    try:
        # Raised inside the try on purpose so that it is reported as a
        # ModelError like every other training failure.
        if model_type != 'random_forest':
            raise ValueError(f"Unsupported model type: {model_type}")

        from sklearn.ensemble import RandomForestClassifier

        classifier = RandomForestClassifier(random_state=42)
        classifier.fit(X_train, y_train)
    except Exception as exc:
        raise ModelError(f"Model training failed: {exc!s}")
    else:
        return classifier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from utils.error_handler import error_handler, DataValidationError, ModelError | ||
import pandas as pd | ||
import numpy as np | ||
|
||
@error_handler
def validate_heart_data(df):
    """Check that a heart-disease frame has the expected columns and values.

    Args:
        df: Frame to validate; accessed through ``df.columns`` and column
            indexing.

    Returns:
        ``True`` when every check passes.

    Raises:
        DataValidationError: If a required column is absent, or if the
            ``age``, ``sex`` or ``trestbps`` values fall outside their
            allowed ranges.
    """
    expected = (
        'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs',
        'restecg', 'thalach', 'exang', 'oldpeak', 'slope',
        'ca', 'thal', 'target',
    )

    # Column presence is checked first; the range checks below index by name.
    absent = [name for name in expected if name not in df.columns]
    if absent:
        raise DataValidationError(f"Missing required columns: {absent}")

    # Each entry pairs a lazily evaluated "all rows are valid" predicate
    # with the message raised when it fails; they run in order.
    range_checks = (
        (lambda: df['age'].ge(0).all(), "Age cannot be negative"),
        (lambda: df['sex'].isin([0, 1]).all(), "Sex must be binary (0 or 1)"),
        (lambda: df['trestbps'].gt(0).all(), "Blood pressure must be positive"),
    )
    for all_valid, message in range_checks:
        if not all_valid():
            raise DataValidationError(message)

    return True
|
||
@error_handler
def prepare_heart_data(df):
    """Validate, impute and standardise a heart-disease dataset.

    The frame is first passed to ``validate_heart_data``. Missing values in
    numeric columns are then replaced by the column mean, and every numeric
    column is standardised with ``StandardScaler``.

    Args:
        df: The heart-disease frame to prepare.

    Returns:
        The prepared frame. When no imputation was needed, the scaled values
        are assigned into the frame that was passed in, so the caller's
        object is modified; when imputation happened, a new frame is
        returned.

    Raises:
        DataValidationError: If validation or any preparation step fails;
            the original exception is attached as the cause.
    """
    # Local import: the import section visible for this module does not
    # bring ``logging`` into scope, so the warning below would otherwise
    # fail with a NameError.
    import logging

    try:
        validate_heart_data(df)

        # Handle missing values
        if df.isnull().sum().any():
            logging.warning("Missing values found - applying mean imputation")
            # numeric_only=True keeps non-numeric columns from making
            # DataFrame.mean raise; only numeric columns get imputed.
            df = df.fillna(df.mean(numeric_only=True))

        # Feature scaling
        from sklearn.preprocessing import StandardScaler

        scaler = StandardScaler()
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        # NOTE(review): this standardises every numeric column, which
        # includes 'target' when it is numeric -- confirm that scaling the
        # label column is intended.
        df[numeric_cols] = scaler.fit_transform(df[numeric_cols])

        return df
    except Exception as e:
        raise DataValidationError(f"Data preparation failed: {str(e)}") from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import logging | ||
import sys | ||
from functools import wraps | ||
|
||
# Logging setup performed at import time: records at INFO level and above
# are sent to the ``ml_capsule.log`` file and to standard output.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_file_handler = logging.FileHandler('ml_capsule.log')
_log_stdout_handler = logging.StreamHandler(sys.stdout)

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_log_file_handler, _log_stdout_handler],
)
|
||
class MLCapsuleError(Exception):
    """Root of the ML-CaPsule exception hierarchy."""
|
||
class DataValidationError(MLCapsuleError):
    """Signals that input data did not pass validation."""
|
||
class ModelError(MLCapsuleError):
    """Signals that a model operation failed."""
|
||
def error_handler(func):
    """Decorate ``func`` so that its failures are logged and normalised.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the metadata of ``func`` that forwards all arguments
        and returns whatever ``func`` returns.

    Raises:
        MLCapsuleError: ``MLCapsuleError`` instances raised by ``func`` are
            logged and re-raised unchanged. Any other ``Exception`` is
            logged with its traceback and re-raised as an
            ``MLCapsuleError`` whose cause is the original exception.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except MLCapsuleError as e:
            logging.error(f"ML-CaPsule error in {func.__name__}: {str(e)}")
            raise
        except Exception as e:
            # logging.exception records the traceback, which would otherwise
            # be lost once the error is converted below.
            logging.exception(f"Unexpected error in {func.__name__}: {str(e)}")
            raise MLCapsuleError(f"Function {func.__name__} failed: {str(e)}") from e
    return wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from utils.error_handler import error_handler, ModelError | ||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score | ||
import numpy as np | ||
|
||
@error_handler
def evaluate_classification_model(y_true, y_pred):
    """Compute accuracy and weighted precision, recall and F1.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels; must have the same length as ``y_true``.

    Returns:
        A dict with the keys ``'accuracy'``, ``'precision'``, ``'recall'``
        and ``'f1'``. The last three are computed with
        ``average='weighted'``.

    Raises:
        ModelError: If the two inputs differ in length or any metric
            computation fails.
    """
    if len(y_true) != len(y_pred):
        raise ModelError("Prediction and ground truth dimensions do not match")

    try:
        report = {'accuracy': accuracy_score(y_true, y_pred)}
        # The remaining metrics share the same call shape, so they are
        # evaluated in a fixed order from a small table.
        weighted_scorers = (
            ('precision', precision_score),
            ('recall', recall_score),
            ('f1', f1_score),
        )
        for label, scorer in weighted_scorers:
            report[label] = scorer(y_true, y_pred, average='weighted')
    except Exception as exc:
        raise ModelError(f"Model evaluation failed: {exc!s}")
    else:
        return report
|
||
@error_handler
def cross_validate_model(model, X, y, cv=5):
    """Run cross-validation and summarise the resulting scores.

    Args:
        model: Estimator handed to ``cross_val_score``.
        X: Feature data.
        y: Target data.
        cv: Cross-validation setting forwarded to ``cross_val_score``
            (default 5).

    Returns:
        A dict with ``'mean_score'`` and ``'std_score'`` (NumPy mean and
        standard deviation of the scores) and ``'scores'`` (the raw result
        of ``cross_val_score``).

    Raises:
        ModelError: If scoring or summarising fails.
    """
    from sklearn.model_selection import cross_val_score

    try:
        fold_scores = cross_val_score(model, X, y, cv=cv)
        summary = dict(
            mean_score=np.mean(fold_scores),
            std_score=np.std(fold_scores),
            scores=fold_scores,
        )
    except Exception as exc:
        raise ModelError(f"Cross-validation failed: {exc!s}")
    else:
        return summary