-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path7_Roc_Curve.py
64 lines (47 loc) · 1.82 KB
/
7_Roc_Curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, auc
# ---------------------------------------------------------------------------
# ROC-curve demo: fit two classifiers on the same synthetic binary problem
# and overlay their ROC curves (with AUC in the legend) on a single figure.
# ---------------------------------------------------------------------------

# Synthetic two-class dataset; the fixed seed keeps every run reproducible.
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)

# Keep 20% of the samples aside for evaluation.
# (Uncomment to inspect the split: print(X_train.shape))
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Fit both models on the identical training split so the curves are comparable.
logistic_model = LogisticRegression(random_state=42)
logistic_model.fit(X_train, y_train)

random_forest_model = RandomForestClassifier(n_estimators=100, random_state=42)
random_forest_model.fit(X_train, y_train)

# roc_curve needs continuous scores rather than hard labels, so take the
# second column of predict_proba (the score for class 1) for each test row.
y_pred_logistic = logistic_model.predict_proba(X_test)[:, 1]
y_pred_rf = random_forest_model.predict_proba(X_test)[:, 1]

# One table holding the true labels next to each model's scores; the column
# names double as the legend labels in the plot below.
test_df = pd.DataFrame(
    {
        'True': y_test,
        'Logistic': y_pred_logistic,
        'RandomForest': y_pred_rf,
    }
)
# Example of print(test_df.head()) recorded by the original author:
#    True  Logistic  RandomForest
# 0     1  0.648888          0.75
# 1     1  0.867905          0.79
# 2     1  0.475407          0.25
# 3     1  0.852635          0.89
# 4     1  0.955500          0.99

# Draw one ROC curve per score column on a shared figure.
plt.figure(figsize=(7, 5))
true_labels = test_df['True']
for model in ('Logistic', 'RandomForest'):
    fpr, tpr, _ = roc_curve(true_labels, test_df[model])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{model} (AUC = {roc_auc:.2f})')

# Diagonal reference line: the ROC of a classifier that guesses at random.
plt.plot([0, 1], [0, 1], 'r--', label='Random Guess')

# Axis labels, title and legend, then display the figure.
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves for Two Models')
plt.legend()
plt.show()