모형과 데이터 불러오기

데이터와 기계학습 예측모형을 불러오자

library(tidyverse)

titanic_list  <-  
  read_rds("data/titanic_list.rds")

str(titanic_list, max.level = 2) 
List of 2
 $ data :List of 3
  ..$ training: tibble[,9] [2,099 × 9] (S3: tbl_df/tbl/data.frame)
  .. ..- attr(*, "na.action")= 'omit' Named int [1:108] 46 90 118 119 122 132 139 145 151 152 ...
  .. .. ..- attr(*, "names")= chr [1:108] "46" "90" "118" "119" ...
  ..$ henry   : tibble[,7] [1 × 7] (S3: tbl_df/tbl/data.frame)
  ..$ johnny_d: tibble[,7] [1 × 7] (S3: tbl_df/tbl/data.frame)
 $ model:List of 4
  ..$ titanic_lmr:List of 24
  .. ..- attr(*, "class")= chr [1:3] "lrm" "rms" "glm"
  ..$ titanic_rf :List of 19
  .. ..- attr(*, "class")= chr [1:2] "randomForest.formula" "randomForest"
  ..$ titanic_gbm:List of 27
  .. ..- attr(*, "class")= chr "gbm"
  ..$ titanic_svm:List of 30
  .. ..- attr(*, "class")= chr [1:2] "svm.formula" "svm"

1 관측점 설명

관측점(instance) 별로 기계가 학습한 모형을 설명을 하는 방식은 다음과 같다.

  • 분해(Break-down) 그래프: 예측에 대한 주요 변수별 기여분을 시각화.
library(tidyverse)
library(DALEX)
library(DALEXtra)
library(randomForest)

explainer_rf  <- DALEX::explain(titanic_list$model$titanic_rf, 
                                data = titanic_list$data$training %>% select(-survived),
                                 y = titanic_list$data$training %>% select(survived))
Preparation of a new explainer is initiated
  -> model label       :  randomForest  (  default  )
  -> data              :  2099  rows  8  cols 
  -> data              :  tibble converted into a data.frame 
  -> target variable   :  Argument 'y' was a data frame. Converted to a vector. (  WARNING  )
  -> target variable   :  2099  values 
  -> predict function  :  yhat.randomForest  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package randomForest , ver. 4.6.14 , task classification (  default  ) 
  -> model_info        :  Model info detected classification task but 'y' is a factor .  (  WARNING  )
  -> model_info        :  By deafult classification tasks supports only numercical 'y' parameter. 
  -> model_info        :  Consider changing to numerical vector with 0 and 1 values.
  -> model_info        :  Otherwise I will not be able to calculate residuals or loss function.
  -> predicted values  :  numerical, min =  0 , mean =  0.2384393 , max =  1  
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  NA , mean =  NA , max =  NA  
  A new explainer has been created!  

2 변수별 기여 분해 그래프

특정 관측점에 대한 변수별 기여를 분해하여 시각적으로 이해하기 쉽게 표현함.

2.1 헨리(henry)

2.1.1

library(reactable)

bd_rf <- predict_parts(explainer = explainer_rf, 
                       new_observation = titanic_list$data$henry,
                       type = "break_down")
bd_rf %>% 
  select(-label) %>%
  reactable::reactable(columns = list(
    contribution  = colDef(format = colFormat(digits = 2)),
    cumulative  = colDef(format = colFormat(digits = 2))
  ))

2.1.2 분해 그래프

bd_rf %>% 
  plot()

2.1.3 다른 탑승객과 비교

바이올린 그래프가 그려져야하는데… 이론상… 하지만 그렇게 구현되지 않음!!! DALEX 버전 1.x 버전에서 생겼던 문제로 최신 버전 2.2.0으로 올리게 되면 문제 없음.

bd_rf_distr <- predict_parts(explainer = explainer_rf, 
                             new_observation = titanic_list$data$henry,
                             type = "break_down", 
                             order = c("age", "class", "fare", "gender", "embarked", "sibsp", "parch"), 
                             keep_distributions = TRUE)

plot(bd_rf_distr, plot_distributions = TRUE) 

