class: center, middle, inverse, title-slide .title[ # PPOL 670-03: Introduction to Data Science ] .subtitle[ ## Week 10 Interpretable Machine Learning ] .author[ ### Alexander Podkul, PhD ] .date[ ### Spring 2023 ] --- ## Tonight's Outline .pull-left[ - The Next Few Weeks - A few quick problem set #3 notes(!) - Unsupervised Learning - Defining Interpretable Machine Learning - Global v. Local Methods - More on Variable Importance - Partial Dependence Plots - Permutation Feature Importance - Global Surrogate Model - Individual Conditional Expectation - 🚨`iml`🚨 - __Break__ - Code Session on Interpretability and More ] .pull-right[ <img src="sticker.png" width="80%" style="display: block; margin: auto;" /> ] --- ## The Next Few Weeks The weeks ahead: - Week 10: Interpretable Machine Learning + Will discuss week 13 options - Week 11: Text-as-Data - Week 12: Missing Data/Relational Databases + _Problem Set #5 Assigned_ - Week 13: TBD + __Problem Set #5 Due__ - Week 14: Presentations(!) --- ## The Next Few Weeks TBD Week (Week 13) is just a few weeks away... some options on the table: - revisiting any topic we've covered this semester in more depth (interactive visualizations, a deeper dive into SVM, etc.) - using R for creating a personal website - a useful assortment of packages - using GitHub with RStudio for version control - Shiny Apps ([Example](https://gpilgrim.shinyapps.io/SwimmingProject-Click/?_ga=2.171880935.671469752.1666748383-726201935.1666748383)) - Time series modeling - Replicating Stata (and other quant tools) Other ideas? I'll send out a poll :) --- ## Some Notes on Problem Set #3 + Be sure to keep the training set and the hold-out set separate. If the two overlap, we end up evaluating the model on data it has already seen, which hides overfitting and can lead to overly confident predictions and insights! -- <img src="holdout.png" width="517" height="90%" /> --- ## Some Notes on Problem Set #3 + When running a kitchen sink regression, we need to be careful not to include variables created in earlier steps (such as stored predictions from another model).
-- Example #1:

```r
# fit a regression tree, then store its predictions as a new column in the training data
mod1 <- train(alcohol ~ ., method = 'rpart', data = wine_train, na.action = na.pass)
wine_train$pred_rpart <- predict(mod1, wine_train)
```

-- Example #2:

```r
# `alcohol ~ .` uses every other column, including the stored pred_rpart
mod2 <- train(alcohol ~ ., method = 'lm', data = wine_train, na.action = na.pass)
```

... `pred_rpart` is now a feature in mod2! --- ## Some Notes on Problem Set #3 + When assessing model performance, we typically use _the best-performing_ model. -- <img src="unnamed.png" width="1672" height="90%" /> --- class: inverse, center, middle ## Unsupervised Learning --- ## Unsupervised Learning A few weeks ago when we introduced this current unit, we spoke of the difference between __supervised__ and __unsupervised__ learning, which often depends on whether our data is labelled or unlabelled, respectively. In the latter case, we would have a dataset that exclusively has features without any outcomes or classes. __Unsupervised__ learning refers to a suite of methods where our algorithm can help find organization and order that exists within the data. -- Types of unsupervised learning might include: .pull-left[ 1. __clustering__ - for identifying groups or hierarchies that exist within the data 2. dimensionality reduction 3. density estimation ] .pull-right[ <img src="k-means-copy.jpg" width="1757" height="80%" /> ] --- ## Unsupervised Learning A few weeks ago when we introduced this current unit, we spoke of the difference between __supervised__ and __unsupervised__ learning, which often depends on whether our data is labelled or unlabelled, respectively. In the latter case, we would have a dataset that exclusively has features without any outcomes or classes. __Unsupervised__ learning refers to a suite of methods where our algorithm can help find organization and order that exists within the data. Types of unsupervised learning might include: .pull-left[ 1. clustering - for identifying groups or hierarchies that exist within the data 2.
__dimensionality reduction__ - to simplify high-dimensional data sets without losing significant amounts of information 3. density estimation ] .pull-right[ <img src="data_set_sample_3d_pca_.jpg" width="1251" height="90%" /> ] --- ## Unsupervised Learning A few weeks ago when we introduced this current unit, we spoke of the difference between __supervised__ and __unsupervised__ learning, which often depends on whether our data is labelled or unlabelled, respectively. In the latter case, we would have a dataset that exclusively has features without any outcomes or classes. __Unsupervised__ learning refers to a suite of methods where our algorithm can help find organization and order that exists within the data. Types of unsupervised learning might include: .pull-left[ 1. clustering - for identifying groups or hierarchies that exist within the data 2. dimensionality reduction - to simplify high-dimensional data sets without losing significant amounts of information 3. __density estimation__ - to learn how features are distributed and subsequently related ] .pull-right[ <img src="plot_bayesian_blocks_11.png" width="1333" height="90%" /> ] --- ## Clustering Clustering can generally refer to a suite of tools used to group data points in a variety of ways. One example of clustering might refer to centroid models developed using the k-means algorithm (cf. KNN). -- .pull-left[ An outline of naive k-means: for a given k (number of clusters): 1. select k random centroids (as starting points) 2. assign each data point to the nearest centroid (forming a "cluster") 3. calculate the variance within each cluster 4. move each centroid to the mean of its assigned points and repeat steps 2-3 until the assignments (and the within-cluster variance) stop changing ] .pull-right[ <img src="Week10_files/figure-html/unnamed-chunk-9-1.png" height="70%" /> ] --- ## Clustering Clustering can generally refer to a suite of tools used to group data points in a variety of ways.
One example of clustering might refer to centroid models developed using the k-means algorithm (cf. KNN). .pull-left[ An outline of naive k-means: for a given k (number of clusters): 1. select k random centroids (as starting points) 2. assign each data point to the nearest centroid (forming a "cluster") 3. calculate the variance within each cluster 4. move each centroid to the mean of its assigned points and repeat steps 2-3 until the assignments (and the within-cluster variance) stop changing ] .pull-right[ <img src="Week10_files/figure-html/unnamed-chunk-10-1.png" height="70%" /> ] --- ## Clustering Clustering can generally refer to a suite of tools used to group data points in a variety of ways. One example of clustering might refer to centroid models developed using the k-means algorithm (cf. KNN). .pull-left[ An outline of naive k-means: for a given k (number of clusters): 1. select k random centroids (as starting points) 2. assign each data point to the nearest centroid (forming a "cluster") 3. calculate the variance within each cluster 4. move each centroid to the mean of its assigned points and repeat steps 2-3 until the assignments (and the within-cluster variance) stop changing ] .pull-right[ <img src="Week10_files/figure-html/unnamed-chunk-11-1.png" height="70%" /> ] --- ## Clustering Clustering can generally refer to a suite of tools used to group data points in a variety of ways. One example of clustering might refer to centroid models developed using the k-means algorithm (cf. KNN). .pull-left[ An outline of naive k-means: for a given k (number of clusters): 1. select k random centroids (as starting points) 2. assign each data point to the nearest centroid (forming a "cluster") 3. calculate the variance within each cluster 4. move each centroid to the mean of its assigned points and repeat steps 2-3 until the assignments (and the within-cluster variance) stop changing ] .pull-right[ <img src="Week10_files/figure-html/unnamed-chunk-12-1.png" height="70%" /> ] --- ## Clustering Clustering can generally refer to a suite of tools used to group data points in a variety of ways. One example of clustering might refer to centroid models developed using the k-means algorithm (cf. KNN). .pull-left[ An outline of naive k-means: for a given k (number of clusters): 1. select k random centroids (as starting points) 2. assign each data point to the nearest centroid (forming a "cluster") 3. calculate the variance within each cluster 4. move each centroid to the mean of its assigned points and repeat steps 2-3 until the assignments (and the within-cluster variance) stop changing ] .pull-right[ <img src="Week10_files/figure-html/unnamed-chunk-13-1.png" height="70%" /> ] --- ## Clustering Clustering can generally refer to a suite of tools used to group data points in a variety of ways. One example of clustering might refer to centroid models developed using the k-means algorithm (cf. KNN). .pull-left[ An outline of naive k-means: for a given k (number of clusters): 1. select k random centroids (as starting points) 2. assign each data point to the nearest centroid (forming a "cluster") 3. calculate the variance within each cluster 4. move each centroid to the mean of its assigned points and repeat steps 2-3 until the assignments (and the within-cluster variance) stop changing ] .pull-right[ <img src="Week10_files/figure-html/unnamed-chunk-14-1.png" height="70%" /> ] --- ## Clustering Other considerations: __how to measure distance between points (cf. KNN)__ It depends on the type of data (in the previous example we use Euclidean distance). -- __how to estimate error__ Often we consider the variability of the observations within a cluster. To measure this, we often use the Within-Cluster Sum of Squares (WCSS), which sums the squared distances between each point and the centroid of its cluster.
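The naive k-means outline and the WCSS measure above can be sketched in a few lines of base R. This is a minimal illustration under stated assumptions, not the code behind the plots on the previous slides: `naive_kmeans()` is a hypothetical helper, the built-in `iris` measurements stand in for real data, and in practice we would call `stats::kmeans()`.

```r
# Naive (Lloyd-style) k-means sketch; in practice use stats::kmeans().
# naive_kmeans() is a hypothetical helper; iris measurements stand in for real data.
naive_kmeans <- function(x, k, max_iter = 100, seed = 1) {
  set.seed(seed)
  # Step 1: use k randomly chosen observations as the starting centroids
  centroids <- x[sample(nrow(x), k), , drop = FALSE]
  cluster <- integer(nrow(x))
  for (iter in seq_len(max_iter)) {
    # Step 2: squared Euclidean distance from every point to every centroid
    d <- sapply(seq_len(k), function(j) colSums((t(x) - centroids[j, ])^2))
    new_cluster <- max.col(-d, ties.method = "first")  # index of the nearest centroid
    if (all(new_cluster == cluster)) break              # assignments stable: stop
    cluster <- new_cluster
    # Step 4: move each centroid to the mean of its assigned points
    for (j in seq_len(k)) {
      if (any(cluster == j)) {
        centroids[j, ] <- colMeans(x[cluster == j, , drop = FALSE])
      }
    }
  }
  # WCSS: squared distances between each point and its own cluster's centroid
  wcss <- sum((x - centroids[cluster, , drop = FALSE])^2)
  list(cluster = cluster, centroids = centroids, wcss = wcss)
}

x <- as.matrix(iris[, 1:4])
fit <- naive_kmeans(x, k = 3)
fit$wcss

# Selecting K: WCSS for K = 1..6 using the built-in kmeans() (look for an "elbow")
wcss_by_k <- sapply(1:6, function(k) kmeans(x, centers = k, nstart = 10)$tot.withinss)
wcss_by_k
```

Because the starting centroids are random, different seeds can converge to different local optima; `nstart = 10` asks `kmeans()` to keep the best of ten random starts.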
Smaller values indicate less variability and more compact clusters. -- __how do we select K__ Similar to KNN, we can estimate several values of K and compare performance across them. To compare performance we can compare the explained variation (or the WCSS) across values of K and then consider the trade-offs between overfitting (too many clusters) and underfitting (too few). --- class: inverse, center, middle ## Interpretable ML --- ## Interpretable ML .pull-left[ At worst, the ML models we use in production are entirely "black box," i.e., they are too complex for us to break apart and see how they arrive at a prediction. Nevertheless, we might be interested in interpretability for a number of reasons: - making substantive inferences about how `\(X\)` relates to `\(y\)` - detecting predictors that might influence data collection practices - debugging a model that is producing curious predictions - identifying bias Interpretation methods can broadly be divided into __model-dependent__ and __model-agnostic__ approaches. Additionally, they can be __global__ or __local__. ] .pull-right[ <div class="figure"> <img src="big-picture.png" alt="from Molnar (2021)" width="1789" /> <p class="caption">from Molnar (2021)</p> </div> ] --- ### Algorithmic Bias .pull-left[ Methods of interpretation are quite useful for mitigating _algorithmic bias_, which refers to systematic biases embedded within a model (often learned from historical training data) that the model then perpetuates (e.g. [Amazon gender biases in resume classification](https://www.liberties.eu/en/stories/algorithmic-bias-17052021/43528)). Methods of interpretability can help make sense of the decisions embedded within a model by allowing researchers to better understand the influence (or lack of influence) that a predictor may have on a particular outcome. These methods can help open up the black box and reveal complicated relationships that might otherwise be obscured. For more on the effects of algorithmic biases, check out _Weapons of Math Destruction_ (2016) by O'Neil.
] .pull-right[ <div class="figure" style="text-align: center"> <img src="x.png" alt="Source: CMU" width="661" /> <p class="caption">Source: CMU</p> </div> ] --- class: inverse, center, middle ## Global v. Local Methods --- ## Global v. Local Methods __Global__ methods reveal how features influence predictions _on average_, whereas __local__ methods look to explain _individual predictions_. -- Global methods: - Seek to understand how a model behaves _generally_ - Analyze how features relate to the model's predictions (including which features contribute little or nothing) - Examples: Partial dependence plots, permutation feature importance -- Local methods: - Seek to understand why a specific, individual prediction is made - Analyze the effect of a particular feature _value_ on that prediction - Examples: Individual conditional expectation, local surrogate models --- class: inverse, center, middle ## Variable Importance --- ## Variable Importance __Variable importance__ generally refers to how much a model or algorithm relies on a particular feature or set of features when making predictions. By determining how _important_ a feature is to the model, we can make assessments about model performance (and collect the right data) and begin to make inferences about the relationship(s) between `\(X\)` and `\(y\)`. How variable importance is measured depends on the model.
For example, we can quickly look at: - linear models - decision trees and random forests --- ### Linear Model Using the Cleveland dataset (from last week): <table style="border-bottom: 0; width: auto !important; margin-left: auto; margin-right: auto;" class="table"> <caption>DV: Cholesterol</caption> <thead> <tr> <th style="text-align:left;"> </th> <th style="text-align:center;"> Model 1 </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> (Intercept) </td> <td style="text-align:center;"> 160.308*** </td> </tr> <tr> <td style="text-align:left;"> </td> <td style="text-align:center;"> (30.636) </td> </tr> <tr> <td style="text-align:left;"> age </td> <td style="text-align:center;"> 1.821*** </td> </tr> <tr> <td style="text-align:left;"> </td> <td style="text-align:center;"> (0.542) </td> </tr> <tr> <td style="text-align:left;"> sex </td> <td style="text-align:center;"> 40.857 </td> </tr> <tr> <td style="text-align:left;"> </td> <td style="text-align:center;"> (37.408) </td> </tr> <tr> <td style="text-align:left;"> age × sex </td> <td style="text-align:center;"> −1.107+ </td> </tr> <tr> <td style="text-align:left;box-shadow: 0px 1px"> </td> <td style="text-align:center;box-shadow: 0px 1px"> (0.670) </td> </tr> <tr> <td style="text-align:left;"> Num.Obs. </td> <td style="text-align:center;"> 303 </td> </tr> <tr> <td style="text-align:left;"> R2 </td> <td style="text-align:center;"> 0.085 </td> </tr> <tr> <td style="text-align:left;"> R2 Adj. </td> <td style="text-align:center;"> 0.075 </td> </tr> </tbody> <tfoot><tr><td style="padding: 0; " colspan="100%"> <sup></sup> + p < 0.1, * p < 0.05, ** p < 0.01, *** p < 0.001</td></tr></tfoot> </table> -- How can we think about _variable importance_ from our regression output?
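One common answer for linear models is to rank predictors by the absolute value of their t-statistic, which is the score `caret::varImp()` reports for `lm` fits. The minimal base-R sketch below uses the built-in `mtcars` data as a stand-in for the Cleveland data, which does not ship with R.

```r
# Rank linear-model predictors by |t|, the score caret::varImp() uses for lm fits.
# mtcars is a stand-in for the Cleveland data, which is not bundled with R.
lin_fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
coefs <- summary(lin_fit)$coefficients

# Drop the intercept row; |t| = |estimate / standard error|
abs_t <- abs(coefs[-1, "t value"])
importance <- sort(abs_t, decreasing = TRUE)
importance
```

With an interaction term like `age × sex` in the table above, each main-effect t-statistic describes that variable's effect when the other variable equals zero, so those scores need to be read with care.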
--- ### Decision Trees and Random Forests Using the Cleveland dataset (from last week): .pull-left[  ] .pull-right[ <table class=" lightable-paper lightable-striped" style='font-family: "Arial Narrow", arial, helvetica, sans-serif; margin-left: auto; margin-right: auto;'> <thead> <tr> <th style="text-align:left;"> Var </th> <th style="text-align:right;"> Overall </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> chest_pain_type </td> <td style="text-align:right;"> 58 </td> </tr> <tr> <td style="text-align:left;"> age:sex </td> <td style="text-align:right;"> 23 </td> </tr> <tr> <td style="text-align:left;"> age </td> <td style="text-align:right;"> 13 </td> </tr> <tr> <td style="text-align:left;"> chol </td> <td style="text-align:right;"> 5 </td> </tr> <tr> <td style="text-align:left;"> sex </td> <td style="text-align:right;"> 1 </td> </tr> </tbody> </table> ] For a decision tree, the reduction in the loss function (here, mean squared error) attributed to each variable at each split is tabulated and then summed across all splits. --- ### Decision Trees and Random Forests Using the Cleveland dataset (from last week): .pull-left[ <table class=" lightable-paper lightable-striped" style='font-family: "Arial Narrow", arial, helvetica, sans-serif; margin-left: auto; margin-right: auto;'>
<thead> <tr> <th style="text-align:left;"> Var </th> <th style="text-align:right;"> Overall </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> chest_pain_type </td> <td style="text-align:right;"> 100.00 </td> </tr> <tr> <td style="text-align:left;"> chol </td> <td style="text-align:right;"> 70.22 </td> </tr> <tr> <td style="text-align:left;"> age:sex </td> <td style="text-align:right;"> 59.15 </td> </tr> <tr> <td style="text-align:left;"> age </td> <td style="text-align:right;"> 52.26 </td> </tr> <tr> <td style="text-align:left;"> sex </td> <td style="text-align:right;"> 0.00 </td> </tr> </tbody> </table> ] For a random forest, the prediction error on the out-of-bag (OOB) portion of the data is recorded for each tree and then recorded again after permuting each predictor; the difference is averaged over all trees and normalized (caret's `varImp()` then rescales the scores to run from 0 to 100 by default). --- class: inverse, center, middle ## Partial Dependence Plots --- ## Partial Dependence Plots __Partial dependence plots__ demonstrate the _marginal effect_ that features have on the predicted outcome of the model. Importantly, PDPs can help show the _type_ of relationship that the model expects between predictors and outcomes. A PDP is _global_: it considers all instances and produces "a statement about the global relationship of a feature with the predicted outcome." -- General method: 1. Fit an ML model 2. Select the feature(s) to explore 3. For each value in a grid of values for the feature, set that feature to the value for every observation (leaving the other predictors at their observed values) and compute the average prediction 4. Plot the average predictions against the grid values (often in 1 or 2 dimensions) --- ## Partial Dependence Plots For example, imagine a classification model (heart disease) using 13 predictors. <img src="Week10_files/figure-html/unnamed-chunk-24-1.png" style="display: block; margin: auto;" /> --- ## Partial Dependence Plots For example, imagine a classification model (heart disease) using 13 predictors.
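The general method above can be sketched from scratch in base R. This is an illustration under stated assumptions, not how the plots on these slides were produced: `partial_dependence()` is a hypothetical helper, the built-in `mtcars` data replace the Cleveland data, and a linear model with a quadratic `hp` term stands in for the black-box model (any model with a `predict()` method would work the same way).

```r
# From-scratch partial dependence for one numeric feature (illustrative sketch).
# partial_dependence() is a hypothetical helper; any model with predict() works.
partial_dependence <- function(model, data, feature, grid_size = 20) {
  # grid of values spanning the observed range of the feature
  grid <- seq(min(data[[feature]]), max(data[[feature]]), length.out = grid_size)
  avg_pred <- sapply(grid, function(v) {
    modified <- data
    modified[[feature]] <- v                  # same value for every observation
    mean(predict(model, newdata = modified))  # other features keep observed values
  })
  data.frame(value = grid, avg_prediction = avg_pred)
}

# Stand-in "black box": mtcars with a quadratic term in hp
pd_model <- lm(mpg ~ poly(hp, 2) + wt, data = mtcars)
pd <- partial_dependence(pd_model, mtcars, "hp")
plot(pd$value, pd$avg_prediction, type = "l",
     xlab = "hp", ylab = "Average predicted mpg")
```

The final `plot()` call is step 4; a two-feature PDP would loop over a grid of value pairs instead of a single sequence.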
<img src="Week10_files/figure-html/unnamed-chunk-25-1.png" style="display: block; margin: auto;" /> --- ## Partial Dependence Plots __Benefit of this method of interpretation:__ - easy to understand - relatively simple to implement __Drawbacks:__ - it does not consider the feature distribution (it can average over unlikely feature combinations) - assumes independence of predictors --- class: inverse, center, middle ## Permutation Feature Importance --- ## Permutation Feature Importance __Permutation feature importance__ is a _global_ method that examines a model's prediction error after a particular predictor's values are _permuted_ (thereby disrupting the relationship between predictor and outcome). -- There are three primary steps: 1. Estimate an ML model and identify some performance metric (e.g. RMSE) 2. For each predictor `\(j \in \{1, ..., p\}\)`: - Generate a feature matrix `\(X_{perm}\)` by permuting feature `\(j\)` - Estimate the error of the predictions created in the permuted data - Calculate feature importance by looking at `\(e_{perm}/e_{orig}\)` or `\(e_{perm} - e_{orig}\)` 3. Arrange predictors by descending feature importance --- ## Permutation Feature Importance: Logic We can quickly review the __logic of permutation__. For example, using the `life_expectancy` dataset we can estimate a model.
```r life_expect <- read.csv('https://github.com/apodkul/ppol670_01/raw/main/Data/life_expect.csv') model_1 <- lm(life_expectancy~log(GDP_per_capita) + factor(Continent), data = life_expect) model_1 ``` -- <table class=" lightable-paper lightable-striped" style='font-family: "Arial Narrow", arial, helvetica, sans-serif; margin-left: auto; margin-right: auto;'> <thead> <tr> <th style="text-align:left;"> term </th> <th style="text-align:right;"> estimate </th> <th style="text-align:right;"> std.error </th> <th style="text-align:right;"> statistic </th> <th style="text-align:right;"> p.value </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> (Intercept) </td> <td style="text-align:right;"> 34.20 </td> <td style="text-align:right;"> 2.85 </td> <td style="text-align:right;"> 12.00 </td> <td style="text-align:right;"> 0 </td> </tr> <tr> <td style="text-align:left;"> log(GDP_per_capita) </td> <td style="text-align:right;"> 3.52 </td> <td style="text-align:right;"> 0.35 </td> <td style="text-align:right;"> 10.19 </td> <td style="text-align:right;"> 0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)Asia </td> <td style="text-align:right;"> 6.24 </td> <td style="text-align:right;"> 0.92 </td> <td style="text-align:right;"> 6.76 </td> <td style="text-align:right;"> 0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)Europe </td> <td style="text-align:right;"> 8.59 </td> <td style="text-align:right;"> 1.10 </td> <td style="text-align:right;"> 7.78 </td> <td style="text-align:right;"> 0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)North America </td> <td style="text-align:right;"> 8.18 </td> <td style="text-align:right;"> 1.16 </td> <td style="text-align:right;"> 7.05 </td> <td style="text-align:right;"> 0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)Oceania </td> <td style="text-align:right;"> 10.88 </td> <td style="text-align:right;"> 2.95 </td> <td style="text-align:right;"> 3.68 </td> <td 
style="text-align:right;"> 0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)South America </td> <td style="text-align:right;"> 7.70 </td> <td style="text-align:right;"> 1.44 </td> <td style="text-align:right;"> 5.35 </td> <td style="text-align:right;"> 0 </td> </tr> </tbody> </table> --- ## Permutation Feature Importance: Logic But we can then _permute_ a predictor (e.g. we can permute `GDP_per_capita`) ```r model_2 <- life_expect %>% mutate(GDP_per_capita = sample(GDP_per_capita)) %>% lm(life_expectancy~log(GDP_per_capita) + factor(Continent), data = .) model_2 ``` -- <table class=" lightable-paper lightable-striped" style='font-family: "Arial Narrow", arial, helvetica, sans-serif; margin-left: auto; margin-right: auto;'> <thead> <tr> <th style="text-align:left;"> term </th> <th style="text-align:right;"> estimate </th> <th style="text-align:right;"> std.error </th> <th style="text-align:right;"> statistic </th> <th style="text-align:right;"> p.value </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> (Intercept) </td> <td style="text-align:right;"> 58.69 </td> <td style="text-align:right;"> 3.18 </td> <td style="text-align:right;"> 18.45 </td> <td style="text-align:right;"> 0.0 </td> </tr> <tr> <td style="text-align:left;"> log(GDP_per_capita) </td> <td style="text-align:right;"> 0.43 </td> <td style="text-align:right;"> 0.34 </td> <td style="text-align:right;"> 1.29 </td> <td style="text-align:right;"> 0.2 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)Asia </td> <td style="text-align:right;"> 11.04 </td> <td style="text-align:right;"> 1.03 </td> <td style="text-align:right;"> 10.69 </td> <td style="text-align:right;"> 0.0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)Europe </td> <td style="text-align:right;"> 15.83 </td> <td style="text-align:right;"> 1.07 </td> <td style="text-align:right;"> 14.81 </td> <td style="text-align:right;"> 0.0 </td> </tr> <tr> <td style="text-align:left;"> 
factor(Continent)North America </td> <td style="text-align:right;"> 12.75 </td> <td style="text-align:right;"> 1.38 </td> <td style="text-align:right;"> 9.25 </td> <td style="text-align:right;"> 0.0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)Oceania </td> <td style="text-align:right;"> 19.44 </td> <td style="text-align:right;"> 3.62 </td> <td style="text-align:right;"> 5.37 </td> <td style="text-align:right;"> 0.0 </td> </tr> <tr> <td style="text-align:left;"> factor(Continent)South America </td> <td style="text-align:right;"> 12.77 </td> <td style="text-align:right;"> 1.74 </td> <td style="text-align:right;"> 7.33 </td> <td style="text-align:right;"> 0.0 </td> </tr> </tbody> </table> --- ## Permutation Feature Importance: Example 1) Estimate an ML model and identify some performance metric (e.g. RMSE) Using the Cleveland data, let's estimate a random forest regression with 13 predictors. Since it is a regression, let's use RMSE as our performance metric. ``` ## Random Forest ## ## 303 samples ## 13 predictor ## ## No pre-processing ## Resampling: Bootstrapped (25 reps) ## Summary of sample sizes: 303, 303, 303, 303, 303, 303, ... ## Resampling results across tuning parameters: ## ## mtry RMSE Rsquared MAE ## 2 0.3527734 0.5288721 0.2868672 ## 4 0.3553294 0.5076610 0.2728981 ## 6 0.3592906 0.4937615 0.2694403 ## 8 0.3635948 0.4810182 0.2687594 ## 10 0.3658224 0.4743724 0.2674346 ## ## RMSE was used to select the optimal model using the smallest value. ## The final value used for the model was mtry = 2. 
``` --- ## Permutation Feature Importance: Example 2) For each predictor `\(j \in \{1, ..., p\}\)`: a) Generate a feature matrix `\(X_{perm}\)` by permuting feature `\(j\)` b) Estimate the error of the predictions created in the permuted data c) Calculate feature importance by looking at `\(e_{perm}/e_{orig}\)` or `\(e_{perm} - e_{orig}\)` -- We can take our data (Molnar recommends using the test set here) and develop predictions from the model estimated in step 1, which allows us to calculate `\(e_{orig}\)`. Then, for _each_ predictor, we take that same data, permute the predictor, generate new predictions, calculate `\(e_{perm}\)`, and compare. <table class=" lightable-paper" style='font-family: "Arial Narrow", arial, helvetica, sans-serif; margin-left: auto; margin-right: auto;'> <thead> <tr> <th style="text-align:right;"> age </th> <th style="text-align:right;"> sex </th> <th style="text-align:right;"> chest_pain_type </th> <th style="text-align:right;"> chol </th> <th style="text-align:right;"> heart </th> <th style="text-align:right;"> orig_prediction </th> </tr> </thead> <tbody> <tr> <td style="text-align:right;"> 63 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 233 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.2461070 </td> </tr> <tr> <td style="text-align:right;"> 67 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 4 </td> <td style="text-align:right;"> 286 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 0.9272854 </td> </tr> <tr> <td style="text-align:right;"> 67 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 4 </td> <td style="text-align:right;"> 229 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 0.9660902 </td> </tr> <tr> <td style="text-align:right;"> 37 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 3 </td> <td
style="text-align:right;"> 250 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.1930450 </td> </tr> <tr> <td style="text-align:right;"> 41 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 2 </td> <td style="text-align:right;"> 204 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.0258744 </td> </tr> <tr> <td style="text-align:right;"> 56 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 2 </td> <td style="text-align:right;"> 236 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.0596381 </td> </tr> </tbody> </table> --- ## Permutation Feature Importance: Example 2) For each predictor `\(j \in \{1, ..., p\}\)`: a) Generate a feature matrix `\(X_{perm}\)` by permuting feature `\(j\)` b) Estimate the error of the predictions created in the permuted data c) Calculate feature importance by looking at `\(e_{perm}/e_{orig}\)` or `\(e_{perm} - e_{orig}\)` We can take our data (Molnar recommends using the test set here) and develop predictions from the model estimated in step 1, which allows us to calculate `\(e_{orig}\)`. Then, for _each_ predictor, we take that same data, permute the predictor, generate new predictions, calculate `\(e_{perm}\)`, and compare.
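Step 2 can be sketched as a short base-R function. Note that the model is _not_ refit: only its predictions on permuted copies of the data change. `permutation_importance()` is a hypothetical helper, `mtcars` and `lm()` stand in for the Cleveland data and the random forest, and ideally `data` would be a held-out test set.

```r
# Permutation feature importance sketch: the model is NOT refit; only its
# predictions on permuted copies of the data change. RMSE is the error metric.
rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))

permutation_importance <- function(model, data, outcome, features, seed = 670) {
  set.seed(seed)
  e_orig <- rmse(data[[outcome]], predict(model, newdata = data))
  rows <- lapply(features, function(j) {
    permuted <- data
    permuted[[j]] <- sample(permuted[[j]])  # break the link between feature j and the outcome
    e_perm <- rmse(data[[outcome]], predict(model, newdata = permuted))
    data.frame(feature = j, ratio = e_perm / e_orig, difference = e_perm - e_orig)
  })
  out <- do.call(rbind, rows)
  out[order(out$ratio, decreasing = TRUE), ]  # step 3: descending importance
}

# Stand-ins: mtcars instead of the Cleveland test set, lm() instead of a random forest
pfi_model <- lm(mpg ~ wt + hp + qsec, data = mtcars)
imp <- permutation_importance(pfi_model, mtcars, "mpg", c("wt", "hp", "qsec"))
imp
```

A single permutation is random, so averaging over several permutations per feature gives more stable estimates; the `iml` package's `FeatureImp` repeats the permutation for this reason (its `n.repetitions` argument).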
<table class=" lightable-paper" style='font-family: "Arial Narrow", arial, helvetica, sans-serif; margin-left: auto; margin-right: auto;'> <thead> <tr> <th style="text-align:right;"> age </th> <th style="text-align:right;"> sex </th> <th style="text-align:right;"> chest_pain_type </th> <th style="text-align:right;"> chol </th> <th style="text-align:right;"> heart </th> <th style="text-align:right;"> new_prediction </th> </tr> </thead> <tbody> <tr> <td style="text-align:right;"> 70 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 233 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.2850819 </td> </tr> <tr> <td style="text-align:right;"> 43 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 4 </td> <td style="text-align:right;"> 286 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 0.8164277 </td> </tr> <tr> <td style="text-align:right;"> 63 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 4 </td> <td style="text-align:right;"> 229 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 0.9789997 </td> </tr> <tr> <td style="text-align:right;"> 56 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 3 </td> <td style="text-align:right;"> 250 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.2487433 </td> </tr> <tr> <td style="text-align:right;"> 43 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 2 </td> <td style="text-align:right;"> 204 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.0258744 </td> </tr> <tr> <td style="text-align:right;"> 60 </td> <td style="text-align:right;"> 1 </td> <td style="text-align:right;"> 2 </td> <td style="text-align:right;"> 236 </td> <td style="text-align:right;"> 0 </td> <td style="text-align:right;"> 0.1426383 </td> </tr> </tbody> 
</table> --- ## Permutation Feature Importance 3) Arrange predictors by descending feature importance <img src="Week10_files/figure-html/unnamed-chunk-33-1.png" style="display: block; margin: auto;" /> --- ## Permutation Feature Importance __Benefit of this method of interpretation:__ - Relatively simple interpretation of global trends - Considers interactions with other features __Drawbacks:__ - Permutation feature importance is linked to the error of the model - Requires observed outcome data (to compute the error) - Shares the unlikely-data problem of PDPs: permuting a feature that is correlated with others can create unrealistic observations --- class: inverse, center, middle ## Global Surrogate Model --- ## Global Surrogate Model The use of __Global Surrogate Models__ is a model-agnostic framework in which black-box models can be understood by training an interpretable _surrogate_ model. Essentially, the surrogate only needs a dataset and the black-box model's predictions for that dataset; fit metrics from the surrogate can then help identify how well it approximates the primary model. 1. Create or select a dataset `\(X\)` (this can include the training set, a subset, a grid, etc.) 2. Render predictions for `\(X\)` using the black-box model 3. Select an interpretable model type (such as OLS or CART) 4. Train the model using `\(X\)` and the black-box model predictions (= surrogate model) 5.
Quantify how closely the surrogate model approximates the black-box model --- ## Global Surrogate Model ```r predictions <- predict(mystery_mod, newdata = cleveland, type = 'prob')['1'] new_data <- cleveland %>% dplyr::select(-heart) %>% dplyr::mutate(preds = predictions$`1`) surrogate <- caret::train(preds~., data = new_data, method = 'rpart', maxdepth = 5, tuneGrid = expand.grid(cp = seq(0.05, 0.3, by = .05))) rpart.plot::rpart.plot(surrogate$finalModel) ``` --- ## Global Surrogate Model ``` ## [21:37:14] WARNING: amalgamation/../src/learner.cc:1040: ## If you are loading a serialized model (like pickle in Python, RDS in R) generated by ## older XGBoost, please export the model by calling `Booster.save_model` from that version ## first, then load it back in current version. See: ## ## https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html ## ## for more details about differences between saving model and serializing. ## ## [21:37:14] WARNING: amalgamation/../src/learner.cc:749: Found JSON model saved before XGBoost 1.6, please save the model using current version again. The support for old JSON model will be discontinued in XGBoost 2.3. 
``` <img src="Week10_files/figure-html/unnamed-chunk-36-1.png" style="display: block; margin: auto;" /> --- ## Global Surrogate Model ```r y_hat <- predict(surrogate, cleveland) caret::R2(pred = y_hat, obs = predictions) ``` ``` ## [,1] ## 1 0.7705317 ``` --- ## Global Surrogate Model __Benefit of this method of interpretation:__ - Flexibility - Relatively straightforward __Drawbacks:__ - Not always simple to identify whether the surrogate model is a "good" approximation - Findings might not be generalizable across datasets --- class: inverse, center, middle ## Individual Conditional Expectation --- ## Individual Conditional Expectation Building on the intuition of partial dependency plots, __individual conditional expectation (ICE)__ plots can explore heterogeneity that exists across predictors by plotting the marginal effect for _each_ instance. Whereas a PDP explores the _average_ relationship, ICE plots explore the full scope of individual observations. -- <img src="Week10_files/figure-html/unnamed-chunk-38-1.png" style="display: block; margin: auto;" /> --- ## Individual Conditional Expectation Building on the intuition of partial dependency plots, __individual conditional expectation (ICE)__ plots can explore heterogeneity that exists across predictors by plotting the marginal effect for _each_ instance. Whereas a PDP explores the _average_ relationship, ICE plots explore the full scope of individual observations. <img src="Week10_files/figure-html/unnamed-chunk-39-1.png" style="display: block; margin: auto;" /> --- ## Individual Conditional Expectation Building on the intuition of partial dependency plots, __individual conditional expectation (ICE)__ plots can explore heterogeneity that exists across predictors by plotting the marginal effect for _each_ instance. Whereas a PDP explores the _average_ relationship, ICE plots explore the full scope of individual observations. 
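To make the mechanics concrete, here is a minimal base-R sketch of how ICE curves, and the PDP as their average, can be computed by hand. It assumes the built-in `mtcars` data and an `lm()` with an interaction term (so the curves are not parallel); `ice_curves()` is a hypothetical helper written for illustration, not a function from `iml` or any other package.

```r
# Hypothetical helper (illustration only): ICE curves for one feature.
# Assumes `model` has a predict() method and `data` is a data frame of features.
ice_curves <- function(model, data, feature, grid_size = 20) {
  # Evenly spaced grid over the observed range of the feature
  grid <- seq(min(data[[feature]]), max(data[[feature]]), length.out = grid_size)
  # For each observation, vary only `feature` along the grid, holding all else fixed
  curves <- t(sapply(seq_len(nrow(data)), function(i) {
    rows <- data[rep(i, grid_size), , drop = FALSE]
    rows[[feature]] <- grid
    predict(model, newdata = rows)
  }))
  list(grid = grid, curves = curves)  # curves: one row per observation
}

# The wt:hp interaction makes the effect of wt differ from car to car
fit <- lm(mpg ~ wt * hp, data = mtcars)
ice <- ice_curves(fit, mtcars, "wt")

# The PDP is the pointwise average of the ICE curves
pdp <- colMeans(ice$curves)

# Individual curves in grey, the PDP as a thick black line
matplot(ice$grid, t(ice$curves), type = "l", lty = 1, col = "grey70",
        xlab = "wt", ylab = "predicted mpg")
lines(ice$grid, pdp, lwd = 3)
```

In practice, `iml::FeatureEffect` with `method = "ice"` (or `"pdp+ice"`) produces these curves as a `ggplot2` object; the hand-rolled version only shows what is being computed.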
<img src="Week10_files/figure-html/unnamed-chunk-40-1.png" style="display: block; margin: auto;" />

---
## Individual Conditional Expectation

__Benefits of this method of interpretation:__

- _Very_ easy to understand
- Relatively simple to implement
- Helps explore heterogeneity that exists within our data!

__Drawbacks:__

- Ignores the feature distribution, so curves can include unrealistic (invalid) data points
- Can meaningfully display only one feature at a time
- Plots become crowded and hard to read when there are many observations

---
## 🚨`iml`🚨

The `iml` package is a companion package to this week's Molnar readings and provides relatively simple methods for exploring and visualizing both global and local model-agnostic (and model-specific) interpretability metrics.

.pull-left[
The package plays well with `ggplot2` and `caret`, so it fits well into our workflow.
]

.pull-right[
- Works with `caret` models and with objects from other ML packages (such as `randomForest`)
- Can work with different predictors
- Produces useful visualizations that can be easily adjusted (since they are built using `ggplot2`)
]

We'll cover some of the nuances of `iml` in more detail during our coding session.

---
## Next Week's Readings

__April 5:__ Text as Data

- [Review] W&G. [Chapter 14: Strings](https://r4ds.had.co.nz/strings.html)
- Baumer et al. (2021). [Chapter 19: Text as data](https://mdsr-book.github.io/mdsr2e/ch-text.html)
- Silge and Robinson (2021). [Chapter 6: Topic Modeling](https://www.tidytextmining.com/topicmodeling.html)
- [Skim] Garson. Chapter 9: Text analytics
- [OPTIONAL] [Article: Grimmer and Stewart (2013)](https://web.stanford.edu/~jgrimmer/tad2.pdf)