<script src="../turi/js/recview.js"></script>

# Random Forest Classifier

A random forest classifier is an ensemble of decision trees and one of the most widely used machine learning models for predictive analytics. For background on how random forests are built, refer to the chapter on random forest regression.
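The sketch below is a minimal, from-scratch random forest classifier in plain Python, with no GraphLab. It makes the idea concrete under a few assumptions:

- All features are categorical.
- Every tree is a depth-1 tree (a stump), which keeps the code short.
- Each tree is fit on a bootstrap sample of the rows, using a random subset of the features.
- The forest predicts by majority vote.

The helper names (`fit_stump`, `train_forest`, `predict_forest`, and so on) are hypothetical and are not GraphLab's API. GraphLab's implementation grows deeper trees (see `max_depth` below) and may split and aggregate differently.

```python
import random
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 means the list is pure)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def majority(labels):
    """Most common label; ties go to the label seen first."""
    return Counter(labels).most_common(1)[0][0]


def fit_stump(rows, labels, feature_ids):
    """Fit a depth-1 tree: the categorical split `row[f] == v` with the lowest
    weighted Gini impurity, searched only over the given feature ids."""
    best = None
    for f in feature_ids:
        # dict.fromkeys keeps first-seen order, so the search is deterministic.
        for v in dict.fromkeys(row[f] for row in rows):
            left = [y for row, y in zip(rows, labels) if row[f] == v]
            right = [y for row, y in zip(rows, labels) if row[f] != v]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or score < best[0]:
                best = (score, f, v, left, right)
    fallback = majority(labels)
    if best is None:
        return (None, None, fallback, fallback)
    _, f, v, left, right = best
    return (f, v,
            majority(left) if left else fallback,
            majority(right) if right else fallback)


def predict_stump(stump, row):
    """Route a row to the left or right leaf of a stump and return its label."""
    f, v, left_label, right_label = stump
    if f is None:
        return left_label
    return left_label if row[f] == v else right_label


def train_forest(rows, labels, num_trees=10, features_per_split=2, seed=0):
    """Train `num_trees` stumps, each on a bootstrap sample of the rows and a
    random subset of the features."""
    rng = random.Random(seed)
    n, num_features = len(rows), len(rows[0])
    forest = []
    for _ in range(num_trees):
        # Bootstrap: draw n row indices with replacement.
        idx = [rng.randrange(n) for _ in range(n)]
        # Random feature subset for this tree's (single) split.
        feats = rng.sample(range(num_features), min(features_per_split, num_features))
        forest.append(fit_stump([rows[i] for i in idx], [labels[i] for i in idx], feats))
    return forest


def predict_forest(forest, row):
    """Majority vote over the trees in the forest."""
    return majority(predict_stump(stump, row) for stump in forest)


# Usage on a toy dataset where the first feature determines the label.
rows = [("a", "x"), ("a", "y"), ("b", "x"), ("b", "y")] * 5
labels = [row[0] == "a" for row in rows]
forest = train_forest(rows, labels, num_trees=5, features_per_split=2, seed=0)
print(predict_forest(forest, ("a", "y")))
print(predict_forest(forest, ("b", "x")))
```

Bootstrapping and feature subsampling make the trees differ from one another. Voting across them then reduces the variance of any single tree.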

## Introductory Example
```python
import graphlab as gl

# Load the data directly from its URL.
data = gl.SFrame.read_csv('https://static.turi.com/datasets/xgboost/mushroom.csv')

# Label 'c' is edible; convert the label column into a binary target.
data['label'] = data['label'] == 'c'

# Make a train-test split (roughly 80% train, 20% test).
train_data, test_data = data.random_split(0.8)

# Create a model.
model = gl.random_forest_classifier.create(train_data, target='label',
                                           max_iterations=2,
                                           max_depth=3)

# Save predictions to an SArray.
predictions = model.predict(test_data)

# Evaluate the model and save the results into a dictionary.
results = model.evaluate(test_data)
```
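The dictionary returned by `evaluate` holds summary metrics that compare the model's predictions with the true test labels. The sketch below is a rough, library-free illustration of two common metrics, computed from plain Python lists: accuracy, and a confusion matrix stored as counts of `(true, predicted)` pairs. The function names and sample labels are hypothetical. This is not how GraphLab computes its metrics, and it does not show the exact keys GraphLab returns.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def confusion_matrix(y_true, y_pred):
    """Count how often each (true_label, predicted_label) pair occurs."""
    return Counter(zip(y_true, y_pred))


# Hypothetical labels standing in for test_data['label'] and predictions.
y_true = [True, True, False, False, True]
y_pred = [True, False, False, False, True]
print(accuracy(y_true, y_pred))  # 0.8
print(confusion_matrix(y_true, y_pred))
```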

We can visualize individual trees in the model, selected by `tree_id`, using:

```python
model.show(view="Tree", tree_id=0)
model.show(view="Tree", tree_id=1)
```


See the chapter on random forest regression for additional tips and tricks for using the random forest classifier model.

## Advanced Features

Refer to the earlier chapters for the following features: