Classification Algorithm: Classification and Regression Trees

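Classification and Regression Trees (CART), also known as decision trees, construct a binary tree from the training data: each internal node splits the rows into two branches using a threshold on a single feature. In scikit-learn, a CART model can be constructed with the DecisionTreeClassifier class.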
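To make "construct a binary tree" concrete, here is a minimal sketch of how one binary split can be chosen. It assumes the Gini impurity criterion (the default `criterion='gini'` of DecisionTreeClassifier) and searches only one feature, using midpoints between distinct values as candidate thresholds. The helper names `gini` and `best_split` and the toy data are hypothetical, for illustration only; a real CART implementation repeats this search over every feature and recursively on each branch.

```python
import numpy as np

def gini(labels):
    """Gini impurity of an array of class labels: 1 - sum_k p_k^2."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p ** 2)

def best_split(x, y):
    """Return (threshold, weighted Gini) of the best binary split x <= t vs x > t."""
    best_t, best_score = None, np.inf
    values = np.unique(x)
    # candidate thresholds: midpoints between consecutive distinct values
    for t in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= t], y[x > t]
        # impurity of each branch, weighted by how many rows it receives
        score = (left.size * gini(left) + right.size * gini(right)) / y.size
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# made-up feature values and binary labels
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0, 0, 0, 1, 1, 1])
t, score = best_split(x, y)
print(t, score)  # 3.5 0.0 (both branches are pure)
```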

Note: CART is a non-linear machine learning (ML) algorithm.
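One way to see what "non-linear" buys you is the XOR pattern, where the class is 1 exactly when two inputs differ. No straight line separates those classes, so a linear model can get at most 3 of the 4 distinct points right, while a tree can carve the plane into rectangles. A small sketch, assuming a synthetic XOR dataset and LogisticRegression as the linear baseline (neither is part of the original recipe):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

# XOR-style data: each of the four input combinations repeated 25 times
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]] * 25, dtype=float)
y = np.array([0, 1, 1, 0] * 25)

linear = LogisticRegression().fit(X, y)
tree = DecisionTreeClassifier(random_state=7).fit(X, y)

# training accuracy: a linear boundary cannot exceed 0.75 on XOR,
# while the unpruned tree fits it with two levels of splits
print("logistic regression:", linear.score(X, y))
print("decision tree:", tree.score(X, y))  # 1.0
```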

Medium Post: Top 10 algorithms for ML newbies


This recipe includes the following topics:

  • Load the classification dataset (Pima Indians Diabetes) from GitHub
  • Split columns into the usual feature columns (X) and target column (Y)
  • Set the k-fold count to 10
  • Set a seed so the same random split is reproduced each time
  • Split the data using the KFold() class
  • Instantiate the classification algorithm: DecisionTreeClassifier
  • Call cross_val_score() to run cross-validation
  • Calculate the mean estimated accuracy from the scores returned by cross_val_score()
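The k-fold and seed steps in the list above can be checked in isolation. In KFold, `random_state` only has an effect when `shuffle=True`; with shuffling on, the same seed gives the same folds every time, and every row appears in exactly one test fold. A minimal sketch, assuming a made-up 10-row array and 5 folds instead of the Pima data and 10 folds (`fold_indices` is a hypothetical helper):

```python
import numpy as np
from sklearn.model_selection import KFold

data = np.arange(20).reshape(10, 2)  # 10 toy rows, 2 columns

def fold_indices(seed):
    """Return the test-row indices of each fold for a shuffled 5-fold split."""
    kfold = KFold(n_splits=5, shuffle=True, random_state=seed)
    return [test_idx.tolist() for _, test_idx in kfold.split(data)]

# the same seed reproduces the same folds on every call
print(fold_indices(7) == fold_indices(7))  # True
# each row lands in exactly one test fold
print(sorted(i for fold in fold_indices(7) for i in fold))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```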


```python
# import modules
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score

# read data file from GitHub
# dataframe: pimaDf
gitFileURL = 'https://raw.githubusercontent.com/andrewgurung/data-repository/master/pima-indians-diabetes.data.csv'
cols = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
pimaDf = pd.read_csv(gitFileURL, names=cols)

# convert into a NumPy array for scikit-learn
pimaArr = pimaDf.values

# split columns into the usual feature columns (X) and target column (Y)
# Y represents the target 'class' column whose value is either 0 or 1
X = pimaArr[:, 0:8]
Y = pimaArr[:, 8]

# set k-fold count
folds = 10

# set seed so the shuffled folds (and the tree) are the same on every run
seed = 7

# split data using KFold; random_state only takes effect with shuffle=True,
# and recent scikit-learn versions raise an error if it is set without shuffling
kfold = KFold(n_splits=folds, shuffle=True, random_state=seed)

# instantiate the classification algorithm;
# random_state makes the tree's internal tie-breaking reproducible
model = DecisionTreeClassifier(random_state=seed)

# call cross_val_score() to run cross-validation (default scoring: accuracy)
resultArr = cross_val_score(model, X, Y, cv=kfold)

# calculate mean of scores for all folds, as a percentage
meanAccuracy = resultArr.mean() * 100

# display mean estimated accuracy
print("Mean estimated accuracy: %.3f%%" % meanAccuracy)
```

Output reported in the original post (the exact value can differ with the fold shuffling and the scikit-learn version):

Mean estimated accuracy: 69.265%
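cross_val_score() returns one accuracy per fold, and the mean alone hides how much the folds disagree. A sketch that also reports the standard deviation, with the same KFold and DecisionTreeClassifier setup as the recipe. To keep it runnable offline it swaps the Pima CSV for synthetic data from make_classification with 768 rows and 8 features; that substitution is an assumption, so its numbers say nothing about the Pima dataset.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# synthetic stand-in data (assumption): 768 rows, 8 features, binary target
X, Y = make_classification(n_samples=768, n_features=8, random_state=7)

kfold = KFold(n_splits=10, shuffle=True, random_state=7)
model = DecisionTreeClassifier(random_state=7)
resultArr = cross_val_score(model, X, Y, cv=kfold)

# one accuracy per fold; report the spread alongside the mean
print("Per-fold accuracy:", resultArr.round(3))
print("Mean estimated accuracy: %.3f%% (std %.3f%%)"
      % (resultArr.mean() * 100, resultArr.std() * 100))
```

A large standard deviation relative to the gap between two models is a sign that the difference in their means may not be meaningful.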
