MULTI-ADVERSARIAL VARIATIONAL AUTO-ENCODER NETWORKS
Abdullah-Al-Zubaer Imran and Demetri Terzopoulos
UCLA Computer Graphics & Vision Laboratory
• Limited training data
• Leveraging unlabeled examples
• Unsupervised learning
• Semi-supervised learning
• Latent codes help in understanding the models
12/17/2019 Multi-Adversarial Variational Autoencoder Networks 2
Why Generative Modeling?
Deep Generative Models: VAE
[Kingma et al. 2013]
• Maximizes a variational lower bound (ELBO) on the log-likelihood
• Encoder: variational inference
• Decoder: sample generation
• Efficient variational inference
• Blurry samples
• Re-parameterization: $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$
• Losses
  • Reconstruction: $\mathbb{E}_{q(z|x)}[\log p(x|z)]$
  • Regularization: $-KL(q(z|x) \,\|\, p(z))$
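To make the re-parameterization and the two loss terms concrete, here is a minimal NumPy sketch. It assumes a diagonal Gaussian encoder output (`mu`, `log_var`), a standard normal prior p(z) = N(0, I), and a Bernoulli decoder over pixel values in [0, 1]; these choices and all function names are illustrative assumptions, not the MAVEN implementation. Training minimizes the negative of (reconstruction − KL).

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """z = mu + sigma * eps with eps ~ N(0, I). Given eps, z is a deterministic
    function of (mu, sigma), so an autodiff framework can backpropagate through it."""
    sigma = np.exp(0.5 * log_var)
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q, summed over latent dims."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def bernoulli_log_likelihood(x, x_prob, eps=1e-7):
    """log p(x|z) for pixels in [0, 1] under an assumed Bernoulli decoder with means x_prob."""
    x_prob = np.clip(x_prob, eps, 1 - eps)
    return np.sum(x * np.log(x_prob) + (1 - x) * np.log(1 - x_prob), axis=-1)

def negative_elbo(x, x_prob, mu, log_var):
    """Batch-averaged loss to minimize: -(reconstruction term) + KL term."""
    recon = bernoulli_log_likelihood(x, x_prob)
    return float(np.mean(-recon + kl_to_standard_normal(mu, log_var)))

# Usage: a batch of 8 two-dimensional latent codes drawn from q(z|x) = N(0, I).
rng = np.random.default_rng(0)
z = reparameterize(np.zeros((8, 2)), np.zeros((8, 2)), rng)
```

The KL term pulls q(z|x) toward the prior, while the reconstruction term rewards decoders that put high probability on the input; their balance is what makes VAE inference efficient but tends to blur samples.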
Deep Generative Models: GAN
[Goodfellow et al. 2014, Radford et al. 2015]
• Mini-max game
• Generator maps latent variables to data samples
• Discriminator distinguishes generated samples from real ones
• Sharpest image generation
• Unstable and difficult to optimize
• Losses
  $\max_D V(D) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$
  $\min_G V(G) = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$
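A minimal NumPy sketch of the two objectives written as losses to minimize, assuming `d_real` and `d_fake` hold the discriminator's probabilities D(x) and D(G(z)) for a batch. The non-saturating generator loss is the practical variant suggested in the original GAN paper and is included only for comparison; function names are illustrative.

```python
import numpy as np

EPS = 1e-7  # keeps log() finite when D saturates at 0 or 1

def discriminator_loss(d_real, d_fake):
    """-V(D): D maximizes E[log D(x)] + E[log(1 - D(G(z)))], so we minimize the negative."""
    d_real, d_fake = np.clip(d_real, EPS, 1 - EPS), np.clip(d_fake, EPS, 1 - EPS)
    return float(-np.mean(np.log(d_real)) - np.mean(np.log(1 - d_fake)))

def generator_loss_minimax(d_fake):
    """V(G) = E[log(1 - D(G(z)))], minimized by G."""
    d_fake = np.clip(d_fake, EPS, 1 - EPS)
    return float(np.mean(np.log(1 - d_fake)))

def generator_loss_nonsaturating(d_fake):
    """Alternative G loss -E[log D(G(z))], which gives stronger gradients early in training."""
    d_fake = np.clip(d_fake, EPS, 1 - EPS)
    return float(-np.mean(np.log(d_fake)))
```

At the game's equilibrium D outputs 0.5 everywhere, giving a discriminator loss of 2 log 2; the tug-of-war between the two losses is also the source of the instability noted above.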
Deep Generative Models: PixelRNN
[van den Oord et al. 2016]
• Autoregressive model
• Simple and stable training process
• Inefficient sampling
• Assigns a probability to every pixel, conditioned on all previous pixels
• Softmax loss
$p(x) = \prod_{i=1}^{n^2} p(x_i \mid x_1, \ldots, x_{i-1})$
$p(x_i \mid x_{<i}) = p(x_{i,R} \mid x_{<i})\, p(x_{i,G} \mid x_{<i}, x_{i,R})\, p(x_{i,B} \mid x_{<i}, x_{i,R}, x_{i,G})$
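The chain-rule factorization can be illustrated with a toy model. This sketch assumes a single grayscale channel (the R/G/B sub-factorization is omitted) with only `n_levels = 4` intensities instead of 256, and replaces the recurrent network with a hypothetical hand-written conditional, `toy_conditional`; only the chain-rule bookkeeping and the pixel-by-pixel sampling loop reflect PixelRNN.

```python
import numpy as np

def toy_conditional(prev_pixels, n_levels=4):
    """Hypothetical stand-in for the PixelRNN network: a softmax over n_levels
    intensities for the next pixel, given all previous pixels in raster order."""
    context = np.mean(prev_pixels) if len(prev_pixels) else 0.0
    logits = -np.abs(np.arange(n_levels) - context)  # favour values near the running mean
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()

def log_likelihood(image, conditional=toy_conditional):
    """log p(x) = sum_i log p(x_i | x_1, ..., x_{i-1}); one softmax term per pixel."""
    flat = image.ravel()
    return float(sum(np.log(conditional(flat[:i])[v]) for i, v in enumerate(flat)))

def sample(n, conditional=toy_conditional, rng=None):
    """Draw an n x n image one pixel at a time: n*n sequential network calls,
    which is why autoregressive sampling is slow."""
    rng = np.random.default_rng() if rng is None else rng
    flat = np.zeros(n * n, dtype=int)
    for i in range(n * n):
        probs = conditional(flat[:i])
        flat[i] = rng.choice(len(probs), p=probs)
    return flat.reshape(n, n)
```

Because every conditional is a normalized softmax, the probabilities of all possible images sum to one, which is what makes exact likelihood training with a softmax loss possible.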
Aim
VAE-GAN
PixelGAN Autoencoder
• Improving deep generative models
• Evaluation measures
[Larsen et al. 2016, Makhzani et al. 2017]
Primary Aim
(Efficient and stable generative modeling
for medical image analysis)
✓ Combining generative models
✓ High quality image generation
✓ Learning from limited labeled data
Proposed: MAVENs
• Highlights
  • Ensemble of multiple discriminators in a VAE-GAN
  • Joint image generation and classification
• Motivation
  • Instability in generative models
  • Mode-collapsed generation
  • Poor image quality in VAEs
  • Small labeled datasets
• Objective
  • Improve samples and semi-supervised classification
  • Unified generative model
  • Variational inference with adversarial learning
Basic comparisons of MAVEN with GAN, VAE, and VAE-GAN
GAN: mode-collapsed generation
MAVENs: Architecture
MAVENs: Objectives
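Discriminator Loss
Supervised:
$L_D^{supervised} = -\mathbb{E}_{x,y \sim p_{data}}[\log p(y = i \mid x, i < n+1)]$
Unsupervised:
$L_D^{real} = -\mathbb{E}_{x \sim p_{data}}[\log(1 - p(y = n+1 \mid x))]$
$L_D^{fake\_G} = -\mathbb{E}_{\hat{x} \sim G}[\log p(y = n+1 \mid \hat{x})]$
$L_D^{fake\_E} = -\mathbb{E}_{\tilde{x} \sim G}[\log p(y = n+1 \mid \tilde{x})]$
Generator Loss
$L_G^{fake\_G} = -\mathbb{E}_{\hat{x} \sim G}[\log(1 - p(y = n+1 \mid \hat{x}))]$
$L_G^{fake\_E} = -\mathbb{E}_{\tilde{x} \sim G}[\log(1 - p(y = n+1 \mid \tilde{x}))]$
$L_G^{feature} = \lVert \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{\hat{x} \sim G} f(\hat{x}) \rVert_2^2$
Encoder Loss
$L_E^{feature} = \lVert \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{\tilde{x} \sim G} f(\tilde{x}) \rVert_2^2$
$L_E^{KL} = -\mathbb{E}_{q_\lambda(z|x)}\left[\log \frac{p(z)}{q_\lambda(z|x)}\right]$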
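One way to compute these terms is sketched below in NumPy under stated assumptions: each discriminator outputs n+1 logits per image with the last column being the "fake" class n+1; x̂ are images generated from prior samples z; x̃ are reconstructions assumed to be produced by passing x through the encoder and then the generator; and f(·) is an intermediate discriminator feature map supplied by the caller. Function names are hypothetical. L_E^KL is the same closed-form Gaussian KL as the VAE regularizer and is omitted; how the terms are weighted into the final objectives is not specified here.

```python
import numpy as np

EPS = 1e-7  # keeps log() finite

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def p_fake(logits):
    """p(y = n+1 | x): probability of the extra 'fake' class (last logit column)."""
    return softmax(logits)[:, -1]

def d_supervised_loss(logits_labeled, labels):
    """L_D^supervised: cross-entropy over the n real classes only, i.e. p(y = i | x, i < n+1)."""
    p = softmax(logits_labeled[:, :-1])  # drop the fake column and renormalize
    return float(-np.mean(np.log(p[np.arange(len(labels)), labels] + EPS)))

def d_unsupervised_loss(logits_real, logits_gen, logits_rec):
    """L_D^real + L_D^fake_G + L_D^fake_E."""
    real = -np.mean(np.log(1 - p_fake(logits_real) + EPS))
    fake_g = -np.mean(np.log(p_fake(logits_gen) + EPS))  # x_hat = G(z), z ~ p(z)
    fake_e = -np.mean(np.log(p_fake(logits_rec) + EPS))  # x_tilde = G(E(x)) (assumed)
    return float(real + fake_g + fake_e)

def g_adversarial_loss(logits_gen, logits_rec):
    """L_G^fake_G + L_G^fake_E: reward G when D labels its outputs as real."""
    return float(-np.mean(np.log(1 - p_fake(logits_gen) + EPS))
                 - np.mean(np.log(1 - p_fake(logits_rec) + EPS)))

def feature_matching_loss(feat_real, feat_fake):
    """||E[f(x)] - E[f(x_fake)]||_2^2 on discriminator features f (rows = samples).
    Serves as L_G^feature with x_hat and as L_E^feature with x_tilde."""
    return float(np.sum((feat_real.mean(axis=0) - feat_fake.mean(axis=0)) ** 2))
```

Using one extra "fake" output lets a single classifier head serve both roles: the first n outputs learn the labels from the small labeled subset, while the (n+1)-th output provides the adversarial signal from unlabeled and generated data.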
MAVENs: Implementation Details
• Datasets
  • SVHN (32 x 32 x 3) [street view house-number digits]
  • CIFAR10 (32 x 32 x 3) [natural images of 10 object classes]
  • Chest X-ray (128 x 128 x 1) [normal, bacterial pneumonia, and viral pneumonia]
• Baselines: DC-GAN and VAE-GAN
• MAVENs with 2, 3, and 5 discriminators
  • Discriminator feedback combined by averaging (mean) or by random selection
• Labels used for only 10% of the training data
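The two feedback schemes behind the MAVEN-mean and MAVEN-rand variants can be sketched as follows. This assumes each of the K discriminators yields a scalar generator loss per training step and that "random" means drawing one discriminator uniformly at random at each step; `aggregate_feedback` is a hypothetical helper, not taken from the MAVEN code.

```python
import numpy as np

def aggregate_feedback(losses, mode="mean", rng=None):
    """Combine the generator losses from K discriminators into one training signal.

    mode="mean":   average over all K discriminators (MAVEN-meanKD).
    mode="random": use one discriminator drawn uniformly at random (MAVEN-randKD).
    """
    losses = np.asarray(losses, dtype=float)
    if mode == "mean":
        return float(losses.mean())
    if mode == "random":
        rng = np.random.default_rng() if rng is None else rng
        return float(losses[rng.integers(len(losses))])
    raise ValueError(f"unknown mode: {mode!r}")

# Usage: per-step feedback from three discriminators (illustrative values only).
step_losses = [0.71, 0.64, 0.69]
mean_signal = aggregate_feedback(step_losses, mode="mean")
random_signal = aggregate_feedback(step_losses, mode="random", rng=np.random.default_rng(0))
```

Averaging smooths out the idiosyncrasies of any single discriminator, while random selection exposes the generator to a different critic at each step; both aim to reduce mode collapse.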
MAVENs: Evaluations
• Image quality
  • Fréchet Inception Distance (FID)
    • Activations from the pool3 layer of the Inception-v3 model
    $FID = \lVert \mu_{data} - \mu_{fake} \rVert_2^2 + Tr\left(\Sigma_{data} + \Sigma_{fake} - 2(\Sigma_{data}\Sigma_{fake})^{1/2}\right)$
  • Descriptive Distribution Distance (DDD)
    • Compares the first four moments of the two distributions
    $DDD = \sum_{i=1}^{4} -\log w_i \,\lvert \mu_i^{data} - \mu_i^{fake} \rvert$
• Classification
  • Overall accuracy
  • Class-wise F1 scores
    $F1 = \frac{2 \cdot precision \cdot recall}{precision + recall}$
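Both image-quality measures can be sketched in NumPy. For FID, the inputs are assumed to be feature vectors that were already extracted (e.g., Inception-v3 pool3 activations); the matrix square root uses the identity Tr((Σ₁Σ₂)^{1/2}) = Tr((Σ₁^{1/2}Σ₂Σ₁^{1/2})^{1/2}), which needs only a symmetric eigendecomposition. For DDD, which moments are compared (here: mean, variance, skewness, and kurtosis of the flattened values) and the weights w_i are assumptions, since the slide gives neither.

```python
import numpy as np

def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def fid(feat_data, feat_fake):
    """Fréchet distance between Gaussians fitted to two feature sets (rows = samples)."""
    mu1, mu2 = feat_data.mean(axis=0), feat_fake.mean(axis=0)
    s1 = np.atleast_2d(np.cov(feat_data, rowvar=False))
    s2 = np.atleast_2d(np.cov(feat_fake, rowvar=False))
    # Tr((S1 S2)^{1/2}) == Tr((S1^{1/2} S2 S1^{1/2})^{1/2}); the right side is symmetric PSD.
    r1 = _sqrtm_psd(s1)
    tr_covmean = np.trace(_sqrtm_psd(r1 @ s2 @ r1))
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(s1) + np.trace(s2) - 2.0 * tr_covmean)

def _four_moments(values):
    """Mean, variance, skewness and kurtosis of the flattened values (assumed choice)."""
    v = np.asarray(values, dtype=float).ravel()
    m, sd = v.mean(), v.std()
    z = (v - m) / sd
    return np.array([m, sd**2, np.mean(z**3), np.mean(z**4)])

def ddd(data, fake, weights):
    """sum_i -log(w_i) * |mu_i^data - mu_i^fake| over the first four moments.
    The weights w_1..w_4 (each in (0, 1)) are hypothetical; the slide gives no values."""
    diff = np.abs(_four_moments(data) - _four_moments(fake))
    return float(np.sum(-np.log(np.asarray(weights, dtype=float)) * diff))
```

Both scores are distances, so lower values indicate generated samples whose statistics are closer to the real data; identical inputs give zero for both.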
SVHN Results: Generated Samples
Model FID Score DDD Score
DC-GAN 16.789±0.303 0.343
VAE-GAN 13.252±0.001 0.329
MAVEN-mean2D 11.675±0.001 0.309
MAVEN-mean3D 11.515±0.065 0.300
MAVEN-mean5D 10.909±0.001 0.294
MAVEN-rand2D 11.384±0.001 0.316
MAVEN-rand3D 10.791±0.029 0.357
MAVEN-rand5D 11.052±0.751 0.323
[Figure: SVHN samples generated by MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, and MAVEN-rand5D]
SVHN Results: Classification
Model         Acc    F1 score by class
                     0      1      2      3      4      5      6      7      8      9
DC-GAN        0.876  0.860  0.920  0.890  0.840  0.890  0.870  0.830  0.890  0.820  0.840
VAE-GAN       0.901  0.900  0.940  0.930  0.860  0.920  0.900  0.860  0.910  0.840  0.850
MAVEN-mean2D  0.909  0.890  0.930  0.940  0.890  0.930  0.900  0.870  0.910  0.870  0.890
MAVEN-mean3D  0.909  0.910  0.940  0.940  0.870  0.920  0.890  0.870  0.920  0.870  0.860
MAVEN-mean5D  0.905  0.910  0.930  0.930  0.870  0.930  0.900  0.860  0.910  0.860  0.870
MAVEN-rand2D  0.905  0.910  0.930  0.940  0.870  0.930  0.890  0.860  0.920  0.850  0.860
MAVEN-rand3D  0.907  0.890  0.910  0.920  0.870  0.900  0.870  0.860  0.900  0.870  0.890
MAVEN-rand5D  0.903  0.910  0.930  0.940  0.860  0.910  0.890  0.870  0.920  0.850  0.870
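The per-class F1 scores in these tables follow the F1 formula from the evaluation slide, with each class treated one-vs-rest. A minimal sketch, assuming integer class labels; the function names are illustrative, and in practice a library routine such as scikit-learn's `f1_score(..., average=None)` computes the same per-class quantity.

```python
import numpy as np

def per_class_f1(y_true, y_pred, n_classes):
    """F1 = 2PR / (P + R) for each class, treating that class as 'positive' (one-vs-rest)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = np.zeros(n_classes)
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores[c] = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return scores

def accuracy(y_true, y_pred):
    """Overall accuracy: fraction of correctly classified samples."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

# Usage: class 0 has P = 1, R = 0.5 (F1 = 2/3); class 1 has P = 2/3, R = 1 (F1 = 0.8).
f1 = per_class_f1([0, 0, 1, 1], [0, 1, 1, 1], n_classes=2)
```

Reporting F1 per class alongside overall accuracy shows whether an accuracy gain is spread across classes or driven by a few easy ones.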
CIFAR10 Results: Generated Samples
Model FID Score DDD Score
DC-GAN 61.293±0.209 0.265
VAE-GAN 15.511±0.125 0.224
MAVEN-mean2D 12.743±0.242 0.223
MAVEN-mean3D 11.316±0.808 0.190
MAVEN-mean5D 12.123±0.140 0.207
MAVEN-rand2D 12.820±0.584 0.194
MAVEN-rand3D 12.620±0.001 0.202
MAVEN-rand5D 18.509±0.001 0.215
[Figure: CIFAR10 samples generated by MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, and MAVEN-rand5D]
CIFAR10 Results: Classification
Model         Acc    F1 score by class
                     airplane automobile bird  cat   deer  dog   frog  horse ship  truck
DC-GAN        0.713  0.760    0.840      0.560 0.510 0.660 0.590 0.780 0.780 0.810 0.810
VAE-GAN       0.743  0.770    0.850      0.640 0.560 0.690 0.620 0.820 0.770 0.860 0.830
MAVEN-mean2D  0.761  0.800    0.860      0.650 0.590 0.750 0.680 0.810 0.780 0.850 0.850
MAVEN-mean3D  0.759  0.770    0.860      0.670 0.580 0.700 0.690 0.800 0.810 0.870 0.830
MAVEN-mean5D  0.771  0.800    0.860      0.650 0.610 0.710 0.640 0.810 0.790 0.880 0.820
MAVEN-rand2D  0.757  0.780    0.860      0.650 0.530 0.720 0.650 0.810 0.800 0.870 0.860
MAVEN-rand3D  0.756  0.780    0.860      0.640 0.580 0.720 0.650 0.830 0.800 0.870 0.830
MAVEN-rand5D  0.762  0.810    0.850      0.680 0.600 0.720 0.660 0.840 0.800 0.850 0.820
Model vs. Real Distributions: Good Match
[Figure: model vs. real distributions for SVHN and CIFAR10]
CXR Results: Generated Samples
Model FID Score DDD Score
DC-GAN 152.511±0.370 0.145
VAE-GAN 141.422±0.580 0.107
MAVEN-mean2D 141.339±0.420 0.138
MAVEN-mean3D 140.865±0.983 0.018
MAVEN-mean5D 147.316±1.169 0.100
MAVEN-rand2D 154.501±0.345 0.038
MAVEN-rand3D 158.749±0.297 0.179
MAVEN-rand5D 152.778±1.254 0.180
[Figure: CXR samples generated by MAVEN-mean2D, MAVEN-mean3D, MAVEN-mean5D, MAVEN-rand2D, MAVEN-rand3D, and MAVEN-rand5D]
CXR Results: Classification
Model vs. Real Distributions: Not-So-Good Match
[Figure: model vs. real distributions for CXR]
Conclusions & Future Work
Significance
• New generative model
• Improved image quality and classification
• Evaluation measure (DDD) for deep generative models
Limitation
• Weaker performance on medical image data
• Longer execution time
What’s Next
• Hyper-parameter tuning for medical images
• Constrained generation
• Complex image analysis tasks
• Generative multi-tasking
Questions?