2.2 쟈니(johnny)

2.2.1

library(reactable)

bd_johnny_rf <- predict_parts(explainer = explainer_rf, 
                       new_observation = titanic_list$data$johnny_d,
                       type = "break_down")
bd_johnny_rf %>% 
  select(-label) %>%
  reactable::reactable(columns = list(
    contribution  = colDef(format = colFormat(digits = 2)),
    cumulative  = colDef(format = colFormat(digits = 2))
  ))

2.2.2 분해 그래프

bd_johnny_rf %>% 
  plot()

2.2.3 다른 탑승객과 비교

바이올린 그래프가 그려져야하는데… 이론상… 하지만 그렇게 구현되지 않음!!!

bd_rf_johnny_distr <- predict_parts(explainer = explainer_rf, 
                             new_observation = titanic_list$data$johnny_d,
                             type = "break_down", 
                             order = c("age", "class", "fare", "gender", "embarked", "sibsp", "parch"), 
                             keep_distributions = TRUE)

plot(bd_rf_johnny_distr, plot_distributions = TRUE) 

3 섀플리 값(Shapley Value)

게임 이론에서 가져온 개념을 기계학습에 적용시킨 것으로 다음과 같이 변수 기여도를 해석할 수 있다. 최적의 변수 조합을 찾는 것이 문제이며 각 변수는 player로 보고 다양한 상호협력 조합을 통해 예측값을 만들어 내느냐는 것이다. 계산량이 많아 다소 불리한 점이 있지만 분해(Break-down) 방법이 갖는 순서 문제(어떤 변수가 먼저 들어가느냐에 따라 해석이 달라지는 문제)와 교호작용(interaction)이 있는 문제점을 해결할 수 있다는 점에서 장점을 갖는다. 또한 새플리 값을 사용하는 경우 가법 모형을 상정하기 때문에 비선형 관계를 갖는 경우 설명에 한계가 존재한다.

3.1 헨리(henry)

3.1.1

shap_henry <- predict_parts(explainer       = explainer_rf, 
                            new_observation = titanic_list$data$henry,
                            type = "shap",
                            B = 5)

shap_henry
                                             min           q1       median
randomForest: age = 47             -0.2074159123 -0.184018247 -0.101094140
randomForest: class = 1st           0.0916169605  0.180594616  0.195558361
randomForest: embarked = Cherbourg  0.0111557885  0.025848452  0.056723106
randomForest: fare = 25            -0.0303592187 -0.015308957  0.004845260
randomForest: gender = male        -0.1591329204 -0.145780419 -0.131365984
randomForest: parch = 0            -0.0206260124 -0.008240877 -0.007172939
randomForest: sibsp = 0            -0.0008870891  0.003097856  0.005363888
                                           mean           q3          max
randomForest: age = 47             -0.113825250 -0.048969271 -0.030093378
randomForest: class = 1st           0.176588280  0.199024059  0.201010005
randomForest: embarked = Cherbourg  0.056297094  0.058251787  0.138862315
randomForest: fare = 25             0.015328442  0.046053645  0.074876608
randomForest: gender = male        -0.134816770 -0.123636017 -0.115391139
randomForest: parch = 0            -0.008438685 -0.005249166 -0.002455455
randomForest: sibsp = 0             0.004427632  0.006300143  0.007770367

3.1.2 그래프

library(patchwork)

shap_boxplot_gg <- plot(shap_henry) +
  scale_y_continuous(limits =c(-0.3, 0.3))

shap_average_gg <- plot(shap_henry, show_boxplots = FALSE) +
  scale_y_continuous(limits =c(-0.3, 0.3))

shap_boxplot_gg / shap_average_gg

4 라임(LIME)

분해(Break-down), 새플리 값(Shapley Value)는 설명변수가 크지 않는 경우 사용할 수 있지만, 설명변수가 많은 경우 Local Interpretable Model-agnostic Explanations (LIME)이 제시되고 있다.

library("DALEXtra") 
library("lime")

