AW Dev Rethought

"Programs must be written for people to read, and only incidentally for machines to execute." - Harold Abelson

🧠 AI with Python – 🔥 Feature Interaction Heatmaps


Description:

Understanding how individual features affect predictions is useful. But in many real-world machine learning problems, the real complexity comes from interactions between features.

Some features may not be powerful alone — yet when combined with another feature, they significantly influence the model’s output.

In this project, we explore Feature Interaction Heatmaps, a visual technique to understand how two features jointly impact model predictions.


Understanding the Problem

Single-feature interpretability methods like:

  • Feature Importance
  • SHAP values
  • Partial Dependence Plots

help explain global or local behaviour. However, they typically treat each feature in isolation.

Models, especially tree-based ones, frequently learn non-linear interactions between variables.

Key question:

How does the model behave when two features change together?

That’s where interaction heatmaps help.


What Is a Feature Interaction Heatmap?

A Feature Interaction Heatmap:

  • Selects two features
  • Varies both across a range of values
  • Keeps all other features fixed
  • Visualises predicted output as a 2D heatmap

This allows us to see:

  • Synergy effects
  • Threshold regions
  • Non-linear interaction zones

1. Training a Model

We begin by training a classification model.

from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(
    n_estimators=200,
    random_state=42
)
model.fit(X_train, y_train)

Tree-based models are particularly good at capturing interaction effects.
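The fragment above assumes `X_train` and `y_train` already exist; the full snippet at the end of the article builds them from the breast-cancer dataset. As a quick sanity check (not part of the original steps), it is worth confirming the model predicts well before interpreting it:

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Same data preparation as the full snippet at the end of the article
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

model = RandomForestClassifier(n_estimators=200, random_state=42)
model.fit(X_train, y_train)

# A model worth interpreting should comfortably beat chance
print(round(model.score(X_test, y_test), 3))
```

An interaction heatmap for a poorly fitted model would mostly visualise noise, so this check is cheap insurance.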


2. Creating a Feature Grid

We select two features and create a grid of value combinations.

feature_1 = "mean radius"
feature_2 = "mean texture"

We generate combinations of values across their ranges while holding other features constant (using mean values).
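This step is worth seeing in full. The sketch below mirrors the grid construction from the complete snippet at the end of the article, with one extra line restoring the original column order, which scikit-learn checks when the model was fitted on a DataFrame:

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

X_train, X_test, _, _ = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

feature_1 = "mean radius"
feature_2 = "mean texture"

# 50 evenly spaced values across each feature's observed range
f1_range = np.linspace(X_test[feature_1].min(), X_test[feature_1].max(), 50)
f2_range = np.linspace(X_test[feature_2].min(), X_test[feature_2].max(), 50)

# Cartesian product: one row per (feature_1, feature_2) combination
grid = pd.DataFrame([
    {feature_1: f1, feature_2: f2}
    for f1 in f1_range
    for f2 in f2_range
])

# Hold every other feature at its training-set mean
for col in X.columns:
    if col not in (feature_1, feature_2):
        grid[col] = X_train[col].mean()

# Match the original column order so the fitted model accepts the grid
grid = grid[X.columns]

print(grid.shape)  # 2500 rows (50 x 50), 30 columns
```

Holding the other features at their means is a simplification: it shows the model's behaviour along one slice of the feature space, not an average over the data.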


3. Predicting Over the Grid

We compute model predictions for every pair in the grid.

probs = model.predict_proba(grid)[:, 1]

These predictions are reshaped into a matrix for visualisation.
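The reshape step depends on the order the grid was built in: in the full snippet below, `feature_1` drives the outer loop, so the flat prediction array fills the matrix row by row. A tiny standalone sketch of that ordering, with made-up values rather than real predictions:

```python
import numpy as np

f1_range = np.array([0.0, 1.0])           # 2 values of feature_1
f2_range = np.array([10.0, 20.0, 30.0])   # 3 values of feature_2

# Same construction order as the grid: f1 in the outer loop, f2 inner
flat = np.array([f1 + f2 for f1 in f1_range for f2 in f2_range])

# Rows follow feature_1, columns follow feature_2
matrix = flat.reshape(len(f1_range), len(f2_range))
print(matrix)
# [[10. 20. 30.]
#  [11. 21. 31.]]
```

Swapping the reshape dimensions would silently transpose the heatmap and mislabel both axes, so this ordering is worth double-checking.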


4. Visualising the Interaction

sns.heatmap(heatmap_data, cmap="viridis")

The heatmap colour intensity represents the predicted probability of the positive class. With the viridis colormap, brighter regions correspond to higher predicted probability.

Abrupt changes in colour across the grid indicate:

  • Threshold behaviour
  • Stronger interaction effects

How to Interpret the Heatmap

  • Smooth gradient → weak interaction
  • Sharp transitions → threshold behaviour
  • Diagonal patterns → joint influence
  • Irregular zones → complex non-linear interaction

This gives a deeper understanding than single-feature plots.


Why Interaction Heatmaps Matter

Feature interactions:

  • Often drive predictive performance
  • Explain model decisions better
  • Reveal hidden patterns in data
  • Support advanced model diagnostics

They complement:

  • SHAP values
  • PDP
  • Permutation importance

Key Takeaways


  1. Models often rely on feature interactions.
  2. Heatmaps visualise joint feature influence.
  3. Useful for detecting non-linear and threshold behaviour.
  4. Best applied with tree-based models.
  5. A powerful interpretability technique in advanced ML workflows.

Conclusion

Feature Interaction Heatmaps provide a powerful way to uncover how pairs of features jointly influence model predictions. By moving beyond independent feature analysis, we gain deeper insights into model behaviour and uncover hidden relationships that drive performance.

This technique strengthens the Advanced Visualisation & Interpretability module within the AI with Python series and helps build more transparent and trustworthy machine learning systems.


Code Snippet:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier


data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target


X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
    stratify=y
)


model = RandomForestClassifier(
    n_estimators=200,
    random_state=42
)
model.fit(X_train, y_train)


feature_1 = "mean radius"
feature_2 = "mean texture"

# Create grid values
f1_range = np.linspace(X_test[feature_1].min(), X_test[feature_1].max(), 50)
f2_range = np.linspace(X_test[feature_2].min(), X_test[feature_2].max(), 50)

grid = pd.DataFrame([
    {feature_1: f1, feature_2: f2}
    for f1 in f1_range
    for f2 in f2_range
])

# Fill other features with their training-set means
for col in X.columns:
    if col not in [feature_1, feature_2]:
        grid[col] = X_train[col].mean()

# Restore the training column order so the model accepts the grid
grid = grid[X_train.columns]

# Predict probabilities for the positive class
probs = model.predict_proba(grid)[:, 1]

# Reshape predictions: rows follow feature_1, columns follow feature_2
heatmap_data = probs.reshape(len(f1_range), len(f2_range))

plt.figure(figsize=(8, 6))
sns.heatmap(
    heatmap_data,
    xticklabels=np.round(f2_range, 2),
    yticklabels=np.round(f1_range, 2),
    cmap="viridis"
)
plt.gca().invert_yaxis()  # place low feature_1 values at the bottom

plt.xlabel(feature_2)
plt.ylabel(feature_1)
plt.title("Feature Interaction Heatmap")
plt.show()
