交差検証を読み解く
こんにちは、Vet Analyticsでデータサイエンティストをしている浜野です。
Vet Analyticsは獣医学に機械学習の知見を導入しようと日々奮闘しているデータサイエンス・エンジニアチームです。 東京大学本郷キャンパス付近でいつも活動しているので興味がありましたらtwitterアカウントから連絡してください。
前回はロジスティック回帰の理論とPythonを使った実装方法について解説しました。
今回は交差検証について解説していきたいと思います。
はじめに
まず汎化誤差(test error)と訓練誤差(train error)について理解する必要があります。汎化誤差とは訓練データを使い統計手法で学習をした後、未知のデータに対して予測を行った場合の誤差です。また、訓練誤差は訓練に使ったデータに対して予測を行った場合の誤差です。
このときどのような基準でこの訓練誤差と汎化誤差を求めるのでしょうか。
バリデーションセット(交差)のアプローチ
このとき、バリデーションセットのアプローチでは、最も簡単なものとして、ランダムにデータを分割して、片方を訓練データ、もう片方をテストデータとします。 そしてそれぞれのデータに対して誤差を計算し、それを訓練誤差、汎化誤差とするパターンです。かなり直感的ですよね。 そこで実際にあるデータを使って実験してみました。
下の図は多項式回帰を使い、先ほど紹介したランダムに分割して求める手法を用いたものです。横軸は多項式回帰の次数、縦軸はMSEを用いた誤差です。
続いて、この試行を10回繰り返してみた結果がこちらです。一色が一回の試行を示しています。
こう見ると、一回の検証ごとの誤差にかなりばらつきがあることがわかりますね。どのデータが訓練に使われて、どのデータがテストに使われたか、というのがかなり誤差に響くことがわかりました。ここでこれらの解決策を2つ紹介します。
Leave-One-Out Cross-Validation
Leave-One-Out Cross-Validationは、「データを一つ取り出す(Leave-One-Out)」ことをし、それをテストデータ、それ以外を訓練データとする手法です。 この試行を全部がテストデータになる(つまりデータがn行ならn回)繰り返し、その平均をとり、それぞれを訓練誤差、汎化誤差としました。 しかしこの手法には欠点があります。計算量が莫大になってしまうことです。例えば1万行のデータセットがあったとしたら、1万回この試行を繰り返さなければならないわけです。これを改善したのが次の手法です。
K-fold Cross-Validation
K-fold Cross-Validationはデータをk個に均等に分割して、1つをテストデータ、それ以外を訓練データとする試行をk回繰り返し、平均をとる手法です。(注意:もしデータn行あったとしてn-fold Cross-Validationをしたとしたら、それはLeave-One-Out Cross-Validationと同じです)
この手法をLeave-One-Out Cross-Validationと比べてみましょう。 左がLeave-One-Out Cross-Validation、右が10-fold Cross-Validationを9回試した結果です。こうしてみてみると、10-fold Cross-Validationの方が計算量がずっと小さいのにもかかわらず、あまり結果が変わらないことがわかります。
実装
今回紹介したなかで最もよく使われるのは、やはりK-fold Cross-Validationです。簡単でライブラリも充実しています。
import numpy as np #こちらがK-fold Cross-Validationです from sklearn.model_selection import KFold X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) y = np.array([1, 2, 3, 4]) #ここでfoldの数を指定します。今回は2とします kf = KFold(n_splits=2) KFold(n_splits=2, random_state=None, shuffle=False) for train_index, test_index in kf.split(X): print("TRAIN:", train_index, "TEST:", test_index) X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]
出力結果
TRAIN: [2 3] TEST: [0 1] TRAIN: [0 1] TEST: [2 3]
しっかりK-fold Cross-Validationが実装できていますね! 以上で交差検証の説明は終わりです。 Vet Analyticsは様々な方のコンタクトを歓迎しています!DMへのご連絡をお待ちしております!
次回は趣向を変えて獣医学への機械学習の応用事例について書きたいと思います。
Vet Analyticsは様々な方のコンタクトを歓迎しています!DMへのご連絡をお待ちしております!
出典
とてもわかりやすい本です faculty.marshall.usc.edu