Classification Algorithm: k-Nearest Neighbors

k-Nearest Neighbors (KNN) uses a distance metric to find the k most similar instances (the neighbors) in the training data and takes their mean outcome (for a regression problem) or their most common class, the mode (for a classification problem), as the prediction.
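
To make the idea concrete, here is a minimal from-scratch sketch of that prediction rule. It is not how scikit-learn implements KNN; the function name knn_predict, the Euclidean distance choice, and the tiny made-up dataset are all assumptions for illustration.

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_query, k=3, task="classification"):
    """Predict the outcome for one query point from its k nearest neighbors."""
    # Euclidean distance from the query to every training instance
    distances = np.sqrt(((X_train - x_query) ** 2).sum(axis=1))
    # indices of the k closest training instances
    nearest = np.argsort(distances)[:k]
    neighbor_outcomes = y_train[nearest]
    if task == "regression":
        # regression: mean outcome of the neighbors
        return float(neighbor_outcomes.mean())
    # classification: most common class (the mode) among the neighbors
    return Counter(neighbor_outcomes.tolist()).most_common(1)[0][0]

# tiny made-up dataset: two well-separated groups on a line
X_train = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y_class = np.array([0, 0, 0, 1, 1, 1])                 # class labels
y_value = np.array([1.5, 2.5, 3.5, 20.0, 22.0, 24.0])  # numeric outcomes

label = knn_predict(X_train, y_class, np.array([2.5]), k=3)
value = knn_predict(X_train, y_value, np.array([10.5]), k=3, task="regression")
print(label, value)
```

The query at 2.5 has only class-0 points among its three nearest neighbors, and the query at 10.5 averages the outcomes of the three points in the second group.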

Note: k-Nearest Neighbors is a non-linear machine learning (ML) algorithm. It keeps the training data around and compares each new instance against it, so it can require a lot of memory; it is therefore advisable to include only the most relevant input variables.
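
One possible way to keep only the most relevant input variables is univariate feature selection. The sketch below is an assumption about how this could be done, not part of the original walkthrough: it uses scikit-learn's SelectKBest with the f_classif score on synthetic data from make_classification, and the choice of k=4 is arbitrary.

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif

# synthetic data (an assumption, not the Pima dataset): 20 inputs, only 4 informative
X_all, y_all = make_classification(n_samples=200, n_features=20, n_informative=4,
                                   n_redundant=0, random_state=7)

# keep the 4 inputs with the highest ANOVA F-score against the target
selector = SelectKBest(score_func=f_classif, k=4)
X_reduced = selector.fit_transform(X_all, y_all)

# fewer columns means less data to store and fewer terms in every distance
print(X_all.shape, "->", X_reduced.shape)
```

The reduced array can then be passed to KNeighborsClassifier in place of the full feature matrix.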

Medium Post: Top 10 algorithms for ML newbies

  • Load the classification dataset (Pima Indians) from GitHub
  • Split columns into the usual feature columns (X) and target column (Y)
  • Set k-fold count to 10
  • Set a seed so the same random split is reproduced each time
  • Split data using the KFold() class
  • Instantiate the classification algorithm: KNeighborsClassifier
  • Call cross_val_score() to run cross validation
  • Calculate mean estimated accuracy from scores returned by cross_val_score()

# import modules
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score

# read data file from github
# dataframe: pimaDf
gitFileURL = ''
cols = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
pimaDf = pd.read_csv(gitFileURL, names = cols)

# convert into numpy array for scikit-learn
pimaArr = pimaDf.values

# Let's split columns into the usual feature columns(X) and target column(Y)
# Y represents the target 'class' column whose value is either '0' or '1'
X = pimaArr[:, 0:8]
Y = pimaArr[:, 8]

# set k-fold count
folds = 10

# set seed to reproduce the same random data each time
seed = 7

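# split data using KFold
# shuffle=True is needed for the seed to take effect; recent scikit-learn
# versions raise a ValueError if random_state is set while shuffle is False
kfold = KFold(n_splits=folds, shuffle=True, random_state=seed)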

# instantiate the classification algorithm
# (the default is n_neighbors=5, i.e. k = 5)
model = KNeighborsClassifier()

# call cross_val_score() to run cross validation
resultArr = cross_val_score(model, X, Y, cv=kfold)

# calculate mean of scores for all folds
meanAccuracy = resultArr.mean() * 100

# display mean estimated accuracy
print("Mean estimated accuracy: %.3f%%" % meanAccuracy)

# sample output (the exact value depends on how the folds are split)
# Mean estimated accuracy: 72.656%
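
To show what cross_val_score() does with the KFold splitter, the sketch below runs the same loop by hand: fit on the training folds, score on the held-out fold, repeat. It uses synthetic stand-in data from make_classification (an assumption, since the Pima file is not loaded here), and the names X_demo, y_demo, kfold_demo and manual_scores are illustrative only.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# synthetic stand-in data (an assumption; 8 input columns like the Pima dataset)
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=7)

kfold_demo = KFold(n_splits=10, shuffle=True, random_state=7)

# manual cross validation: one fit and one accuracy score per fold
manual_scores = []
for train_idx, test_idx in kfold_demo.split(X_demo):
    fold_model = KNeighborsClassifier()
    fold_model.fit(X_demo[train_idx], y_demo[train_idx])
    manual_scores.append(fold_model.score(X_demo[test_idx], y_demo[test_idx]))
manual_scores = np.array(manual_scores)

# the same procedure in a single call
auto_scores = cross_val_score(KNeighborsClassifier(), X_demo, y_demo, cv=kfold_demo)

# compare the per-fold scores from both approaches
print(np.allclose(manual_scores, auto_scores))
print("Mean accuracy: %.3f%%" % (manual_scores.mean() * 100))
```

Because the splitter is seeded with an integer, both approaches see the same ten train/test partitions, which is why the per-fold scores can be compared directly.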
