Page 1: Collapsed Variational Inference for Sum-Product Networks

Collapsed Variational Inference for Sum-Product Networks

Han Zhao1, Tameem Adel2, Geoff Gordon1, Brandon Amos1

Presented by: Han Zhao

Carnegie Mellon University1, University of Amsterdam2

June 20th, 2016

1 / 26

Page 2

Outline

Background
- Sum-Product Networks
- Variational Inference

Collapsed Variational Inference
- Motivations and Challenges
- Efficient Marginalization
- Logarithmic Transformation

Experiments

Summary

2 / 26

Page 3

Sum-Product Networks: Definition

A Sum-Product Network (SPN) is a rooted directed acyclic graph of univariate distributions, sum nodes and product nodes.

- The value of a product node is the product of the values of its children.

- The value of a sum node is the weighted sum of the values of its children, where the weights are nonnegative.

- The value of the network is the value at the root.

[Figure: an example SPN with a sum root (weights w1, w2, w3) over three product nodes of indicator leaves over X1 and X2.]
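The recursive definition above can be made concrete with a minimal bottom-up evaluator. This is an illustrative Python sketch (the class names and the toy network are hypothetical, not the authors' implementation):

```python
# Minimal bottom-up SPN evaluator (illustrative sketch; names are
# hypothetical, not from the paper's code).

class Leaf:
    def __init__(self, var, dist):
        self.var, self.dist = var, dist        # univariate distribution

    def value(self, x):
        return self.dist[x[self.var]]

class Product:
    def __init__(self, children):
        self.children = children

    def value(self, x):                        # product of children's values
        v = 1.0
        for c in self.children:
            v *= c.value(x)
        return v

class Sum:
    def __init__(self, weighted_children):     # list of (weight, child)
        self.wc = weighted_children            # weights are nonnegative

    def value(self, x):                        # weighted sum of children
        return sum(w * c.value(x) for w, c in self.wc)

# Toy network: a sum over two product nodes of indicator leaves on X1, X2.
ix1  = Leaf(0, {0: 0.0, 1: 1.0})               # indicator [X1 = 1]
nix1 = Leaf(0, {0: 1.0, 1: 0.0})               # indicator [X1 = 0]
ix2  = Leaf(1, {0: 0.0, 1: 1.0})
nix2 = Leaf(1, {0: 1.0, 1: 0.0})
root = Sum([(0.6, Product([ix1, ix2])), (0.4, Product([nix1, nix2]))])
print(root.value({0: 1, 1: 1}))                # 0.6
```

With indicator leaves and normalized sum weights, the root value is the probability the network assigns to the complete assignment.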

3 / 26

Page 4

Sum-Product Networks: Mixture of Trees

Each SPN can be decomposed as a mixture of trees:

[Figure: the example SPN decomposed as w1 · (tree 1) + w2 · (tree 2) + w3 · (tree 3), where each tree is a product of indicator leaves over X1 and X2.]

- Each tree is a product of univariate distributions.

- The number of mixture components is Ω(2^Depth).

- Each network computes a positive polynomial (posynomial) function of the model parameters:

$$V_{\mathrm{root}}(x \mid w) = \sum_{t=1}^{\tau_S} \prod_{(k,j)\in T_t^E} w_{kj} \prod_{i=1}^{n} p_t(X_i = x_i)$$
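The Ω(2^Depth) count of mixture components can be illustrated by counting induced trees recursively: a sum node contributes the sum of its children's counts (each tree selects one child), a product node the product. A hedged Python sketch on a hypothetical recursive structure (not the paper's code):

```python
# Counting induced trees in a toy recursive SPN structure to illustrate
# the Omega(2^Depth) growth of mixture components (illustrative sketch).
from math import prod

def num_trees(node):
    kind, children = node
    if kind == 'leaf':
        return 1
    counts = [num_trees(c) for c in children]
    # A sum node selects one child per induced tree; a product keeps all.
    return sum(counts) if kind == 'sum' else prod(counts)

def build(depth):
    # Alternate sum-over-product layers; each sum offers two choices.
    if depth == 0:
        return ('leaf', [])
    p = ('prod', [build(depth - 1), build(depth - 1)])
    return ('sum', [p, p])

print([num_trees(build(d)) for d in range(1, 5)])   # [2, 8, 128, 32768]
```

The count at depth d already exceeds 2^d, matching the lower bound stated above.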

4 / 26

Page 5

Sum-Product Networks: Bayesian Network

Alternatively, each SPN S is equivalent to a Bayesian network B with bipartite structure.

[Figure: the example SPN annotated with hidden variables H, H1–H4 at its sum nodes, shown equivalent (⇔) to a bipartite Bayesian network in which the hidden variables H, H1–H4 are parents of the observables X1–X4.]

- Number of sum nodes in S = number of hidden variables in B = Θ(|S|); |B| = O(n|S|).

- Number of observable variables in B = number of variables modeled by S.

- Typically the number of hidden variables ≫ the number of observable variables.

5 / 26

Page 6

Variational Inference: Brief Introduction

Bayesian inference:

$$\underbrace{p(w \mid x)}_{\text{posterior}} \ \propto\ \underbrace{p(w)}_{\text{prior}}\ \underbrace{p(x \mid w)}_{\text{likelihood}}$$

Often intractable because of:

- No analytical solution.

- Expensive numerical integration.

General idea: find the best approximation within a tractable family of distributions Q:

$$\operatorname*{minimize}_{q \in \mathcal{Q}}\ \mathrm{KL}[q(w)\,\|\,p(w \mid x)]$$

Typical choices of approximating family: mean-field, structured mean-field, etc.

6 / 26

Page 7

Variational Inference: Brief Introduction

Variational method: an optimization-based, deterministic approach to approximate Bayesian inference.

$$\inf_{q\in\mathcal{Q}} \mathrm{KL}[q(w)\,\|\,p(w \mid x)] \ \Leftrightarrow\ \sup_{q\in\mathcal{Q}} \mathbb{E}_q[\log p(w, x)] + H[q]$$

Evidence Lower Bound (ELBO) L:

$$\log p(x) \ \ge\ \sup_{q\in\mathcal{Q}} \mathbb{E}_q[\log p(w, x)] + H[q] \ =:\ \mathcal{L}$$
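The ELBO inequality follows from the identity log p(x) = E_q[log p(w, x)] + H[q] + KL[q(w) ‖ p(w | x)], which can be checked numerically on a toy discrete model (illustrative numbers, single binary hidden variable):

```python
# Numeric check of the ELBO identity on a toy discrete model:
#   log p(x) = E_q[log p(w, x)] + H[q] + KL[q(w) || p(w | x)]
import math

p_w   = {0: 0.3, 1: 0.7}                          # prior p(w)
p_x_w = {0: 0.9, 1: 0.2}                          # likelihood p(x | w)
p_wx  = {w: p_w[w] * p_x_w[w] for w in p_w}       # joint p(w, x)
log_px = math.log(sum(p_wx.values()))             # evidence log p(x)
post  = {w: p_wx[w] / math.exp(log_px) for w in p_wx}   # posterior p(w | x)

q = {0: 0.5, 1: 0.5}                              # some variational q(w)
elbo = sum(q[w] * (math.log(p_wx[w]) - math.log(q[w])) for w in q)
kl   = sum(q[w] * math.log(q[w] / post[w]) for w in q)

assert abs((elbo + kl) - log_px) < 1e-9           # the identity holds
assert elbo <= log_px                             # hence ELBO is a lower bound
```

Since KL is nonnegative and zero only at q = p(w | x), maximizing the ELBO over Q is exactly minimizing the KL divergence to the posterior.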

7 / 26

Page 8

Collapsed Variational Inference: Motivations and Challenges

Bayesian inference algorithms for SPNs:

- Flexible at incorporating prior knowledge about the structure of SPNs.

- More robust to overfitting.

8 / 26

Page 9

Collapsed Variational Inference: Motivations and Challenges

[Figure: graphical model with global parameters W1, ..., Wm, local hidden variables H1, ..., Hm and observables X1, ..., Xn inside a plate replicated D times.]

- W – model parameters, global hidden variables.

- H – assignments of sum nodes, local hidden variables.

- X – observable variables.

- D – number of instances.

Challenges for standard VB:

- Large number of local hidden variables: number of local hidden variables = number of sum nodes = Θ(|S|).

- Memory overhead: space complexity O(D|S|).

- Time complexity: O(nD|S|).

9 / 26


Page 12

Collapsed Variational Inference: Contributions

Our contributions:

- A better ELBO L_w to optimize than L, the one obtained by mean-field.

- Reduced space complexity: O(D|S|) ⇒ O(|S|); the space complexity is independent of the training set size.

- Reduced time complexity: O(nD|S|) ⇒ O(D|S|), removing the explicit dependence on the dimension.

12 / 26

Page 13

Collapsed Variational Inference: Efficient Marginalization

Recall the ELBO in standard VI:

$$\mathcal{L} := \mathbb{E}_{q(w,h)}[\log p(w, h, x)] + H[q(w, h)]$$

Consider the new ELBO in collapsed VI:

$$\mathcal{L}_w := \mathbb{E}_{q(w)}[\log p(w, x)] + H[q(w)] = \mathbb{E}_{q(w)}\Big[\log \sum_{h} p(w, h, x)\Big] + H[q(w)]$$

We can establish the following inequality:

$$\log p(x) \ \ge\ \mathcal{L}_w \ \ge\ \mathcal{L}$$

The new ELBO in collapsed VI is therefore a tighter lower bound than the one used in standard VI!

13 / 26

Page 14

Collapsed Variational Inference: Comparisons

Standard Variational Inference
- Mean-field assumption: $q(w, h) = \prod_i q(w_i) \prod_j q(h_j)$
- ELBO: $\mathcal{L} := \mathbb{E}_{q(w,h)}[\log p(w, h, x)] + H[q(w, h)]$

Collapsed Variational Inference for LDA, HDP
- Collapse out the global hidden variables: $q(h) = \int_w q(w, h)\,dw$
- ELBO: $\mathcal{L}_h := \mathbb{E}_{q(h)}[\log p(h, x)] + H[q(h)]$
- Better lower bound: $\mathcal{L}_h \ge \mathcal{L}$

Collapsed Variational Inference for SPN
- Collapse out the local hidden variables: $q(w) = \sum_h q(w, h)$
- ELBO: $\mathcal{L}_w := \mathbb{E}_{q(w)}[\log p(w, x)] + H[q(w)]$
- Better lower bound: $\mathcal{L}_w \ge \mathcal{L}$

14 / 26

Page 15

Collapsed Variational Inference: Efficient Marginalization

Time complexity of the exact marginalization incurred in computing $\sum_h p(w, h, x)$:

- Marginalization in a graphical model G: O(D · 2^tw(G)), where tw(G) is the treewidth of G.

- Exact marginalization in the BN B with algebraic decision diagrams as local factors: O(D|B|) = O(nD|S|).

- Exact marginalization in the SPN S: O(D|S|).

Space complexity reduction:

- No posterior over h to approximate anymore.

- No variational variables over h needed: O(D|S|) ⇒ O(|S|).
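The O(D|S|) cost rests on each sum node marginalizing its own hidden variable during a single bottom-up pass. A toy Python sketch (a hypothetical two-sum-node network, not the paper's code) checks that one forward pass agrees with brute-force enumeration over all assignments of h:

```python
# One bottom-up SPN pass vs. brute-force summation over all local hidden
# assignments h (illustrative toy; weights normalized at every sum node).
import itertools

w_root = [0.6, 0.4]               # weights of the root sum node
w_in   = [0.7, 0.3]               # weights of an inner sum node
l0     = [0.9, 0.2]               # leaf values under the inner sum, fixed x
l1     = 0.5                      # leaf value on the root's other branch

def forward():
    # One pass: each sum node marginalizes its hidden variable in place.
    return w_root[0] * (w_in[0] * l0[0] + w_in[1] * l0[1]) + w_root[1] * l1

def brute_force():
    # Explicitly sum the joint over every assignment of (h_root, h_in).
    total = 0.0
    for h_root, h_in in itertools.product([0, 1], repeat=2):
        if h_root == 0:
            total += w_root[0] * w_in[h_in] * l0[h_in]
        else:
            # The inner sum node is off-branch; its normalized prior
            # weight simply sums out to 1 over h_in.
            total += w_root[1] * w_in[h_in] * l1
    return total

assert abs(forward() - brute_force()) < 1e-9
```

The brute-force loop is exponential in the number of sum nodes, while the forward pass touches each edge once, which is the O(|S|)-per-instance claim.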

15 / 26


Page 20

Collapsed Variational Inference: Logarithmic Transformation

New optimization objective:

$$\operatorname*{maximize}_{q\in\mathcal{Q}}\ \mathbb{E}_{q(w)}\Big[\log \sum_h p(w, h, x)\Big] + H[q(w)]$$

which is equivalent to

$$\operatorname*{minimize}_{q\in\mathcal{Q}}\ \mathrm{KL}[q(w)\,\|\,p(w)] - \mathbb{E}_{q(w)}[\log p(x \mid w)]$$

- p(w) – prior distribution over w, a product of Dirichlets.

- q(w) – variational posterior over w, a product of Dirichlets.

- p(x | w) – likelihood; no longer multinomial after marginalization.

Since q(w) and p(x | w) are non-conjugate, there is no analytical solution for $\mathbb{E}_{q(w)}[\log p(x \mid w)]$.

20 / 26

Page 21

Collapsed Variational Inference: Logarithmic Transformation

Key observation:

$$p(x \mid w) = V_{\mathrm{root}}(x \mid w) = \sum_{t=1}^{\tau_S} \prod_{(k,j)\in T_t^E} w_{kj} \prod_{i=1}^{n} p_t(X_i = x_i)$$

is a posynomial function of w.

Make a bijective mapping (change of variable): $w' = \log(w)$.

- Dates back to the literature on geometric programming.

- The new objective after the transformation is convex in w′.

$$\log p(x \mid w) = \log \sum_{t=1}^{\tau_S} \exp\Big(c_t + \sum_{(k,j)\in T_t^E} w'_{kj}\Big), \qquad c_t = \sum_{i=1}^{n} \log p_t(X_i = x_i)$$

Apply Jensen's inequality to obtain a further lower bound.
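In log-space the likelihood is a log-sum-exp of affine functions of w′, hence convex. A small Python sketch (toy coefficients and tree structure, assumed for illustration) evaluates it stably and spot-checks midpoint convexity:

```python
# log p(x | w') as a numerically stable log-sum-exp of affine functions
# of w' (toy two-tree model; coefficients are illustrative only).
import math, random

C = [math.log(0.9), math.log(0.2)]    # c_t: log of each tree's leaf product
A = [[1, 0], [0, 1]]                  # which w'_kj each tree selects

def log_p(wp):
    # Shift by the max term before exponentiating to avoid underflow.
    terms = [c + sum(a * v for a, v in zip(row, wp)) for c, row in zip(C, A)]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

# At w' = 0 (all weights equal to 1) the value is log(0.9 + 0.2).
assert abs(log_p([0.0, 0.0]) - math.log(1.1)) < 1e-12

# Spot-check convexity along a random segment in log-weight space.
random.seed(0)
u   = [random.uniform(-2, 0) for _ in range(2)]
v   = [random.uniform(-2, 0) for _ in range(2)]
mid = [(a + b) / 2 for a, b in zip(u, v)]
assert log_p(mid) <= (log_p(u) + log_p(v)) / 2 + 1e-12
```

Convexity of log-sum-exp in w′ is exactly what makes Jensen's inequality applicable in the next step.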

21 / 26

Page 22

Collapsed Variational Inference: Logarithmic Transformation

Further lower bound:

$$\mathbb{E}_{q(w)}[\log p(x \mid w)] = \mathbb{E}_{q(w')}[\log p(x \mid w')] \ \ge\ \log p(x \mid \mathbb{E}_{q'(w')}[w'])$$

Relaxed objective:

$$\operatorname*{minimize}_{q\in\mathcal{Q}}\ \underbrace{\mathrm{KL}[q(w)\,\|\,p(w)]}_{\text{regularity}} \ -\ \underbrace{\log p(x \mid \mathbb{E}_{q'(w')}[w'])}_{\text{data fitting}}$$

Roughly, $\log p(x \mid \mathbb{E}_{q'(w')}[w'])$ is the log-likelihood obtained by setting the weights of the SPN to the posterior mean of q(w).

Optimized by projected gradient descent.
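Projected gradient descent here means taking a gradient step on the objective and projecting each sum node's weight vector back onto the probability simplex. A hedged Python sketch on a toy depth-1 likelihood (the projection routine is the standard Euclidean simplex projection; none of this is code from the paper):

```python
# Projected gradient ascent on a toy depth-1 SPN log-likelihood
# log(sum_t w_t * c_t), keeping w on the probability simplex.

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex."""
    u = sorted(v, reverse=True)
    css, theta = 0.0, 0.0
    for i, ui in enumerate(u, start=1):
        css += ui
        t = (css - 1.0) / i
        if ui - t > 0:            # condition holds for a prefix of u
            theta = t
    return [max(x - theta, 0.0) for x in v]

c = [0.9, 0.2, 0.4]               # per-tree leaf products at a fixed x
w = [1 / 3] * 3                   # initial sum-node weights (uniform)
for _ in range(200):
    denom = sum(wi * ci for wi, ci in zip(w, c))
    grad = [ci / denom for ci in c]           # gradient of log(sum w_t c_t)
    w = project_simplex([wi + 0.1 * gi for wi, gi in zip(w, grad)])

# The objective is linear in w on the simplex, so the optimum puts
# all mass on the best tree (here the one with c = 0.9).
```

In the full algorithm the gradient also carries the KL regularity term and there is one simplex per sum node; the toy keeps only the data-fitting structure.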

22 / 26

Page 23

Collapsed Variational Inference: Algorithm

[Algorithm listing not reproduced in this transcript.]

- Lines 4–8 are easily parallelizable: distributed version.

- Sampling a minibatch in lines 4–8 yields a stochastic version.

23 / 26

Page 24

Experiments

- Experiments on 20 data sets; we report average log-likelihoods and the Wilcoxon signed-rank test.

- Compared with (O)MLE-SPN and OBMM-SPN.

[Scatter plots of average log-likelihoods: MLE-SPN vs CVB-SPN (both with projected GD), and OMLE-SPN vs OCVB-SPN.]

24 / 26

Page 25

Experiments

[Scatter plot of average log-likelihoods: OBMM (moment matching) vs OCVB (projected GD).]

25 / 26

Page 26

Summary

Thanks

Q & A

26 / 26