Prerequisites

# Load required packages
library(dplyr)    # for data wrangling 
library(ggplot2)  # for general visualization
library(kernlab)  # for fitting SVMs
library(pdp)      # for partial dependence plots
library(ranger)   # for fitting random forests

User-defined prediction functions

Partial dependence plots (PDPs) are essentially just averaged predictions; this is apparent from step 1(c) of Algorithm 1 in Greenwell (2017). Consequently, as pointed out by Goldstein et al. (2015), strong heterogeneity can conceal the complexity of the modeled relationship between the response and the predictors of interest. This was part of the motivation behind the ICE plot procedure of Goldstein et al. (2015).
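To make the "averaged predictions" idea concrete, here is a minimal brute-force sketch in base R. It is not the implementation used by partial(); the helper manual_pd() and the objects fit, grid, and pd are hypothetical names, and the linear model on mtcars is used only to keep the example self-contained.

```r
# Brute-force sketch: for each grid value, fix the predictor at that value
# for every row, predict, and aggregate the predictions (mean by default)
manual_pd <- function(object, data, pred.var, grid, fun = mean) {
  sapply(grid, function(value) {
    newdata <- data
    newdata[[pred.var]] <- value   # hold the predictor of interest fixed
    fun(predict(object, newdata))  # aggregate over all observations
  })
}

# Illustration with a simple linear model (assumed example, base R only)
fit <- lm(mpg ~ wt + hp, data = mtcars)
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pd <- manual_pd(fit, mtcars, "wt", grid)  # one averaged prediction per grid value
```

Swapping fun = mean for fun = identity keeps one prediction per observation at each grid value, which is the raw material for ICE curves.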

With partial() it is possible to replace the mean in step 1(c) of Algorithm 1 with any other function (e.g., the median or a trimmed mean), or to obtain PDPs for classification problems on the probability scale (see note 1). It is even possible to obtain ICE curves (see note 2). This flexibility is due to the pred.fun argument in partial() (available starting with pdp version 0.4.0). This argument accepts an optional prediction function that requires two arguments: object and newdata. The supplied prediction function must return either a single prediction or a vector of predictions. Returning the mean of all the predictions results in the traditional PDP. Returning a vector of predictions (i.e., one for each observation) results in a set of ICE curves. The examples below illustrate how the pred.fun argument gives partial() the flexibility to handle a wide range of situations.
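As a small illustration of swapping the mean for another summary, the sketch below defines a median-based prediction function with the required object and newdata arguments. The name pred_median is hypothetical, and the sketch assumes predict() returns a plain numeric vector for the model at hand (as it does for lm fits, for example).

```r
# Hypothetical prediction function: median instead of mean
# (assumes predict() returns a numeric vector of predictions)
pred_median <- function(object, newdata) {
  median(predict(object, newdata), na.rm = TRUE)
}
```

For models whose predict() method returns a richer object, the numeric predictions would first have to be extracted inside the function.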

Using the pred.fun argument, it is possible to obtain PDPs for classification problems on the probability scale. We just need to write a function that computes the predicted class probability of interest averaged across all observations.

To illustrate, we consider Edgar Anderson’s iris data from the datasets package. The iris data frame contains the sepal length, sepal width, petal length, and petal width (in centimeters) for 50 flowers from each of three species of iris: setosa, versicolor, and virginica (i.e., \(K = 3\)). In the code chunk below, we fit a support vector machine (SVM) with a Gaussian radial basis function kernel to the iris data using the ksvm() function in the kernlab package (Karatzoglou, Smola, and Hornik 2018); the tuning parameters were determined using 5-fold cross-validation. Note that the partial() function has to be able to extract the predicted probabilities for each class, so it is necessary to set prob.model = TRUE in the call to ksvm(). See the vignette titled Interpreting classification models for how to obtain variable importance plots for arbitrary models (like the SVM below) using the vip package (Greenwell and Boehmke, n.d.).

# Fit an SVM to Edgar Anderson's iris data
iris_svm <- ksvm(Species ~ ., data = iris, kernel = "rbfdot", 
                 kpar = list(sigma = 0.709), C = 0.5, prob.model = TRUE)

The function below can be used to extract the average predicted probability of belonging to the setosa class.
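A minimal sketch of such a function follows. It assumes that kernlab's predict() method for a ksvm fit, called with type = "probabilities", returns a matrix with one column per class named after the factor levels; the name pred_prob matches the one used in the call to partial() further on.

```r
# Sketch: average predicted probability of the setosa class
# (assumes a kernlab ksvm model fitted with prob.model = TRUE)
pred_prob <- function(object, newdata) {
  pred <- predict(object, newdata, type = "probabilities")
  prob_setosa <- pred[, "setosa"]  # column for the class of interest
  mean(prob_setosa)                # a single value yields a PDP
}
```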

Next, we simply pass this function via the pred.fun argument in the call to partial(). The following chunk of code uses pred_prob to obtain PDPs for Petal.Width and Petal.Length on the probability scale. The results are displayed in Figure 1.

# PDPs for Petal.Width and Petal.Length on the probability scale
pdp1 <- partial(iris_svm, pred.var = "Petal.Width", pred.fun = pred_prob,
                plot = TRUE, train = iris)
pdp2 <- partial(iris_svm, pred.var = "Petal.Length", pred.fun = pred_prob,
                plot = TRUE, train = iris)
pdp3 <- partial(iris_svm, pred.var = c("Petal.Width", "Petal.Length"),
                pred.fun = pred_prob, plot = TRUE, train = iris)

# Figure 1
grid.arrange(pdp1, pdp2, pdp3, ncol = 3)
**Figure 1** Partial dependence of `setosa` on `Petal.Width` and `Petal.Length` plotted on the probability scale; in this case, the probability of belonging to the setosa species.

We could also plot the PDP for a single feature and include pointwise standard deviation bands! To do this, we simply augment the user-defined prediction function to return the mean, as well as the mean +/- one standard deviation (see Figure 2):
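One way to write the augmented function is sketched below. The name pred_prob_sd and the labels of the returned values are hypothetical, and the sketch again assumes that kernlab's predict() with type = "probabilities" returns a matrix with a "setosa" column.

```r
# Hypothetical helper: mean predicted probability of setosa, plus and minus
# one standard deviation (assumes a kernlab ksvm fit with prob.model = TRUE)
pred_prob_sd <- function(object, newdata) {
  pred <- predict(object, newdata, type = "probabilities")
  prob_setosa <- pred[, "setosa"]
  avg <- mean(prob_setosa)
  std <- sd(prob_setosa)
  c("avg" = avg, "avg-1sd" = avg - std, "avg+1sd" = avg + std)
}
```

Because the function returns three values per grid point, passing it via pred.fun for Petal.Width should produce three curves, which can be told apart through the yhat.id column of the result.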

**Figure 2** Partial dependence of `setosa` on `Petal.Width` +/- one (pointwise) standard deviation.

For regression problems, the default prediction function is essentially the mean of the model's predictions over newdata.
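As a minimal sketch (not necessarily the exact code used inside partial()), and assuming predict() returns a numeric vector for the fitted model, such a function could be written as:

```r
# Sketch of a mean-based prediction function for regression
# (assumes predict() returns a numeric vector of predictions)
pred.fun <- function(object, newdata) {
  mean(predict(object, newdata), na.rm = TRUE)
}
```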

This corresponds to step 1(c) in Algorithm 1. Suppose we would like to manually construct ICE curves instead. To accomplish this, we need to pass a prediction function that returns a vector of predictions, one for each observation in newdata (i.e., just remove the call to mean in pred.fun).

For illustration, we’ll use a corrected version of the Boston housing data analyzed in Harrison and Rubinfeld (1978); the data are available in the pdp package (see ?pdp::boston for details). We begin by loading the data and fitting a random forest with default tuning parameters and 500 trees using the ranger package (Wright, Wager, and Probst 2018).

# Fit a random forest to the Boston housing data
set.seed(101)  # for reproducibility
boston_rfo <- ranger(cmedv ~ ., data = boston)

The model fit is reasonable, with an out-of-bag (pseudo) \(R^2\) of 0.879.

The code snippet below manually constructs ICE curves for the Boston housing example using the predictor rm. The result is displayed in Figure 3. Note that when the function supplied to pred.fun returns multiple predictions, the data frame returned by partial() includes an additional column, yhat.id, that indicates which curve a point belongs to; in the following code chunk, there will be one curve for each observation in boston.

# Use partial to obtain ICE curves
ranger_ice <- function(object, newdata) {
  predict(object, newdata)$predictions
}
rm_ice <- partial(boston_rfo, pred.var = "rm", pred.fun = ranger_ice)

# Figure 3
autoplot(rm_ice, rug = TRUE, train = boston, alpha = 0.3)
#> Warning: Ignoring unknown parameters: csides
**Figure 3** ICE curves depicting the relationship between `cmedv` and `rm` for the Boston housing example. Each curve corresponds to a different observation.

The curves in Figure 3 indicate some heterogeneity in the fitted model (i.e., some of the curves depict the opposite relationship). Such heterogeneity can be easier to spot using c-ICE curves; see Equation (4) on page 49 of Goldstein et al. (2015). Using dplyr (Wickham et al. 2018), it is rather straightforward to post-process the output from partial() to obtain c-ICE curves (see note 3); this is similar to the construction of raw change scores (Fitzmaurice, Laird, and Ware 2011, p. 130) for longitudinal data. This is shown below.

# Post-process rm_ice to obtain c-ICE curves
rm_ice <- rm_ice %>%
  group_by(yhat.id) %>%  # perform next operation within each yhat.id
  mutate(yhat.centered = yhat - first(yhat))  # so each curve starts at yhat = 0

Since the PDP is just the average of the corresponding ICE curves, it is quite simple to display both on the same plot. This is easily accomplished using the stat_summary() function from the ggplot2 package to average the ICE curves together. The code snippet below plots the ICE curves and c-ICE curves, along with their averages, for the predictor rm in the Boston housing example. The results are displayed in Figure 4.

# ICE curves with their average
# (fun replaces the deprecated fun.y argument in ggplot2 >= 3.3.0)
p1 <- ggplot(rm_ice, aes(rm, yhat)) +
  geom_line(aes(group = yhat.id), alpha = 0.2) +
  stat_summary(fun = mean, geom = "line", col = "red", size = 1)

# c-ICE curves with their average
p2 <- ggplot(rm_ice, aes(rm, yhat.centered)) +
  geom_line(aes(group = yhat.id), alpha = 0.2) +
  stat_summary(fun = mean, geom = "line", col = "red", size = 1)

# Figure 4
grid.arrange(p1, p2, ncol = 2)
**Figure 4** ICE curves (black curves) and their average (red curve) depicting the relationship between `cmedv` and `rm` for the Boston housing example. *Left*: Uncentered (here the red curve is just the traditional PDP). *Right*: Centered.
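The averaging that stat_summary() performs can also be done explicitly with dplyr. The sketch below uses a small made-up data frame, toy_ice, laid out in the same long format as the ICE output (one row per grid value and curve, with curve membership in yhat.id); the values are invented purely for illustration.

```r
library(dplyr)

# Toy ICE data (hypothetical values): two curves evaluated at three grid points
toy_ice <- data.frame(
  rm      = rep(c(4, 6, 8), times = 2),
  yhat    = c(10, 20, 30, 14, 22, 40),
  yhat.id = rep(1:2, each = 3)
)

# Averaging the curves at each grid value gives the partial dependence
toy_pdp <- toy_ice %>%
  group_by(rm) %>%
  summarize(yhat = mean(yhat))
```

The same two lines of post-processing applied to real ICE output would give the red curve in the left panel of Figure 4.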

Fitzmaurice, G. M., N. M. Laird, and J. H. Ware. 2011. Applied Longitudinal Analysis. Wiley Series in Probability and Statistics. John Wiley & Sons.

Goldstein, Alex, Adam Kapelner, Justin Bleich, and Emil Pitkin. 2015. “Peeking Inside the Black Box: Visualizing Statistical Learning with Plots of Individual Conditional Expectation.” Journal of Computational and Graphical Statistics 24 (1): 44–65. https://doi.org/10.1080/10618600.2014.907095.

Greenwell, Brandon, and Brad Boehmke. n.d. Vip: Variable Importance Plots. https://koalaverse.github.io/vip/index.html.

Greenwell, Brandon M. 2017. “Pdp: An R Package for Constructing Partial Dependence Plots.” The R Journal 9 (1): 421–36. https://journal.r-project.org/archive/2017/RJ-2017-016/index.html.

Harrison, David, and Daniel L. Rubinfeld. 1978. “Hedonic Housing Prices and the Demand for Clean Air.” Journal of Environmental Economics and Management 5 (1): 81–102. https://doi.org/10.1016/0095-0696(78)90006-2.

Karatzoglou, Alexandros, Alex Smola, and Kurt Hornik. 2018. Kernlab: Kernel-Based Machine Learning Lab. https://CRAN.R-project.org/package=kernlab.

Wickham, Hadley, Romain François, Lionel Henry, and Kirill Müller. 2018. Dplyr: A Grammar of Data Manipulation. https://CRAN.R-project.org/package=dplyr.

Wright, Marvin N., Stefan Wager, and Philipp Probst. 2018. Ranger: A Fast Implementation of Random Forests. https://CRAN.R-project.org/package=ranger.


  1. PDPs on the probability scale are more conveniently available via the prob argument starting with pdp version 0.5.0.

  2. ICE curves are more conveniently available via the ice argument starting with pdp version 0.6.0.

  3. c-ICE curves are more conveniently available via the ice and center arguments starting with pdp version 0.6.0.