lime_johnny <- DALEXtra::predict_surrogate(explainer = explainer_rf, 
                                           new_observation = titanic_list$data$johnny_d,
                                           n_features = 3, 
                                           n_permutations = 1000, 
                                           type = "lime")
Error: The class of model must have a model_type method. See ?model_type to get an overview of models supported out of the box

5 Ceteris Paribus

라틴어 Ceteris Paribus는 ‘세테리스 패러버스’ 로 발음하고 영어로 “all other things being equal” 로 표현되며 “다른 모든 조건이 동일하다면” 을 의미한다. 즉, What-If 처럼 다른 조건을 동일하게 둔 상태에서 관심있는 변수를 변화시켰을 때 예측값의 변화를 살펴보는 방법이다.

5.1 쟈니(johnny)

5.1.1

cp_johnny_rf <- predict_profile(explainer = explainer_rf, 
                                 new_observation = titanic_list$data$johnny_d)
cp_johnny_rf
Top profiles    : 
             class gender age sibsp parch fare    embarked _yhat_ _vname_ _ids_
1              1st   male   8     0     0   72 Southampton  0.402   class     1
2              2nd   male   8     0     0   72 Southampton  0.412   class     1
3              3rd   male   8     0     0   72 Southampton  0.384   class     1
4        deck crew   male   8     0     0   72 Southampton  0.496   class     1
5 engineering crew   male   8     0     0   72 Southampton  0.326   class     1
6 restaurant staff   male   8     0     0   72 Southampton  0.328   class     1
       _label_
1 randomForest
2 randomForest
3 randomForest
4 randomForest
5 randomForest
6 randomForest


Top observations:
  class gender age sibsp parch fare    embarked _yhat_      _label_ _ids_
1   1st   male   8     0     0   72 Southampton  0.402 randomForest     1

5.1.2 연속형 변수

library(patchwork)

plot(cp_johnny_rf, variables = "age") + plot(cp_johnny_rf, variables = "fare") 

5.1.3 범주형 변수

plot(cp_johnny_rf, 
     variables = "embarked", 
     variable_type = "categorical", 
     categorical_type = "bars") 
Error in tmp[, sv] == key[as.character(tmp$`_ids_`), sv]: comparison of these types is not implemented

5.2 헨리와 쟈니

5.2.1 헨리

variable_splits <-list(age = seq(0, 70, 0.1), fare = seq(0, 100, 0.1))
cp_henry_rf <- predict_profile(explainer = explainer_rf, 
                                 new_observation = titanic_list$data$henry,
                                 variable_splits = variable_splits)

plot(cp_henry_rf, variables = "age") + plot(cp_henry_rf, variables = "fare") 

5.2.2 쟈니

cp_johnny_rf <- predict_profile(explainer = explainer_rf, 
                                 new_observation = titanic_list$data$johnny_d,
                                 variable_splits = variable_splits)

plot(cp_johnny_rf, variables = "age") + plot(cp_johnny_rf, variables = "fare") 

5.2.3 헨리와 쟈니

cp_henry_johnny_rf <- predict_profile(explainer = explainer_rf, 
                                 new_observation = rbind(titanic_list$data$henry, titanic_list$data$johnny_d),
                                 variable_splits = variable_splits)

plot(cp_henry_johnny_rf, variables = "age", color = "_ids_") + 
  plot(cp_henry_johnny_rf, variables = "fare", color = "_ids_") 

6 지역-검진 그래프

관측점에 대해 안정성(stability)을 확인하는 과정으로 잔차를 비교한다.

ldiag_rf <- predict_diagnostics(explainer = explainer_rf, 
                                new_observation = titanic_list$data$johnny_d,
                                neighbors = 100)
Error in ks.test(residuals_other, residuals_sel): not enough 'x' data
ldiag_rf %>% plot
Error in plot(.): object 'ldiag_rf' not found
ldiag_rf <- predict_diagnostics(explainer = explainer_rf, 
                                new_observation = titanic_list$data$henry,
                                neighbors = 100,
                                variable = "age")

ldiag_rf %>% plot

 

데이터 과학자 이광춘 저작

kwangchun.lee.7@gmail.com