{ "cells": [ { "cell_type": "markdown", "id": "2aaab578", "metadata": {}, "source": [ "\n", "# Strict GUIDE Variable Importance\n", "\n", "This notebook demonstrates the **Strict GUIDE Variable Importance** algorithm (Loh & Zhou, 2021).\n", "\n", "Unlike standard impurity-based importance scores, which are biased toward high-cardinality features, Strict GUIDE scores are:\n", "1. **Unbiased:** Derived from chi-squared tests of independence between each feature and the target, rather than from an exhaustive search over split points.\n", "2. **Normalized:** A score of **1.0** corresponds to the importance of a pure noise variable.\n", "3. **Robust:** Includes bias correction via permutation tests.\n", "\n", "## Synthetic Data Example\n", "\n", "We generate a dataset with:\n", "- **Signal Variables:** `x0` (main effect, thresholded at 0.5), `x1` & `x2` (XOR-style interaction).\n", "- **Noise Variables:** `x3` (high-cardinality categorical, 50 levels), `x4` (continuous).\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b32e770c", "metadata": {}, "outputs": [], "source": [ "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from pyguide import GuideTreeClassifier\n", "\n", "# Reproducibility\n", "rng = np.random.default_rng(42)\n", "n_samples = 1000\n", "\n", "# 1. Main-effect signal: x0 (enters y through a threshold at 0.5)\n", "x0 = rng.uniform(0, 1, n_samples)\n", "\n", "# 2. Interaction signal: x1 and x2 (XOR-like)\n", "x1 = rng.uniform(0, 1, n_samples)\n", "x2 = rng.uniform(0, 1, n_samples)\n", "\n", "# 3. High-cardinality noise: x3 (50 levels)\n", "x3 = rng.choice([f\"cat_{i}\" for i in range(50)], n_samples)\n", "\n", "# 4. Continuous noise: x4\n", "x4 = rng.uniform(0, 1, n_samples)\n", "\n", "# Target: depends on x0 and interaction(x1, x2)\n", "# y = 1 if (x0 > 0.5) OR (x1 > 0.5 XOR x2 > 0.5)\n", "y = ((x0 > 0.5) | ((x1 > 0.5) ^ (x2 > 0.5))).astype(int)\n", "\n", "df = pd.DataFrame({\n", " 'signal_main (x0)': x0,\n", " 'signal_int_1 (x1)': x1,\n", " 'signal_int_2 (x2)': x2,\n", " 'noise_high_card (x3)': x3,\n", " 'noise_cont (x4)': x4\n", "})\n", "\n", "print(\"Data generated. 
Calculating importance...\")\n" ] }, { "cell_type": "markdown", "id": "1c10a85a", "metadata": {}, "source": [ "\n", "## Calculate Importance Scores\n", "\n", "We call `compute_guide_importance` with `bias_correction=True`, which runs permutation tests to establish a null distribution of scores for every feature.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a6bc56e0", "metadata": {}, "outputs": [], "source": [ "\n", "clf = GuideTreeClassifier(interaction_depth=1, random_state=42)\n", "\n", "# Compute strict importance\n", "# Permutation tests make this take a few seconds (default n_permutations=300)\n", "scores = clf.compute_guide_importance(\n", " df, y,\n", " bias_correction=True,\n", " n_permutations=100 # Reduced for demo speed\n", ")\n", "\n", "# Create a DataFrame for visualization\n", "importance_df = pd.DataFrame({\n", " 'Feature': df.columns,\n", " 'Strict Importance (VI)': scores\n", "}).sort_values('Strict Importance (VI)', ascending=False)\n", "\n", "print(importance_df)\n" ] }, { "cell_type": "markdown", "id": "9f9a3226", "metadata": {}, "source": [ "\n", "## Interpretation\n", "\n", "- **Scores well above 1.0:** Evidence of association with the target.\n", "- **Scores ≈ 1.0 or below:** Indistinguishable from noise.\n", "\n", "In the output, `noise_high_card` (x3) should score near or below 1.0 despite its 50 distinct levels. 
By contrast, a random forest's impurity-based importance would likely rank this noise variable highly, because a feature with many distinct values offers more candidate splits and therefore more chances to produce a spuriously good split.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b9420d15", "metadata": {}, "outputs": [], "source": [ "\n", "# Visualization\n", "plt.figure(figsize=(10, 6))\n", "plt.barh(importance_df['Feature'], importance_df['Strict Importance (VI)'], color='skyblue')\n", "plt.axvline(x=1.0, color='red', linestyle='--', label='Noise Threshold (1.0)')\n", "plt.xlabel('Strict Importance Score (normalized)')\n", "plt.title('Unbiased Variable Importance (Loh & Zhou, 2021)')\n", "plt.legend()\n", "# Rows are sorted in descending order; flip the y-axis so the top feature appears at the top\n", "plt.gca().invert_yaxis()\n", "plt.tight_layout()\n", "plt.show()\n" ] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 }