-
Notifications
You must be signed in to change notification settings - Fork 0
/
random_forest_synthetic.py
64 lines (50 loc) · 2.09 KB
/
random_forest_synthetic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
# Load the dataset. Set the INJURY_DATA_CSV environment variable to point at a
# different copy of the CSV; otherwise the original author's path is used.
DATA_PATH = os.environ.get(
    'INJURY_DATA_CSV',
    '/Users/meeps360/cs158-final-project/week_approach_maskedID_timeseries.csv',
)
df = pd.read_csv(DATA_PATH)
# Features: everything except the label and the identifier/date columns.
x = df.drop(columns=['injury', 'Athlete ID', 'Date'])
y = df['injury']

# Hold out a test set from the ORIGINAL data *before* any oversampling.
# Running SMOTE on the full dataset first leaks information: synthetic rows in
# the test split are interpolated from rows in the training split, and the
# old re-split of the original data produced a test set that largely
# overlapped rows the model had already been trained on. Stratifying keeps the
# rare injury class represented in both splits.
# NOTE(review): rows are split at random, so the same athlete (and later dates)
# can appear in both train and test; consider a group- or time-based split if
# the report should reflect unseen athletes/weeks.
X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42, stratify=y
)

# Use SMOTE to generate synthetic minority-class samples from the training
# split only, so the test split stays untouched.
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

# Train a model on the balanced training data.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_resampled, y_resampled)

# Evaluate on the held-out split, which keeps the natural class imbalance.
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))
# Rank features by the random forest's impurity-based importance, i.e. how much
# the model relies on each input. Note this measures reliance only: it does not
# say whether a feature raises or lowers injury likelihood.
# The Series is named 'Importance' and its index 'Feature', so both the printed
# output and the CSV header carry those labels. (Assigning `.columns` on a
# Series only sets a plain attribute and labels nothing.)
feature_imp = (
    pd.Series(clf.feature_importances_, index=X_train.columns, name='Importance')
    .rename_axis('Feature')
    .sort_values(ascending=False)
)
# Top 10 most important features.
print(feature_imp.head(10))
# Least important feature.
print(feature_imp.tail(1))
# Export the table to CSV (header: Feature,Importance).
feature_imp.to_csv('feature_imp.csv')

# Optional bar chart of the 10 most important features; off by default.
PLOT_FEATURE_IMPORTANCE = False
if PLOT_FEATURE_IMPORTANCE:
    top_features = feature_imp.head(10)
    sns.barplot(x=top_features.values, y=top_features.index)
    plt.xlabel('Feature Importance Score')
    plt.ylabel('Features')
    plt.title("Visualizing Important Features")
    plt.tight_layout()
    # Save before show(): show() may close/clear the figure, which would make
    # a later savefig() write an empty image.
    plt.savefig('feature_imp.png')
    plt.show()
# Show every feature the model was trained on, in the ranked order of the
# importance table (most important first).
all_feature_names = feature_imp.index
print(all_feature_names)