Py: Decision Tree Classification

This notebook fits a decision tree to a spam dataset to classify emails as spam or not spam, based on features of each email.


The spam dataset is sourced from the University of California, Irvine Machine Learning Repository: Hopkins, M., Reeber, E., Forman, G., and Suermondt, J. (1999). Spambase Data Set [Dataset].

This dataset contains the following:

  • 4,601 observations, each representing an email originally collected from a Hewlett-Packard email server, of which 1,813 (39%) were identified as spam;

  • 57 continuous features:

    • 48 features of type ‘word_freq_WORD’ that represent the percentage (0 to 100) of words in the email that match ‘WORD’;

    • 6 features of type ‘char_freq_CHAR’ that represent the percentage (0 to 100) of characters in the email that match ‘CHAR’;

    • 1 feature, ‘capital_run_length_average’, that is the average length of uninterrupted sequences of capital letters in the email;

    • 1 feature, ‘capital_run_length_longest’, that is the length of the longest uninterrupted sequence of capital letters in the email;

    • 1 feature, ‘capital_run_length_total’, that is the total number of capital letters in the email; and

  • a binary response variable that takes on a value 0 if the email is not spam and 1 if the email is spam.
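To make the feature definitions above concrete, the sketch below computes one ‘word_freq_WORD’, one ‘char_freq_CHAR’ and the three ‘capital_run_length_*’ values for a made-up email. The exact tokenization used to build Spambase is not given here, so treating a word as a run of letters and digits, and matching words case-insensitively, are assumptions of this sketch; the helper names `word_freq`, `char_freq` and `capital_run_lengths` are hypothetical.

```python
import re


def word_freq(text, word):
    """Percentage of words in `text` that equal `word`.

    Assumption: a word is a run of letters/digits, matched case-insensitively.
    """
    words = re.findall(r"[A-Za-z0-9]+", text)
    if not words:
        return 0.0
    hits = sum(1 for w in words if w.lower() == word.lower())
    return 100.0 * hits / len(words)


def char_freq(text, char):
    """Percentage of all characters in `text` that equal `char`."""
    return 100.0 * text.count(char) / len(text) if text else 0.0


def capital_run_lengths(text):
    """Average, longest and total length of uninterrupted runs of capital letters."""
    runs = [len(run) for run in re.findall(r"[A-Z]+", text)]
    if not runs:
        return 0.0, 0, 0
    return sum(runs) / len(runs), max(runs), sum(runs)


# Made-up example email (not from the dataset).
email = "FREE money NOW!!! Click HERE for free money"
print(word_freq(email, "free"))          # 25.0  (2 of 8 words)
print(round(char_freq(email, "!"), 2))   # 6.98  (3 of 43 characters)
print(capital_run_lengths(email))        # (3.0, 4, 12): runs FREE, NOW, C, HERE
```

Note that the capital-letter features overlap: the total is the sum of the run lengths, and the average is that total divided by the number of runs.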


This section imports the packages that will be required for this exercise/case study.

import pandas as pd
import matplotlib.pyplot as plt

# Used to build the decision tree and evaluate it
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import sklearn.metrics as metrics

# Used to graph the decision tree
import graphviz
from sklearn.tree import export_graphviz
from IPython.display import Image


This section:

  • imports the data that will be used in the modelling; and

  • explores the data.

Import data

# Create a list of headings for the data.
namearray = [
'Spam_fl' ]

# Read in the data from the Stanford website (whitespace-separated values, no header row).
spam = pd.read_csv("", sep=r"\s+", names=namearray)
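The import step relies on the column order of the raw file: features first, the response last. The sketch below shows how `pd.read_csv` with `sep=r"\s+"` and `names=` assigns headings to a whitespace-separated file without a header row, and how `iloc[:, :-1]` later separates the features from the response. The three feature columns, their names and the numbers are a hypothetical toy stand-in, not the real file.

```python
import io

import pandas as pd

# Hypothetical toy file: whitespace-separated, no header row, response in the last column.
raw = io.StringIO(
    "0.0 0.5 2.0 1\n"
    "0.1 0.0 1.5 1\n"
    "0.0 0.0 1.0 0\n"
)
names = ["feat_a", "feat_b", "feat_c", "Spam_fl"]  # hypothetical headings

# When `names` is given, pandas treats the file as having no header row.
df = pd.read_csv(raw, sep=r"\s+", names=names)

X = df.iloc[:, :-1]          # every column except the last (the features)
y = df["Spam_fl"].values     # the response

print(df.shape)              # (3, 4)
print(list(X.columns))       # ['feat_a', 'feat_b', 'feat_c']
```

With the real data, `namearray` must list all 58 headings in file order, otherwise pandas will misalign the columns.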

Explore data (EDA)

# Check the dimensions of the data.
print(spam.shape)

# Print the first 10 observations from the data.
print(spam.head(10))

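A useful extra EDA check is the class balance of the response, because a classifier that always predicts "not spam" would already be right about 61% of the time. The sketch below builds a stand-in `Spam_fl` column purely from the counts quoted in the dataset description (4,601 emails, 1,813 spam); it is not the real data.

```python
import numpy as np
import pandas as pd

# Stand-in response built only from the counts in the dataset description.
n_total, n_spam = 4601, 1813
response = pd.Series(
    np.r_[np.zeros(n_total - n_spam, dtype=int), np.ones(n_spam, dtype=int)],
    name="Spam_fl",
)

counts = response.value_counts().sort_index()
share_spam = response.mean()

print(counts.to_dict())                  # {0: 2788, 1: 1813}
print(f"Share of spam: {share_spam:.1%}")  # Share of spam: 39.4%
```

On the real data frame the same check is `spam['Spam_fl'].value_counts()` and `spam['Spam_fl'].mean()`.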
This section:

  • fits a model; and

  • evaluates the fitted model.

Fit model

# Build a decision tree to classify an email as spam or non-spam.

X = spam.iloc[:, :-1]  # All columns except the last one, which contains the spam indicator (response).
Y = spam['Spam_fl'].values

# Separate the data into training (75%) and test (25%) datasets.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25, random_state=42)

alpha = 0.001  # Cost-complexity pruning parameter: larger values prune the tree more.
               # Experiment with changing this value of alpha.
               # What happens to the size of the decision tree and its accuracy?

criteria = 'gini'  # Experiment with changing this to the 'entropy' criterion.
                   # What happens to the size of the decision tree and its accuracy?

classifier = tree.DecisionTreeClassifier(random_state=0, ccp_alpha=alpha, min_samples_split=20,
                                         criterion=criteria, min_samples_leaf=5).fit(X_train, Y_train)

Y_predict = classifier.predict(X_test)
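The `criterion` setting chooses how the tree scores the impurity of a node when searching for splits: Gini impurity is 1 − Σ pₖ², and entropy is −Σ pₖ log₂ pₖ, where pₖ is the share of class k in the node. The sketch below evaluates both for a node with the dataset's overall mix (about 61% not spam, 39% spam) and for a pure node. The helper names are hypothetical; scikit-learn computes these internally.

```python
import math


def gini(proportions):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    return 1.0 - sum(p * p for p in proportions)


def entropy(proportions):
    """Entropy in bits: sum of -p * log2(p), skipping empty classes."""
    return sum(-p * math.log2(p) for p in proportions if p > 0)


mix = [0.61, 0.39]  # roughly the class mix of the whole dataset
print(f"Gini:    {gini(mix):.4f}")     # Gini:    0.4758
print(f"Entropy: {entropy(mix):.4f}")

# A pure node (all one class) has zero impurity under both criteria.
print(gini([1.0, 0.0]), entropy([1.0, 0.0]))  # 0.0 0.0
```

Both measures are largest for a 50/50 node and zero for a pure node, so they often pick similar splits; the notebook's suggested experiment shows how much the choice matters on this data.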

Evaluate model

# Calculate the confusion matrix and test accuracy.
print('Test confusion matrix:')
print(confusion_matrix(Y_test, Y_predict))
print('Test accuracy is {:.4f}.'.format(accuracy_score(Y_test, Y_predict)))

# Calculate the ROC curve and the AUC.
Y_prob = classifier.predict_proba(X_test)
fpr, tpr, threshold = metrics.roc_curve(Y_test, Y_prob[:, 1])  # Predicted probability of spam (class 1).
roc_auc = metrics.auc(fpr, tpr)
print('AUC is {:.4f}.'.format(roc_auc))

# Calculate the number of terminal nodes (leaves) in the tree.
# Every split creates exactly two child nodes, so a tree with n nodes has (n + 1) / 2 leaves.
num_leaves = (1 + classifier.tree_.node_count) / 2
print('The number of terminal nodes is {:.0f}.'.format(num_leaves))
Test confusion matrix:
[[636  40]
 [ 61 414]]
Test accuracy is 0.9123.
AUC is 0.9572.
The number of terminal nodes is 37.
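The printed confusion matrix can be checked by hand. In scikit-learn's `confusion_matrix`, rows are the actual classes and columns the predicted classes, so the entries are TN, FP (first row) and FN, TP (second row). The sketch below copies those four numbers, recovers the reported accuracy, adds precision and recall, and confirms that the matrix covers the whole test set: `train_test_split` rounds the test share up, giving ceil(0.25 × 4,601) = 1,151 emails.

```python
import math

# Values copied from the printed test confusion matrix.
tn, fp = 636, 40   # actual not spam: predicted not spam / predicted spam
fn, tp = 61, 414   # actual spam:     predicted not spam / predicted spam

n_test = tn + fp + fn + tp
accuracy = (tn + tp) / n_test
precision = tp / (tp + fp)  # share of emails flagged as spam that really are spam
recall = tp / (tp + fn)     # share of spam emails that the tree caught

expected_test_size = math.ceil(0.25 * 4601)  # test share is rounded up

print(n_test, expected_test_size)     # 1151 1151
print(f"Accuracy:  {accuracy:.4f}")   # Accuracy:  0.9123
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
```

Recall (about 0.87) is lower than precision (about 0.91): the tree misses more spam (61 emails) than it wrongly flags (40 emails).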
# Print the tree in a format that is easier to visualise.
# A larger ccp_alpha (0.01) prunes more, giving a smaller tree that is easier to read.
clf2 = tree.DecisionTreeClassifier(random_state=0, ccp_alpha=0.01, min_samples_split=20,
                                   criterion="gini", min_samples_leaf=5).fit(X_train, Y_train)

export_graphviz(clf2, '', feature_names=list(X_train.columns), class_names=['Not Spam', 'Spam'],
                filled=True, rounded=True, special_characters=True)
! dot -Tpng -o DAA_M05_Fig11.png
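The fit step asks what happens to the tree as alpha changes, and the visualised tree uses `ccp_alpha=0.01` rather than 0.001. The sketch below repeats the notebook's tree settings over several alpha values and reports the number of leaves and the test accuracy. It runs on a synthetic stand-in dataset from `make_classification` (57 features, roughly 39% positives), not on the spam data, so the numbers will differ from those above. It also checks the (node_count + 1) / 2 leaf formula against `get_n_leaves()`.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the spam features (NOT the real data).
X, y = make_classification(n_samples=2000, n_features=57, n_informative=10,
                           weights=[0.61], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

results = []
for alpha in [0.0, 0.001, 0.01, 0.05]:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha, min_samples_split=20,
                                 min_samples_leaf=5, criterion="gini")
    clf.fit(X_train, y_train)
    n_leaves = clf.get_n_leaves()
    # In a binary tree every split adds two children, so leaves = (nodes + 1) / 2.
    assert n_leaves == (clf.tree_.node_count + 1) // 2
    results.append((alpha, n_leaves, clf.score(X_test, y_test)))

for alpha, n_leaves, acc in results:
    print(f"alpha={alpha:<6} leaves={n_leaves:<4} test accuracy={acc:.4f}")
```

Because the unpruned tree is the same for every alpha (same data and `random_state`), a larger alpha can only remove leaves. Accuracy, by contrast, need not move in one direction: light pruning may help by cutting noisy splits, while heavy pruning eventually underfits.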
