Regression with GUIDE
This notebook demonstrates how to use GuideTreeRegressor, GuideRandomForestRegressor, and GuideGradientBoostingRegressor from pyguide on the scikit-learn Diabetes dataset.
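We will compare three models by their R2 score on a held-out test set:
Single GUIDE Tree: Interpretable, with unbiased variable selection.
GUIDE Gradient Boosting: Sequential ensemble in which each tree corrects the errors of its predecessors.
GUIDE Random Forest: Ensemble that averages many trees for robustness.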
[1]:
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from pyguide import (
    GuideGradientBoostingRegressor,
    GuideRandomForestRegressor,
    GuideTreeRegressor,
    plot_tree,
)
# Load Data
X, y = load_diabetes(return_X_y=True, as_frame=True)
feature_names = X.columns
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"Training samples: {len(X_train)}")
print(f"Features: {len(feature_names)}")
Training samples: 353
Features: 10
1. Single GUIDE Tree
[2]:
# Initialize GUIDE Regressor
# interaction_depth=1 enables pairwise interaction detection
reg = GuideTreeRegressor(max_depth=3, interaction_depth=1)
reg.fit(X_train, y_train)
# Evaluate
y_pred = reg.predict(X_test)
r2 = r2_score(y_test, y_pred)
print(f"Single Tree R2: {r2:.4f}")
# Visualize
plt.figure(figsize=(12, 8))
plot_tree(reg, feature_names=feature_names, fontsize=10)
plt.title("GUIDE Regression Tree (Diabetes)")
plt.show()
Single Tree R2: 0.3414
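To make "unbiased variable selection" more concrete, here is a minimal sketch of the kind of chi-squared test that the GUIDE method uses to choose a split variable at a node: fit a simple model, take the signs of its residuals, cross-tabulate them against each binned predictor, and prefer the predictor with the smallest p-value. This is an illustration only, not pyguide's implementation. The constant (mean) fit, the quartile binning, and the use of scipy.stats.chi2_contingency are all assumptions made for the example.

```python
import numpy as np
from scipy.stats import chi2_contingency
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True, as_frame=True)

# Root-node residuals from the simplest possible fit: a constant (the mean of y).
# Only the sign of each residual is used. (Assumption: constant node model.)
positive = (y - y.mean()).to_numpy() > 0

p_values = {}
for name in X.columns:
    col = X[name].to_numpy()
    # Bin the predictor at its quartiles (assumption). np.unique on the edges
    # and on the bin labels keeps only distinct, occupied bins, so a
    # two-valued column such as "sex" produces fewer than four groups.
    edges = np.unique(np.quantile(col, [0.25, 0.5, 0.75]))
    _, bins = np.unique(np.digitize(col, edges), return_inverse=True)

    # 2 x k contingency table: residual sign (rows) by predictor bin (columns).
    table = np.zeros((2, bins.max() + 1))
    np.add.at(table, (positive.astype(int), bins), 1)

    _, p_value, _, _ = chi2_contingency(table)
    p_values[name] = p_value

# The candidate split variable is the one with the smallest p-value.
for name, p in sorted(p_values.items(), key=lambda kv: kv[1]):
    print(f"{name:>4}: p = {p:.3g}")
```

Because every predictor is judged by a test on the same footing, rather than by searching over all of its possible split points, variables with many distinct values gain no built-in advantage. That is the sense in which the selection is described as unbiased.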
2. GUIDE Gradient Boosting
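Gradient boosting builds trees sequentially: each new tree is fit to the errors (residuals) of the ensemble built so far, and its contribution is scaled by learning_rate before being added.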
[3]:
gbm = GuideGradientBoostingRegressor(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=3,
    subsample=0.8,
    random_state=42,
)
gbm.fit(X_train, y_train)
y_pred_gbm = gbm.predict(X_test)
r2_gbm = r2_score(y_test, y_pred_gbm)
print(f"Gradient Boosting R2: {r2_gbm:.4f}")
Gradient Boosting R2: 0.5186
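The sequential procedure described in this section can be sketched by hand for squared-error loss, where the quantity each new tree fits is simply the current residual. The sketch below is a rough illustration under stated assumptions, not how GuideGradientBoostingRegressor is implemented: it uses scikit-learn's DecisionTreeRegressor as a stand-in base learner, and it omits the row subsampling (subsample=0.8) used in the cell above.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

learning_rate = 0.1
n_estimators = 100

# Start from a constant model: the mean of the training targets.
init = y_train.mean()
pred_train = np.full(len(y_train), init)
trees = []
train_mse = [np.mean((y_train - pred_train) ** 2)]

for _ in range(n_estimators):
    # For squared-error loss, the negative gradient is the residual.
    residuals = y_train - pred_train
    # Stand-in base learner (assumption): a CART tree, not a GUIDE tree.
    tree = DecisionTreeRegressor(max_depth=3, random_state=0)
    tree.fit(X_train, residuals)
    # Add a shrunken copy of the new tree's prediction to the ensemble.
    pred_train += learning_rate * tree.predict(X_train)
    trees.append(tree)
    train_mse.append(np.mean((y_train - pred_train) ** 2))


def predict(X_new):
    """Sum the initial constant and the shrunken contribution of every tree."""
    return init + learning_rate * sum(tree.predict(X_new) for tree in trees)


print(f"Training MSE: {train_mse[0]:.1f} -> {train_mse[-1]:.1f}")
```

Each round lowers (or at worst leaves unchanged) the training error, which is why n_estimators and learning_rate are usually tuned together: a smaller learning rate needs more trees to reach the same fit, but tends to overfit more slowly.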
3. Comparison with Random Forest
[4]:
rf = GuideRandomForestRegressor(n_estimators=100, max_depth=5, random_state=42)
rf.fit(X_train, y_train)
y_pred_rf = rf.predict(X_test)
r2_rf = r2_score(y_test, y_pred_rf)
print(f"Random Forest R2: {r2_rf:.4f}")
Random Forest R2: 0.4896
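To read the three results side by side, the short snippet below ranks the test-set R2 scores reported in the cell outputs of this notebook. The values are copied by hand from those outputs rather than recomputed, so this is only a convenience summary. The dictionary keys are labels chosen for the example.

```python
# Test-set R2 scores copied from the cell outputs in this notebook.
scores = {
    "Single GUIDE tree": 0.3414,
    "GUIDE gradient boosting": 0.5186,
    "GUIDE random forest": 0.4896,
}

baseline = scores["Single GUIDE tree"]

# Rank models from highest to lowest R2 and show the gain over the single tree.
ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
for name, r2 in ranked:
    print(f"{name:<24} R2 = {r2:.4f}  (vs. single tree: {r2 - baseline:+.4f})")
```

On this particular train/test split, both ensembles score higher than the single tree, with gradient boosting ahead of the random forest. Note that the random forest above uses max_depth=5 while the other two models use max_depth=3, so the comparison is not hyperparameter-matched.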