Matplotlib: Correlation Matrix Plot

This recipe includes the following topics:

  • Draw a Correlation Matrix Plot
  • Add axes tick labels
  • Add a colorbar


# import modules
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

fileGitURL = 'https://raw.githubusercontent.com/andrewgurung/data-repository/master/pima-indians-diabetes.data.csv'

# define column names
cols = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']

# load file as a Pandas DataFrame
pimaDf = pd.read_csv(fileGitURL, names=cols)

# calculate correlation between all columns
correlations = pimaDf.corr()

# create a new figure
fig = plt.figure(figsize=(10,10))

# 111: 1x1 grid, first subplot
ax = fig.add_subplot(111)

# draw the matrix; vmin and vmax pin the color scale to the full [-1, 1] correlation range
cax = ax.matshow(correlations, vmin=-1, vmax=1)

# add a colorbar to a plot.
fig.colorbar(cax)

# define one tick per column (9 columns, positions 0..8)
ticks = np.arange(len(cols))

# set x and y tick marks
ax.set_xticks(ticks)
ax.set_yticks(ticks)

# set x and y tick labels
ax.set_xticklabels(cols)
ax.set_yticklabels(cols)

# display the figure
plt.show()

Fig: Correlation Matrix Plot
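
The same three steps (matrix, tick labels, colorbar) can be wrapped in a small helper. The following is a minimal sketch, not part of the original recipe: the function name `plot_correlation_matrix`, the synthetic columns `a` to `d`, and the per-cell text annotations are illustrative assumptions. It uses random data instead of the Pima file so it runs without a network connection.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_correlation_matrix(df):
    """Hypothetical helper: draw df.corr() with labelled ticks, a colorbar and cell values."""
    correlations = df.corr()
    labels = list(correlations.columns)

    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.matshow(correlations, vmin=-1, vmax=1)
    fig.colorbar(image)

    # one tick per column, so the tick count always matches the label count
    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticklabels(labels)

    # write each coefficient into its cell (x = column index, y = row index)
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax.text(j, i, f"{correlations.iloc[i, j]:.2f}", ha="center", va="center")
    return fig, ax


# synthetic stand-in data (hypothetical columns, not the Pima dataset)
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(100, 3)), columns=["a", "b", "c"])
demo["d"] = -demo["a"] + rng.normal(scale=0.5, size=100)  # negatively correlated with "a"

fig, ax = plot_correlation_matrix(demo)
```

To view the result, either call `fig.savefig(...)` with a file name of your choice, or remove the `matplotlib.use("Agg")` line and call `plt.show()` as in the recipe above.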

One Comment

  1. Pearson correlation coefficients lie in the [-1, 1] range, so the -1 to 1 scale is the general choice. If every coefficient in your data is non-negative, you can narrow the scale to [0, 1]:
    cax = ax.matshow(correlations, vmin=0, vmax=1)
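
Whether a [0, 1] scale is enough depends on the data, because Pearson coefficients can be negative. One way to decide is to look at the smallest off-diagonal coefficient before plotting. The sketch below assumes synthetic data with made-up column names (`x`, `same_direction`, `opposite_direction`); with the recipe's data you would run the same check on `pimaDf.corr()`.

```python
import numpy as np
import pandas as pd

# synthetic stand-in data (hypothetical columns, not the Pima dataset)
rng = np.random.default_rng(0)
x = rng.normal(size=200)
demo = pd.DataFrame({
    "x": x,
    "same_direction": x + rng.normal(scale=0.5, size=200),
    "opposite_direction": -x + rng.normal(scale=0.5, size=200),
})

correlations = demo.corr()

# ignore the diagonal (always 1) when looking for the smallest coefficient
off_diagonal = correlations.values[~np.eye(len(correlations), dtype=bool)]
lowest = off_diagonal.min()

# keep the full [-1, 1] scale as soon as any coefficient is negative
vmin = -1 if lowest < 0 else 0
print(f"lowest coefficient: {lowest:.2f}, suggested vmin: {vmin}")
```

The resulting `vmin` can then be passed to `ax.matshow(correlations, vmin=vmin, vmax=1)`.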
