AW Dev Rethought

🌟 "The best way to predict the future is to invent it." – Alan Kay

🧠 AI with Python – 📉 Feature Drift Detection


Description:

A machine learning model can perform exceptionally well during development but gradually degrade in production without any visible errors. One of the most common reasons for this is feature drift — when the input data distribution changes over time.

The model itself hasn’t changed, but the data it receives has.

In this project, we explore how to detect feature drift using simple visualization techniques and statistical tests.


Understanding the Problem

Machine learning models are trained on historical data. They learn patterns based on that data distribution.

However, in production:

  • User behavior changes
  • Market conditions shift
  • Data collection pipelines evolve
  • New patterns emerge

This causes the distribution of features to shift, leading to unreliable predictions.


What Is Feature Drift?

Feature drift occurs when the statistical properties of input features change between:

  • Training data (reference)
  • Production data (live input)

Example:

  • Average income increases over time
  • Age distribution shifts due to new user segments

Even if the model is correct, the input space has changed.


1. Simulating Training and Production Data

We create two datasets to represent:

  • Training data
  • Drifted production data

train_data = pd.DataFrame({
    "age": np.random.normal(35, 10, 1000),
    "income": np.random.normal(50000, 15000, 1000)
})

production_data = pd.DataFrame({
    "age": np.random.normal(42, 12, 1000),
    "income": np.random.normal(65000, 20000, 1000)
})

Here, we intentionally shift the mean to simulate drift.
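As a quick sanity check on the simulated shift (a sketch reusing the same column names and parameters as above), comparing the sample means of the two frames should reflect the generating parameters:

```python
import numpy as np
import pandas as pd

np.random.seed(42)
train_data = pd.DataFrame({
    "age": np.random.normal(35, 10, 1000),
    "income": np.random.normal(50000, 15000, 1000),
})
production_data = pd.DataFrame({
    "age": np.random.normal(42, 12, 1000),
    "income": np.random.normal(65000, 20000, 1000),
})

# Sample means should land near the generating means
# (age: 35 vs 42, income: 50,000 vs 65,000).
print(train_data.mean().round(1))
print(production_data.mean().round(1))
```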


2. Visualizing Distribution Changes

We compare feature distributions using histograms.

for col in train_data.columns:
    plt.hist(train_data[col], alpha=0.5)
    plt.hist(production_data[col], alpha=0.5)
    plt.show()

Visual inspection helps identify:

  • shifts in mean
  • changes in variance
  • distribution overlap
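
The "distribution overlap" in the last bullet can also be turned into a single number. One simple sketch (my own addition, not part of the original walkthrough) is the histogram intersection: bin both samples on a shared grid and sum the smaller of the two normalized bar heights per bin.

```python
import numpy as np

rng = np.random.default_rng(1)
train = rng.normal(35, 10, 1000)  # reference "age" sample
prod = rng.normal(42, 12, 1000)   # drifted "age" sample

# Shared bin edges so both histograms are directly comparable.
bins = np.histogram_bin_edges(np.concatenate([train, prod]), bins=30)
h_train, _ = np.histogram(train, bins=bins)
h_prod, _ = np.histogram(prod, bins=bins)

# Intersection of the normalized histograms: 1.0 = identical, 0.0 = disjoint.
overlap = np.minimum(h_train / h_train.sum(), h_prod / h_prod.sum()).sum()
print(f"histogram overlap: {overlap:.2f}")
```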

3. Detecting Drift Using Statistical Tests

We use the Kolmogorov–Smirnov (KS) test to quantify drift.

from scipy.stats import ks_2samp

for col in train_data.columns:
    stat, p_value = ks_2samp(train_data[col], production_data[col])

The KS test compares two distributions and returns:

  • statistic → the maximum distance between the two empirical CDFs
  • p-value → probability of seeing a gap that large if both samples came from the same distribution
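
To see both outputs in action, here is a small sketch comparing a no-drift pair against a shifted pair (synthetic data, not the article's datasets):

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(7)
same = ks_2samp(rng.normal(0, 1, 1000), rng.normal(0, 1, 1000))
shifted = ks_2samp(rng.normal(0, 1, 1000), rng.normal(0.5, 1, 1000))

# The shifted pair yields a larger statistic and a far smaller p-value.
print(f"same distribution:    stat={same.statistic:.3f}, p={same.pvalue:.3f}")
print(f"shifted distribution: stat={shifted.statistic:.3f}, p={shifted.pvalue:.2e}")
```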

4. Interpreting Results

if p_value < 0.05:
    print("Drift detected")

  • Low p-value → distributions differ significantly
  • High p-value → no strong evidence of drift

This provides a systematic way to monitor changes.
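The per-feature check above can be packaged into a small helper for monitoring. This is a sketch; the function name `detect_drift` and the `alpha` parameter are my own, not part of the article:

```python
import numpy as np
import pandas as pd
from scipy.stats import ks_2samp

def detect_drift(reference: pd.DataFrame, live: pd.DataFrame,
                 alpha: float = 0.05) -> dict:
    """Run a KS test per shared column; True means drift at level alpha."""
    return {
        col: bool(ks_2samp(reference[col], live[col]).pvalue < alpha)
        for col in reference.columns.intersection(live.columns)
    }

rng = np.random.default_rng(0)
reference = pd.DataFrame({"age": rng.normal(35, 10, 1000)})
live = pd.DataFrame({"age": rng.normal(42, 12, 1000)})
print(detect_drift(reference, live))  # → {'age': True}
```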


Why Feature Drift Matters

Feature drift can lead to:

  • degraded model accuracy
  • incorrect predictions
  • business impact (financial loss, poor decisions)
  • reduced trust in ML systems

The dangerous part is that drift happens silently — without system errors.


Key Takeaways

  1. Feature drift occurs when input data distribution changes.
  2. It can silently degrade model performance.
  3. Visualizations help identify drift intuitively.
  4. The KS test provides statistical confirmation.
  5. Essential for production ML monitoring systems.

Conclusion

Feature drift detection is a foundational component of production machine learning systems. By continuously monitoring how incoming data differs from training data, we can detect issues early and take corrective actions — such as retraining models or updating pipelines.

This is a critical step in building reliable ML systems within the Production ML track of the AI with Python series.


Code Snippet:

# 📦 Import Required Libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import ks_2samp


# 🧩 Create Training Data (Reference Data)
np.random.seed(42)

train_data = pd.DataFrame({
    "age": np.random.normal(35, 10, 1000),
    "income": np.random.normal(50000, 15000, 1000)
})


# 🚀 Simulate Production Data (Drifted Data)
production_data = pd.DataFrame({
    "age": np.random.normal(42, 12, 1000),
    "income": np.random.normal(65000, 20000, 1000)
})


# 📊 Visualize Distribution Shift
for col in train_data.columns:
    plt.figure(figsize=(6, 4))

    plt.hist(train_data[col], bins=30, alpha=0.5, label="Train")
    plt.hist(production_data[col], bins=30, alpha=0.5, label="Production")

    plt.title(f"Distribution Shift – {col}")
    plt.xlabel(col)
    plt.ylabel("Count")
    plt.legend()
    plt.tight_layout()
    plt.show()


# 🔍 Statistical Drift Detection (KS Test)
drift_results = {}

for col in train_data.columns:
    stat, p_value = ks_2samp(train_data[col], production_data[col])

    drift_results[col] = {
        "ks_stat": stat,
        "p_value": p_value
    }

    print(f"{col} → KS Stat: {stat:.4f}, p-value: {p_value:.6f}")


# 🚨 Interpret Drift Results
print("\nDrift Detection Summary:\n")

for col, values in drift_results.items():
    if values["p_value"] < 0.05:
        print(f"⚠️ Drift detected in feature: {col}")
    else:
        print(f"✅ No significant drift in feature: {col}")
