Decision Tree Classifier, Explained: A Visual Guide with Code Examples for Beginners

CLASSIFICATION ALGORITHM

A fresh look at our favorite upside-down tree

Decision Trees are everywhere in machine learning, beloved for their intuitive output. Who doesn’t love a simple “if-then” flowchart? Despite their popularity, it’s surprising how challenging it is to find a clear, step-by-step explanation of how Decision Trees work. (I’m actually embarrassed by how long it took me to truly understand how the algorithm works.)

So, in this breakdown, I’ll be focusing on the essentials of tree construction. We’ll unpack EXACTLY what’s happening in each node and why, from root to final leaves (with visuals, of course).

All visuals: Author-created using Canva Pro. Optimized for mobile; may appear oversized on desktop.

Definition

A Decision Tree classifier creates an upside-down tree to make predictions. It starts at the top with a question about an important feature in your data, then branches out based on the answers. As you follow these branches down, each stop asks another question, narrowing down the possibilities. This question-and-answer game continues until you reach the bottom (a leaf node), where you get your final prediction or classification.

Decision Tree is one of the most important machine learning algorithms: at its core, it’s a series of yes-or-no questions.

Dataset Used

Throughout this article, we’ll use this artificial golf dataset (inspired by [1]) as an example. This dataset predicts whether a person will play golf based on weather conditions.

Columns: ‘Outlook’ (one-hot encoded into sunny, overcast, rainy), ‘Temperature’ (in Fahrenheit), ‘Humidity’ (in %), ‘Wind’ (yes/no), and ‘Play’ (target feature)
# Import libraries
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np

# Load data
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)

# Preprocess data
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)

# Reorder the columns
df = df[['sunny', 'overcast', 'rainy', 'Temperature', 'Humidity', 'Wind', 'Play']]

# Prepare features and target
X, y = df.drop(columns='Play'), df['Play']

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)

# Display results
print(pd.concat([X_train, y_train], axis=1), '\n')
print(pd.concat([X_test, y_test], axis=1))

Main Mechanism

The Decision Tree classifier operates by recursively splitting the data based on the most informative features. Here’s how it works:

  1. Start with the entire dataset at the root node.
  2. Select the best feature to split the data (based on measures like Gini impurity).
  3. Create child nodes for each possible value of the selected feature.
  4. Repeat steps 2–3 for each child node until a stopping criterion is met (e.g., maximum depth reached, minimum samples per leaf, or pure leaf nodes).
  5. Assign the majority class to each leaf node.

Training Steps

In scikit-learn, decision trees are built with an optimized version of the CART (Classification and Regression Trees) algorithm. It builds binary trees and typically follows these steps:

  1. Start with all training samples in the root node.
Starting with the root node containing all 14 training samples, we will figure out the best feature and the best point at which to split the data to start building the tree.

2. For each feature:
a. Sort the feature values.
b. Consider all possible thresholds between adjacent values as potential split points.

In this root node, there are 23 candidate split points to check. Binary columns have only one split point (0.5).
def potential_split_points(attr_name, attr_values):
    sorted_attr = np.sort(attr_values)
    unique_values = np.unique(sorted_attr)
    split_points = [(unique_values[i] + unique_values[i+1]) / 2 for i in range(len(unique_values) - 1)]
    return {attr_name: split_points}

# Calculate and display potential split points for all columns
for column in X_train.columns:
    splits = potential_split_points(column, X_train[column])
    for attr, points in splits.items():
        print(f"{attr:11}: {points}")
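As a quick check on the count of 23 quoted above, the plain-Python sketch below recomputes the candidate thresholds for the 14 training rows. The data is copied by hand from the first half of `dataset_dict`, encoded the way the preprocessing step does it (one-hot Outlook, Wind as 0/1); `train_columns` and `midpoints` are illustrative names, not part of the original code.

```python
# The 14 training rows (first half of the dataset), encoded as in the
# preprocessing step: one-hot Outlook columns, Wind as 0/1.
train_columns = {
    "sunny":       [1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0],
    "overcast":    [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
    "rainy":       [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1],
    "Temperature": [85, 80, 83, 70, 68, 65, 64, 72, 69, 75, 75, 72, 81, 71],
    "Humidity":    [85, 90, 78, 96, 80, 70, 65, 95, 70, 80, 70, 90, 75, 80],
    "Wind":        [0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1],
}


def midpoints(values):
    """Candidate thresholds: midpoints between adjacent distinct sorted values."""
    uniq = sorted(set(values))
    return [(a + b) / 2 for a, b in zip(uniq, uniq[1:])]


counts = {name: len(midpoints(vals)) for name, vals in train_columns.items()}
total = sum(counts.values())
print(counts)
print("total:", total)  # 1 + 1 + 1 + 11 + 8 + 1 = 23
```

Temperature has 12 distinct training values (11 midpoints) and Humidity has 9 (8 midpoints), while each binary column contributes a single threshold at 0.5.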

3. For each potential split point:
a. Calculate the impurity (e.g., Gini impurity) of each of the two subsets the split creates.
b. Calculate the weighted average of the two impurities, weighting each subset by its number of samples.

For example, for the feature “sunny” with split point 0.5, the impurity (here, Gini impurity) is calculated for both parts of the dataset.
The same process applies to continuous features such as “Temperature”.
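Worked by hand, the Gini calculation for “sunny” at 0.5 looks like this. The class counts come from the 14 training rows: the 5 sunny days split 2 Yes / 3 No, and the 9 remaining days split 7 Yes / 2 No. The final value is what the snippet below should print when rounded to three decimals.

```latex
G(\text{node}) = 1 - \sum_k p_k^2

G_{\text{sunny} \le 0.5} = 1 - \left(\tfrac{7}{9}\right)^2 - \left(\tfrac{2}{9}\right)^2 = \tfrac{28}{81} \approx 0.346

G_{\text{sunny} > 0.5} = 1 - \left(\tfrac{2}{5}\right)^2 - \left(\tfrac{3}{5}\right)^2 = \tfrac{12}{25} = 0.480

G_{\text{weighted}} = \tfrac{9}{14} \cdot \tfrac{28}{81} + \tfrac{5}{14} \cdot \tfrac{12}{25} = \tfrac{2}{9} + \tfrac{6}{35} = \tfrac{124}{315} \approx 0.394
```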
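def gini_impurity(y):
    # Gini impurity: 1 minus the sum of squared class proportions
    p = np.bincount(y) / len(y)
    return 1 - np.sum(p**2)

def weighted_average_impurity(y, split_index):
    # Left part = first split_index labels, right part = the rest;
    # each part's impurity is weighted by its number of samples
    y = np.asarray(y)  # slice by position, whatever the pandas index is
    n = len(y)
    left_impurity = gini_impurity(y[:split_index])
    right_impurity = gini_impurity(y[split_index:])
    return (split_index * left_impurity + (n - split_index) * right_impurity) / n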

# Sort 'sunny' feature and corresponding labels
sunny = X_train['sunny']
sorted_indices = np.argsort(sunny)
sorted_sunny = sunny.iloc[sorted_indices]
sorted_labels = y_train.iloc[sorted_indices]

# Find split index for 0.5
split_index = np.searchsorted(sorted_sunny, 0.5, side='right')

# Calculate impurity
impurity = weighted_average_impurity(sorted_labels, split_index)

print(f"Weighted average impurity for 'sunny' at split point 0.5: {impurity:.3f}")

4. After calculating the weighted impurity for every feature and split point, choose the split with the lowest value.

The feature “overcast” with split point 0.5 gives the lowest impurity. This means it is the purest split among all the candidates!
def calculate_split_impurities(X, y):
    split_data = []

    for feature in X.columns:
        # Sort the feature values together with their labels
        sorted_indices = np.argsort(X[feature])
        sorted_feature = X[feature].iloc[sorted_indices]
        sorted_y = y.iloc[sorted_indices]

        # Candidate thresholds: midpoints between adjacent unique values
        unique_values = sorted_feature.unique()
        split_points = (unique_values[1:] + unique_values[:-1]) / 2

        for split in split_points:
            # Number of samples that fall on the left (<= split) side
            split_index = np.searchsorted(sorted_feature, split, side='right')
            impurity = weighted_average_impurity(sorted_y, split_index)
            split_data.append({
                'feature': feature,
                'split_point': split,
                'weighted_avg_impurity': impurity
            })

    return pd.DataFrame(split_data)

# Calculate split impurities for all features
calculate_split_impurities(X_train, y_train).round(3)
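If you want to confirm that “overcast” at 0.5 wins without relying on pandas, the search in steps 2–4 fits in a few lines of plain Python. This is a minimal sketch, not scikit-learn's implementation: it scans every feature and midpoint exhaustively and keeps the first candidate in case of a tie. The training rows are re-entered by hand (encoded as in the preprocessing step), and `find_best_split` is a hypothetical helper name.

```python
# The 14 training rows (first half of the dataset), encoded as in the
# preprocessing step; Play is 1 for Yes and 0 for No.
columns = {
    "sunny":       [1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0],
    "overcast":    [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
    "rainy":       [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1],
    "Temperature": [85, 80, 83, 70, 68, 65, 64, 72, 69, 75, 75, 72, 81, 71],
    "Humidity":    [85, 90, 78, 96, 80, 70, 65, 95, 70, 80, 70, 90, 75, 80],
    "Wind":        [0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1],
}
play = [0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def find_best_split(columns, labels):
    """Scan every (feature, midpoint) pair and keep the lowest weighted Gini.
    Ties keep the candidate found first."""
    best = None
    for feature, values in columns.items():
        uniq = sorted(set(values))
        for lo, hi in zip(uniq, uniq[1:]):
            threshold = (lo + hi) / 2
            left = [lab for v, lab in zip(values, labels) if v <= threshold]
            right = [lab for v, lab in zip(values, labels) if v > threshold]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or score < best[2]:
                best = (feature, threshold, score)
    return best


feature, threshold, score = find_best_split(columns, play)
print(feature, threshold, round(score, 3))  # overcast 0.5 0.357
```

The winning score is exactly 5/14: the 10 non-overcast rows are a 5/5 mix (Gini 0.5) and the 4 overcast rows are all “Yes” (Gini 0).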

5. Create two child nodes based on the chosen feature and split point:
– Left child: samples with feature value <= split point
– Right child: samples with feature value > split point

The selected split point divides the data into two parts. Since one part is already pure (the right side! That’s why its impurity is low!), we only need to continue growing the tree from the left node.
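Step 5 applied to the root is just a partition of row positions. The sketch below (plain Python; `partition` is an illustrative helper, not from the original code) uses the encoded “overcast” column of the 14 training rows. Because the data was split with `shuffle=False`, positions 0–13 are also the index labels of `X_train`, and the `selected_index` values used in the next snippet all come from the left child.

```python
# Encoded 'overcast' column and Play labels (1 = Yes) for the 14 training rows.
overcast = [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0]
play     = [0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]


def partition(values, threshold):
    """Return row positions for the left (<= threshold) and right (> threshold) child."""
    left = [i for i, v in enumerate(values) if v <= threshold]
    right = [i for i, v in enumerate(values) if v > threshold]
    return left, right


left_idx, right_idx = partition(overcast, 0.5)
left_labels = [play[i] for i in left_idx]
right_labels = [play[i] for i in right_idx]

print("left :", left_idx, "Yes:", sum(left_labels), "No:", len(left_labels) - sum(left_labels))
print("right:", right_idx, "Yes:", sum(right_labels), "No:", len(right_labels) - sum(right_labels))
# left : [0, 1, 3, 4, 5, 7, 8, 9, 10, 13] Yes: 5 No: 5
# right: [2, 6, 11, 12] Yes: 4 No: 0
```

The right child holds only “Yes” rows, so it becomes a leaf; the left child is a 5/5 mix and keeps growing.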

6. Recursively repeat steps 2–5 for each child node until every node is pure. You can also stop earlier, when a stopping criterion is met (e.g., maximum depth reached, minimum number of samples per leaf node, or minimum impurity decrease).

# Calculate split impurities for the selected indices
selected_index = [4,8,3,13,7,9,10] # Change it depending on which indices you want to check
calculate_split_impurities(X_train.iloc[selected_index], y_train.iloc[selected_index]).round(3)
from sklearn.tree import DecisionTreeClassifier

# The whole Training Phase above is done inside sklearn like this
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X_train, y_train)
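To see steps 1–6 in one place, here is a minimal from-scratch sketch of the recursive procedure in plain Python. It is not how scikit-learn implements `DecisionTreeClassifier` (that version is heavily optimized and may break ties between equally good splits differently); it only mirrors the logic described above: exhaustive midpoint search, size-weighted Gini impurity, two children per split, and majority-class leaves. The `max_depth`, `min_samples_split`, and `min_samples_leaf` arguments borrow the names of the scikit-learn parameters discussed in Key Parameters, while the dict-based node format and helper names are illustrative choices, not part of the original code.

```python
from collections import Counter

# The 14 training rows (first half of the dataset), encoded as in the
# preprocessing step. Columns: sunny, overcast, rainy, Temperature, Humidity, Wind.
FEATURES = ["sunny", "overcast", "rainy", "Temperature", "Humidity", "Wind"]
X_rows = [
    [1, 0, 0, 85, 85, 0], [1, 0, 0, 80, 90, 1], [0, 1, 0, 83, 78, 0],
    [0, 0, 1, 70, 96, 0], [0, 0, 1, 68, 80, 0], [0, 0, 1, 65, 70, 1],
    [0, 1, 0, 64, 65, 1], [1, 0, 0, 72, 95, 0], [1, 0, 0, 69, 70, 0],
    [0, 0, 1, 75, 80, 0], [1, 0, 0, 75, 70, 1], [0, 1, 0, 72, 90, 1],
    [0, 1, 0, 81, 75, 0], [0, 0, 1, 71, 80, 1],
]
y_rows = [0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]  # Play: 1 = Yes, 0 = No


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(X, y, min_samples_leaf):
    """Try every midpoint of every feature; return (feature_index, threshold,
    weighted_gini) for the lowest weighted Gini, or None if no split is allowed."""
    best = None
    for j in range(len(X[0])):
        values = sorted({row[j] for row in X})
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            left = [label for row, label in zip(X, y) if row[j] <= threshold]
            right = [label for row, label in zip(X, y) if row[j] > threshold]
            if min(len(left), len(right)) < min_samples_leaf:
                continue  # would create a leaf that is too small
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or score < best[2]:
                best = (j, threshold, score)
    return best


def build(X, y, depth=0, max_depth=None, min_samples_split=2, min_samples_leaf=1):
    """Grow the tree recursively. Leaves store the majority class of their samples."""
    majority = Counter(y).most_common(1)[0][0]
    should_stop = (
        gini(y) == 0.0                                   # node is already pure
        or len(y) < min_samples_split                    # too few samples to split
        or (max_depth is not None and depth >= max_depth)
    )
    split = None if should_stop else best_split(X, y, min_samples_leaf)
    if split is None:
        return {"leaf": majority}
    j, threshold, _ = split
    params = dict(max_depth=max_depth, min_samples_split=min_samples_split,
                  min_samples_leaf=min_samples_leaf)
    left = [i for i, row in enumerate(X) if row[j] <= threshold]
    right = [i for i, row in enumerate(X) if row[j] > threshold]
    return {
        "feature": j,
        "threshold": threshold,
        "left": build([X[i] for i in left], [y[i] for i in left], depth + 1, **params),
        "right": build([X[i] for i in right], [y[i] for i in right], depth + 1, **params),
    }


def predict(node, row):
    """Follow <= to the left child and > to the right child until a leaf."""
    while "leaf" not in node:
        node = node["left"] if row[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]


tree = build(X_rows, y_rows)
print("root split:", FEATURES[tree["feature"]], "<=", tree["threshold"])
train_acc = sum(predict(tree, row) == label for row, label in zip(X_rows, y_rows)) / len(y_rows)
print("training accuracy:", train_acc)
```

Because every training row is distinct, the fully grown tree separates all 14 samples (training accuracy 1.0), which is exactly the overfitting risk the Key Parameters section addresses. Calling `build(X_rows, y_rows, max_depth=1)` instead returns a one-split stump.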

Final Complete Tree

The class label of a leaf node is the majority class of the training samples that reached that node.

The right one is the final tree that will be used for classification. We do not need the samples anymore at this point.
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
# Plot the decision tree
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=X.columns, class_names=['Not Play', 'Play'])
plt.show()
In this scikit-learn output, non-leaf nodes also display their information, such as the number of samples in the node (samples) and the class breakdown (value).

Classification Step

Here’s how the prediction process works once the decision tree has been trained:

  1. Start at the root node of the trained decision tree.
  2. Evaluate the split condition at the current node, then move to the left child (feature value <= split point) or the right child (feature value > split point).
  3. Repeat step 2 at each subsequent node until reaching a leaf node.
  4. The class label of the leaf node becomes the prediction for the new instance.
We only need the columns the tree actually asks about. Apart from “overcast” and “Temperature”, the other values do not matter in making the prediction.
# Make predictions
y_pred = dt_clf.predict(X_test)
print(y_pred)
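The four classification steps amount to a short loop. In this plain-Python sketch the tree is a hand-written toy: its root mirrors the “overcast” ≤ 0.5 split found during training, but the “Temperature” threshold of 73.5 in the left branch is a hypothetical placeholder, not a value learned by `dt_clf`, and the fitted tree has more nodes. `toy_tree` and `predict_one` are illustrative names.

```python
# A toy tree in nested-dict form. The root mirrors the overcast <= 0.5 split
# found above; the Temperature threshold is a HYPOTHETICAL placeholder, not the
# value learned by dt_clf, and the real fitted tree has more nodes.
toy_tree = {
    "feature": "overcast", "threshold": 0.5,
    "right": {"leaf": "Play"},
    "left": {
        "feature": "Temperature", "threshold": 73.5,  # hypothetical threshold
        "left": {"leaf": "Play"},
        "right": {"leaf": "Not Play"},
    },
}


def predict_one(node, sample):
    """Walk from the root to a leaf; return the leaf's class and the features asked."""
    asked = []
    while "leaf" not in node:
        feature = node["feature"]
        asked.append(feature)
        # <= goes to the left child, > goes to the right child
        node = node["left"] if sample[feature] <= node["threshold"] else node["right"]
    return node["leaf"], asked


sample = {"sunny": 1, "overcast": 0, "rainy": 0, "Temperature": 80, "Humidity": 90, "Wind": 1}
label, asked = predict_one(toy_tree, sample)
print(label, asked)  # Not Play ['overcast', 'Temperature']

# Changing a feature the tree never asks about leaves the prediction unchanged.
print(predict_one(toy_tree, {**sample, "Humidity": 60, "Wind": 0})[0])  # Not Play
```

For the fitted model, the same information lives in `dt_clf.tree_` (`children_left`, `children_right`, `feature`, `threshold`), and `sklearn.tree.export_text(dt_clf, feature_names=list(X.columns))` prints the learned rules as text.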

Evaluation Step

The decision tree achieves adequate accuracy. Because our tree only checks two features, it might not capture the characteristics of the test set well.
# Evaluate the classifier
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

Key Parameters

Decision Trees have several important parameters that control their growth and complexity:

1. Max Depth: This sets the maximum depth of the tree, which can be a valuable tool in preventing overfitting.

👍 Helpful Tip: Consider starting with a shallow tree (perhaps 3–5 levels deep) and gradually increasing the depth until you find the optimal balance between model complexity and performance on validation data.

2. Min Samples Split: This parameter determines the minimum number of samples needed to split an internal node.

👍 Helpful Tip: Setting this to a higher value (around 5–10% of your training data) can help prevent the tree from creating too many small, specific splits that might not generalize well to new data.

3. Min Samples Leaf: This specifies the minimum number of samples required at a leaf node.

👍 Helpful Tip: Choose a value that ensures each leaf represents a meaningful subset of your data (approximately 1–5% of your training data). This can help avoid overly specific predictions.

4. Criterion: The function used to measure the quality of a split (usually “gini” for Gini impurity or “entropy” for information gain).

👍 Helpful Tip: Gini is slightly simpler and faster to compute because it avoids logarithms. Entropy is sometimes preferred for multi-class problems, but in practice the two frequently give similar results.

Example of Entropy calculation for ‘sunny’ with split point 0.5.
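The entropy version of this calculation can be reproduced in plain Python. Entropy is H = −Σ p_k · log2(p_k) over the class proportions p_k, and information gain is the parent node's entropy minus the size-weighted entropy of the two children, so the split with the highest gain is the one with the lowest weighted child entropy. The lists below are the encoded “sunny” column and Play labels of the 14 training rows.

```python
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((labels.count(c) / n) * log2(labels.count(c) / n) for c in set(labels))


# Encoded 'sunny' column and Play labels (1 = Yes) for the 14 training rows.
sunny = [1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0]
play  = [0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]

left  = [p for s, p in zip(sunny, play) if s <= 0.5]   # not sunny: 7 Yes / 2 No
right = [p for s, p in zip(sunny, play) if s > 0.5]    # sunny: 2 Yes / 3 No

weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(play)
gain = entropy(play) - weighted
print(f"left={entropy(left):.3f} right={entropy(right):.3f} "
      f"weighted={weighted:.3f} gain={gain:.3f}")
# left=0.764 right=0.971 weighted=0.838 gain=0.102
```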

Pros & Cons

Like any algorithm in machine learning, Decision Trees have their strengths and limitations.

Pros:

  1. Interpretability: Easy to understand and visualize the decision-making process.
  2. No Feature Scaling: Can handle both numerical and categorical data without normalization.
  3. Handles Non-linear Relationships: Can capture complex patterns in the data.
  4. Feature Importance: Provides a clear indication of which features are most important for prediction.

Cons:

  1. Overfitting: Prone to creating overly complex trees that don’t generalize well, especially with small datasets.
  2. Instability: Small changes in the data can result in a completely different tree being generated.
  3. Biased with Imbalanced Datasets: Can be biased towards dominant classes.
  4. Inability to Extrapolate: Cannot make predictions beyond the range of the training data.

In our golf example, a Decision Tree might create very accurate and interpretable rules for deciding whether to play golf based on weather conditions. However, it might overfit to specific combinations of conditions if not properly pruned or if the dataset is small.

Final Remarks

Decision Tree Classifiers are a great tool for solving many types of problems in machine learning. They’re easy to understand, can handle complex data, and show us how they make decisions. This makes them useful in many areas, from business to medicine. While Decision Trees are powerful and interpretable, they’re often used as building blocks for more advanced ensemble methods like Random Forests or Gradient Boosting Machines.

🌟 Decision Tree Classifier Simplified

# Import libraries
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.tree import plot_tree, DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load data
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)

# Prepare data
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)

# Split data
X, y = df.drop(columns='Play'), df['Play']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)

# Train model
dt_clf = DecisionTreeClassifier(
    max_depth=None,        # Maximum depth of the tree
    min_samples_split=2,   # Minimum number of samples required to split an internal node
    min_samples_leaf=1,    # Minimum number of samples required to be at a leaf node
    criterion='gini'       # Function to measure the quality of a split
)
dt_clf.fit(X_train, y_train)

# Make predictions
y_pred = dt_clf.predict(X_test)

# Evaluate model
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

# Visualize tree
plt.figure(figsize=(20, 10))
plot_tree(dt_clf, filled=True, feature_names=X.columns,
          class_names=['Not Play', 'Play'], impurity=False)
plt.show()

Further Reading

For a detailed explanation of the Decision Tree Classifier and its implementation in scikit-learn, readers can refer to the official documentation [2], which provides comprehensive information on its usage and parameters.

Technical Environment

This article uses Python 3.7 and scikit-learn 1.5. While the concepts discussed are generally applicable, specific code implementations may vary slightly with different versions.

About the Illustrations

Unless otherwise noted, all images are created by the author, incorporating licensed design elements from Canva Pro.

For a concise visual summary of Decision Tree Classifier, check out the companion Instagram post.

Reference

[1] T. M. Mitchell, Machine Learning (1997), McGraw-Hill Science/Engineering/Math, p. 59

[2] F. Pedregosa et al., Scikit-learn: Machine Learning in Python, Journal of Machine Learning Research, vol. 12, pp. 2825–2830, 2011. [Online]. Available: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
