Classification tree

Iris dataset

This example is adapted from the scikit-learn documentation.

Download and visualize the Iris dataset:

import seaborn as sns

# seaborn downloads the Iris data and returns it as a pandas DataFrame
iris = sns.load_dataset("iris")
sns.pairplot(iris, hue="species");

Fit a decision tree classifier:

from sklearn import tree

y = iris['species']
X = iris.drop("species", axis=1)
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
clf.score(X, y)  # accuracy on the training data, not on unseen data

Plot the tree:

tree.plot_tree(clf, filled=True);

A prettier tree can be drawn with graphviz:

import graphviz  # needs the graphviz Python package and the Graphviz binaries

# Export the fitted tree in DOT format and render it with graphviz
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=list(X.columns),
                                class_names=list(clf.classes_),
                                filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph
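If graphviz is not available, scikit-learn's `export_text` prints the same kind of tree as indented text rules. This sketch is self-contained; as assumptions, it uses the copy of Iris bundled with scikit-learn (`load_iris`) instead of the seaborn download, and limits the depth to 2 so the printout stays short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Bundled copy of Iris (no network access needed)
iris_data = load_iris()
labels = iris_data.target_names[iris_data.target]  # integer codes -> species names

# A shallow tree keeps the printed rules readable
small_tree = DecisionTreeClassifier(max_depth=2, random_state=0)
small_tree.fit(iris_data.data, labels)

rules = export_text(small_tree, feature_names=list(iris_data.feature_names))
print(rules)
```

Each `|---` line is one splitting condition, and the `class:` lines are the leaves.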

MNIST

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

%config InlineBackend.figure_format = 'svg'

# Download MNIST (70000 images of 28x28 = 784 pixels) from OpenML
X, Y = fetch_openml('mnist_784', return_X_y=True, parser='auto')

X = X.astype(float).values / 255  # scale pixel intensities to [0, 1]
Y = Y.astype(int).values          # labels arrive as strings; convert them to integers

Visualize data:

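The plotting code for this step is not shown. As a hedged sketch, a helper like the hypothetical `plot_digits` below (its name and parameters are assumptions) draws a random grid of digits with their labels, reshaping each 784-pixel row back into a 28x28 image.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_digits(X, Y, n_rows=3, n_cols=6, random_state=None, title=None):
    """Hypothetical helper: draw a random grid of 28x28 digits with their labels.

    X has shape (n_samples, 784) with pixel values in [0, 1];
    Y holds the corresponding integer labels.
    """
    rng = np.random.default_rng(random_state)
    idx = rng.choice(len(X), size=n_rows * n_cols, replace=False)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(1.5 * n_cols, 1.8 * n_rows))
    for ax, i in zip(axes.ravel(), idx):
        ax.imshow(X[i].reshape(28, 28), cmap="gray_r")  # row -> 28x28 image
        ax.set_title(str(Y[i]))
        ax.axis("off")
    if title is not None:
        fig.suptitle(title, size=20)
    return fig
```

With `X` and `Y` loaded as above, `fig = plot_digits(X, Y, random_state=11)` would show 18 random digits in the notebook.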

Split into train and test:

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=10000)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

Fit a decision tree model:

from sklearn.tree import DecisionTreeClassifier

DT = DecisionTreeClassifier()
DT.fit(X_train, y_train)

Test accuracy:

DT.score(X_test, y_test)
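The notebook imports `accuracy_score` and `confusion_matrix` without using them in the cells shown. As an offline sketch, the example below applies both to scikit-learn's bundled `load_digits` data (1797 images of 8x8 pixels), which is an assumption standing in for MNIST since MNIST requires a download. The confusion matrix shows which digits the tree confuses, which a single accuracy number hides.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Small bundled digits dataset as a stand-in for MNIST; pixel values 0..16 scaled to [0, 1]
digits = load_digits()
X_tr, X_te, y_tr, y_te = train_test_split(
    digits.data / 16, digits.target, test_size=0.25, random_state=0
)

dt = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
y_pred = dt.predict(X_te)

acc = accuracy_score(y_te, y_pred)   # same value as dt.score(X_te, y_te)
cm = confusion_matrix(y_te, y_pred)  # rows: true digit, columns: predicted digit
print(f"test accuracy: {acc:.3f}")
print(cm)
```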

Splitting conditions

Each non-terminal node contains a splitting condition that determines whether an object goes to the left or to the right subtree. Usually the condition compares the value of some feature \(x_j\) with a threshold \(t\):

\[ \mathbb I[x_j \leqslant t], \quad 1\leqslant j \leqslant d. \]

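According to the splitting condition, the training sample \(X\) is split into two disjoint subsamples \(X_l = \{x \in X \colon x_j \leqslant t\}\) and \(X_r = \{x \in X \colon x_j > t\}\), so that \(X = X_l \cup X_r\) and \(X_l \cap X_r = \varnothing\).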
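In code, the indicator \(\mathbb I[x_j \leqslant t]\) becomes a boolean mask over the rows of the sample. A minimal NumPy sketch; the function name `split_sample` and the toy data are illustrative, and `j` is a 0-based column index, whereas the formula above counts features from 1.

```python
import numpy as np


def split_sample(X, y, j, t):
    """Split (X, y) by the condition x_j <= t (j is a 0-based column index)."""
    mask = X[:, j] <= t          # indicator I[x_j <= t] for every row
    return X[mask], y[mask], X[~mask], y[~mask]


X_toy = np.array([[1.0, 5.0],
                  [2.0, 3.0],
                  [3.0, 4.0],
                  [4.0, 1.0]])
y_toy = np.array([0, 0, 1, 1])

X_l, y_l, X_r, y_r = split_sample(X_toy, y_toy, j=0, t=2.5)
print(y_l, y_r)  # [0 0] [1 1]
```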

ChatGPT suggestions

  1. Node Splitting

At each internal node, the tree algorithm selects a feature and a splitting criterion to divide the data into two or more child nodes. The goal is to create splits that maximize the purity or homogeneity of the class labels within each node.
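A node's split can be chosen by exhaustive search. This sketch assumes binary splits of the form \(x_j \leqslant t\) and Gini impurity as the purity measure (the CART-style setup that scikit-learn's trees use by default). The helper names are hypothetical, and taking candidate thresholds as midpoints between consecutive distinct feature values is one common choice.

```python
import numpy as np


def gini(y):
    """Gini impurity 1 - sum_k p_k^2 of a label array."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def best_split(X, y):
    """Return (feature, threshold, score) minimizing the size-weighted Gini
    impurity of the two children; (None, None, inf) if no split exists."""
    n, d = X.shape
    best_j, best_t, best_score = None, None, np.inf
    for j in range(d):
        values = np.unique(X[:, j])
        # Candidate thresholds: midpoints between consecutive distinct values,
        # so both children are always non-empty
        for t in (values[:-1] + values[1:]) / 2:
            left = X[:, j] <= t
            n_l = left.sum()
            score = (n_l * gini(y[left]) + (n - n_l) * gini(y[~left])) / n
            if score < best_score:
                best_j, best_t, best_score = j, t, score
    return best_j, best_t, best_score


X_toy = np.array([[2.0, 7.0], [3.0, 1.0], [6.0, 6.0], [7.0, 2.0]])
y_toy = np.array([0, 0, 1, 1])
j, t, score = best_split(X_toy, y_toy)
print(j, t, score)
```

On this toy sample the threshold 4.5 on the first feature separates the classes perfectly, so both children are pure.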

  2. Leaf Nodes

The leaf nodes are the terminal nodes of the tree. Each leaf node contains a predicted class label, representing the majority class of the training samples in that node.

  3. Predictive Modeling

To make predictions for new data, you traverse the tree from the root to a leaf node based on the feature values of the new data point. The class label in the selected leaf node is the predicted class for that data point.
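The root-to-leaf traversal can be reproduced by hand on a fitted scikit-learn tree through its low-level `tree_` attribute (`children_left`, `children_right`, `feature`, `threshold`, `value`). This is an illustrative sketch on the bundled Iris data; `predict_one` is a hypothetical helper, and `clf.predict` remains the practical way to predict.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris_data = load_iris()
clf_iris = DecisionTreeClassifier(random_state=0).fit(iris_data.data, iris_data.target)


def predict_one(model, x):
    """Follow the splitting conditions from the root down to a leaf."""
    t = model.tree_
    # scikit-learn compares float32 feature values with the stored thresholds
    x = np.asarray(x, dtype=np.float32)
    node = 0  # the root
    # In a leaf both child indices equal -1; internal nodes have distinct children
    while t.children_left[node] != t.children_right[node]:
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    # value[node] holds the class distribution in the leaf; take the majority class
    return model.classes_[np.argmax(t.value[node])]


print(predict_one(clf_iris, iris_data.data[0]), clf_iris.predict(iris_data.data[:1])[0])
```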

  4. Recursive Partitioning

The process of building a classification tree is recursive. The algorithm starts with the entire dataset and recursively splits it into subsets by choosing the best feature and split criterion at each node, continuing until a stopping condition is met.

  5. Stopping Criteria

Stopping criteria are used to determine when to stop growing the tree. Common stopping criteria include limiting the tree depth, setting a minimum number of samples per leaf, or using a minimum impurity reduction threshold.
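In scikit-learn these stopping criteria correspond to the `DecisionTreeClassifier` parameters `max_depth`, `min_samples_leaf` and `min_impurity_decrease`. The sketch below fits one tree per setting on the bundled Iris data and reports its depth and number of leaves; the particular values are arbitrary choices for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)

settings = {
    "no stopping rule": {},
    "max_depth=3": {"max_depth": 3},
    "min_samples_leaf=10": {"min_samples_leaf": 10},
    "min_impurity_decrease=0.01": {"min_impurity_decrease": 0.01},
}

tree_shape = {}
for name, params in settings.items():
    model = DecisionTreeClassifier(random_state=0, **params).fit(X_iris, y_iris)
    depth, leaves = model.get_depth(), model.get_n_leaves()
    tree_shape[name] = (depth, leaves)
    print(f"{name:28s} depth={depth:2d} leaves={leaves:2d}")
```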

  6. Impurity Measures

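In classification trees, impurity measures such as Gini impurity, entropy, or misclassification rate quantify how mixed the class labels in a node are. A candidate split is scored by the impurity of its child nodes weighted by their sizes, and the split that minimizes this weighted impurity (equivalently, maximizes the impurity decrease relative to the parent node) is selected.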
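For a node with class proportions \(p_k\), the three measures are \(1 - \sum_k p_k^2\) (Gini), \(-\sum_k p_k \log_2 p_k\) (entropy) and \(1 - \max_k p_k\) (misclassification rate). A minimal NumPy sketch; the function names are illustrative, and the split score is the size-weighted average of the children's impurities.

```python
import numpy as np


def class_proportions(y):
    _, counts = np.unique(y, return_counts=True)
    return counts / counts.sum()


def gini(y):
    p = class_proportions(y)
    return 1.0 - np.sum(p ** 2)


def entropy(y):
    p = class_proportions(y)  # only classes that occur, so p > 0 and log2 is safe
    return -np.sum(p * np.log2(p))


def misclassification(y):
    return 1.0 - class_proportions(y).max()


def weighted_child_impurity(y_left, y_right, impurity=gini):
    """Impurity of a split: children's impurities weighted by their sizes."""
    n_l, n_r = len(y_left), len(y_right)
    return (n_l * impurity(y_left) + n_r * impurity(y_right)) / (n_l + n_r)


y_node = np.array([0, 0, 0, 1, 1, 1])
print(gini(y_node), entropy(y_node), misclassification(y_node))  # 0.5 1.0 0.5

# Splitting y_node into [0, 0, 0, 1] and [1, 1] lowers the Gini impurity from 0.5 to 0.25
print(weighted_child_impurity(np.array([0, 0, 0, 1]), np.array([1, 1])))
```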

  7. Pruning

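After building a classification tree, it may be pruned to reduce overfitting. Pruning removes subtrees that do not sufficiently improve the model: for example, reduced-error pruning removes nodes that do not improve accuracy on a validation dataset, while cost-complexity pruning trades off the tree's training error against its number of leaves.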
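scikit-learn implements minimal cost-complexity pruning through the `ccp_alpha` parameter and the `cost_complexity_pruning_path` method. In this sketch, the bundled breast-cancer dataset, the 70/30 split and choosing `ccp_alpha` by validation accuracy are all assumptions for illustration, not a prescribed procedure.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_bc, y_bc = load_breast_cancer(return_X_y=True)
X_fit, X_val, y_fit, y_val = train_test_split(X_bc, y_bc, test_size=0.3, random_state=0)

# Fully grown tree: splitting continues until the leaves are pure
full = DecisionTreeClassifier(random_state=0).fit(X_fit, y_fit)

# Effective alphas at which subtrees are pruned away; the largest one leaves only the root.
# Clipping guards against tiny negative values from floating-point round-off.
path = full.cost_complexity_pruning_path(X_fit, y_fit)
alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)

# Choose the alpha with the best accuracy on the validation split
val_scores = [
    DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_fit, y_fit).score(X_val, y_val)
    for a in alphas
]
best_alpha = alphas[int(np.argmax(val_scores))]
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(X_fit, y_fit)

print("leaves:", full.get_n_leaves(), "->", pruned.get_n_leaves())
print("validation accuracy:", full.score(X_val, y_val), "->", pruned.score(X_val, y_val))
```

Selecting `ccp_alpha` on the same split used to report accuracy is optimistic; a separate test set or cross-validation gives a fairer estimate.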

  8. Visualization

Classification trees can be visualized graphically, making it easy to interpret and understand the model’s decision-making process.

  9. Ensemble Methods

Classification trees are often used as building blocks in ensemble methods like Random Forests and Gradient Boosting, which combine multiple trees to improve predictive accuracy and reduce overfitting.
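A quick way to compare a single tree with the two ensembles named above is cross-validated accuracy in scikit-learn. The dataset (bundled breast-cancer data), the number of trees and 5-fold cross-validation are arbitrary choices for this sketch, and the resulting ranking depends on the data.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X_bc, y_bc = load_breast_cancer(return_X_y=True)

models = {
    "single tree": DecisionTreeClassifier(random_state=0),
    "random forest": RandomForestClassifier(n_estimators=200, random_state=0),
    "gradient boosting": GradientBoostingClassifier(random_state=0),
}

cv_means = {}
for name, model in models.items():
    scores = cross_val_score(model, X_bc, y_bc, cv=5)  # accuracy on each of 5 folds
    cv_means[name] = scores.mean()
    print(f"{name:18s} mean accuracy {scores.mean():.3f} (std {scores.std():.3f})")
```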

  10. Advantages

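Classification trees are interpretable: each prediction can be explained by the sequence of splitting conditions on the path from the root to a leaf. They can capture complex, nonlinear decision boundaries (built from axis-aligned pieces when each split thresholds a single feature) and interactions between features.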

  11. Limitations

Trees are prone to overfitting, especially when they are allowed to grow deep. Single trees may also generalize poorly because they are unstable: small changes in the training data can produce a very different tree. Ensemble methods can mitigate these limitations.

Classification trees are widely used in various domains, including healthcare, finance, and natural language processing, for tasks such as spam email detection, disease diagnosis, and sentiment analysis. Proper tuning of hyperparameters and consideration of potential overfitting are essential when working with classification trees.