교차 검증 데이터 구성하기와 “caret::createFolds”
k-겹 교차 검증(k-fold cross validation)
교차 검증은 모형의 성능을 판단하기 위해 사용한다. 선형 모형에서 \(R^2\) (또는 \(adj-R^2\) 나 AIC, BIC 등은 모두 특정한 가정을 성립할 때 일반화 성능을 판단하기 위해 사용할 수 있지만, 모형이 복잡하거나, 오차항이 정규분포를 띄지 않거나 하는 좀 더 일반적인 상황에서는 모형의 일반화 성능을 정확하게 반영한다고 말하기 힘들기 때문이다.
먼저 다음과 같은 데이터가 있다고 하자.
N = 100
x1 <- rexp(N)
x2 <- rexp(N, 0.5)
e <- rpois(N, 1)
y <- x1 + 2*log(x2+1) + e
dat <- data.frame(x1, x2, y)
head(dat)
## x1 x2 y ## 1 1.35960161 1.09324970 3.837037 ## 2 2.13923671 0.03644536 2.210831 ## 3 1.81704509 1.42709966 3.590439 ## 4 1.27645172 3.51469184 5.291126 ## 5 0.08845231 0.94381118 2.417753 ## 6 2.57794238 0.71397535 3.655573
5-겹 교차 검증 데이터를 구성하는 방법을 생각해보자. 다음과 같다.
isamp <- sample(100)
folds <- list(fold1 = isamp[1:20],
fold2 = isamp[21:40],
fold3 = isamp[41:60],
fold4 = isamp[61:80],
fold5 = isamp[81:100])
folds
## $fold1 ## [1] 6 92 28 80 54 66 55 11 69 32 17 75 87 84 45 25 68 21 86 85 ## ## $fold2 ## [1] 31 33 94 99 29 27 88 89 41 15 64 42 18 72 12 50 26 3 37 44 ## ## $fold3 ## [1] 39 76 81 14 48 38 97 100 46 63 90 43 5 49 52 51 56 59 78 23 ## ## $fold4 ## [1] 24 93 22 20 67 8 53 61 74 10 1 96 62 9 79 70 65 34 2 40 ## ## $fold5 ## [1] 73 77 60 47 91 58 71 4 36 57 19 35 95 82 83 98 13 7 16 30
함수 sample()
은 1
부터 100
의 자연수를 적절히 섞어서 길이 100의 자연수 벡터를 뱉어낸다. 그리고 이 결과는 isamp
에 저장되므로, 100개를 20개씩 쪼개서 validation set으로 만들면 된다.
(우리는 가장 구체적인 100개의 자료, 5개의 교차 검증 횟수 5에서 시작해서 임의의 자료 갯수 N, 임의의 교차 검증 k에서 작동하는 방법을 구성해나가고자 한다.)
따라서 첫 번째 데이터는 dat[-folds[[1]],]
을 train set으로 dat[folds[[1]],]
을 test set으로 구성하면 된다.
하지만 위의 코드를 보면 뭔가 매우 중복되는 요소가 많이 들어가 있음을 확인할 수 있다. list()
안에 isamp
가 5번 반복되고 있다. 숫자 1,20,21,40,41,60,61,80,81,100
에서도 연속되는 자연수는 다소 중복적인 요소이다.
isamp
의 원소를 적당하게 나눠주는 R의 함수를 사용하면 다음과 같이 쓸 수 있다.
folds <- split(isamp, cut(1:100,
breaks=100*c(0,0.2, 0.4, 0.6, 0.8, 1.0),
include.lowest=TRUE))
folds
## $`[0,20]` ## [1] 6 92 28 80 54 66 55 11 69 32 17 75 87 84 45 25 68 21 86 85 ## ## $`(20,40]` ## [1] 31 33 94 99 29 27 88 89 41 15 64 42 18 72 12 50 26 3 37 44 ## ## $`(40,60]` ## [1] 39 76 81 14 48 38 97 100 46 63 90 43 5 49 52 51 56 59 78 23 ## ## $`(60,80]` ## [1] 24 93 22 20 67 8 53 61 74 10 1 96 62 9 79 70 65 34 2 40 ## ## $`(80,100]` ## [1] 73 77 60 47 91 58 71 4 36 57 19 35 95 82 83 98 13 7 16 30
split(x, f)
는 벡터 x
의 내용을 팩터 f
에 따라 나눠준다. ?split
을 해보면 Divide into Groups and Reassemble
이라고 설명을 해놓았다.
따라서 우선 다음과 같이 해볼 수 있다.
folds <- split(isamp, c(rep(1,20), rep(2,20), rep(3,20), rep(4,20), rep(5,20)))
folds
## $`1` ## [1] 6 92 28 80 54 66 55 11 69 32 17 75 87 84 45 25 68 21 86 85 ## ## $`2` ## [1] 31 33 94 99 29 27 88 89 41 15 64 42 18 72 12 50 26 3 37 44 ## ## $`3` ## [1] 39 76 81 14 48 38 97 100 46 63 90 43 5 49 52 51 56 59 78 23 ## ## $`4` ## [1] 24 93 22 20 67 8 53 61 74 10 1 96 62 9 79 70 65 34 2 40 ## ## $`5` ## [1] 73 77 60 47 91 58 71 4 36 57 19 35 95 82 83 98 13 7 16 30
c(rep(1,20), rep(2,20), rep(3,20), rep(4,20), rep(5,20))
은 총 원소 100개를 20개씩 집단으로 나누는 표식이라고 생각할 수 있다. 위의 cut()
은 비슷한 방법이지만 누적비율로 나타낼 수 있다는 장점이 있다.
위의 방법은 자료의 크기를 일반화할 수 있다는 장점이 있다. 다음의 예를 보자. N
에 어떤 수를 집어 넣어도 잘 작동한다!
N=100
folds1 <- split(isamp, cut(1:N,
breaks=N*c(0,0.2, 0.4, 0.6, 0.8, 1.0),
include.lowest=TRUE))
folds2 <- split(isamp, c(rep(1,N*0.2), rep(2,N*0.2), rep(3,N*0.2),
rep(4,N*0.2), rep(5,N*0.2)))
folds1
## $`[0,20]` ## [1] 6 92 28 80 54 66 55 11 69 32 17 75 87 84 45 25 68 21 86 85 ## ## $`(20,40]` ## [1] 31 33 94 99 29 27 88 89 41 15 64 42 18 72 12 50 26 3 37 44 ## ## $`(40,60]` ## [1] 39 76 81 14 48 38 97 100 46 63 90 43 5 49 52 51 56 59 78 23 ## ## $`(60,80]` ## [1] 24 93 22 20 67 8 53 61 74 10 1 96 62 9 79 70 65 34 2 40 ## ## $`(80,100]` ## [1] 73 77 60 47 91 58 71 4 36 57 19 35 95 82 83 98 13 7 16 30
names(folds1) = NULL
names(folds2) = NULL
all.equal(folds1, folds2)
## [1] TRUE
두 번째 관문은 k
를 일반화하는 것이다.
k
를 일반화하기 위해서는 c(0, 0.2, 0.4, 0.6, 0.8, 1.0)
부분을 적당히 고쳐야 할 것이다. 예를 들어 k=2
라면 c(0, 0.5, 1)
이 되어야 하며, k=4
라면 c(0, 0.25, 0.5, 0.75, 1)
이 되어야 한다. 공통점이 보이는가? 시작은 언제나 0
이고, 마지막은 언제나 1
이다. 그리고 k
에 따라 길이가 달라져야 한다.
이런 경우 R의 seq
를 사용하여 seq(from=0, to=1, length.out=k+1)
로 쓸 수 있다. 따라서 다음과 같은 결과를 얻는다.
N = 100; k = 10
isamp <- sample(N)
folds <- split(isamp, cut(1:N,
breaks=N*seq(0, 1, length.out=k+1),
include.lowest=TRUE))
#folds
N
에 101
를 넣어도, k
에 17
를 넣어도 잘 작동함을 확인할 수 있을 것이다! 결국 일반적인 N
과 k
에 대해 k
-겹 교차 검증 데이터 셋을 구성할 수 있는 방법을 개발하였다!
N = 101; k = 17
isamp <- sample(N)
folds <- split(isamp, cut(1:N,
breaks=N*seq(0, 1, length.out=k+1),
include.lowest=TRUE))
folds
## $`[0,5.94]` ## [1] 26 98 86 4 17 ## ## $`(5.94,11.9]` ## [1] 83 13 56 10 81 91 ## ## $`(11.9,17.8]` ## [1] 18 1 88 62 16 78 ## ## $`(17.8,23.8]` ## [1] 5 49 72 47 59 82 ## ## $`(23.8,29.7]` ## [1] 38 24 21 63 68 67 ## ## $`(29.7,35.6]` ## [1] 41 71 53 61 23 64 ## ## $`(35.6,41.6]` ## [1] 27 69 85 76 30 54 ## ## $`(41.6,47.5]` ## [1] 11 33 75 52 84 92 ## ## $`(47.5,53.5]` ## [1] 89 73 36 93 25 15 ## ## $`(53.5,59.4]` ## [1] 94 55 42 80 32 44 ## ## $`(59.4,65.4]` ## [1] 3 45 66 50 22 90 ## ## $`(65.4,71.3]` ## [1] 96 40 20 43 28 74 ## ## $`(71.3,77.2]` ## [1] 79 77 97 7 35 34 ## ## $`(77.2,83.2]` ## [1] 12 19 99 29 2 100 ## ## $`(83.2,89.1]` ## [1] 70 48 60 51 101 65 ## ## $`(89.1,95.1]` ## [1] 37 14 87 46 39 57 ## ## $`(95.1,101]` ## [1] 8 58 9 31 95 6
sapply(folds, length)
## [0,5.94] (5.94,11.9] (11.9,17.8] (17.8,23.8] (23.8,29.7] (29.7,35.6] (35.6,41.6] (41.6,47.5] ## 5 6 6 6 6 6 6 6 ## (47.5,53.5] (53.5,59.4] (59.4,65.4] (65.4,71.3] (71.3,77.2] (77.2,83.2] (83.2,89.1] (89.1,95.1] ## 6 6 6 6 6 6 6 6 ## (95.1,101] ## 6
lelem <- sapply(folds, length)
sum(lelem)
## [1] 101
caret
의 createFolds
caret
의 createFolds
는 우리가 했던 바로 그 작업을 해주는 함수이다.
folds <- caret::createFolds(y=1:101, k=7)
folds
## $Fold1 ## [1] 1 4 12 25 33 36 42 53 54 57 76 79 81 86 ## ## $Fold2 ## [1] 14 15 17 27 29 38 52 55 62 63 78 93 97 ## ## $Fold3 ## [1] 2 11 13 20 40 44 46 49 59 61 65 75 84 90 92 100 ## ## $Fold4 ## [1] 3 8 9 23 32 34 35 51 58 71 73 91 94 98 101 ## ## $Fold5 ## [1] 5 19 21 24 31 41 45 48 64 66 67 85 87 89 95 ## ## $Fold6 ## [1] 6 7 22 26 30 37 47 50 56 60 70 72 77 82 83 88 ## ## $Fold7 ## [1] 10 16 18 28 39 43 68 69 74 80 96 99
sapply(folds, length)
## Fold1 Fold2 Fold3 Fold4 Fold5 Fold6 Fold7 ## 14 13 16 15 15 16 12
푸하하! 우리가 만들었던 함수보다 약간 못한 것 같다.
lelem <- sapply(folds, length)
sum(lelem)
## [1] 101
caret::createFolds
의 장점
하지만 createFolds
는 target label을 균등하게 배분하는 기능이 있다. 이게 무슨 말인가?
library(caret)
y <- sample(3, 100, replace=TRUE, prob=c(0.2, 0.3, 0.5))
folds <- createFolds(y, k=3)
folds
## $Fold1 ## [1] 6 7 10 15 16 21 22 23 24 27 29 30 32 46 50 51 52 53 59 63 64 65 66 68 69 75 76 78 81 86 89 92 ## [33] 94 96 ## ## $Fold2 ## [1] 2 5 9 11 14 17 18 20 34 35 36 37 38 40 41 42 43 44 47 48 49 55 60 61 ## [25] 62 72 83 84 85 87 93 95 98 100 ## ## $Fold3 ## [1] 1 3 4 8 12 13 19 25 26 28 31 33 39 45 54 56 57 58 67 70 71 73 74 77 79 80 82 88 90 91 97 99
위에서 y
에는 1
또는 2
또는 3
이 저장되어 있으며, 그 비율은 0.2, 0.3, 0.5과 비슷하다.
table(y)
## y ## 1 2 3 ## 22 34 44
우리가 k-겹 교차 검증 데이터를 구성할 때, target label의 비율이 각 fold마다 지나치게 다르게 되며 학습에 이런 imbalance가 반영될 수 있다. 예를 들어 y를 예측해야 하는데, train set에 모두 y=1인 사례만 들어가 있는 극단적인 경우를 생각해보면 imbalance의 문제를 이해할 수 있을 것이다.
caret::createFolds
는 이렇게 target label의 비율을 의도적으로 맞춰준다. 다음에서 확인할 수 있다.
library(magrittr)
table(y[-folds[[1]]]) %>% prop.table %>% round(1)
## ## 1 2 3 ## 0.2 0.3 0.4
table(y[-folds[[2]]]) %>% prop.table %>% round(1)
## ## 1 2 3 ## 0.2 0.3 0.4
table(y[-folds[[3]]]) %>% prop.table %>% round(1)
## ## 1 2 3 ## 0.2 0.4 0.4
반면 우리가 개발한 함수에는 그런 기능이 없다.
N = 100; k = 3
isamp <- sample(N)
folds <- split(isamp, cut(1:N,
breaks=N*seq(0, 1, length.out=k+1),
include.lowest=TRUE))
sapply(folds, length)
## [0,33.3] (33.3,66.7] (66.7,100] ## 33 33 34
table(y[-folds[[1]]]) %>% prop.table %>% round(1)
## ## 1 2 3 ## 0.2 0.3 0.5
table(y[-folds[[2]]]) %>% prop.table %>% round(1)
## ## 1 2 3 ## 0.2 0.3 0.4
table(y[-folds[[3]]]) %>% prop.table %>% round(1)
## ## 1 2 3 ## 0.2 0.4 0.4
결론
k-겹 교차 검증 데이터를 구성하기 위해 필요한 함수를 개발하였다.
createfolds = function(N, k) {
isamp <- sample(N)
split(isamp, cut(1:N,
breaks=N*seq(0, 1, length.out=k+1),
include.lowest=TRUE))
}
createfolds(77,4)
## $`[0,19.2]` ## [1] 26 43 60 34 70 24 2 7 64 9 52 16 14 74 57 59 10 69 42 ## ## $`(19.2,38.5]` ## [1] 62 19 13 73 77 65 1 22 67 3 37 56 58 29 8 63 17 53 6 ## ## $`(38.5,57.8]` ## [1] 41 66 36 76 44 11 21 15 49 4 61 27 48 38 32 54 51 39 5 ## ## $`(57.8,77]` ## [1] 12 28 46 20 50 45 40 31 75 23 18 55 68 33 30 35 25 71 72 47
Leave a comment