【Deep Learning】過学習とDropoutについて
前回、Deep Learningを用いてCIFAR-10の画像を識別しました。今回は機械学習において重要な問題である過学習と、その対策について取り上げます。
sonickun.hatenablog.com
過学習について
過学習(Overfitting)とは、機械学習において、訓練データに対して学習されているが、未知のデータに対して適合できていない(汎化できていない)状態を指します。たとえ訓練データに対する精度が100%近くに達したとしても、テストデータに対する精度が高くならなければ、それは良い学習とはいえません。特にニューラルネットは複雑なモデルのため過学習に陥りやすいと言われています。
過学習の例
過学習の例として、最小二乗法による多項式近似を用いてサインカーブ(+標準偏差0.3の乱数)を推測してみます。
下の図は、多項式の次数Mを変化させた時の学習の結果を表しています。E(RMS)は近似したい関数との平均二乗誤差を表しています。M=9のときE(RMS)の値は0.0(回帰曲線が全てのプロットを通っている)となっていますが、近似したいサインカーブからは大きくはずれてしまっています。このような状態を過学習と呼びます。一方、M=3の時は予測したいサインカーブに近い回帰曲線が描けており、4つのグラフの中ではこれが最も良いモデルだといえます。E(RMS)の値が0.3に近いのは、データが本質的に0.3程度の誤差を含んでいることを示唆しています。
最小二乗法による多項式近似 スクリプト
・https://gist.github.com/sonickun/c7837d0cf732cda7d69d373abc82f99c
Dropoutについて
Dropoutは、階層の深いニューラルネットを精度よく最適化するためにHintonらによって提案された手法です。
Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, Ruslan Salakhutdinov. Dropout: A Simple Way to Prevent Neural Networks from Overfitting. The Journal of Machine Learning Research, Volume 15 Issue 1, January 2014 Pages 1929-1958
・PDF: https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf
Dropoutでは、ニューラルネットワークを学習する際に、ある更新で層の中のノードのうちのいくつかを無効にして(そもそも存在しないかのように扱って)学習を行い、次の更新では別のノードを無効にして学習を行うことを繰り返します。これにより学習時にネットワークの自由度を強制的に小さくして汎化性能を上げ、過学習を避けることができます。隠れ層においては、一般的に50%程度を無効すると良いと言われています。当初Dropoutは全結合のみに適用されていましたが、先ほど挙げた論文によれば、畳み込み層等に適用しても同様に性能を向上させることが確かめられています。
Dropoutが高性能である理由は、「アンサンブル学習」という方法の近似になるからとも言われています。アンサンブル学習とは、複数の機械学習結果を利用して判定を行うことで、学習の性能を上げることです。これを応用した学習器として、ランダムサンプリングしたデータによって学習した多数の決定木を平均化するランダムフォレストなどがあります。
実験
前回記事とおなじCIFAR-10の画像識別を行い、Dropoutを実装した時としない時の比較を行ってみたいと思います。使用したDeep Learningフレームワークは前回と同じCaffeです。
学習の設定は以下のとおりです(細かいパラメータは割愛)。前回の知見を踏まえて、「学習率は徐々に小さく」、「活性化関数はReLU関数」、「ニューラルネットワークはCNN」を意識しました。
- 学習回数(Iterations): 70000
- バッチサイズ(Batch Size): 100
- 勾配降下法の学習率(Learning Rate): 0.001(~60000iters) -> 0.0001(~65000iters) -> 0.0001(~70000iters)
- 活性化関数(Activation Function): ReLU関数
- ニューラルネットワーク: CNN (畳み込み層×2+プーリング層×2+LRN層×2+全結合層×2)
今回は3つのパターンを試します。
1. CNNのみ
2. CNN + Dropout (全結合層のみ)
3. CNN + Dropout (全層)
なお、Dropoutのユニットの選出確率pは全結合層ではp=0.5、その他はp=0.2としています。
結果と考察
以下の図は、上記の3つのパターンそれぞれについて、Train Accuracy(訓練データに対する精度)とTest Accuracy(テストデータに対する精度)の変化の様子を表したグラフです。CNNのみのグラフでは、学習が進むにつれてTest AccuracyとTrain Accuracyの差が開いていっていますが、これは過学習が起きていることを表しています。一方、Dropoutを実装するとこの差が小さくなり、特に全層に適用した場合はTest AccuracyとTrain Accuracyの差はほとんどなくなっています。このことより、Dropoutが過学習の回避策として機能していることが分かります。
CNNのみ
CNN + Dropout (全結合層のみ)
CNN + Dropout (全層)
最後まで学習したときの(70000 iters)3つのパターンの識別精度を以下の表にまとめました。やはり全層に対してDropoutを適用した場合は過学習をほぼ回避できているようです。ただし、Dropoutで過学習を回避することがそのまま識別精度の向上につながるわけではなさそうです。予備実験では、全結合層以外のユニット選出確率もすべてp=0.5にしたところ、確かに過学習を回避できましたが、識別精度は60%代にまで落ち込んでしまいました。どの層にDropoutを施すかによって、ユニット選出確率の値を慎重に選ぶ必要がありそうです。
Dropoutの詳細 | Train Accuracy (%) | Test Accuracy (%) | Train - Test |
---|---|---|---|
CNNのみ | 94.0 | 80.3 | 13.7 |
CNN + Dropout (全結合層のみ) | 91.0 | 81.8 | 9.2 |
CNN + Dropout (全層) | 82.0 | 81.1 | 0.9 |
おまけTips: CaffeのログにTrain Accuracyを出力する
Caffeのチュートリアル通りに学習を行うと、Iterationsの途中でTest Accuracy(テストデータに対する正解率)とTrain Loss(誤差関数の値)をログに出力してくれますが、Train Accuracy(学習データに対する正解率)は出力してくれません。
I0712 12:22:57.474715 4721 solver.cpp:337] Iteration 62000, Testing net (#0) I0712 12:23:44.508533 4721 solver.cpp:404] Test net output #0: accuracy = 0.8094 I0712 12:23:44.508831 4721 solver.cpp:404] Test net output #1: loss = 0.555002 (* 1 = 0.555002 loss) I0712 12:23:45.563415 4721 solver.cpp:228] Iteration 62000, loss = 0.294429 I0712 12:23:45.563508 4721 solver.cpp:244] Train net output #0: loss = 0.294429 (* 1 = 0.294429 loss)
Train Accuracy(および Test Loss)を出力するにはニューラルネットワークの定義ファイル(.prototxt)を編集し、"Accuracy"の層のincludeの部分を削除すればよいです。
layer { name: "accuracy" type: "Accuracy" bottom: "ip1" bottom: "label" top: "accuracy" include { # <- phase: TEST # <- この3行を削除 } # <- }
修正後のログ
I0714 17:34:24.503883 27309 solver.cpp:337] Iteration 60000, Testing net (#0) I0714 17:35:07.727932 27309 solver.cpp:404] Test net output #0: accuracy = 0.7468 I0714 17:35:07.728186 27309 solver.cpp:404] Test net output #1: loss = 0.842801 (* 1 = 0.842801 loss) I0714 17:35:08.998235 27309 solver.cpp:228] Iteration 60000, loss = 0.287727 I0714 17:35:08.998337 27309 solver.cpp:244] Train net output #0: accuracy = 0.93 I0714 17:35:08.998360 27309 solver.cpp:244] Train net output #1: loss = 0.287727 (* 1 = 0.287727 loss)
次回 -> http://そのうち