17
Dropout Distillation Samuel Rota Bulò, Lorenzo Porzi , Peter Kontschieder ICML2016読み会 紹介者:佐野正太郎 株式会社リクルートコミュニケーションズ

Dropout Distillation

Embed Size (px)

Citation preview

Page 1: Dropout Distillation

0

  

Dropout Distillation Samuel Rota Bulò, Lorenzo Porzi , Peter Kontschieder

ICML2016読み会

紹介者:佐野正太郎

株式会社リクルートコミュニケーションズ

Page 2: Dropout Distillation

(C)Recruit Communications Co., Ltd.

背景:Dropout学習

•  ニューラルネットワークの過学習を抑制する手法

•  学習ステップ毎にランダムに一部のユニットを落とす

•  暗に多数のネットワークのアンサンブルモデルを学習している

•  [Srivastava et al., 2014]

1

学習対象のネットワーク 学習ステップ1 学習ステップ2

・・・ 学習時

Page 3: Dropout Distillation

(C)Recruit Communications Co., Ltd.

背景:Dropoutにおける予測計算

2

Doropout学習時にはネットワーク構造がランダム => 予測時にどの構造を採用するか?

理想:全てのDropoutパターンでの予測計算の期待値をとる

Standard Dropout [Srivastava et al., 2014]

•  予測時にはユニットを落とさない

•  各ユニットの出力を (1 – dropout率) でスケールすることで実用的な精度が得られる

Monte-Carlo Dropout [Gal & Ghahramani, 2015]

•  予測時に複数のDropoutパターンを試して平均をとる

•  予測の計算コストが高い代わりにStandard Dropoutよりも良い精度が得られる

Page 4: Dropout Distillation

(C)Recruit Communications Co., Ltd.

背景:Distillation

3

Distilling the knowledge in Neural Network [Hinton et al., 2014]

•  distill = 蒸留する

•  複数のネットワークや複雑なネットワークを単一の小さなモデルに圧縮する手法

蒸留モデル

アンサンブルモデル

Page 5: Dropout Distillation

(C)Recruit Communications Co., Ltd.

提案手法:Dropout Distillation

概要

•  Dropout学習が暗に獲得しているアンサンブルモデルを圧縮/蒸留(Distillation)する

•  Dropout学習後モデルのMonte-Carlo予測を模倣する新しいモデルを学習する

利点

•  Standard Dropoutと同じ予測計算コストでStandard Dropoutよりも高い予測精度

•  半教師あり学習への応用可能性:教師信号が欠損したデータをDistillationフェーズで活用できる

•  モデル圧縮への応用可能性:Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮できる

欠点

•  Distillationフェーズに余計な時間がかかる

4

Page 6: Dropout Distillation

(C)Recruit Communications Co., Ltd.

提案手法:Dropout Distillation

5

Dropout 学習済み モデル

生徒モデル

損失関数

Dropout パターン

Page 7: Dropout Distillation

(C)Recruit Communications Co., Ltd.

提案手法:Dropout Distillation

6

教師モデル (Dropout学習済み)

生徒モデル

Distillationフェーズでは 教師モデルの振る舞いを真似るよう

生徒モデルを学習する

通常のDropout学習で 教師となるモデルを獲得

Page 8: Dropout Distillation

(C)Recruit Communications Co., Ltd.

提案手法:Dropout Distillation

7

Distillation用 学習データ

(教師信号無し)

教師モデル (Dropout学習済み)

生徒モデル

生徒モデルの出力

出力間の損失を 埋めるように 生徒モデルの パラメタを更新

教師モデルの出力

生徒モデルには ドロップアウトをかけない

教師モデルにドロップアウトを かけながら出力データを生成

Page 9: Dropout Distillation

(C)Recruit Communications Co., Ltd.

提案手法:Dropout Distillation

8

Distillation用 学習データ

(教師信号無し)

教師モデル (Dropout学習済み)

生徒モデル

生徒モデルの出力

教師モデルの出力

教師モデルと生徒モデルの ネットワーク構造は違っていてもよい

データはDropoutフェーズから流用可 新しいデータを用意するのも可

