Training a Naive Bayes classifier using sklearn

This is the second post in a series. In this post we will look at how to apply Naive Bayes to train a model and solve a classification problem on the iris dataset. Here’s the link to the first post, "naive bayes a primer", in case you missed it; in it we break down the mechanics of the algorithm.

Okay, so let's get down to business with an overview of what this post covers. We will get our hands dirty by creating a Naive Bayes model using the scikit-learn Python library.

Before we dive in, let's look at the software prerequisites for running the code.

  1. Python 2.7 or higher
  2. Install the SciPy package
  3. Install the scikit-learn (sklearn) package

About scikit-learn 

Scikit-learn is an open source machine learning library for Python. The library is simple to use and contains tools for data analysis and data mining, not to mention several machine learning algorithms. It is built on the NumPy, SciPy and matplotlib packages.

The Dataset

In the program we will be using the iris dataset that ships with the sklearn library. The dataset contains a total of 150 observations, made up of 3 classes of 50 instances each, where each class refers to a type of iris plant.

Please refer to this link for more reading on the dataset. Each observation (row) is made up of 4 feature attributes and 1 class attribute, which is the attribute to be predicted.

  1. Sepal Length in cm
  2. Sepal Width in cm
  3. Petal Length in cm
  4. Petal Width in cm
  5. Class: The class labels are
  • Setosa
  • Versicolour
  • Virginica
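
As a quick check on the description above, here is a minimal sketch that loads the dataset bundled with sklearn and prints its shape, feature names, class names and per-class counts. The variable name `class_counts` is just for illustration.

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()

# Feature matrix: one row per observation, one column per measurement
print(iris.data.shape)

# Names of the four feature columns
print(iris.feature_names)

# Classes are stored as integers 0..2; target_names maps them to species names
print(iris.target_names)

# Number of observations in each class
class_counts = np.bincount(iris.target)
print(class_counts)
```

Note that sklearn spells the second class name `versicolor`.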

Exploring the dataset

Before we get to the code, it is important to understand the relationships between the features. One way to visualize the data is to generate a scatter matrix plot, as shown below.

[Figure: scatter matrix plot of the iris data]

In each plot, the attribute in the row is the variable on the y axis and the attribute in the column is the variable on the x axis. So in the first plot, sepal length is plotted on the y axis and sepal width on the x axis. The plot suggests a roughly linear relationship between sepal length and petal length, and between sepal length and petal width.
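
The original plotting code is not shown in this post, so the following is only a sketch of one way to draw such a scatter matrix with matplotlib. The grid layout, marker size and the histograms on the diagonal are assumptions, not a reproduction of the figure.

```python
import matplotlib.pyplot as plt
from sklearn import datasets

iris = datasets.load_iris()
X, y = iris.data, iris.target
n_features = X.shape[1]

# One subplot per (row feature, column feature) pair
fig, axes = plt.subplots(n_features, n_features, figsize=(10, 10))
for row in range(n_features):
    for col in range(n_features):
        ax = axes[row, col]
        if row == col:
            # Diagonal: histogram of a single feature
            ax.hist(X[:, row], bins=15)
        else:
            # Off-diagonal: column feature on x, row feature on y, colored by class
            ax.scatter(X[:, col], X[:, row], c=y, s=10)
        # Only label the bottom row and the left column to keep the grid readable
        if row == n_features - 1:
            ax.set_xlabel(iris.feature_names[col])
        if col == 0:
            ax.set_ylabel(iris.feature_names[row])
```

Call `plt.show()` afterwards to display the figure.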

Another way to find linear relationships between features is to use a bivariate statistic called the correlation coefficient, r. Its value ranges from -1 to 1, where 1 indicates a perfect positive linear relationship, -1 a perfect negative linear relationship, and values near 0 little or no linear relationship. Here’s a plot of the correlation coefficients:

[Figure: plot of the correlation coefficients]
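
The coefficients themselves can also be computed directly. This is a minimal sketch using NumPy's `corrcoef`. It assumes the Pearson correlation is the measure intended here, and the variable name `corr` is only illustrative.

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()

# rowvar=False: each column is a variable (feature), each row an observation
corr = np.corrcoef(iris.data, rowvar=False)

# Print one row of the 4x4 correlation matrix per feature
for name, row in zip(iris.feature_names, corr):
    print("{:<20}".format(name), np.round(row, 2))
```

Each entry `corr[i, j]` is the correlation between feature i and feature j, so the diagonal is always 1.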


Train a Naive Bayes Classifier

Now that we’ve developed an intuition for the data, let’s write an application to train a Naive Bayes classifier and have it predict the class outcomes.

But first, let’s break the problem into smaller steps:

  • We will first load the features and the class into two separate variables called X and y respectively, and then randomly divide the dataset into a training set and a test set.
  • The training set will be used to train the classifier, and the test set will be used to have the classifier predict the class of each observation.
  • We will then measure the accuracy of the predictions by comparing the predicted classes to the true classes for the test set.
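
To make the first step concrete, here is a small sketch of the split on its own. It relies on the default behaviour of `train_test_split`, which holds out 25% of the rows for testing when no `test_size` is given. `random_state=0` only makes the shuffle reproducible.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Shuffle the rows and split them into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Number of rows and columns in each part
print(X_train.shape, X_test.shape)
```

With 150 observations this gives 112 training rows and 38 test rows.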

Now that we’ve defined the problem, let’s code the solution:


import matplotlib.pyplot as plt
import numpy as np

from sklearn import datasets
from sklearn.metrics import accuracy_score, confusion_matrix
# train_test_split lives in sklearn.model_selection; the older
# sklearn.cross_validation module was removed in scikit-learn 0.20
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Load the iris dataset: X holds the four features, y the class labels
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Create the Naive Bayes Classifier
clf = GaussianNB()

# Train the classifier using the fit method
clf.fit(X_train, y_train)

# Generate predictions i.e. class names on the test data set
y_predict = clf.predict(X_test)

# normalize=False returns the number of correct predictions instead of a fraction
score = accuracy_score(y_test, y_predict, normalize=False)

print("Total number of correctly classified observations: {0} out of {2} observations, "
      "Accuracy of the predictions: {1}".format(score, score / float(len(y_test)), len(y_test)))

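# Draw a confusion matrix as a heat map with the class names on both axes.
# The cmap default is an assumption; any matplotlib colormap works.
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(iris.target_names))
    plt.xticks(tick_marks, iris.target_names, rotation=45)
    plt.yticks(tick_marks, iris.target_names)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')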

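# Compute the confusion matrix (rows: true classes, columns: predicted classes)
cm = confusion_matrix(y_test, y_predict)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)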


# Normalize the confusion matrix by row (i.e. by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()
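
To connect the trained model back to the mechanics of the algorithm, here is a short sketch that inspects what `GaussianNB` learned: the class priors, the per-class feature means, and the posterior probabilities behind a few predictions. It retrains the same model so that it runs on its own. The names `proba` and `predicted` are only illustrative.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

clf = GaussianNB().fit(X_train, y_train)

# Prior probability of each class, estimated from the training labels
print(clf.class_prior_)

# Mean of each feature within each class, shape (n_classes, n_features)
print(clf.theta_)

# Posterior probability of each class for the first five test observations
proba = clf.predict_proba(X_test[:5])
print(np.round(proba, 3))

# The predicted class is the one with the highest posterior probability
predicted = clf.classes_[np.argmax(proba, axis=1)]
print(iris.target_names[predicted])
```

Each row of `proba` sums to 1, and picking the largest entry in a row matches what `clf.predict` returns for that observation.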

All the code is available on GitHub.

