
10-417/10-617 Intermediate Deep Learning: Fall 2019

Russ Salakhutdinov
Machine Learning Department
rsalakhu@cs.cmu.edu
https://deeplearning-cmu-10417.github.io/

Variational Autoencoders

Motivating Example
• Can we generate images from natural language descriptions?

A stop sign is flying in blue skies
A pale yellow school bus is flying in blue skies
A herd of elephants is flying in blue skies
A large commercial airplane is flying in blue skies

(Mansimov, Parisotto, Ba, Salakhutdinov, 2015)

Overall Model

[Figure: overall model, a Variational Autoencoder with a stochastic layer]

Motivation
• Hinton, G.E., Dayan, P., Frey, B.J. and Neal, R., Science 1995

[Figure: generative process and approximate inference over input data v and stochastic layers h1, h2, h3 connected by weights W1, W2, W3]

• Kingma & Welling, 2014
• Rezende, Mohamed, Wierstra, 2014
• Mnih & Gregor, 2014
• Bornschein & Bengio, 2015
• Tang & Salakhutdinov, 2013

Variational Autoencoders (VAEs)
• The VAE defines a generative process in terms of ancestral sampling through a cascade of hidden stochastic layers:

p_θ(x) = Σ_{h^1,…,h^L} p_θ(h^L) p_θ(h^{L-1} | h^L) ⋯ p_θ(x | h^1)

[Figure: generative process over input data v and stochastic layers h1, h2, h3 with weights W1, W2, W3]

• Each term may denote a complicated nonlinear relationship.
• θ denotes the parameters of the VAE.
• L is the number of stochastic layers.
• Sampling and probability evaluation is tractable for each term.
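To make the ancestral-sampling view concrete, here is a minimal PyTorch sketch (not the lecture's code; layer sizes, class and variable names are illustrative assumptions) of a three-layer cascade p(h3) p(h2|h3) p(h1|h2) p(v|h1), where each conditional is a Gaussian whose mean and log-variance come from a one-layer network.

import torch
import torch.nn as nn

class GaussianLayer(nn.Module):
    """One-layer net mapping the sample from the layer above to the mean and
    log-variance of a diagonal Gaussian over the layer below."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.mu = nn.Linear(dim_in, dim_out)
        self.logvar = nn.Linear(dim_in, dim_out)

    def sample(self, h_above):
        mu, logvar = self.mu(h_above), self.logvar(h_above)
        return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

# Ancestral sampling: p(h3) p(h2|h3) p(h1|h2) p(v|h1), sampled top-down.
p_h2_given_h3 = GaussianLayer(8, 16)
p_h1_given_h2 = GaussianLayer(16, 32)
p_v_given_h1  = GaussianLayer(32, 784)

h3 = torch.randn(1, 8)           # top-level prior: standard normal
h2 = p_h2_given_h3.sample(h3)
h1 = p_h1_given_h2.sample(h2)
v  = p_v_given_h1.sample(h1)     # generated observation

Because each conditional is an explicit Gaussian, both sampling and evaluating its log-density are tractable, which is the property the slide emphasizes.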

VAE: Example
• The VAE defines a generative process in terms of ancestral sampling through a cascade of hidden stochastic layers.

[Figure: example with one deterministic layer and two stochastic layers; this term denotes a one-layer neural net]

• θ denotes the parameters of the VAE.
• Sampling and probability evaluation is tractable for each term.
• L is the number of stochastic layers.

Recognition Network
• The recognition model is defined in terms of an analogous factorization:

q_φ(h | x) = q(h^1 | x) q(h^2 | h^1) ⋯ q(h^L | h^{L-1})

[Figure: generative process (top-down over h3, h2, h1, v with weights W3, W2, W1) and approximate inference (bottom-up from the input data)]

• Each term may denote a complicated nonlinear relationship.
• We assume that the approximate posterior takes this factorized form.
• The conditionals q(h^1 | x), q(h^2 | h^1), … are Gaussians with diagonal covariances.
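Continuing the sketch above (reusing the same illustrative GaussianLayer and sizes, an assumption rather than the lecture's code), the recognition network runs the analogous factorization bottom-up, with each conditional a diagonal-covariance Gaussian:

# Recognition model q(h1|v) q(h2|h1) q(h3|h2), mirroring the generative cascade.
q_h1_given_v  = GaussianLayer(784, 32)
q_h2_given_h1 = GaussianLayer(32, 16)
q_h3_given_h2 = GaussianLayer(16, 8)

v_obs = torch.rand(1, 784)         # placeholder input data
h1 = q_h1_given_v.sample(v_obs)    # bottom-up sampling through the
h2 = q_h2_given_h1.sample(h1)      # stochastic layers
h3 = q_h3_given_h2.sample(h2)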

Variational Bound
• The VAE is trained to maximize the variational lower bound:

L(x) = E_{q(h|x)}[ log p_θ(x, h) − log q_φ(h | x) ] = log p_θ(x) − KL( q_φ(h|x) || p_θ(h|x) )

[Figure: generative process and approximate inference over input data v and layers h1, h2, h3]

• Trading off the data log-likelihood and the KL divergence from the true posterior.
• Hard to optimize the variational bound with respect to the recognition network (high variance).
• Key idea of Kingma and Welling is to use the reparameterization trick.

Reparameterization Trick
• Assume that the recognition distribution is Gaussian:

q_φ(h | x) = N(h | μ(x), Σ(x))

with mean and covariance computed from the state of the hidden units at the previous layer.

• Alternatively, we can express this in terms of an auxiliary variable:

ε ~ N(0, I),  h = μ(x) + Σ^{1/2}(x) ε

Reparameterization Trick
• Assume that the recognition distribution is Gaussian:

q_φ(h | x) = N(h | μ(x), Σ(x))

• Or: ε ~ N(0, I),  h = μ(x) + Σ^{1/2}(x) ε

• The recognition distribution can be expressed in terms of a deterministic mapping (a deterministic encoder):

h = f(x, ε; φ)

The distribution of ε does not depend on φ.
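A minimal PyTorch sketch of the trick (shapes are illustrative): the sample h is written as a deterministic function of the parameters and an auxiliary noise variable ε whose distribution does not depend on them, so gradients can flow into μ and σ.

import torch

mu = torch.zeros(5, requires_grad=True)
log_sigma = torch.zeros(5, requires_grad=True)

eps = torch.randn(5)                   # eps ~ N(0, I): parameter-free noise
h = mu + torch.exp(log_sigma) * eps    # h ~ N(mu, sigma^2) as a deterministic map of eps

loss = (h ** 2).sum()                  # any downstream objective
loss.backward()                        # gradients reach mu and log_sigma through h
print(mu.grad, log_sigma.grad)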

Computing the Gradients
• The gradient w.r.t. the parameters, both recognition and generative:

∇_{θ,φ} E_{ε∼N(0,I)}[ log p_θ(x, h(x, ε)) − log q_φ(h(x, ε) | x) ]

Gradients can be computed by backprop: the mapping h is a deterministic neural net for fixed ε.

Computing the Gradients
• The gradient w.r.t. the parameters, recognition and generative: approximate the expectation by generating k samples from ε ~ N(0, I):

∇_{θ,φ} (1/k) Σ_{i=1}^{k} log w(x, h(x, ε_i); θ, φ)

where we defined the unnormalized importance weights:

w(x, h; θ, φ) = p_θ(x, h) / q_φ(h | x)

• VAE update: low variance, as it uses the log-likelihood gradients with respect to the latent variables.
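This estimate can be implemented directly with backprop. Below is a minimal single-stochastic-layer sketch (the shapes, the Bernoulli decoder, and all names are assumptions for illustration, not the lecture's model): draw k reparameterized samples from q(h|x), form the log unnormalized importance weights log w_i = log p(x, h_i) − log q(h_i | x), and average them to get the standard VAE estimate.

import torch
import torch.nn as nn
from torch.distributions import Normal, Bernoulli

enc = nn.Linear(784, 2 * 20)     # outputs mean and log-variance of q(h|x)
dec = nn.Linear(20, 784)         # outputs Bernoulli logits of p(x|h)
x = torch.rand(1, 784).round()   # a placeholder binary image
k = 5

mu, logvar = enc(x).chunk(2, dim=-1)
q = Normal(mu, torch.exp(0.5 * logvar))
h = q.rsample((k,))                                    # k reparameterized samples of shape (k, 1, 20)
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))

log_w = (Bernoulli(logits=dec(h)).log_prob(x).sum(-1)  # log p(x|h)
         + prior.log_prob(h).sum(-1)                   # log p(h)
         - q.log_prob(h).sum(-1))                      # log q(h|x)

vae_bound = log_w.mean()          # (1/k) sum_i log w_i: the standard VAE estimate
(-vae_bound).backward()           # backprop gives gradients for encoder and decoder

Because the samples are expressed through ε, a single backward pass yields gradients for both the recognition and generative parameters, as stated above.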

VAE: Assumptions
• Remember the variational bound:

log p(x) ≥ E_{q(h|x)}[ log p_θ(x, h) − log q_φ(h | x) ]

• The variational assumptions must be approximately satisfied: the posterior distribution must be approximately factorial (common practice) and predictable with a feed-forward net.
• We show that we can relax these assumptions using a tighter lower bound on the marginal log-likelihood.

Importance Weighted Autoencoders
• Consider the following k-sample importance weighting of the log-likelihood:

L_k(x) = E_{h_1,…,h_k ∼ q(h|x)}[ log (1/k) Σ_{i=1}^{k} p_θ(x, h_i) / q_φ(h_i | x) ]

where h_1, …, h_k are sampled from the recognition network, and p_θ(x, h_i) / q_φ(h_i | x) are the unnormalized importance weights.

[Figure: generative process and approximate inference over input data v and layers h1, h2, h3]

Importance Weighted Autoencoders
• Consider the k-sample importance weighting of the log-likelihood defined above.
• This is a lower bound on the marginal log-likelihood: log p(x) ≥ L_k(x).
• Special case of k = 1: same as the standard VAE objective.
• Using more samples → improves the tightness of the bound.
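As a sketch (assuming the (k, batch) tensor log_w of log unnormalized importance weights from the earlier sketch), the k-sample bound is a log-mean-exp of the log weights, computed stably with logsumexp; with k = 1 it reduces to the standard VAE objective.

import math
import torch

def iwae_bound(log_w):
    """L_k = E[ log (1/k) sum_i w_i ], with log_w of shape (k, batch)."""
    k = log_w.shape[0]
    return (torch.logsumexp(log_w, dim=0) - math.log(k)).mean()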

Tighter Lower Bound
• For all k, the lower bounds satisfy:

log p(x) ≥ L_{k+1}(x) ≥ L_k(x)

• Using more samples can only improve the tightness of the bound.
• Moreover, if p_θ(x, h) / q_φ(h | x) is bounded, then L_k(x) → log p(x) as k → ∞.

Computing the Gradients
• We can use an unbiased estimate of the gradient using the reparameterization trick:

∇_{θ,φ} L_k ≈ Σ_{i=1}^{k} w̃_i ∇_{θ,φ} log w(x, h(x, ε_i); θ, φ)

where we define the normalized importance weights:

w̃_i = w_i / Σ_{j=1}^{k} w_j
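A minimal sketch of this estimator (again assuming the (k, batch) tensor log_w of reparameterized log weights): detaching the normalized weights and multiplying them into the log weights yields exactly Σ_i w̃_i ∇ log w_i under backprop; differentiating the logsumexp form of L_k gives the same gradient.

import torch

def iwae_gradient_surrogate(log_w):
    """Surrogate objective (to be maximized) whose gradient is sum_i w~_i * grad(log w_i)."""
    w_tilde = torch.softmax(log_w, dim=0).detach()   # normalized importance weights, held constant
    return (w_tilde * log_w).sum(dim=0).mean()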

IWAEs vs. VAEs
• Draw k samples from the recognition network, or k sets of auxiliary variables.
• Obtain the following Monte Carlo estimate of the gradient:

Σ_{i=1}^{k} w̃_i ∇ log w_i

• Compare this to the VAE's estimate of the gradient:

(1/k) Σ_{i=1}^{k} ∇ log w_i
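Side by side, on the same k reparameterized samples (a sketch assuming `import torch` and the log_w tensor above), the two surrogate objectives differ only in how the per-sample log-weight gradients are weighted:

# VAE: uniform weights 1/k on the log-weight gradients.
vae_objective  = log_w.mean(dim=0).mean()

# IWAE: normalized importance weights w~_i (treated as constants) on the gradients.
iwae_objective = (torch.softmax(log_w, dim=0).detach() * log_w).sum(dim=0).mean()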

IWAE: Intuition
• The gradient of the log weights decomposes:

[Figure: deterministic encoder and deterministic decoder over input data v and layers h1, h2, h3 with weights W1, W2, W3]

First term:
- Decoder: encourages the generative model to assign high probability to each sample.
- Encoder: encourages the recognition net to adjust its latent states h so that the generative network makes better predictions.

IWAE: Intuition
• The gradient of the log weights decomposes:

[Figure: deterministic encoder and deterministic decoder over input data v and layers h1, h2, h3 with weights W1, W2, W3]

Second term:
- Encoder: encourages the recognition network to have a spread-out distribution over predictions.

Two Architectures
• For the MNIST experiments, we considered two architectures:

[Figure: 1 stochastic layer: 784 input units, two deterministic layers of 200 units, and a stochastic layer of 50 units. 2 stochastic layers: 784 input units, two deterministic layers of 200 units, a stochastic layer of 100 units, two deterministic layers of 100 units, and a stochastic layer of 50 units.]

MNIST Results

Latent Space Representation
• Both VAEs and IWAEs tend to learn latent representations with effective dimensions far below their capacity.
• Measure the activity of a latent dimension u using the statistic (see the sketch below):

A_u = Cov_x( E_{u∼q(u|x)}[u] )

• The distribution of A_u consists of two separated modes.
• Inactive dimensions → units dying out.
• Optimization issue?
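A minimal sketch of this statistic (following Burda et al., 2015; names are illustrative): for each latent dimension, take the posterior mean E_{q(u|x)}[u] for every input x and measure its variance across the dataset; dimensions whose variance falls below a small threshold (10^-2 in the paper) are counted as inactive.

import torch

def latent_activity(posterior_means, threshold=1e-2):
    """posterior_means: (num_examples, num_latent) tensor of E_{q(u|x)}[u] per input x.
    Returns A_u = Cov_x(E_{q(u|x)}[u]) and a boolean mask of active dimensions."""
    A = posterior_means.var(dim=0, unbiased=False)
    return A, A > threshold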

IWAEs vs. VAEs

OMNIGLOT Experiments

Modeling Image Patches: BSDS Dataset
• Model 8x8 patches.

[Figure: architecture with 64 input units, a deterministic layer of 500 units, and a stochastic layer of 40 units (1 stochastic layer)]

• Report test log-likelihoods on 10^6 8x8 patches extracted from the BSDS test dataset.
• Evaluation protocol established by Uria, Murray and Larochelle (a preprocessing sketch follows below):
  - add uniform noise between 0 and 1, divide by 256,
  - subtract the mean and discard the last pixel.
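A minimal NumPy sketch of the preprocessing listed above (function and argument names are illustrative): dequantize the integer pixels with uniform noise, rescale by 256, subtract each patch's mean, and drop the last pixel, which the mean subtraction makes redundant.

import numpy as np

def preprocess_patches(patches, rng=np.random.default_rng()):
    """patches: (N, 64) array of integer pixel values in [0, 255] from 8x8 patches."""
    x = (patches + rng.random(patches.shape)) / 256.0   # add uniform [0,1) noise, divide by 256
    x = x - x.mean(axis=1, keepdims=True)                # subtract the per-patch mean
    return x[:, :-1]                                     # discard the last pixel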

Test Log-probabilities

Model                                                      nats         bits/pixel
RNADE, 6 hidden layers (Uria et al., 2013)                 155.2        3.55
MoG, 200 full-covariance mixtures (Zoran and Weiss, 2012)  152.8        3.50
IWAE (k=500)                                               151.4        3.47
VAE (k=500)                                                148.0        3.39
GSM (Gaussian Scale Mixture)                               142          3.25
ICA                                                        111          2.54
PCA                                                        96           2.21

(Burda et al., 2015)

Learned Filters

(Burda et al., 2015)

Motivating Example
• Can we generate images from natural language descriptions?

A stop sign is flying in blue skies
A pale yellow school bus is flying in blue skies
A herd of elephants is flying in blue skies
A large commercial airplane is flying in blue skies

(Mansimov, Parisotto, Ba, Salakhutdinov, 2015)

Overall Model

[Figure: overall model, a Variational Autoencoder with a stochastic layer]

Sequence-to-Sequence
• Sequence-to-sequence framework (Sutskever et al., 2014; Cho et al., 2014; Srivastava et al., 2015).
• Caption (y) is represented as a sequence of consecutive words.
• Image (x) is represented as a sequence of patches drawn on a canvas.
• Attention mechanism over:
  - Words: which words to focus on when generating a patch.
  - Image locations: where to place the generated patches on the canvas.

Representing Captions: Bidirectional RNN
• Forward RNN reads the sentence y from left to right.
• Backward RNN reads the sentence y from right to left.
• The hidden states are then concatenated (a sketch follows below).
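A minimal PyTorch sketch of the caption encoder (vocabulary size, dimensions, and the choice of LSTM cells are illustrative assumptions): one pass reads the word sequence left to right, the other right to left, and the per-word hidden states are concatenated.

import torch
import torch.nn as nn

vocab_size, emb_dim, hid_dim = 10000, 128, 256
embed = nn.Embedding(vocab_size, emb_dim)
birnn = nn.LSTM(emb_dim, hid_dim, bidirectional=True, batch_first=True)

caption = torch.randint(0, vocab_size, (1, 7))   # token ids of a 7-word caption y
h_lang, _ = birnn(embed(caption))                # (1, 7, 2*hid_dim): [forward; backward] per word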

DRAW Model
• At each step the model generates a p x p patch.
• It gets transformed onto a w x h canvas using two arrays of Gaussian filter banks, via the write operator, whose filter locations and scales are computed from the generator's hidden state (a sketch follows below).

(Gregor et al., 2015)
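A minimal sketch of a DRAW-style write step (simplified relative to Gregor et al., 2015: the attention parameters are fixed constants here rather than computed from the generator's hidden state, and all names and sizes are illustrative): a p x p patch is spread onto a w x h canvas through two banks of 1-D Gaussian filters.

import torch

def gaussian_filterbank(grid_size, canvas_size, center, stride, sigma):
    """Returns a (grid_size, canvas_size) bank of normalized 1-D Gaussian filters."""
    i = torch.arange(grid_size, dtype=torch.float32)
    mu = center + (i - grid_size / 2 + 0.5) * stride           # filter centers on the canvas
    a = torch.arange(canvas_size, dtype=torch.float32)
    F = torch.exp(-((a[None, :] - mu[:, None]) ** 2) / (2 * sigma ** 2))
    return F / (F.sum(dim=1, keepdim=True) + 1e-8)             # normalize each filter

p, W, H = 5, 32, 32
patch = torch.rand(p, p)                                       # the p x p patch to be written
Fx = gaussian_filterbank(p, W, center=16.0, stride=2.0, sigma=1.0)
Fy = gaussian_filterbank(p, H, center=16.0, stride=2.0, sigma=1.0)
canvas_update = Fy.t() @ patch @ Fx                            # (H, W) contribution added to the canvas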

Overall Model
• Generative Model: Stochastic Recurrent Network, a chained sequence of Variational Autoencoders, each with a single stochastic layer.

[Figure: generative model with a bidirectional LSTM over the caption and a stochastic layer at each step]

(Gregor et al., 2015; Mansimov, Parisotto, Ba, Salakhutdinov, 2015)

Overall Model
• Generative Model: Stochastic Recurrent Network, a chained sequence of Variational Autoencoders, each with a single stochastic layer.
• Recognition Model: Deterministic Recurrent Network.

[Figure: generative and recognition model with a bidirectional LSTM over the caption and a stochastic layer at each step]

(Gregor et al., 2015; Mansimov, Parisotto, Ba, Salakhutdinov, 2015)

Overall Model
• Attention (alignment): focus on different words at different time steps when generating patches and placing them on the canvas.
• Sentence representation: dynamically weighted average of the hidden states representing the words.

(Bahdanau et al., 2015)

Generating Images
• Image is represented as a sequence of patches (t = 1, …, T) drawn on the canvas.
• In practice, we use the conditional mean of the final canvas.

Alignment Model
• Dynamic sentence representation at time t: a weighted average of the bidirectional hidden states,

s_t = Σ_k α_{t,k} h_k

where the alignment probabilities α_{t,k} are computed with a softmax over scores produced from the hidden states (a sketch follows below).
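A minimal sketch of this alignment step (Bahdanau-style; the scoring network, sizes, and names are assumptions for illustration): scores depend on each bidirectional hidden state and the previous state of the generative LSTM, a softmax turns them into alignment probabilities, and the dynamic sentence representation is the resulting weighted average.

import torch
import torch.nn as nn

hid_lang, hid_gen = 512, 256
score_net = nn.Sequential(nn.Linear(hid_lang + hid_gen, 128), nn.Tanh(), nn.Linear(128, 1))

def align(h_lang, h_gen_prev):
    """h_lang: (num_words, hid_lang) bidirectional states; h_gen_prev: (hid_gen,)."""
    scores = score_net(torch.cat([h_lang, h_gen_prev.expand(h_lang.size(0), -1)], dim=-1))
    alpha = torch.softmax(scores.squeeze(-1), dim=0)      # alignment probabilities
    return (alpha.unsqueeze(-1) * h_lang).sum(dim=0)      # dynamic sentence representation s_t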

Learning
• Maximize the variational lower bound on the marginal log-likelihood of the correct image x given the caption y.

Sharpening
• Additional post-processing step: use an adversarial network trained on residuals of a Laplacian pyramid to sharpen the generated images (Denton et al., 2015).

MS COCO Dataset
• Contains 83K images.
• Each image contains 5 captions.
• Standard benchmark dataset for many of the recent image captioning systems.

(Lin et al., 2014)

Flipping Colors
A yellow school bus parked in the parking lot
A red school bus parked in the parking lot
A green school bus parked in the parking lot
A blue school bus parked in the parking lot

Flipping Backgrounds
A very large commercial plane flying in clear skies.
A very large commercial plane flying in rainy skies.
A herd of elephants walking across a dry grass field.
A herd of elephants walking across a green grass field.

Flipping Objects
The decadent chocolate dessert is on the table.
A bowl of bananas is on the table.
A vintage photo of a cat.
A vintage photo of a dog.

Qualitative Comparison
"A group of people walk on a beach with surfboards"

[Figure: samples from Our Model, LAPGAN (Denton et al., 2015), Fully Connected VAE, and Conv-Deconv VAE]

Variational Lower-Bound
• We can estimate the variational lower bound on the average test log-probabilities:

Model              Training    Test
Our Model          -1792.15    -1791.53
Skipthought-Draw   -1794.29    -1791.37
noAlignDraw        -1792.14    -1791.15

• At least we can see that we do not overfit to the training data, unlike many other approaches.

Novel Scene Compositions
A toilet seat sits open in the bathroom (Ask Google?)
A toilet seat sits open in the grass field