Page 10: Dropout Distillation

(C)Recruit Communications Co., Ltd.

理想の予測関数

•  全てのDropoutパターンでの出力期待値

•  Dropoutパターンはユニット数に対し指数関数的に増加するので事実上計算できない

問題設定

•  『理想の予測関数』を教師モデルとした生徒モデルを学習したい

どうやって『理想の予測関数』を計算に取り入れるか?

Dropout学習済みモデル

導出

9

理想の予測関数

損失関数

生徒モデル 評価できない

Dropoutパターン

Page 11: Dropout Distillation

(C)Recruit Communications Co., Ltd.

アプローチ

•  『理想の予測関数』をDropout学習済みモデルで置き換える

•  損失関数がBregmanダイバージェンスのとき以下の最小化問題が等価

Bregmanダイバージェンス

•  二乗損失・Logistic損失・KLダイバージェンスなどを一般化したもの

Dropout 学習済み モデル

導出

10

生徒モデル

微分可能な凸関数

Dropoutパターン この表現を形にしたのが スライド5〜8のアルゴリズム

Page 12: Dropout Distillation

(C)Recruit Communications Co., Ltd.

証明

qに関係ないので定数とみなせる

導出

11

本来の最小化対象

Dropout Distillationでの最小化対象

Page 13: Dropout Distillation

(C)Recruit Communications Co., Ltd.

実験1:予測計算手法による性能比較

12

MNIST/CIFAR10/CIFAR100データセットで3予測手法のエラー率比較

•  Standard Dropout

•  Monte-Carlo Dropout(100サンプリング)

•  Dropout Distillation

実験手順

1.  Dropout学習でベースラインモデルを獲得(300エポック)

2.  ベースラインモデルでStandard DropoutとMonte-Carlo Dropoutの性能評価

3.  ベースラインモデルを教師としてDropout Distillation(30エポック)

–  生徒モデルのネットワーク構造はベースラインモデルと同様

–  ベースラインモデルの学習後パラメタで生徒モデルを初期化

–  ベースラインモデルの入力データを流用(pixel毎に確率0.2で値をゼロ化)

4.  生徒モデルでDropout Distillationの性能評価

Page 14: Dropout Distillation

(C)Recruit Communications Co., Ltd.

実験1:予測計算手法による性能比較

13

•  平均エラー率は Standard > Distillation > Monte-Carlo の順

•  Monte-CarloよりDistillationの方がパフォーマンスの分散は低い

Page 15: Dropout Distillation

(C)Recruit Communications Co., Ltd.

実験2:Distillationに使うデータセットによる性能比較

14

Distillationフェーズの入力データについて3シナリオで性能比較

•  [Train] 教師モデルのトレーニングセットをそのまま利用

•  [Pert. Train] 教師モデルのトレーニングセットをピクセル毎に確率0.2で値をゼロ化

•  [Test] テストデータを利用

どのシナリオが 優れているかは

場合による

Page 16: Dropout Distillation

(C)Recruit Communications Co., Ltd.

実験3:モデル圧縮への応用可能性

15

CIFAR10/Quickでユニット数を削減した場合のパフォーマンス変化

•  [Baseline] Dropout学習のみで削減後モデルを学習

•  [Distillation] Dropoutフェーズで削減前モデルを学習してDistillationフェーズで削減後モデルに圧縮

青枠内では『Dropoutフェーズで複雑なモデルを学習 => Distillationフェーズで圧縮』が有効に働いている

FC層からのみユニットを削った場合 全層からフィルタ/ユニットを削った場合

Page 17: Dropout Distillation

(C)Recruit Communications Co., Ltd.

従来手法

•  Standard Dropout:予測時間が短いけど精度が低め

•  Monte-Carlo Dropout:予測時間が長いけど精度が高め

提案手法の主な貢献

•  Standard Dropoutと同じオーダーの予測時間

•  安定してStandard Dropoutよりも良い精度が出る

場合によって効いてくるメリット

•  教師信号が欠損したデータをDistillationフェーズで活用

•  Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮

まとめ

16