(Talk by Emtiyaz Khan) Bayesian Deep Learning using Weight-Perturbation in Adam (Poster #190)

Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam

Mohammad Emtiyaz Khan*, RIKEN Center for AI Project (AIP), Tokyo, Japan
Didrik Nielsen* (AIP RIKEN), Voot Tangkaratt* (AIP RIKEN), Wu Lin (UBC), Yarin Gal (University of Oxford), Akash Srivastava (University of Edinburgh)
*Equal contribution


Bayesian Deep Learning

Compute averages over samples from the posterior distribution.
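To make the averaging concrete, here is a minimal sketch of a Monte Carlo posterior average for a logistic model. The Gaussian `q` standing in for the posterior, the input `x`, and all function names are assumptions made for illustration; none of them come from the talk.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def predictive_mean(x, mu, sigma, num_samples=1000, seed=0):
    """Monte Carlo estimate of E_q[sigmoid(x . theta)], q = N(mu, diag(sigma^2))."""
    rng = np.random.default_rng(seed)
    # Draw weight samples from the (assumed) Gaussian posterior approximation.
    thetas = mu + sigma * rng.standard_normal((num_samples, mu.size))
    # Average the per-sample predictions instead of predicting at one point.
    return sigmoid(thetas @ x).mean()

x = np.array([1.0, -0.5])        # hypothetical input
mu = np.array([2.0, 1.0])        # hypothetical posterior mean
sigma = np.array([1.5, 1.5])     # hypothetical posterior standard deviations

plug_in = sigmoid(x @ mu)                  # prediction at the mean only
averaged = predictive_mean(x, mu, sigma)   # prediction averaged over samples
```

In this toy setting the averaged prediction is less extreme than the plug-in prediction at the mean, which is the kind of uncertainty information a point estimate cannot provide.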


Approximate Bayesian Inference

Convert Bayesian inference to an optimization problem using Variational Inference (VI), and then use gradient-based methods for optimization.

Bayes by Backprop (Blundell et al. 2015), Practical VI (Graves 2011), Black-box VI (Ranganath et al. 2014), and many more.


Approximate Bayesian inference requires more computation, memory, and implementation effort than MLE.

Is it possible to reduce these costs?

By replacing gradients with natural gradients.


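Maximum Likelihood Estimation (MLE)

Log-likelihood objective:

$$\max_{\theta} \sum_{i=1}^{N} \log p(D_i \mid \theta)$$

RMSprop for MLE (the sum in g runs over a minibatch of size M):

$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) && \text{(backprop on minibatches)} \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 && \text{(scale vector, gradient magnitude)} \\
\mu &\leftarrow \mu + \alpha\, \frac{g}{\sqrt{s} + \delta} && \text{(adaptive gradient update)}
\end{aligned}
$$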
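The four RMSprop steps on this slide can be sketched in NumPy for logistic regression, where the log-likelihood gradient has a closed form. The synthetic data, the hyperparameter values, and the function names are assumptions chosen for illustration, not settings from the talk.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def rmsprop_mle(X, y, alpha=0.01, beta=0.1, delta=1e-8, M=10, steps=1000, seed=0):
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, s = np.zeros(D), np.zeros(D)
    for _ in range(steps):
        idx = rng.choice(N, size=M, replace=False)   # draw a minibatch
        Xb, yb = X[idx], y[idx]
        theta = mu                                   # theta <- mu
        g = Xb.T @ (yb - sigmoid(Xb @ theta)) / M    # minibatch log-likelihood gradient
        s = (1 - beta) * s + beta * g**2             # scale vector
        mu = mu + alpha * g / (np.sqrt(s) + delta)   # adaptive ascent step
    return mu

X, y = make_toy_data()
mu_hat = rmsprop_mle(X, y)
```

Because the objective is a maximization, the update adds the scaled gradient; a minimization of the negative log-likelihood would flip both the gradient sign and the step sign.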


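Gaussian Mean-Field Variational Inference

Prior with known precision λ:

$$p(\theta) = \mathcal{N}(0, I/\lambda)$$

Posterior and its Gaussian approximation, with covariance matrix diag(σ²):

$$p(\theta \mid D) = \frac{p(D \mid \theta)\, p(\theta)}{\int p(D \mid \theta)\, p(\theta)\, d\theta} \approx q(\theta) = \mathcal{N}(\mu, \sigma^2)$$

Variational objective:

$$\max_{\mu, \sigma^2} \mathcal{L}(\mu, \sigma^2) := \underbrace{\sum_{i=1}^{N} \mathbb{E}_q[\log p(D_i \mid \theta)]}_{\text{data-fit term}} - \underbrace{\mathrm{KL}\big[q(\theta) \,\|\, p(\theta)\big]}_{\text{regularizer}}$$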
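As a rough illustration of this objective, the sketch below estimates the data-fit term by Monte Carlo and uses the closed-form KL divergence between a diagonal Gaussian and the prior N(0, I/λ). The logistic-regression likelihood, the synthetic data, and all function names are assumptions for illustration only.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def log_lik(theta, X, y):
    # Sum over the data of log p(D_i | theta) for a Bernoulli likelihood.
    z = X @ theta
    return np.sum(y * z - np.logaddexp(0.0, z))

def kl_to_prior(mu, sigma2, lam):
    # KL[N(mu, diag(sigma2)) || N(0, I/lam)] in closed form.
    return 0.5 * np.sum(lam * sigma2 + lam * mu**2 - 1.0 - np.log(lam * sigma2))

def elbo_estimate(mu, sigma2, X, y, lam=1.0, num_samples=100, seed=0):
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_samples, mu.size))
    thetas = mu + np.sqrt(sigma2) * eps                          # samples from q
    data_fit = np.mean([log_lik(t, X, y) for t in thetas])       # data-fit term
    return data_fit - kl_to_prior(mu, sigma2, lam)               # minus regularizer

X, y = make_toy_data()
w = np.array([2.0, -1.0, 0.5])
elbo_good = elbo_estimate(w, 0.01 * np.ones(3), X, y)
elbo_bad = elbo_estimate(-w, 0.01 * np.ones(3), X, y)
```

The regularizer vanishes only when q equals the prior, and it is identical for the two candidate means above, so the difference between the two estimates comes entirely from the data-fit term.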


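MLE vs Gradient-based VI

RMSprop for max-likelihood

$$\max_{\theta} \sum_{i=1}^{N} \log p(D_i \mid \theta)$$

$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g}{\sqrt{s} + \delta}
\end{aligned}
$$

Gradient-based Variational Inference (Graves 2011, Blundell et al. 2015)

$$\max_{\mu, \sigma^2} \mathcal{L}(\mu, \sigma^2) := \sum_{i=1}^{N} \mathbb{E}_q[\log p(D_i \mid \theta)] - \mathrm{KL}\big[q(\theta) \,\|\, p(\theta)\big]$$

$$
\begin{aligned}
\mu &\leftarrow \mu + \alpha\, \frac{\nabla_{\mu} \mathcal{L}}{\sqrt{s_{\mu}} + \delta} \\
\sigma &\leftarrow \sigma + \alpha\, \frac{\nabla_{\sigma} \mathcal{L}}{\sqrt{s_{\sigma}} + \delta}
\end{aligned}
$$

Here s_μ and s_σ are the scale vectors for the two sets of variational parameters.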
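One way to read the gradient-based VI column is as two RMSprop-style updates, one for the mean and one for the standard deviation, driven by reparameterized gradient estimates. The sketch below assumes logistic regression, a single Monte Carlo sample per step, and a small floor on σ to keep it positive; these choices, the data, and the names are illustrative assumptions and not the exact procedure of the cited papers.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def gradient_based_vi(X, y, lam=1.0, alpha=0.01, beta=0.1, delta=1e-8,
                      M=10, steps=2000, seed=0):
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, sigma = np.zeros(D), 0.5 * np.ones(D)
    s_mu, s_sigma = np.zeros(D), np.zeros(D)
    for _ in range(steps):
        idx = rng.choice(N, size=M, replace=False)
        Xb, yb = X[idx], y[idx]
        eps = rng.standard_normal(D)
        theta = mu + sigma * eps                       # reparameterized sample from q
        g = Xb.T @ (yb - sigmoid(Xb @ theta)) / M      # minibatch log-likelihood gradient
        grad_mu = N * g - lam * mu                     # estimate of dL/dmu
        grad_sigma = N * g * eps - lam * sigma + 1.0 / sigma   # estimate of dL/dsigma
        s_mu = (1 - beta) * s_mu + beta * grad_mu**2
        s_sigma = (1 - beta) * s_sigma + beta * grad_sigma**2
        mu = mu + alpha * grad_mu / (np.sqrt(s_mu) + delta)
        sigma = sigma + alpha * grad_sigma / (np.sqrt(s_sigma) + delta)
        sigma = np.maximum(sigma, 1e-3)                # assumed safeguard: keep sigma positive
    return mu, sigma

X, y = make_toy_data()
mu_vi, sigma_vi = gradient_based_vi(X, y)
```

Compared with RMSprop for MLE, this variant carries two parameter vectors and two scale vectors, which is the extra memory and implementation effort the talk refers to.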


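MLE vs Natural-Gradient VI

RMSprop for max-likelihood

$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g}{\sqrt{s} + \delta}
\end{aligned}
$$

Natural-Gradient VI (Khan and Lin 2017; Khan and Nielsen 2018): Variational Online-Newton (VON), Khan et al. 2017

$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\!\left(0, \frac{1}{N s + \lambda}\right) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s - \beta\, \frac{1}{M} \sum_{i} \nabla^2_{\theta\theta} \log p(D_i \mid \theta) \\
\mu &\leftarrow \mu + \alpha\, \frac{g - \lambda \mu / N}{s + \lambda / N}
\end{aligned}
$$

With the objective written as a maximization, s tracks the diagonal of the negative Hessian of the log-likelihood, so Ns + λ is the precision of q and the prior contributes −λμ/N to the mean update.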
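The VON updates can be sketched for logistic regression, where the diagonal of the negative Hessian of the log-likelihood is available in closed form as the average of x² p(1 − p). The sketch follows the maximization convention, so the prior enters the mean update as −λμ/N and the perturbation variance is 1/(Ns + λ). The data, the initialization s = 1, the hyperparameters, and the names are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def von(X, y, lam=1.0, alpha=0.05, beta=0.1, M=10, steps=1000, seed=0):
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, s = np.zeros(D), np.ones(D)          # s = 1 is an assumed, conservative start
    for _ in range(steps):
        idx = rng.choice(N, size=M, replace=False)
        Xb, yb = X[idx], y[idx]
        sigma = 1.0 / np.sqrt(N * s + lam)               # std of the weight perturbation
        theta = mu + sigma * rng.standard_normal(D)      # perturb the weights
        p = sigmoid(Xb @ theta)
        g = Xb.T @ (yb - p) / M                          # minibatch log-likelihood gradient
        h = (Xb**2).T @ (p * (1 - p)) / M                # diagonal of the negative Hessian
        s = (1 - beta) * s + beta * h
        mu = mu + alpha * (g - lam * mu / N) / (s + lam / N)
    return mu, 1.0 / (N * s + lam)           # mean and diagonal variance of q

X, y = make_toy_data()
mu_von, var_von = von(X, y)
```

For this model the negative Hessian is nonnegative, so the variance stays well defined; for neural networks the Hessian diagonal can be negative, which is one motivation for the Gauss-Newton variant.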


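MLE vs Natural-Gradient VI

RMSprop for max-likelihood

$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g}{\sqrt{s} + \delta}
\end{aligned}
$$

Natural-Gradient VI (Khan and Lin 2017; Khan and Nielsen 2018): Variational Online Gauss-Newton (VOGN)

$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\!\left(0, \frac{1}{N s + \lambda}\right) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, \frac{1}{M} \sum_{i} \big[\nabla_{\theta} \log p(D_i \mid \theta)\big]^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g - \lambda \mu / N}{s + \lambda / N}
\end{aligned}
$$

VOGN replaces the Hessian in VON with the average of the squared per-example gradients; with the objective written as a maximization, the prior contributes −λμ/N to the mean update.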
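A minimal sketch of the VOGN step, again for logistic regression, is shown below. The key difference from the other variants is that the squares are taken per example before averaging, which requires per-example gradients. The maximization convention is used, so the prior enters as −λμ/N. The data, initialization, hyperparameters, and names are illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def vogn(X, y, lam=1.0, alpha=0.05, beta=0.1, M=10, steps=1000, seed=0):
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, s = np.zeros(D), np.ones(D)          # s = 1 is an assumed, conservative start
    for _ in range(steps):
        idx = rng.choice(N, size=M, replace=False)
        Xb, yb = X[idx], y[idx]
        sigma = 1.0 / np.sqrt(N * s + lam)
        theta = mu + sigma * rng.standard_normal(D)      # perturb the weights
        per_example = Xb * (yb - sigmoid(Xb @ theta))[:, None]   # one gradient per row
        g = per_example.mean(axis=0)                     # minibatch gradient
        gn = (per_example**2).mean(axis=0)               # mean of squared per-example gradients
        s = (1 - beta) * s + beta * gn
        mu = mu + alpha * (g - lam * mu / N) / (s + lam / N)
    return mu, 1.0 / (N * s + lam)

X, y = make_toy_data()
mu_vogn, var_vogn = vogn(X, y)
```

Since the scale vector is an average of squares, it is nonnegative by construction, which keeps the variance of q positive without any extra safeguard.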


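MLE vs Natural-Gradient VI

RMSprop for max-likelihood

$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g}{\sqrt{s} + \delta}
\end{aligned}
$$

Natural-Gradient VI (Khan and Lin 2017; Khan and Nielsen 2018): Variational RMSprop (Vprop)

$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\!\left(0, \frac{1}{N s + \lambda}\right) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
\mu &\leftarrow \mu + \alpha\, \frac{g - \lambda \mu / N}{\sqrt{s} + \lambda / N}
\end{aligned}
$$

Vprop differs from RMSprop only in the weight perturbation and in the prior terms λμ/N and λ/N; with the objective written as a maximization, the prior contributes −λμ/N to the mean update.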
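The Vprop column differs from RMSprop in three places: the weights are perturbed before the gradient is computed, the prior term λμ/N enters the numerator, and λ/N replaces δ in the denominator. A minimal NumPy sketch under the maximization convention follows; the data, hyperparameters, and names are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def vprop(X, y, lam=1.0, alpha=0.01, beta=0.1, M=10, steps=2000, seed=0):
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, s = np.zeros(D), np.zeros(D)         # s = 0 makes q start at the prior variance
    for _ in range(steps):
        idx = rng.choice(N, size=M, replace=False)
        Xb, yb = X[idx], y[idx]
        sigma = 1.0 / np.sqrt(N * s + lam)
        theta = mu + sigma * rng.standard_normal(D)      # change 1: perturb the weights
        g = Xb.T @ (yb - sigmoid(Xb @ theta)) / M        # same backprop call as RMSprop
        s = (1 - beta) * s + beta * g**2                 # same scale vector as RMSprop
        # changes 2 and 3: prior term in the numerator, lam/N instead of delta
        mu = mu + alpha * (g - lam * mu / N) / (np.sqrt(s) + lam / N)
    return mu, 1.0 / (N * s + lam)

X, y = make_toy_data()
mu_vprop, var_vprop = vprop(X, y)
```

The returned variance comes for free from the scale vector that RMSprop already maintains, so no additional per-parameter state is needed.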


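Variational Adam (Vadam)

Adam for max-likelihood

$$
\begin{aligned}
\theta &\leftarrow \mu \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
m &\leftarrow (1-\gamma)\, m + \gamma\, g \\
\hat{m} &\leftarrow m / \big(1 - (1-\gamma)^t\big) \\
\hat{s} &\leftarrow s / \big(1 - (1-\beta)^t\big) \\
\mu &\leftarrow \mu + \alpha\, \frac{\hat{m}}{\sqrt{\hat{s}} + \delta}
\end{aligned}
$$

Vadam for VI

$$
\begin{aligned}
\theta &\leftarrow \mu + \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}\!\left(0, \frac{1}{N s + \lambda}\right) \\
g &\leftarrow \frac{1}{M} \sum_{i} \nabla_{\theta} \log p(D_i \mid \theta) \\
s &\leftarrow (1-\beta)\, s + \beta\, g^2 \\
m &\leftarrow (1-\gamma)\, m + \gamma\, (g - \lambda \mu / N) \\
\hat{m} &\leftarrow m / \big(1 - (1-\gamma)^t\big) \\
\hat{s} &\leftarrow s / \big(1 - (1-\beta)^t\big) \\
\mu &\leftarrow \mu + \alpha\, \frac{\hat{m}}{\sqrt{\hat{s}} + \lambda / N}
\end{aligned}
$$

Here t is the iteration counter used for bias correction; with the objective written as a maximization, the prior contributes −λμ/N to the momentum update.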
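The Vadam column can be sketched as a small modification of an Adam loop. The version below is written for logistic regression under the maximization convention, with the prior entering the momentum as −λμ/N. Using the uncorrected s for the perturbation variance, the hyperparameter values (chosen to mirror common Adam defaults), the data, and the names are assumptions of this sketch; the authors' implementation is at the repository linked at the end of the talk.

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def make_toy_data(N=200, seed=1):
    # Hypothetical synthetic logistic-regression data (not from the talk).
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((N, 3))
    y = (rng.random(N) < sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)
    return X, y

def vadam(X, y, lam=1.0, alpha=0.01, beta=0.001, gamma=0.1, M=10, steps=2000, seed=0):
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, s, m = np.zeros(D), np.zeros(D), np.zeros(D)
    for t in range(1, steps + 1):
        idx = rng.choice(N, size=M, replace=False)
        Xb, yb = X[idx], y[idx]
        sigma = 1.0 / np.sqrt(N * s + lam)               # assumed: uncorrected s
        theta = mu + sigma * rng.standard_normal(D)      # perturb the weights
        g = Xb.T @ (yb - sigmoid(Xb @ theta)) / M        # minibatch log-likelihood gradient
        s = (1 - beta) * s + beta * g**2                 # second-moment estimate
        m = (1 - gamma) * m + gamma * (g - lam * mu / N) # momentum with prior term
        m_hat = m / (1 - (1 - gamma)**t)                 # bias corrections
        s_hat = s / (1 - (1 - beta)**t)
        mu = mu + alpha * m_hat / (np.sqrt(s_hat) + lam / N)
    return mu, 1.0 / (N * s + lam)

X, y = make_toy_data()
mu_vadam, var_vadam = vadam(X, y)
```

Setting the perturbation to zero, dropping the λμ/N term, and replacing λ/N by δ recovers the plain Adam column, which is why the method fits inside an existing Adam implementation with only a few changed lines.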


Summary: Uncertainty using Adam

Perturb the weights before backprop. Choose a small minibatch size.


Bayesian logistic regression on “Breast-Cancer” (N=683, D=8)

As we reduce the minibatch size, Vadam gives performance similar to VOGN.

[Figure: error in posterior approximation; legend entry: VOGN]


Neural network with 1 layer of 64 hidden units (ReLU) on Breast Cancer [N=683, D=10]

[Figure: log2 loss (0.0 to 1.0) versus epoch (0.0 to 2.0); legend: BBVI, Vadam, CVI-GGN; overlaid labels: BBVI, Vadam, VOGN]

VOGN shows fast convergence.


Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam

Poster tonight (Hall B #190). Code available at https://github.com/emtiyaz/vadam/

Also check out “Noisy Natural-gradient as VI” by Zhang et al. at this conference.