1. 손글씨 숫자 (mnist) 데이터 1

MNIST in CSV 웹사이트에 데이터 분석의 표준 파일형식(.csv)으로 기계학습과 딥러닝 교육용으로 많이 활용되는 손글씨 숫자 mnist 데이터를 제공하고 있다.

1.1. 데이터 가져오기

웹사이트에서 데이터를 바로 불러오면 되기 때문에 tidyverse 팩키지를 구성하는 readr 팩키지 read_csv() 함수를 사용해서 데이터를 불러온다. 훈련데이터 60,000, 검증데이터 10,000 손글씨로 쓴 숫자 데이터가 제공되는데 이미지라서 \(28 \times 28 = 784\) 그리고 정답 숫자라벨이 한 칼럼을 차지해서 785 칼럼으로 구성되어 있다.

# 0. 환경설정 -----
library(tidyverse)
library(extrafont)
loadfonts()
library(gridExtra)

   
# 1. 데이터 가져오기 -----

mnist_train_dat <- read_csv("https://pjreddie.com/media/files/mnist_train.csv", col_names = FALSE)
# mnist_test_dat  <- read_csv("https://pjreddie.com/media/files/mnist_test.csv", col_names = FALSE)

1.2. 데이터 전처리

데이터를 tidy 형태로 가공하는 작업을 수행하고 mnist_train_df로 저장하여 탐색적 데이터 분석을 수행할 수 있도록 준비한다.

# 2. 데이터 전처리 -----

mnist_train_df <- mnist_train_dat %>%
#    head(10000) %>%
    rename(정답 = X1) %>%
    mutate(사례 = row_number()) %>%
    gather(픽셀, RGB값, -정답, -사례) %>%
    tidyr::extract(픽셀, "픽셀", "(\\d+)", convert = TRUE) %>%
    mutate(픽셀 = 픽셀 - 2,
           x = 픽셀 %% 28,
           y = 28 - 픽셀 %/% 28)

2. 탐색적 데이터 분석

2.1. 데이터 살펴보기

탐색적 데이터 분석을 위해서 ggplot을 통해 준비한 데이터 형태를 살펴본다.

# 3. 탐색적 데이터 분석 -----
## 3.1. 시각화 통해 살펴보기
mnist_train_df %>%
    filter(사례 <= 12) %>%
    ggplot(aes(x, y, fill = RGB값)) +
    geom_tile() +
    facet_wrap(~ 사례 + 정답) +
    theme_void() +
    scale_fill_gradient2(low = "white", high = "black", mid = "gray", midpoint = 255/2)

2.2. 데이터 중심

\(28 \times 28 = 784\) 크기를 갖는 60,000개 관측점을 하나의 숫자 평균을 낸다면 어떤 모양일지 살펴보자. 손글씨 숫자 중심을 중위수와 평균으로 잡아 이를 시각적으로 파악해 보자. 신규 손글씨 숫자가 나왔을 때 k-최근접 이웃 알고리즘을 돌려 가장 오차가 적은 것으로 분류하는 것도 좋은 접근법이 될 듯 싶다.

## 3.2. 중앙 평균값과 평균
mnist_median_g <- mnist_train_df %>% 
    group_by(x, y, 정답) %>%
    summarize(RGB중위수 = median(RGB값)) %>%
    ungroup() %>% 
    ggplot(aes(x, y, fill = RGB중위수)) +
    geom_tile() +
    scale_fill_gradient2(low = "white", high = "black", mid = "gray", midpoint = 255/2) +
    facet_wrap( ~ 정답, nrow = 2) +
    labs(title = "손글씨 MNIST 숫자 10개에 대한 RGB 중위수",
         fill = "RGB 중위수") +
    theme_void(base_family="NanumGothic")
    
mnist_mean_g <- mnist_train_df %>% 
    group_by(x, y, 정답) %>%
    summarize(RGB중위수 = mean(RGB값)) %>%
    ungroup() %>% 
    ggplot(aes(x, y, fill = RGB중위수)) +
    geom_tile() +
    scale_fill_gradient2(low = "white", high = "black", mid = "gray", midpoint = 255/2) +
    facet_wrap( ~ 정답, nrow = 2) +
    labs(title = "손글씨 MNIST 숫자 10개에 대한 RGB 평균",
         fill = "RGB 평균") +
    theme_void(base_family="NanumGothic")

grid.arrange(mnist_median_g, mnist_mean_g, nrow=2)

2.3. 손글씨 이상점

\(28 \times 28 = 784\)의 중심을 찾았다면, 평균거리가 가장 먼 손글씨 숫자 사례를 찾아 각 숫자 0-9에 대해 6가지 사례만 시각적으로 파악해 본다. 오분류 가능성이 많은 사례로 … 기계학습 알고리즘이 찾아내지 못하는 것이 어떤 경우가 되는지 미리 파악해 본다.

## 3.3. 이상 사례 찾아내기 -----
mnist_train_mean_df <- mnist_train_df %>% 
    group_by(x, y, 정답) %>%
    summarize(RGB평균 = mean(RGB값))

mnist_train_eda_df <- inner_join(mnist_train_df, mnist_train_mean_df, by = c("x", "y", "정답"))

mnist_train_dist_df <- mnist_train_eda_df %>%
    group_by(정답, 사례) %>%
    summarize(유클리드_거리 = sqrt(mean((RGB값 - RGB평균) ^ 2))) %>% 
    ungroup()

worst_instances <- mnist_train_dist_df %>%
    group_by(정답) %>% 
    top_n(6, 유클리드_거리) %>%
    mutate(number = rank(-유클리드_거리))

mnist_train_df %>%
    inner_join(worst_instances, by = c("정답", "사례")) %>%
    ggplot(aes(x, y, fill = RGB값)) +
    geom_tile(show.legend = FALSE) +
    scale_fill_gradient2(low = "white", high = "black", mid = "gray", midpoint = 255/2) +
    facet_grid(정답 ~ number) +
    labs(title = "식별이 쉽지 않은 숫자",
         subtitle = "평균과 차이가 많이 나는 손글씨 숫자 6개") +
    theme_void(base_family="NanumGothic") +
    theme(strip.text = element_blank())

2.3. 교차분석

0-9 사이 숫자가 잘 구분되거나 구분이 잘 되지 않는 경우가 있을 때, 어떤 아라비아 숫자 두개가 이슈가 되는지 교차분석을 통해 파악한다.

## 3.4. 교차분석 -----

digit_differences <- crossing(compare1 = 0:9, compare2 = 0:9) %>%
    filter(compare1 != compare2) %>%
    mutate(negative = compare1, positive = compare2) %>%
    gather(class, 정답, positive, negative) %>%
    inner_join(mnist_train_mean_df, by = "정답") %>%
    select(-정답) %>%
    spread(class, RGB평균)

ggplot(digit_differences, aes(x, y, fill = positive - negative)) +
    geom_tile() +
    scale_fill_gradient2(low = "blue", high = "red", mid = "white", midpoint = .5) +
    facet_grid(compare2 ~ compare1) +
    theme_void()


  1. http://varianceexplained.org - Exploring handwritten digit classification: a tidy analysis of the MNIST dataset