Learning Structured Prediction Models:
A Large Margin Approach
Ben Taskar, U.C. Berkeley
with Vassil Chatalbashev, Michael Collins, Carlos Guestrin, Dan Klein, Daphne Koller, Chris Manning
“Don’t worry, Howard. The big questions are multiple choice.”
Handwriting recognition
x: image of a handwritten word → y: "brace" (sequential structure)
Object segmentation
x: 3D scan → y: per-point object labels (spatial structure)
Natural language parsing
x: "The screen was a sea of red" → y: parse tree (recursive structure)
Disulfide connectivity prediction
x: RSCCPCYWGGCPWGQNCYPEGCSGPKV → y: bonding pattern among cysteines (combinatorial structure)
Outline
- Structured prediction models: sequences (CRFs), trees (CFGs), associative Markov networks (special MRFs), matchings
- Geometric view: structured model polytopes, linear programming inference
- Structured large margin estimation: min-max formulation (application: 3D object segmentation); certificate formulation (application: disulfide connectivity prediction)
Structured models
Prediction: h(x) = argmax_{y ∈ Y(x)} s(x, y), where Y(x) is the space of feasible outputs and s is the scoring function.
Mild assumption: s is a linear combination of features, s(x, y) = wᵀf(x, y).
Chain Markov Net (aka CRF*)
[figure: chain of label nodes y_i (each over a–z) with observations x_i]
φ(x_i, y_i) = exp{ w · f(x_i, y_i) }
φ(y_i, y_{i+1}) = exp{ w · f(y_i, y_{i+1}) }
P(y|x) ∝ ∏_i φ(x_i, y_i) ∏_i φ(y_i, y_{i+1})
Example features: f(y, y') = I(y = 'z', y' = 'a'); f(x, y) = I(x_p = 1, y = 'z')
*Lafferty et al. '01
Chain Markov Net (aka CRF*)
[figure: same chain, now with position-specific parameters]
φ_i(x_i, y_i) = exp{ w_i · f(x_i, y_i) }
φ_i(y_i, y_{i+1}) = exp{ w_i · f(y_i, y_{i+1}) }
P(y|x) ∝ ∏_i φ_i(x_i, y_i) ∏_i φ_i(y_i, y_{i+1}) = exp{ wᵀf(x, y) }
with stacked parameters w = […, w_i, …] and features f(x, y) = […, f_i(x, y), …];
with tied parameters the features become counts, e.g. f(x, y) = #(y_i = 'z', y_{i+1} = 'a'), f(x, y) = #(x_p = 1, y_i = 'z')
*Lafferty et al. '01
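To make the tied-parameter score wᵀf(x, y) concrete, a minimal sketch; the node/edge feature lookups here are hypothetical stand-ins for real feature functions:

```python
import numpy as np

def chain_score(w, node_feats, edge_feats, y):
    """Score w.T @ f(x, y) for a chain: sum node and edge feature
    vectors along the labeling y, all sharing one weight vector w."""
    # node_feats[i][label] and edge_feats[(label, next_label)] return
    # feature vectors with the same dimension as w (hypothetical API).
    total = np.zeros_like(w)
    for i, label in enumerate(y):
        total += node_feats[i][label]      # f(x_i, y_i)
    for a, b in zip(y, y[1:]):
        total += edge_feats[(a, b)]        # f(y_i, y_{i+1})
    return w @ total
```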
Associative Markov Nets
Point features: spin-images, point height
Edge features: length of edge, edge orientation
[figure: neighboring nodes y_i, y_j with node potentials φ_i and edge potentials φ_ij]
"Associative" restriction: edge potentials reward agreement between neighboring labels (y_i = y_j).
PCFG
Features count rule occurrences in the tree:
#(NP → DT NN), …, #(PP → IN NP), …, #(NN → 'sea')
Disulfide bonds: non-bipartite matching
RSCCPCYWGGCPWGQNCYPEGCSGPKV, cysteines numbered 1–6
[figure: the six cysteines connected by a perfect matching]
Fariselli & Casadio '01, Baldi et al. '04
Scoring function
Each candidate bond between a pair of cysteines is scored from string features around the pair: residues, physical properties. The score of a matching is the sum of its edge scores.
[figure: sequence with cysteines 1–6 and a candidate matching]
Structured models
Mild assumption: the scoring function is linear, s(x, y) = wᵀf(x, y).
Another mild assumption: MAP inference, argmax over the space of feasible outputs Y(x), can be solved as a linear program.
LP inference exists for: chains, trees, associative Markov nets, bipartite matchings, …
Markov Net Inference LP
max_μ Σ_i Σ_y θ_i(y) μ_i(y) + Σ_{ij} Σ_{y,y'} θ_ij(y, y') μ_ij(y, y')
s.t. μ ≥ 0; Σ_y μ_i(y) = 1; Σ_{y'} μ_ij(y, y') = μ_i(y) (marginal consistency)
[figure: 0/1 indicator values of μ for one assignment]
Has integral solutions y for chains and trees; gives an upper bound for general networks.
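A minimal sketch of this inference LP for a chain using scipy.optimize.linprog; the node scores θ_i(y) (one row per position) and a shared edge score table θ(y, y') are assumed given as arrays:

```python
import numpy as np
from scipy.optimize import linprog

def chain_map_lp(theta_node, theta_edge):
    """MAP inference in a chain Markov net as an LP over node marginals
    mu_i(y) and edge marginals mu_i(y, y'). Integral for chains."""
    n, K = theta_node.shape                     # positions, labels
    nv = n * K + (n - 1) * K * K                # total LP variables
    node = lambda i, y: i * K + y
    edge = lambda i, y, yn: n * K + i * K * K + y * K + yn

    c = np.zeros(nv)                            # linprog minimizes, so negate
    c[:n * K] = -theta_node.ravel()
    for i in range(n - 1):
        for y in range(K):
            for yn in range(K):
                c[edge(i, y, yn)] = -theta_edge[y, yn]

    A, b = [], []
    for i in range(n):                          # normalization: sum_y mu_i(y) = 1
        row = np.zeros(nv)
        row[[node(i, y) for y in range(K)]] = 1
        A.append(row); b.append(1)
    for i in range(n - 1):                      # consistency with left node
        for y in range(K):
            row = np.zeros(nv); row[node(i, y)] = -1
            for yn in range(K): row[edge(i, y, yn)] = 1
            A.append(row); b.append(0)
    for i in range(n - 1):                      # consistency with right node
        for yn in range(K):
            row = np.zeros(nv); row[node(i + 1, yn)] = -1
            for y in range(K): row[edge(i, y, yn)] = 1
            A.append(row); b.append(0)

    res = linprog(c, A_eq=np.array(A), b_eq=np.array(b), bounds=(0, 1))
    return res.x[:n * K].reshape(n, K).argmax(axis=1)  # integral for chains
```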
Associative MN Inference LP
With the "associative" restriction:
- For K = 2 labels, solutions are always integral (optimal).
- For K > 2, solutions are within a factor of 2 of optimal.
- The constraint matrix A is linear in the number of nodes and edges, regardless of tree-width.
Other Inference LPs
- Context-free parsing: dynamic programs
- Bipartite matching: network flow
- Many other combinatorial problems
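For bipartite matchings the LP's vertices are integral, so an assignment solver gives the same answer as the LP; a quick sketch with SciPy, on toy scores invented for illustration:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

scores = np.array([[3., 1., 2.],
                   [2., 4., 6.],
                   [5., 9., 4.]])   # scores[i, j] = value of matching i to j
rows, cols = linear_sum_assignment(scores, maximize=True)
print(list(zip(rows, cols)), scores[rows, cols].sum())  # [(0,0),(1,2),(2,1)] 18.0
```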
Outline
- Structured prediction models: sequences (CRFs), trees (CFGs), associative Markov networks (special MRFs), matchings
- Geometric view: structured model polytopes, linear programming inference
- Structured large margin estimation: min-max formulation (application: 3D object segmentation); certificate formulation (application: disulfide connectivity prediction)
Learning w
Training example: (x, y*)
Probabilistic approach: maximize conditional likelihood P_w(y*|x) = exp{ wᵀf(x, y*) } / Z_w(x)
Problem: computing the partition function Z_w(x) is #P-complete.
Geometric Example
Training data: [figure: feature vectors of candidate outputs]
Goal: learn w s.t. wᵀf(x, y*) points the "right" way, scoring the correct structure highest.
OCR Example
We want: argmax_word wᵀf(x, word) = "brace"
Equivalently:
wᵀf(x, "brace") > wᵀf(x, "aaaaa")
wᵀf(x, "brace") > wᵀf(x, "aaaab")
…
wᵀf(x, "brace") > wᵀf(x, "zzzzz")
That's a lot of constraints!
Large margin estimation
Given training example (x, y*), we want: wᵀf(x, y*) > wᵀf(x, y) for all y ≠ y*.
Maximize the margin γ: wᵀf(x, y*) ≥ wᵀf(x, y) + γ.
Mistake-weighted margin: scale the required gap by ℓ(y*, y), the number of mistakes in y.
*Taskar et al. '03
Large margin estimation: two approaches
- Brute force enumeration of constraints
- Min-max formulation: 'plug in' the linear program for inference
Min-max formulation
Replace the exponentially many constraints with a single one:
wᵀf(x, y*) ≥ max_y [ wᵀf(x, y) + ℓ(y*, y) ]
Assume linear (Hamming) loss, so the inner maximization is itself LP inference (loss-augmented).
Min-max formulation
By strong LP duality, the inner max becomes a min, turning the problem into a joint minimization over w and the dual variables z.
Min-max formulation
The formulation produces a compact QP for: low-treewidth Markov networks, associative Markov networks, context-free grammars, bipartite matchings, and any problem with compact LP inference.
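The talk solves this QP with a solver; as an illustrative alternative built around the same loss-augmented inner problem, here is a plain subgradient sketch on the structured hinge objective (the map_infer oracle, the feats function, and all names here are hypothetical placeholders):

```python
import numpy as np

def subgradient_train(examples, feats, map_infer, dim,
                      lr=0.1, reg=1e-3, epochs=10):
    """Minimize reg/2 ||w||^2 + max_y [w.f(x,y) + loss(y*,y)] - w.f(x,y*)
    by subgradient steps; map_infer must solve the loss-augmented argmax."""
    w = np.zeros(dim)
    for _ in range(epochs):
        for x, y_star in examples:
            y_hat = map_infer(w, x, loss_against=y_star)  # loss-augmented MAP
            g = reg * w + feats(x, y_hat) - feats(x, y_star)
            w -= lr * g                                    # subgradient step
    return w
```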
3D Mapping
Data collected with a laser range finder, GPS, and IMU (data provided by Michael Montemerlo & Sebastian Thrun)
Labels: ground, building, tree, shrub
Training: 30 thousand points; testing: 3 million points
Segmentation results (hand-labeled 180K test points)
Model   Accuracy
SVM     68%
V-SVM   73%
M3N     93%
Fly-through
Certificate formulation
Non-bipartite matchings: O(n³) combinatorial algorithm, but no polynomial-size LP known
Spanning trees: no polynomial-size LP known, but a simple certificate of optimality
Intuition: verifying optimality is easier than optimizing
Idea: impose a compact optimality condition of y* w.r.t. w as constraints.
[figure: 6-node matching example]
Certificate for non-bipartite matching
- Alternating cycle: every other edge is in the matching.
- Augmenting alternating cycle: the score of the edges not in the matching is greater than that of the edges in the matching.
- Negate the scores of edges not in the matching; an augmenting alternating cycle becomes a negative-length alternating cycle.
- A matching is optimal ⟺ there are no negative-length alternating cycles (Edmonds '65).
[figure: 6-node matching example]
Certificate for non-bipartite matching
Pick any node r as root and let d_j = length of the shortest alternating path from r to j.
Triangle inequality: d_j ≤ d_k + length(k, j) along alternating edges.
Theorem: no negative-length alternating cycle ⟺ such a distance function d exists.
This can be expressed as linear constraints: O(n) distance variables, O(n²) constraints.
[figure: 6-node matching example]
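For intuition, the certificate for ordinary negative cycles (Bellman–Ford feasible potentials) has the same shape; the slide's alternating-path version adds a parity state, so this is a sketch of the analogy, not the exact constraint set from the paper:

```latex
% A directed graph with edge lengths \ell has no negative-length cycle
% iff there exist potentials d with
d_j \le d_k + \ell(k, j) \quad \text{for every edge } (k, j).
% Summing these inequalities around any cycle cancels the d's and
% yields 0 \le (\text{total length of the cycle}).
```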
Certificate formulation
The formulation produces a compact QP for: spanning trees, non-bipartite matchings, and any problem with a compact optimality condition.
Disulfide connectivity prediction
Dataset: Swiss-Prot protein database, release 39 (Fariselli & Casadio '01, Baldi et al. '04); 446 sequences (4–50 cysteines)
Features: window profiles (size 9) around each pair
Two modes: bonded state known / unknown
Comparison: SVM-trained weights (ignoring constraints during learning); DAG Recursive Neural Network [Baldi et al. '04]
Our model: max-margin matching using an RBF kernel; trained with the off-the-shelf LP/QP solver CPLEX (~1 hour)
Known bonded state (Precision / Accuracy, 4-fold cross-validation)
Bonds  SVM          DAG RNN [Baldi et al. '04]  Max-margin matching
2      0.63 / 0.63  0.74 / 0.74                 0.77 / 0.77
3      0.51 / 0.38  0.61 / 0.51                 0.62 / 0.52
4      0.34 / 0.12  0.44 / 0.27                 0.51 / 0.36
5      0.31 / 0.07  0.41 / 0.11                 0.43 / 0.16
Unknown bonded state (Precision / Recall / Accuracy, 4-fold cross-validation)
Bonds  DAG RNN [Baldi et al. '04]  Max-margin matching
2      0.49 / 0.59 / 0.40          0.57 / 0.59 / 0.44
3      0.45 / 0.50 / 0.32          0.48 / 0.52 / 0.28
4      0.37 / 0.36 / 0.15          0.39 / 0.40 / 0.14
5      0.31 / 0.28 / 0.03          0.31 / 0.33 / 0.07
Formulation summary
- Brute force enumeration
- Min-max formulation: 'plug in' a convex program for inference
- Certificate formulation: directly guarantee optimality of y*
Estimation
                         Local: P(z) = ∏_i P(z_i | z_{pa(i)})   Global: P(z) = 1/Z ∏_c φ(z_c)
Generative, P(x, y):     HMMs, PCFGs                            MRFs
Discriminative, P(y|x):  MEMMs                                  CRFs → Margin
Omissions
- Formulation details: kernels, multiple examples, slacks for the non-separable case
- Approximate learning of intractable models: general MRFs, learning to cluster
- Structured generalization bounds
- Scalable algorithms (no QP solver needed): structured SMO (chains, trees), structured EG (chains, trees), structured PG (chains, matchings, AMNs, …)
Current Work
- Learning approximate energy functions: protein folding, physical processes
- Semi-supervised learning: hidden variables, mixing labeled and unlabeled data
- Discriminative structure learning using sparsifying priors
Conclusion
Two general techniques for structured large-margin estimation: exact, compact, convex formulations that allow efficient use of kernels and remain tractable when other estimation methods are not. Also: structured generalization bounds, efficient learning algorithms, and empirical success on many domains.
Papers at http://www.cs.berkeley.edu/~taskar
Duals and Kernels
The kernel trick works: scoring functions (log-potentials) can use kernels. The same holds for the certificate formulation.
Handwriting Recognition
Data: word length ~8 characters; letters are 16×8 pixels; 10-fold train/test with 5000/50000 letters (600/6000 words)
Models: multiclass SVMs*, CRFs, M^3 nets (*Crammer & Singer '01)
[bar chart: average per-character test error (0–30%) for CRFs, MC-SVMs, and M^3 nets with raw pixels, quadratic kernel, and cubic kernel; lower is better]
Result: 45% error reduction over linear CRFs; 33% error reduction over multiclass SVMs
Hypertext Classification
WebKB dataset*: four CS department websites, 1300 pages / 3500 links; classify each page as faculty, course, student, project, or other; train on three universities, test on the fourth
[bar chart: test error (0–20%) for SVMs, RMNs, and M^3Ns; lower is better; annotations: relaxed dual, loopy belief propagation]
Results: 53% error reduction over SVMs; 38% error reduction over RMNs
*Taskar et al. '02
Projected Gradient
Projecting y' onto the constraints: a min-cost convex flow problem for Markov nets and matchings
Convergence: same as steepest gradient; conjugate gradient also possible (two-metric projection)
[figure: iterates y_k, y_{k+1}, … projected onto the feasible polytope]
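The projection here is a min-cost convex flow; as a simpler illustrative special case, the standard sort-based Euclidean projection onto a single simplex constraint (one node's marginals):

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto {x : x >= 0, sum(x) = 1},
    via the standard sort-and-threshold algorithm."""
    u = np.sort(v)[::-1]                  # sort descending
    css = np.cumsum(u) - 1
    rho = np.nonzero(u * np.arange(1, len(v) + 1) > css)[0][-1]
    theta = css[rho] / (rho + 1.0)        # shared threshold
    return np.maximum(v - theta, 0.0)

print(project_simplex(np.array([0.7, 0.6, -0.2])))  # [0.55 0.45 0.  ]
```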
Min-Cost Flow for Markov Chains
Capacities = C; edge costs derived from the potentials; edges from node s and to node t have cost 0
[figure: flow network with one column of label nodes a–z per chain position, between source s and sink t]
Min-Cost Flow for Bipartite Matchings
Capacities = C; edge costs derived from the scores; edges from node s and to node t have cost 0
[figure: bipartite flow network between source s and sink t]
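A sketch of this construction with networkx, reusing the toy scores from the assignment example above (unit capacities, negated scores as costs, so min-cost flow recovers the max-score matching):

```python
import networkx as nx

# Toy 3x3 matching; score[i][j] = value of matching left i to right j.
score = [[3, 1, 2], [2, 4, 6], [5, 9, 4]]
G = nx.DiGraph()
for i in range(3):
    G.add_edge('s', ('L', i), capacity=1, weight=0)   # source edges, cost 0
    G.add_edge(('R', i), 't', capacity=1, weight=0)   # sink edges, cost 0
for i in range(3):
    for j in range(3):
        # Negate scores: min-cost flow then finds the max-score matching.
        G.add_edge(('L', i), ('R', j), capacity=1, weight=-score[i][j])
flow = nx.max_flow_min_cost(G, 's', 't')
matching = [(i, j) for i in range(3) for j in range(3)
            if flow[('L', i)][('R', j)] == 1]
print(matching)   # [(0, 0), (1, 2), (2, 1)], same as the assignment solver
```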
CFG Chart
A CNF tree is a set of two types of parts: constituents (A, s, e) and CF rules (A → B C, s, m, e).
CFG Inference LP
[figure: chart diagram showing inside and outside regions]
Has integral solutions y for trees.
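A minimal CKY sketch over these two part types, the dynamic program whose LP relaxation the slide refers to (rule_score and span_score are assumed inputs, not part of the original talk):

```python
def cky(n, rule_score, span_score, root='S'):
    """Max-score CNF parse: chart[s][e][A] = best score of constituent
    (A, s, e); rules are triples (A, B, C) scored by rule_score."""
    chart = [[{} for _ in range(n + 1)] for _ in range(n + 1)]
    for s in range(n):                       # width-1 constituents (A, s, s+1)
        for A, sc in span_score(s, s + 1).items():
            chart[s][s + 1][A] = sc
    for width in range(2, n + 1):
        for s in range(n - width + 1):
            e = s + width
            for (A, B, C), sc in rule_score.items():
                for m in range(s + 1, e):    # split point of part (A->B C, s, m, e)
                    if B in chart[s][m] and C in chart[m][e]:
                        cand = sc + chart[s][m][B] + chart[m][e][C]
                        if cand > chart[s][e].get(A, float('-inf')):
                            chart[s][e][A] = cand
    return chart[0][n].get(root)             # best score of a full parse
```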