Loading the Model and Data

Let's load the data and the machine-learning prediction models.

library(tidyverse)

titanic_list  <-  
  read_rds("data/titinic_list.rds")

str(titanic_list, max.level = 2) 
List of 2
 $ data :List of 3
  ..$ training:'data.frame':    2099 obs. of  9 variables:
  .. ..- attr(*, "na.action")= 'omit' Named int [1:108] 46 90 118 119 122 132 139 145 151 152 ...
  .. .. ..- attr(*, "names")= chr [1:108] "46" "90" "118" "119" ...
  ..$ henry   :'data.frame':    1 obs. of  7 variables:
  ..$ johnny_d:'data.frame':    1 obs. of  7 variables:
 $ model:List of 4
  ..$ titanic_lmr:List of 24
  .. ..- attr(*, "class")= chr [1:3] "lrm" "rms" "glm"
  ..$ titanic_rf :List of 19
  .. ..- attr(*, "class")= chr [1:2] "randomForest.formula" "randomForest"
  ..$ titanic_gbm:List of 27
  .. ..- attr(*, "class")= chr "gbm"
  ..$ titanic_svm:List of 30
  .. ..- attr(*, "class")= chr [1:2] "svm.formula" "svm"

1 XAI Model Architecture

Once we have an algorithm (function) that the machine has learned, we wrap it in a DALEX explainer object with explain() and then use that object to examine the model from various perspectives.

library("rms")
library("DALEX")

## Logistic regression explainer
# `titanic` below is the titanic data frame (2,207 passengers) that ships with the DALEX package
explain_lmr <- explain(model = titanic_list$model$titanic_lmr,
                       data  = titanic %>% select(-survived),
                       y     = titanic$survived == "yes", 
                       type = "classification",
                       label = "Logistic Regression")
Preparation of a new explainer is initiated
  -> model label       :  Logistic Regression 
  -> data              :  2207  rows  8  cols 
  -> target variable   :  2207  values 
  -> predict function  :  yhat.lrm  will be used (  default  )
  -> predicted values  :  numerical, min =  NA , mean =  NA , max =  NA  
  -> model_info        :  package rms , ver. 6.2.0 , task classification (  default  ) 
  -> model_info        :  type set to  classification 
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  NA , mean =  NA , max =  NA  
  A new explainer has been created!  
## Random Forest explainer
library("randomForest")
explain_rf <- explain(model = titanic_list$model$titanic_rf,
                      data = titanic %>% select(-survived),
                      y = titanic$survived == "yes", 
                      label = "Random Forest")
Preparation of a new explainer is initiated
  -> model label       :  Random Forest 
  -> data              :  2207  rows  8  cols 
  -> target variable   :  2207  values 
  -> predict function  :  yhat.randomForest  will be used (  default  )
  -> predicted values  :  numerical, min =  NA , mean =  NA , max =  NA  
  -> model_info        :  package randomForest , ver. 4.6.14 , task classification (  default  ) 
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  NA , mean =  NA , max =  NA  
  A new explainer has been created!  
## SVM explainer
library(e1071)
explain_svm <- explain(model = titanic_list$model$titanic_svm,
                      data = titanic %>% select(-survived),
                      y = titanic$survived == "yes", 
                      label = "SVM")
Preparation of a new explainer is initiated
  -> model label       :  SVM 
  -> data              :  2207  rows  8  cols 
  -> target variable   :  2207  values 
  -> predict function  :  yhat.svm  will be used (  default  )
  -> predicted values  :  numerical, min =  0.08561812 , mean =  0.3267464 , max =  0.9582763  
  -> model_info        :  package e1071 , ver. 1.7.6 , task classification (  default  ) 
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -0.9556032 , mean =  -0.007418432 , max =  0.9143819  
  A new explainer has been created!  
## GBM explainer
library(gbm)
explain_gbm <- explain(model = titanic_list$model$titanic_gbm,
                      data = titanic %>% select(-survived),
                      y = titanic$survived == "yes", 
                      label = "GBM")
Preparation of a new explainer is initiated
  -> model label       :  GBM 
  -> data              :  2207  rows  8  cols 
  -> target variable   :  2207  values 
  -> predict function  :  yhat.gbm  will be used (  default  )
  -> predicted values  :  numerical, min =  0.0003582783 , mean =  0.3238207 , max =  0.9983936  
  -> model_info        :  package gbm , ver. 2.1.8 , task classification (  default  ) 
  -> residual function :  difference between y and yhat (  default  )
  -> residuals         :  numerical, min =  -0.9409688 , mean =  -0.001663897 , max =  0.966334  
  A new explainer has been created!  
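
Every explainer exposes the same predict() interface regardless of the underlying model, which is what makes the comparisons below possible. As a quick sanity check, here is a minimal sketch that scores the two example passengers stored in titanic_list (henry and johnny_d, visible in the str() output above), assuming each contains the predictor columns the models expect:

# Unified predict() interface: the same call works for every explainer
johnny_d <- titanic_list$data$johnny_d
henry    <- titanic_list$data$henry

predict(explain_lmr, johnny_d)   # survival probability, logistic regression
predict(explain_rf,  johnny_d)   # survival probability, random forest
predict(explain_gbm, henry)      # survival probability for Henry, GBM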

2 Model Performance

Model performance evaluation can be approached from three angles:

  • model evaluation
  • model comparison
  • training/test data comparison (out-of-sample, out-of-time comparisons), illustrated in the sketch below
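
The first two perspectives are handled below with DALEX::model_performance(); the third is not demonstrated in this section, so the following minimal sketch illustrates it. The hold-out indices (test_idx) and the seed are illustrative assumptions, not part of the original analysis.

# Hypothetical out-of-sample check: re-score the random forest on a held-out subset
set.seed(777)
test_idx <- sample(nrow(titanic), size = 500)   # assumed hold-out split, for illustration only

explain_rf_test <- explain(model = titanic_list$model$titanic_rf,
                           data  = titanic[test_idx, ] %>% select(-survived),
                           y     = titanic$survived[test_idx] == "yes",
                           label = "Random Forest (hold-out)")

# Compare measures on the full data vs. the held-out rows
model_performance(explain_rf)
model_performance(explain_rf_test)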

2.1 AUC

# Logistic regression model ---------------------------
eva_lr <- DALEX::model_performance(explain_lmr)

eva_lr_tbl <- eva_lr$measures %>% 
  enframe() %>% 
  mutate(performance = map_dbl(value, 1)) %>% 
  mutate(model = "LR")

# Random Forest ------------------------------
eva_rf <- DALEX::model_performance(explain_rf)

eva_rf_tbl <- eva_rf$measures %>% 
  enframe() %>% 
  mutate(performance = map_dbl(value, 1)) %>% 
  mutate(model = "RF")

# # SVM ------------------------------
# eva_svm <- DALEX::model_performance(explain_svm)
# 
# eva_svm_tbl <- eva_svm$measures %>% 
#   enframe() %>% 
#   mutate(performance = map_dbl(value, 1)) %>% 
#   mutate(model = "SVM")

# GBM ------------------------------
eva_gbm <- DALEX::model_performance(explain_gbm)

eva_gbm_tbl <- eva_gbm$measures %>% 
  enframe() %>% 
  mutate(performance = map_dbl(value, 1)) %>% 
  mutate(model = "GBM")

# Combined model performance ======================
plot(eva_lr, 
     eva_rf, 
     eva_gbm,
     # eva_svm,
     geom = "roc")

2.2 Performance Table

library(reactable)

performance_tbl <- 
  bind_rows(eva_lr_tbl, eva_rf_tbl) %>% 
  # bind_rows(eva_svm_tbl) %>% 
  bind_rows(eva_gbm_tbl) %>% 
  select(-value)

performance_tbl %>% 
  pivot_wider(names_from = "model", values_from = "performance") %>% 
  reactable::reactable(columns = list(
    LR  = colDef(format = colFormat(digits = 2)),
    RF  = colDef(format = colFormat(digits = 2)),
    # SVM = colDef(format = colFormat(digits = 2)),
    GBM = colDef(format = colFormat(digits = 2))
  ))
 

Written by Kwangchun Lee (이광춘), Data Scientist

kwangchun.lee.7@gmail.com