
Multiple data imputation and explainability


Introduction

Imputing missing values is an important task, but in my experience it is very often performed using overly simplistic approaches. The basic approach is to impute missing values for numerical features using the average of each feature, or using the mode for categorical features. There are better ways of imputing missing values, for instance by predicting them using a regression model, or KNN. However, imputing only once is not enough, because each imputed value carries with it a certain level of uncertainty. To account for this, it is better to perform multiple imputation. This means that if you impute your dataset 10 times, you'll end up with 10 different datasets. Then, you should perform your analysis 10 times; for instance, if training a machine learning model, you should train it on the 10 datasets (and do a train/test split for each, even potentially tune a model for each). Finally, you should pool the results of these 10 analyses.
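
To make this workflow concrete, here is a minimal conceptual sketch using {mice} (the dataset and variable names are hypothetical, just to illustrate the impute/analyse/pool sequence):

library(mice)

imp <- mice(some_data, m = 10)       # impute 10 times: 10 completed datasets
fits <- with(imp, lm(y ~ x1 + x2))   # run the same analysis on each dataset
pool(fits)                           # pool the 10 sets of estimates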

I have come across this approach in the social sciences and statistical literature in general, but very rarely in machine learning. Usually, in the social sciences, explainability is the goal of fitting statistical models to data, and the approach I described above is very well suited for this. Fit a (linear) regression to each of the 10 imputed datasets, and then pool the estimated coefficients/weights together. Rubin's rule is used to pool these estimates. You can read more about this rule here. In machine learning, the task is very often prediction; in this case, you should pool the predictions. Computing the average and other statistics of the predictions seems to work just fine in practice.
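
In a nutshell, Rubin's rule averages the m point estimates and combines the within- and between-imputation variances. Here is a minimal sketch of it for a single parameter (my own illustration; {mice}'s pool() function does all of this for you):

# m point estimates and their variances (squared standard errors),
# one per imputed dataset
pool_rubin <- function(estimates, variances){
    m <- length(estimates)
    q_bar <- mean(estimates)          # pooled point estimate
    w_bar <- mean(variances)          # within-imputation variance
    b <- var(estimates)               # between-imputation variance
    total <- w_bar + (1 + 1/m) * b    # total variance
    c(estimate = q_bar, std_error = sqrt(total))
}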

However, if you are mainly interested in explainability, how should you proceed? I've thought a bit about it, and the answer is "exactly the same way"… I think. What I'm sure about is that you should impute m times, run the analysis m times (which in this case includes getting explanations), and then pool. So the idea is to be able to pool explanations.

Explainability in the “standard” case (no missing values)

To illustrate this idea, I'll be using the {mice} package for multiple imputation, {h2o} for the machine learning bit and {iml} for explainability. Note that I could have used any other machine learning package instead of {h2o}, as {iml} is totally package-agnostic. However, I have been experimenting with {h2o}'s automl implementation lately, so I happened to have code on hand. Let's start with the "standard" case where the data does not have any missing values.

First let’s load the needed packages and initialize h2o functions with h2o.init():

library(tidyverse)
library(Ecdat)
library(mice)
library(h2o)
library(iml)
h2o.init()

I’ll be using the DoctorContacts data. Here’s a description:


DoctorContacts              package:Ecdat              R Documentation

Contacts With Medical Doctor

Description:

     a cross-section from 1977-1978

     _number of observations_ : 20186

Usage:

     data(DoctorContacts)
     
Format:

     A data frame containing:

     mdu number of outpatient visits to a medical doctor

     lc log(coinsrate+1) where coinsurance rate is 0 to 100

     idp individual deductible plan ?

     lpi log(annual participation incentive payment) or 0 if no payment

     fmde log(max(medical deductible expenditure)) if IDP=1 and MDE>1
          or 0 otherwise

     physlim physical limitation ?

     ndisease number of chronic diseases

     health self-rated health (excellent, good, fair, poor)

     linc log of annual family income (in $)

     lfam log of family size

     educdec years of schooling of household head

     age exact age

     sex sex (male,female)

     child age less than 18 ?

     black is household head black ?

Source:

     Deb, P.  and P.K.  Trivedi (2002) “The Structure of Demand for
     Medical Care: Latent Class versus Two-Part Models”, _Journal of
     Health Economics_, *21*, 601-625.

References:

     Cameron, A.C.  and P.K.  Trivedi (2005) _Microeconometrics :
     methods and applications_, Cambridge, pp. 553-556 and 565.

The task is to predict "mdu", the number of outpatient visits to an MD. Let's prepare the data and split it into three: a training, a validation and a holdout set.

data("DoctorContacts")

contacts <- as.h2o(DoctorContacts)
splits <- h2o.splitFrame(data=contacts, ratios = c(0.7, 0.2))

original_train <- splits[[1]]

validation <- splits[[2]]

holdout <- splits[[3]]

features_names <- setdiff(colnames(original_train), "mdu")

As you see, the ratios argument c(0.7, 0.2) does not add up to 1. This means that the first of the splits will have 70% of the data, the second split 20% and the final 10% will be the holdout set.
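
Note that h2o.splitFrame() splits the data approximately, not exactly; if you want to check the actual shares, something along these lines should do:

# actual proportions of the three splits (roughly 0.7/0.2/0.1)
sapply(splits, h2o.nrow) / h2o.nrow(contacts)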

Let’s first go with a poisson regression. To obtain the same results as with R’s built-in glm() function, I use the options below, as per H2o’s glm faq.

If you read Cameron and Trivedi's Microeconometrics, where this data is presented in the context of count models, you'll see that they also fit a negative binomial 2 (NB2) model to this data, as it allows for overdispersion. Here, I'll stick to a simple Poisson regression, simply because the goal of this blog post is not to get the best model; as explained in the beginning, this is an attempt at pooling explanations when doing multiple imputation (and also because GBMs, which I use below, do not support the negative binomial distribution).

glm_model <- h2o.glm(y = "mdu", x = features_names,
                     training_frame = original_train,
                     validation_frame = validation,
                     compute_p_values = TRUE,
                     solver = "IRLSM",
                     lambda = 0,
                     remove_collinear_columns = TRUE,
                     score_each_iteration = TRUE,
                     family = "poisson", 
                     link = "log")

Now that I have this simple model, which returns (almost) the same results as R's glm() function, I can take a look at the coefficients and see which ones are important, because GLMs are easily interpretable:


summary(glm_model)
## Model Details:
## ==============
## 
## H2ORegressionModel: glm
## Model Key:  GLM_model_R_1572735931328_5 
## GLM Model: summary
##    family link regularization number_of_predictors_total
## 1 poisson  log           None                         16
##   number_of_active_predictors number_of_iterations  training_frame
## 1                          16                    5 RTMP_sid_8588_3
## 
## H2ORegressionMetrics: glm
## ** Reported on training data. **
## 
## MSE:  17.6446
## RMSE:  4.200547
## MAE:  2.504063
## RMSLE:  0.8359751
## Mean Residual Deviance :  3.88367
## R^2 :  0.1006768
## Null Deviance :64161.44
## Null D.o.F. :14131
## Residual Deviance :54884.02
## Residual D.o.F. :14115
## AIC :83474.52
## 
## 
## H2ORegressionMetrics: glm
## ** Reported on validation data. **
## 
## MSE:  20.85941
## RMSE:  4.56721
## MAE:  2.574582
## RMSLE:  0.8403465
## Mean Residual Deviance :  4.153042
## R^2 :  0.09933874
## Null Deviance :19667.55
## Null D.o.F. :4078
## Residual Deviance :16940.26
## Residual D.o.F. :4062
## AIC :25273.25
## 
## 
## 
## 
## Scoring History: 
##             timestamp   duration iterations negative_log_likelihood
## 1 2019-11-03 00:33:46  0.000 sec          0             64161.43611
## 2 2019-11-03 00:33:46  0.004 sec          1             56464.99004
## 3 2019-11-03 00:33:46  0.020 sec          2             54935.05581
## 4 2019-11-03 00:33:47  0.032 sec          3             54884.19756
## 5 2019-11-03 00:33:47  0.047 sec          4             54884.02255
## 6 2019-11-03 00:33:47  0.063 sec          5             54884.02255
##   objective
## 1   4.54015
## 2   3.99554
## 3   3.88728
## 4   3.88368
## 5   3.88367
## 6   3.88367
## 
## Variable Importances: (Extract with `h2o.varimp`) 
## =================================================
## 
##        variable relative_importance scaled_importance  percentage
## 1    black.TRUE          0.67756097        1.00000000 0.236627982
## 2   health.poor          0.48287163        0.71266152 0.168635657
## 3  physlim.TRUE          0.33962316        0.50124369 0.118608283
## 4   health.fair          0.25602066        0.37785627 0.089411366
## 5      sex.male          0.19542639        0.28842628 0.068249730
## 6      ndisease          0.16661902        0.24591001 0.058189190
## 7      idp.TRUE          0.15703578        0.23176627 0.054842384
## 8    child.TRUE          0.09988003        0.14741114 0.034881600
## 9          linc          0.09830075        0.14508030 0.034330059
## 10           lc          0.08126160        0.11993253 0.028379394
## 11         lfam          0.07234463        0.10677213 0.025265273
## 12         fmde          0.06622332        0.09773781 0.023127501
## 13      educdec          0.06416087        0.09469387 0.022407220
## 14  health.good          0.05501613        0.08119732 0.019213558
## 15          age          0.03167598        0.04675000 0.011062359
## 16          lpi          0.01938077        0.02860373 0.006768444

As a bonus, let’s see the output of the glm() function:


train_tibble <- as_tibble(original_train)

r_glm <- glm(mdu ~ ., data = train_tibble,
            family = poisson(link = "log"))

summary(r_glm)
## 
## Call:
## glm(formula = mdu ~ ., family = poisson(link = "log"), data = train_tibble)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -5.7039  -1.7890  -0.8433   0.4816  18.4703  
## 
## Coefficients:
##               Estimate Std. Error z value Pr(>|z|)    
## (Intercept)  0.0005100  0.0585681   0.009   0.9931    
## lc          -0.0475077  0.0072280  -6.573 4.94e-11 ***
## idpTRUE     -0.1794563  0.0139749 -12.841  < 2e-16 ***
## lpi          0.0129742  0.0022141   5.860 4.63e-09 ***
## fmde        -0.0166968  0.0042265  -3.951 7.80e-05 ***
## physlimTRUE  0.3182780  0.0126868  25.087  < 2e-16 ***
## ndisease     0.0222300  0.0007215  30.811  < 2e-16 ***
## healthfair   0.2434235  0.0192873  12.621  < 2e-16 ***
## healthgood   0.0231824  0.0115398   2.009   0.0445 *  
## healthpoor   0.4608598  0.0329124  14.003  < 2e-16 ***
## linc         0.0826053  0.0062208  13.279  < 2e-16 ***
## lfam        -0.1194981  0.0106904 -11.178  < 2e-16 ***
## educdec      0.0205582  0.0019404  10.595  < 2e-16 ***
## age          0.0041397  0.0005152   8.035 9.39e-16 ***
## sexmale     -0.2096761  0.0104668 -20.032  < 2e-16 ***
## childTRUE    0.1529588  0.0179179   8.537  < 2e-16 ***
## blackTRUE   -0.6231230  0.0176758 -35.253  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for poisson family taken to be 1)
## 
##     Null deviance: 64043  on 14096  degrees of freedom
## Residual deviance: 55529  on 14080  degrees of freedom
## AIC: 84052
## 
## Number of Fisher Scoring iterations: 6

I could also use the excellent {ggeffects} package to see the marginal effects of different variables, for instance "linc":

library(ggeffects)

ggeffect(r_glm, "linc") %>% 
    ggplot(aes(x, predicted)) +
    geom_ribbon(aes(ymin = conf.low, ymax = conf.high), fill = "#0f4150") +
    geom_line(colour = "#82518c") +
    brotools::theme_blog()

We can see that as "linc" increases (with the other covariates held constant), the predicted number of visits increases as well.

Let’s also take a look at the marginal effect of a categorical variable, namely "sex":


library(ggeffects)

ggeffect(r_glm, "sex") %>% 
    ggplot(aes(x, predicted)) +
    geom_point(colour = "#82518c") +
    geom_errorbar(aes(x, ymin = conf.low, ymax = conf.high), colour = "#82518c") +
    brotools::theme_blog()

In the case of the "sex" variable, men have significantly fewer doctor contacts than women.

Now, let’s suppose that I want to train a model with a more complicated name, in order to justify my salary. Suppose I go with one of those nifty black-box models, for instance a GBM, which very likely will perform better than the GLM from before. GBMs are available in {h2o} via the h2o.gbm() function:

gbm_model <- h2o.gbm(y = "mdu", x = features_names,
            training_frame = original_train,
            validation_frame = validation,
            distribution = "poisson",
            score_each_iteration = TRUE,
            ntrees = 110,
            max_depth = 20,
            sample_rate = 0.6,
            col_sample_rate = 0.8,
            col_sample_rate_per_tree = 0.9,
            learn_rate = 0.05)

To find a set of good hyper-parameter values, I actually used h2o.automl() and then used the returned parameter values from the leader model. Maybe I'll write another blog post about h2o.automl() one day; it's quite cool. Anyways, now, how do I get me some explainability out of this? The model does perform better than the GLM, as indicated by all the different metrics, but now I cannot compute any marginal effects, or anything like that. I do get feature importance by default with:

h2o.varimp(gbm_model)
## Variable Importances: 
##    variable relative_importance scaled_importance percentage
## 1       age       380350.093750          1.000000   0.214908
## 2      linc       282274.343750          0.742143   0.159492
## 3  ndisease       245862.718750          0.646412   0.138919
## 4       lpi       173552.734375          0.456297   0.098062
## 5   educdec       148186.265625          0.389605   0.083729
## 6      lfam       139174.312500          0.365911   0.078637
## 7      fmde        94193.585938          0.247650   0.053222
## 8    health        86160.679688          0.226530   0.048683
## 9       sex        63502.667969          0.166958   0.035881
## 10       lc        50674.968750          0.133232   0.028633
## 11  physlim        45328.382812          0.119175   0.025612
## 12    black        26376.841797          0.069349   0.014904
## 13      idp        24809.185547          0.065227   0.014018
## 14    child         9382.916992          0.024669   0.005302

but that’s it. And had I chosen a different “black-box” model, not based on trees, then I would not even have that. Thankfully, there’s the amazing {iml} package that contains a lot of functions for model-agnostic explanations. If you are not familiar with this package and the methods it implements, I highly encourage you to read the free online ebook written by the packages author, Christoph Molnar (who you can follow on Twitter).

Out of the box, {iml} works with several machine learning frameworks, such as {caret} or {mlr}, but not with {h2o}. However, this is not an issue; you only need to create a predict function which returns a data frame (h2o.predict(), used for prediction with h2o models, returns an h2o frame instead). I have found this interesting blog post from business-science.io which explains how to do this. I highly recommend you read it, as it goes much deeper into the capabilities of {iml}.

So let’s write a predict function that {iml} can use:

#source: https://www.business-science.io/business/2018/08/13/iml-model-interpretability.html
predict_for_iml <- function(model, newdata){
  as_tibble(h2o.predict(model, as.h2o(newdata)))
}

And let’s now create a Predictor object. These objects are used by {iml} to create explanations:

just_features <- as_tibble(holdout[, 2:15])
actual_target <- as_tibble(holdout[, 1])

predictor_original <- Predictor$new(
  model = gbm_model, 
  data = just_features, 
  y = actual_target, 
  predict.fun = predict_for_iml
  )

predictor_original can now be used to compute all kinds of explanations. I won't go into much detail here, as this blog post is already quite long (and I haven't even reached what I actually want to write about yet), and you can read more in the aforementioned blog post or directly in Christoph Molnar's ebook linked above.

First, let’s compute a partial dependence plot, which shows the marginal effect of a variable on the outcome. This is to compare it to the one from the GLM model:

feature_effect_original <- FeatureEffect$new(predictor_original, "linc", method = "pdp")

plot(feature_effect_original) +
    brotools::theme_blog()

Quite similar to the marginal effects from the GLM! Let’s now compute model-agnostic feature importances:

feature_importance_original <- FeatureImp$new(predictor_original, loss = "mse")

plot(feature_importance_original)

And finally, the strength of the interactions between the sex variable and all the others:

interaction_sex_original <- Interaction$new(predictor_original, feature = "sex")

plot(interaction_sex_original)

Ok so let’s assume that I’m happy with these explanations, and do need or want to go further. This would be the end of it in an ideal world, but this is not an ideal world unfortunately, but it’s the best we’ve got. In the real world, it often happens that data comes with missing values.

Missing data and explainability

As explained in the beginning, I’ve been wondering how to deal with missing values when the goal of the analysis is explainability. How can the explanations be pooled? Let’s start with creating a data set with missing values, then perform multiple imputation, then perform the analysis.

First, let me create a patterns matrix that I will pass to the ampute() function from the {mice} package. This function creates a dataset with missing values, and by using its patterns argument, I can decide which columns should have missing values:

# a 15 by 15 matrix of 1s with 0s on the diagonal: each row is a missingness
# pattern, where 0 means the variable gets amputed and 1 means it stays observed
patterns <- -1*(diag(1, nrow = 15, ncol = 15) - 1)

# additionally ampute columns 1 to 6, 9 and 13 in every pattern
patterns[, c(seq(1, 6), 9, 13)] <- 0

amputed_train <- ampute(as_tibble(original_train), prop = 0.1, patterns = patterns, mech = "MNAR")
## Warning: Data is made numeric because the calculation of weights requires
## numeric data
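
As a quick sanity check (my own addition), the proportion of incomplete rows should be close to the prop argument passed to ampute():

# share of rows with at least one missing value; should be around 0.1
mean(!complete.cases(amputed_train$amp))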

Let’s take a look at the missingness pattern:

naniar::vis_miss(amputed_train$amp) + 
    brotools::theme_blog() + 
    theme(axis.text.x = element_text(angle = 90, hjust = 1))

Ok, so now let’s suppose that this was the dataset I was given. As a serious data scientist, I decide to perform multiple imputation first:

imputed_train_data <- mice(data = amputed_train$amp, m = 10)

long_train_data <- complete(imputed_train_data, "long")
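
As a quick check, the long dataset should contain the 10 stacked imputed datasets, indexed by the .imp column that complete() adds:

# each value of .imp should appear once per row of the original data
long_train_data %>% 
    count(.imp)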

So because I performed multiple imputation 10 times, I now have 10 different datasets. I should now perform my analysis on these 10 datasets, which means running my GBM on each of them and then getting the explanations out of each of them. So let's do just that. But first, let's change the columns back to how they were; to perform amputation, the factor columns were converted to numbers:

long_train_data <- long_train_data %>% 
    mutate(idp = ifelse(idp == 1, FALSE, TRUE),
           physlim = ifelse(physlim == 1, FALSE, TRUE),
           health = as.factor(case_when(health == 1 ~ "excellent",
                              health == 2 ~ "fair",
                              health == 3 ~ "good", 
                              health == 4 ~  "poor")),
           sex = as.factor(ifelse(sex == 1, "female", "male")),
           child = ifelse(child == 1, FALSE, TRUE),
           black = ifelse(black == 1, FALSE, TRUE))

Ok, so now we’re ready to go. I will use the h2o.gbm() function on each imputed data set. For this, I’ll use the group_by()-nest() trick which consists in grouping the dataset by the .imp column, then nesting it, then mapping the h2o.gbm() function to each imputed dataset. If you are not familiar with this, you can read this other blog post, which explains the approach. I define a custom function, train_on_imputed_data() to run h2o.gbm() on each imputed data set:

train_on_imputed_data <- function(long_data){
    long_data %>% 
        group_by(.imp) %>% 
        nest() %>% 
        mutate(model = map(data, ~h2o.gbm(y = "mdu", x = features_names,
            training_frame = as.h2o(.),
            validation_frame = validation,
            distribution = "poisson",
            score_each_iteration = TRUE,
            ntrees = 110,
            max_depth = 20,
            sample_rate = 0.6,
            col_sample_rate = 0.8,
            col_sample_rate_per_tree = 0.9,
            learn_rate = 0.05)))
}

Now the training takes place:

imp_trained <- train_on_imputed_data(long_train_data)

Let’s take a look at imp_trained:

imp_trained
## # A tibble: 10 x 3
## # Groups:   .imp [10]
##     .imp            data model     
##    <int> <list<df[,16]>> <list>    
##  1     1   [14,042 × 16] <H2ORgrsM>
##  2     2   [14,042 × 16] <H2ORgrsM>
##  3     3   [14,042 × 16] <H2ORgrsM>
##  4     4   [14,042 × 16] <H2ORgrsM>
##  5     5   [14,042 × 16] <H2ORgrsM>
##  6     6   [14,042 × 16] <H2ORgrsM>
##  7     7   [14,042 × 16] <H2ORgrsM>
##  8     8   [14,042 × 16] <H2ORgrsM>
##  9     9   [14,042 × 16] <H2ORgrsM>
## 10    10   [14,042 × 16] <H2ORgrsM>

We see that the column model contains one model for each imputed dataset. Now comes the part I wanted to write about (finally): getting explanations out of this. Getting the explanations from each model is not the hard part; that's easily done using some {tidyverse} magic (if you're following along, run this bit of code below, and go make dinner, have dinner, and wash the dishes, because it takes time to run):

make_predictors <- function(model){
    Predictor$new(
        model = model, 
        data = just_features, 
        y = actual_target, 
        predict.fun = predict_for_iml
        )
}

make_effect <- function(predictor_object, feature = "linc", method = "pdp"){
    FeatureEffect$new(predictor_object, feature, method)
}

make_feat_imp <- function(predictor_object, loss = "mse"){
    FeatureImp$new(predictor_object, loss)
}

make_interactions <- function(predictor_object, feature = "sex"){
    Interaction$new(predictor_object, feature = feature)
}

imp_trained <- imp_trained %>%
    mutate(predictors = map(model, make_predictors)) %>% 
    mutate(effect_linc = map(predictors, make_effect)) %>% 
    mutate(feat_imp = map(predictors, make_feat_imp)) %>% 
    mutate(interactions_sex = map(predictors, make_interactions))

Ok so now that I’ve got these explanations, I am done with my analysis. This is the time to pool the results together. Remember, in the case of regression models as used in the social sciences, this means averaging the estimated model parameters and using Rubin’s rule to compute their standard errors. But in this case, this is not so obvious. Should the explanations be averaged? Should I instead analyse them one by one, and see if they differ? My gut feeling is that they shouldn’t differ much, but who knows? Perhaps the answer is doing a bit of both. I have checked online for a paper that would shed some light into this, but have not found any. So let’s take a closer look to the explanations. Let’s look at feature importance:


imp_trained %>% 
    pull(feat_imp)
## [[1]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1 ndisease     1.0421605   1.362672      1.467244          22.03037
## 2     fmde     0.8611917   1.142809      1.258692          18.47583
## 3      lpi     0.8706659   1.103367      1.196081          17.83817
## 4   health     0.8941010   1.098014      1.480508          17.75164
## 5       lc     0.8745229   1.024288      1.296668          16.55970
## 6    black     0.7537278   1.006294      1.095054          16.26879
## 
## [[2]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1      age      0.984304   1.365702      1.473146          22.52529
## 2     linc      1.102023   1.179169      1.457907          19.44869
## 3 ndisease      1.075821   1.173938      1.642938          19.36241
## 4     fmde      1.059303   1.150112      1.281291          18.96944
## 5       lc      0.837573   1.132719      1.200556          18.68257
## 6  physlim      0.763757   1.117635      1.644434          18.43379
## 
## [[3]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1      age     0.8641304   1.334382      1.821797          21.62554
## 2    black     1.0553001   1.301338      1.429119          21.09001
## 3     fmde     0.8965085   1.208761      1.360217          19.58967
## 4 ndisease     1.0577766   1.203418      1.651611          19.50309
## 5     linc     0.9299725   1.114041      1.298379          18.05460
## 6      sex     0.9854144   1.091391      1.361406          17.68754
## 
## [[4]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##   feature importance.05 importance importance.95 permutation.error
## 1 educdec     0.9469049   1.263961      1.358115          20.52909
## 2     age     1.0980269   1.197441      1.763202          19.44868
## 3  health     0.8539843   1.133338      1.343389          18.40753
## 4    linc     0.7608811   1.123423      1.328756          18.24649
## 5     lpi     0.8203850   1.103394      1.251688          17.92118
## 6   black     0.9476909   1.089861      1.328960          17.70139
## 
## [[5]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##   feature importance.05 importance importance.95 permutation.error
## 1     lpi     0.9897789   1.336405      1.601778          22.03791
## 2 educdec     0.8701162   1.236741      1.424602          20.39440
## 3     age     0.8537786   1.181242      1.261411          19.47920
## 4    lfam     1.0185313   1.133158      1.400151          18.68627
## 5     idp     0.9502284   1.069772      1.203147          17.64101
## 6    linc     0.8600586   1.042453      1.395231          17.19052
## 
## [[6]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##   feature importance.05 importance importance.95 permutation.error
## 1      lc     0.7707383   1.208190      1.379422          19.65436
## 2     sex     0.9309901   1.202629      1.479511          19.56391
## 3    linc     1.0549563   1.138404      1.624217          18.51912
## 4     lpi     0.9360817   1.135198      1.302084          18.46696
## 5 physlim     0.7357272   1.132525      1.312584          18.42349
## 6   child     1.0199964   1.109120      1.316306          18.04274
## 
## [[7]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1     linc     0.9403425   1.262994      1.511122          20.65942
## 2       lc     1.0481333   1.233136      1.602796          20.17103
## 3 ndisease     1.1612194   1.212454      1.320208          19.83272
## 4  educdec     0.7924637   1.197343      1.388218          19.58554
## 5     lfam     0.8423790   1.178545      1.349884          19.27805
## 6      age     0.9125829   1.168297      1.409525          19.11043
## 
## [[8]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1      age     1.1281736   1.261273      1.609524          20.55410
## 2   health     0.9134557   1.240597      1.432366          20.21716
## 3     lfam     0.7469043   1.182294      1.345910          19.26704
## 4      lpi     0.8088552   1.160863      1.491139          18.91779
## 5 ndisease     1.0756671   1.104357      1.517278          17.99695
## 6     fmde     0.6929092   1.093465      1.333544          17.81946
## 
## [[9]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1  educdec     1.0188109   1.287697      1.381982          20.92713
## 2      lpi     0.9853336   1.213095      1.479002          19.71473
## 3     linc     0.8354715   1.195344      1.254350          19.42625
## 4      age     0.9980451   1.179371      1.383545          19.16666
## 5 ndisease     1.0492685   1.176804      1.397398          19.12495
## 6     lfam     1.0814043   1.166626      1.264592          18.95953
## 
## [[10]]
## Interpretation method:  FeatureImp 
## error function: mse
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##    feature importance.05 importance importance.95 permutation.error
## 1      age     0.9538824   1.211869      1.621151          19.53671
## 2      sex     0.9148921   1.211253      1.298311          19.52678
## 3     lfam     0.8227355   1.093094      1.393815          17.62192
## 4 ndisease     0.8282127   1.090779      1.205994          17.58459
## 5       lc     0.7004401   1.060870      1.541697          17.10244
## 6   health     0.8137149   1.058324      1.183639          17.06138

As you can see, the feature importances are quite different from each other, but I don't think this comes from the imputations; rather, it comes from the fact that feature importance depends on shuffling the feature, which adds randomness to the measurement (source: https://christophm.github.io/interpretable-ml-book/feature-importance.html#disadvantages-9). To mitigate this, Christoph Molnar suggests repeating the permutation and averaging the importance measures; I think that this would be my approach for pooling as well.
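
To make this concrete, here is a sketch of what such a pooling could look like, extracting the results element of each FeatureImp object and averaging by feature (this is my own illustration, not an established recipe):

imp_trained %>% 
    mutate(feat_imp_results = map(feat_imp, function(x)(x$results))) %>% 
    pull(feat_imp_results) %>% 
    bind_rows() %>% 
    group_by(feature) %>% 
    summarise(mean_importance = mean(importance),
              sd_importance = sd(importance)) %>% 
    arrange(desc(mean_importance))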

Let’s now take a look at interactions:


imp_trained %>% 
    pull(interactions_sex)
## [[1]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.07635197
## 2      idp:sex   0.08172754
## 3      lpi:sex   0.10704357
## 4     fmde:sex   0.11267146
## 5  physlim:sex   0.04099073
## 6 ndisease:sex   0.16314524
## 
## [[2]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.10349820
## 2      idp:sex   0.07432519
## 3      lpi:sex   0.11651413
## 4     fmde:sex   0.18123926
## 5  physlim:sex   0.12952808
## 6 ndisease:sex   0.14528876
## 
## [[3]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.05919320
## 2      idp:sex   0.05586197
## 3      lpi:sex   0.24253335
## 4     fmde:sex   0.05240474
## 5  physlim:sex   0.06404969
## 6 ndisease:sex   0.14508072
## 
## [[4]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.02775529
## 2      idp:sex   0.02050390
## 3      lpi:sex   0.11781130
## 4     fmde:sex   0.11084240
## 5  physlim:sex   0.17932694
## 6 ndisease:sex   0.07181589
## 
## [[5]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.12873151
## 2      idp:sex   0.03681428
## 3      lpi:sex   0.15879389
## 4     fmde:sex   0.16952900
## 5  physlim:sex   0.07031520
## 6 ndisease:sex   0.10567463
## 
## [[6]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.15320481
## 2      idp:sex   0.08645037
## 3      lpi:sex   0.16674641
## 4     fmde:sex   0.14671054
## 5  physlim:sex   0.09236257
## 6 ndisease:sex   0.14605618
## 
## [[7]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.04072960
## 2      idp:sex   0.05641868
## 3      lpi:sex   0.19491959
## 4     fmde:sex   0.07119644
## 5  physlim:sex   0.05777469
## 6 ndisease:sex   0.16555363
## 
## [[8]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.04979709
## 2      idp:sex   0.06036898
## 3      lpi:sex   0.14009307
## 4     fmde:sex   0.10927688
## 5  physlim:sex   0.08761533
## 6 ndisease:sex   0.20544585
## 
## [[9]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.08572075
## 2      idp:sex   0.12254979
## 3      lpi:sex   0.17532347
## 4     fmde:sex   0.12557420
## 5  physlim:sex   0.05084209
## 6 ndisease:sex   0.13977328
## 
## [[10]]
## Interpretation method:  Interaction 
## 
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##       .feature .interaction
## 1       lc:sex   0.08636490
## 2      idp:sex   0.04807331
## 3      lpi:sex   0.17922280
## 4     fmde:sex   0.05728403
## 5  physlim:sex   0.09392774
## 6 ndisease:sex   0.13408956

It would seem that interactions are a bit more stable. Let's average the values; for this, I need to access the results element of each interactions object and pull the result out:

interactions_sex_result <- imp_trained %>% 
    mutate(interactions_results = map(interactions_sex, function(x)(x$results))) %>% 
    pull()

interactions_sex_result is a list of data frames, which means I can bind the rows together and compute whatever I need:

interactions_sex_result %>% 
    bind_rows() %>% 
    group_by(.feature) %>% 
    summarise_at(.vars = vars(.interaction), 
                 .funs = funs(mean, sd, low_ci = quantile(., 0.05), high_ci = quantile(., 0.95)))
## # A tibble: 13 x 5
##    .feature       mean     sd low_ci high_ci
##    <chr>         <dbl>  <dbl>  <dbl>   <dbl>
##  1 age:sex      0.294  0.0668 0.181    0.369
##  2 black:sex    0.117  0.0286 0.0763   0.148
##  3 child:sex    0.0817 0.0308 0.0408   0.125
##  4 educdec:sex  0.148  0.0411 0.104    0.220
##  5 fmde:sex     0.114  0.0443 0.0546   0.176
##  6 health:sex   0.130  0.0190 0.104    0.151
##  7 idp:sex      0.0643 0.0286 0.0278   0.106
##  8 lc:sex       0.0811 0.0394 0.0336   0.142
##  9 lfam:sex     0.149  0.0278 0.125    0.198
## 10 linc:sex     0.142  0.0277 0.104    0.179
## 11 lpi:sex      0.160  0.0416 0.111    0.221
## 12 ndisease:sex 0.142  0.0356 0.0871   0.187
## 13 physlim:sex  0.0867 0.0415 0.0454   0.157

That seems pretty good. Now, what about the partial dependence? Let’s take a closer look:


imp_trained %>% 
    pull(effect_linc)
## [[1]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.652445   pdp
## 2 0.5312226 1.687522   pdp
## 3 1.0624453 1.687522   pdp
## 4 1.5936679 1.687522   pdp
## 5 2.1248905 1.685088   pdp
## 6 2.6561132 1.694112   pdp
## 
## [[2]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.813449   pdp
## 2 0.5312226 1.816195   pdp
## 3 1.0624453 1.816195   pdp
## 4 1.5936679 1.816195   pdp
## 5 2.1248905 1.804457   pdp
## 6 2.6561132 1.797238   pdp
## 
## [[3]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.906515   pdp
## 2 0.5312226 2.039318   pdp
## 3 1.0624453 2.039318   pdp
## 4 1.5936679 2.039318   pdp
## 5 2.1248905 2.002970   pdp
## 6 2.6561132 2.000922   pdp
## 
## [[4]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.799552   pdp
## 2 0.5312226 2.012634   pdp
## 3 1.0624453 2.012634   pdp
## 4 1.5936679 2.012634   pdp
## 5 2.1248905 1.982425   pdp
## 6 2.6561132 1.966392   pdp
## 
## [[5]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.929158   pdp
## 2 0.5312226 1.905171   pdp
## 3 1.0624453 1.905171   pdp
## 4 1.5936679 1.905171   pdp
## 5 2.1248905 1.879721   pdp
## 6 2.6561132 1.869113   pdp
## 
## [[6]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 2.147697   pdp
## 2 0.5312226 2.162393   pdp
## 3 1.0624453 2.162393   pdp
## 4 1.5936679 2.162393   pdp
## 5 2.1248905 2.119923   pdp
## 6 2.6561132 2.115131   pdp
## 
## [[7]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.776742   pdp
## 2 0.5312226 1.957938   pdp
## 3 1.0624453 1.957938   pdp
## 4 1.5936679 1.957938   pdp
## 5 2.1248905 1.933847   pdp
## 6 2.6561132 1.885287   pdp
## 
## [[8]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 2.020647   pdp
## 2 0.5312226 2.017981   pdp
## 3 1.0624453 2.017981   pdp
## 4 1.5936679 2.017981   pdp
## 5 2.1248905 1.981122   pdp
## 6 2.6561132 2.017604   pdp
## 
## [[9]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.811189   pdp
## 2 0.5312226 2.003053   pdp
## 3 1.0624453 2.003053   pdp
## 4 1.5936679 2.003053   pdp
## 5 2.1248905 1.938150   pdp
## 6 2.6561132 1.918518   pdp
## 
## [[10]]
## Interpretation method:  FeatureEffect 
## features: linc[numerical]
## grid size: 20
## 
## Analysed predictor: 
## Prediction task: unknown 
## 
## 
## Analysed data:
## Sampling from data.frame with 2013 rows and 14 columns.
## 
## Head of results:
##        linc   .y.hat .type
## 1 0.0000000 1.780325   pdp
## 2 0.5312226 1.850203   pdp
## 3 1.0624453 1.850203   pdp
## 4 1.5936679 1.850203   pdp
## 5 2.1248905 1.880805   pdp
## 6 2.6561132 1.881305   pdp

As you can see, the values are quite similar. I think that in the case of plots, the best way to visualize the impact of the imputation is to simply plot all the lines in a single plot:

effect_linc_results <- imp_trained %>% 
    mutate(effect_linc_results = map(effect_linc, function(x)(x$results))) %>% 
    select(.imp, effect_linc_results) %>% 
    unnest(effect_linc_results)

effect_linc_results %>% 
    ggplot() + 
    geom_line(aes(y = .y.hat, x = linc, group = .imp), colour = "#82518c") + 
    brotools::theme_blog()

Overall, the partial dependence plot seems to behave in a very similar way across the different imputed datasets!
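
If a single pooled curve is needed, one simple option (again, just an idea) would be to average the 10 curves point-wise; this works here because all the Predictor objects were built on the same holdout features, so the pdps share the same grid of "linc" values:

effect_linc_results %>% 
    group_by(linc) %>% 
    summarise(mean_y_hat = mean(.y.hat)) %>% 
    ggplot() + 
    geom_line(aes(y = mean_y_hat, x = linc), colour = "#82518c") + 
    brotools::theme_blog()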

To conclude, I think that the approach I suggest here is nothing revolutionary; it is consistent with the way one should conduct an analysis on multiply imputed datasets. However, the pooling step is non-trivial, and there is no magic recipe; it really depends on the goal of the analysis and what you want or need to show.

Hope you enjoyed! If you found this blog post useful, you might want to follow me on Twitter for blog post updates, buy me an espresso or donate via paypal.me, or buy my ebook on Leanpub.
