A Distributed Method for Fitting Laplacian Regularized Stratified Models
Jonathan Tuck Shane Barratt Stephen Boyd
International Conference on Statistical Optimization and Learning, Beijing Jiaotong University, June 19, 2019
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
Data records
- Data records have the form (z, x, y) ∈ Z × X × Y
- z ∈ Z = {1, . . . , K} are the categorical features (we'll stratify over)
- x ∈ X are the other features
- y ∈ Y is the label/dependent variable
Stratified model
- Fit a model to data (z, x, y)
- Stratified model: fit a different model for each value of z
- θ_k is the model parameter for z = k
- θ = (θ1, . . . , θK) ∈ Θ ⊆ R^{Kn} parameterizes the stratified model
- Old idea [Kernan 99]
- Example: stratified regression model ŷ = xᵀθ_z
Example
[Figure: scatter plot of the (x, y) data, marked by z = 1 and z = −1]

ŷ = θ1 + θ2 x + θ3 z,   z ∈ {−1, 1}
Example
[Figure: the same data, now with a separate line fit for each value of z]

ŷ = θ_{−1,1} + θ_{−1,2} x   if z = −1
ŷ = θ_{1,1} + θ_{1,2} x   if z = 1
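For concreteness, the unregularized version of this two-line model amounts to one least-squares fit per value of z. A minimal numpy sketch on synthetic data (the coefficients and noise level below are invented for illustration):

```python
import numpy as np

# Toy data: two groups z in {-1, 1}, each with its own underlying line.
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=40)
z = np.where(np.arange(40) % 2 == 0, -1, 1)
y = np.where(z == -1, 2.0 - 5.0 * x, -1.0 + 4.0 * x) + 0.01 * rng.standard_normal(40)

# Unregularized stratified fit: an independent least-squares line per group.
theta = {}
for k in (-1, 1):
    A = np.column_stack([np.ones((z == k).sum()), x[z == k]])
    theta[k], *_ = np.linalg.lstsq(A, y[z == k], rcond=None)

print(theta[-1])  # ≈ [2.0, -5.0]
print(theta[1])   # ≈ [-1.0, 4.0]
```

With plenty of data per group this works fine; the Laplacian regularization introduced later matters when some groups have little or no data.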
Stratified model
- Stratified model is a simple function of x (e.g., linear), arbitrary function of z
- Examples: separate models for
  – z ∈ {Male, Female}
  – z ∈ {Monday, . . . , Sunday}
- If K is large, we might not have enough training data to fit each θ_k
- Extreme case: no training data for some values of z
- We'll add regularization so nearby θ_k's are close
Stratified model with Laplacian regularization
- Choose θ = (θ1, . . . , θK) to minimize

  ∑_{i=1}^N l(θ_{z_i}, x_i, y_i) + ∑_{k=1}^K r(θ_k) + ∑_{u,v=1}^K W_{uv} ‖θ_u − θ_v‖₂²

- l is the loss function, r is the (local) regularization
- The last term is Laplacian regularization
- W_{uv} ≥ 0 are edge weights of a graph with node set Z
- Convex problem when l, r are convex in θ
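In code the objective is a direct translation of the three sums. A minimal numpy sketch, assuming a square loss and a sum-of-squares local regularizer (both choices, and λ, are illustrative, not the talk's):

```python
import numpy as np

def stratified_objective(theta, data, W, lam=0.1):
    """Objective for a stratified least-squares model (illustrative l and r).

    theta: (K, n) array, one parameter vector theta_k per stratum
    data:  iterable of records (z, x, y) with z in {0, ..., K-1}
    W:     (K, K) symmetric nonnegative matrix of edge weights W_uv
    """
    K = theta.shape[0]
    loss = sum((x @ theta[z] - y) ** 2 for z, x, y in data)          # sum_i l
    local = (lam / 2) * np.sum(theta ** 2)                           # sum_k r
    laplacian = sum(W[u, v] * np.sum((theta[u] - theta[v]) ** 2)
                    for u in range(K) for v in range(K))             # Laplacian term
    return loss + local + laplacian

# With all theta_k equal, the Laplacian term vanishes regardless of W.
theta = np.ones((3, 2))
data = [(0, np.array([1.0, 1.0]), 2.0)]   # x @ theta_0 = 2.0 = y, so zero loss
print(stratified_objective(theta, data, np.ones((3, 3))))  # 0.3 = (lam/2)*6
```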
Stratified model with Laplacian regularization
- The graph encodes the prior that nearby values of z should have similar models
- Can be used to capture periodicities and other structure
- Examples:
  – θ_male and θ_female should be close
  – θ_jan should be close to θ_feb and θ_dec
- The model for each value of z 'borrows strength' from its neighbors
- Works even when there's no data for some values of z
- As W_{uv} → 0, we get the traditional (unregularized) stratified model
- As W_{uv} → ∞, we get a common model (θ1 = · · · = θK)
Related work
- Network lasso [Hallac 15]
- Pliable lasso [Tibshirani 17]
- Varying-coefficient models [Hastie 93, Fan 08]
- Multi-task learning [Caruana 97]
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
Point estimate: Predict y
- Regression: X = R^n, Y = R
  – l(θ, x, y) = p(xᵀθ − y), p is a penalty function
  – ŷ = xᵀθ_z
- Classification: X = R^n, Y = {−1, 1}
  – l(θ, x, y) = p(y xᵀθ)
  – ŷ = sign(xᵀθ_z)
- M-class classification: X = R^n, Y = {1, . . . , M}
  – l(θ, x, y) = p_y(xᵀθ), θ ∈ R^{n×M}, p_y is the penalty function for class y
  – ŷ = argmax(xᵀθ_z)
Conditional distribution estimate: Predict prob(y | x , z)
- Multinomial logistic regression: X = R^n, Y = {1, . . . , M}
- Conditional distribution:

  prob(y | x, z) = exp(xᵀθ_z)_y / ∑_{j=1}^M exp(xᵀθ_z)_j,   y = 1, . . . , M
- Loss function (convex in θ):

  l(θ, x, y) = log(∑_{j=1}^M exp(xᵀθ)_j) − (xᵀθ)_y
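Numerically this loss is just a log-sum-exp of the class scores; a short sketch, using 0-indexed class labels:

```python
import numpy as np
from scipy.special import logsumexp

def mnlogit_loss(theta, x, y):
    """l(theta, x, y) = log sum_j exp((x^T theta)_j) - (x^T theta)_y,
    for theta in R^{n x M} and class label y in {0, ..., M-1}."""
    s = x @ theta                # vector of M class scores
    return logsumexp(s) - s[y]   # numerically stable log-sum-exp

# With theta = 0 all M classes are equally likely, so the loss is log M.
print(mnlogit_loss(np.zeros((2, 3)), np.ones(2), 0))  # log(3) ≈ 1.0986
```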
Distribution estimate: Predict p(y | z)
- Gaussian distribution: Y = R^m
- Density:

  p(y | z) = (2π)^{−m/2} det(Σ)^{−1/2} exp(−(1/2)(y − μ)ᵀ Σ^{−1} (y − μ))

- Use the natural parameter θ = (S, ν) = (Σ^{−1}, Σ^{−1}μ) (so Σ = S^{−1}, μ = S^{−1}ν)
- Loss function (convex in θ):

  l(θ, y) = − log det S + yᵀ S y − 2 yᵀν + νᵀ S^{−1} ν
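As a sanity check, this natural-parameter loss matches the Gaussian log-density above: it equals −2 log p(y | z) − m log 2π, so minimizing it is maximum likelihood. A quick numpy check (the random Σ, μ, y below are arbitrary test values):

```python
import numpy as np

def gaussian_loss(S, nu, y):
    """Loss in the natural parameters (S, nu) = (Sigma^{-1}, Sigma^{-1} mu)."""
    _, logdet_S = np.linalg.slogdet(S)
    return -logdet_S + y @ S @ y - 2 * y @ nu + nu @ np.linalg.solve(S, nu)

# Arbitrary positive definite Sigma, mean mu, and observation y.
rng = np.random.default_rng(1)
m = 3
A = rng.standard_normal((m, m))
Sigma = A @ A.T + np.eye(m)
mu = rng.standard_normal(m)
y = rng.standard_normal(m)

# Standard Gaussian log-density at y.
S = np.linalg.inv(Sigma)
_, logdet_Sigma = np.linalg.slogdet(Sigma)
log_p = (-0.5 * m * np.log(2 * np.pi) - 0.5 * logdet_Sigma
         - 0.5 * (y - mu) @ S @ (y - mu))

# The loss is -2*log p(y|z) up to the constant m*log(2*pi).
assert np.isclose(gaussian_loss(S, S @ mu, y), -2 * log_p - m * np.log(2 * np.pi))
```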
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
Path graph
[Figure: path graph on nodes 0, 1, 2, 3, 4, 5]

- Models time, distance, . . .
- Yields time-varying, distance-varying models
Cycle graph
[Figure: cycle graph on the days of the week M, Tu, W, Th, F, Sa, Su]

- Yields diurnal, weekly, seasonally-varying models
Tree graph
[Figure: tree graph of an org chart with nodes CEO, CTO, COO, SVP1–SVP3, VP1–VP3]

- Yields hierarchical models
Grid graph
[Figure: 4 × 4 grid graph on nodes (i, j), i, j ∈ {0, 1, 2, 3}]

- Yields (2D) space-varying models
Products of graphs
[Figure: product graph on nodes (M, 1), (F, 1), . . . , (M, 100), (F, 100)]

- Z = {Male, Female} × {1, 2, . . . , 99, 100} (sex × age)
- Yields sex- and age-varying models
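Graphs like these, and their Cartesian products, are a few lines of networkx; a sketch (the particular graphs below are illustrative):

```python
import networkx as nx

# Path, cycle, grid, and product graphs like the ones on these slides.
sex = nx.path_graph(["M", "F"])            # single edge: Male -- Female
age = nx.path_graph(range(1, 101))         # path graph on ages 1..100
week = nx.cycle_graph(7)                   # cycle graph for day of week
grid = nx.grid_2d_graph(4, 4)              # 4 x 4 grid graph

G = nx.cartesian_product(sex, age)         # nodes are (sex, age) pairs
print(G.number_of_nodes())                 # 200
print(G.number_of_edges())                 # 298 = 2*99 age edges + 100 sex edges

L = nx.laplacian_matrix(G)                 # sparse graph Laplacian for L(theta)
```

Edge weights can be attached as `weight` attributes; `laplacian_matrix` then produces the weighted Laplacian used in the regularization term.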
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
Fitting problem
- To fit the stratified model, minimize ℓ(θ) + r(θ) + L(θ)
- ℓ(θ) = ∑_{k=1}^K ℓ_k(θ_k) is the loss, with ℓ_k(θ_k) = ∑_{i : z_i = k} l(θ_k, x_i, y_i)
- r(θ) = ∑_{k=1}^K r(θ_k) is the (local) regularization
- L(θ) = ∑_{u,v=1}^K W_{uv} ‖θ_u − θ_v‖₂² is the Laplacian regularization
- ℓ and r are separable in θ_k
- L is quadratic and separable in the components of θ_k
- We'll use an operator splitting method (ADMM)
Reformulation
- Replicate variables:

  minimize ℓ(θ) + r(θ̃) + L(θ̂)
  subject to θ = θ̃ = θ̂

- Augmented Lagrangian:

  L(θ, θ̃, θ̂, u, ũ) = ℓ(θ) + r(θ̃) + L(θ̂) + (1/2t)‖θ − θ̂ + u‖₂² + (1/2t)‖θ̃ − θ̂ + ũ‖₂²

- u, ũ are (scaled) dual variables for the two constraints; t > 0 is an algorithm parameter
ADMM
- Augmented Lagrangian:

  L(θ, θ̃, θ̂, u, ũ) = ℓ(θ) + r(θ̃) + L(θ̂) + (1/2t)‖θ − θ̂ + u‖₂² + (1/2t)‖θ̃ − θ̂ + ũ‖₂²

- ADMM: for i = 1, 2, . . .

  θ^{i+1}, θ̃^{i+1} := argmin_{θ, θ̃} L(θ, θ̃, θ̂^i, u^i, ũ^i)
  θ̂^{i+1} := argmin_{θ̂} L(θ^{i+1}, θ̃^{i+1}, θ̂, u^i, ũ^i)
  u^{i+1} := u^i + θ^{i+1} − θ̂^{i+1}
  ũ^{i+1} := ũ^i + θ̃^{i+1} − θ̂^{i+1}
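Here is a hedged end-to-end sketch of these updates for a stratified ridge model (square loss, sum-of-squares local regularizer). The sizes, t, and λ are made up, and the final θ̂ is checked against a direct solve of the full quadratic problem:

```python
import numpy as np

rng = np.random.default_rng(0)
K, n, t, lam = 4, 3, 0.5, 0.1             # illustrative sizes and parameters

# Per-stratum data for a square loss, and a path graph on the K strata.
X = [rng.standard_normal((10, n)) for _ in range(K)]
y = [rng.standard_normal(10) for _ in range(K)]
W = np.zeros((K, K))
for k in range(K - 1):
    W[k, k + 1] = W[k + 1, k] = 1.0
L_G = np.diag(W.sum(axis=1)) - W          # graph Laplacian of the weights

theta = np.zeros((K, n))                  # loss copy
theta_t = np.zeros((K, n))                # local-regularizer copy (theta tilde)
theta_h = np.zeros((K, n))                # consensus copy (theta hat)
u = np.zeros((K, n))                      # scaled dual for theta = theta_hat
u_t = np.zeros((K, n))                    # scaled dual for theta_tilde = theta_hat

for _ in range(1000):
    # prox of t*||X_k a - y_k||^2: solve (I + 2t X^T X) a = nu + 2t X^T y
    for k in range(K):
        nu = theta_h[k] - u[k]
        theta[k] = np.linalg.solve(np.eye(n) + 2 * t * X[k].T @ X[k],
                                   nu + 2 * t * X[k].T @ y[k])
    # prox of t*(lam/2)*||.||_2^2 is simple shrinkage
    theta_t = (theta_h - u_t) / (t * lam + 1)
    # theta_hat update: one regularized Laplacian solve, all columns at once
    theta_h = np.linalg.solve(2 * t * L_G + np.eye(K),
                              (theta + u + theta_t + u_t) / 2)
    # dual updates
    u += theta - theta_h
    u_t += theta_t - theta_h

# Check against a direct solve of the full quadratic fitting problem.
A = np.kron(4 * L_G, np.eye(n)) + lam * np.eye(K * n)
for k in range(K):
    A[k * n:(k + 1) * n, k * n:(k + 1) * n] += 2 * X[k].T @ X[k]
b = np.concatenate([2 * X[k].T @ y[k] for k in range(K)])
theta_star = np.linalg.solve(A, b).reshape(K, n)
assert np.allclose(theta_h, theta_star, atol=1e-3)
```

The two prox evaluations split across the K strata, and the θ̂ solve splits across the n components, which is where the distributed parallelism in the talk comes from.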
ADMM
- The first step can be expressed as

  θ_k^{i+1} = prox_{tℓ_k}(θ̂_k^i − u_k^i),   θ̃_k^{i+1} = prox_{tr}(θ̂_k^i − ũ_k^i),   k = 1, . . . , K

- The proximal operator of tg is

  prox_{tg}(ν) = argmin_θ ( tg(θ) + (1/2)‖θ − ν‖₂² )

- These 2K proximal operators can be evaluated in parallel
Loss proximal operators
- Often has a closed-form expression or an efficient implementation
- Square loss: ℓ_k(θ_k) = ‖X_k θ_k − y_k‖₂²
  – prox_{tℓ_k}(ν) = (I + 2t X_kᵀ X_k)^{−1}(ν + 2t X_kᵀ y_k)
  – Cache the factorization or warm-start CG
- Logistic loss: ℓ_k(θ_k) = ∑_i log(1 + exp(−y_{ki} θ_kᵀ x_{ki}))
  – Use L-BFGS to evaluate prox_{tℓ_k}(ν); warm-start
- Many others . . .
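With the convention ℓ_k(θ_k) = ‖X_kθ_k − y_k‖₂², the square-loss prox reduces to the single linear solve (I + 2tX_kᵀX_k)θ = ν + 2tX_kᵀy_k. A quick numerical check of that closed form against the prox definition:

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4))
y = rng.standard_normal(20)
nu = rng.standard_normal(4)
t = 0.3

# Closed form: one linear solve.
closed = np.linalg.solve(np.eye(4) + 2 * t * X.T @ X, nu + 2 * t * X.T @ y)

# The same prox, evaluated by generic numerical minimization of its definition.
obj = lambda a: t * np.sum((X @ a - y) ** 2) + 0.5 * np.sum((a - nu) ** 2)
numeric = minimize(obj, np.zeros(4)).x
assert np.allclose(closed, numeric, atol=1e-3)
```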
Regularizer proximal operators
- Often has a closed-form expression or an efficient implementation

  Regularizer       r(θ_k)          prox_{tr}(ν)
  Sum of squares    (λ/2)‖θ_k‖₂²    ν / (tλ + 1)
  ℓ1 norm           λ‖θ_k‖₁         (ν − tλ)₊ − (−ν − tλ)₊
  Nonnegative       I₊(θ_k)         (ν)₊

- λ > 0 is the local regularization parameter
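The ℓ1 row of the table is elementwise soft thresholding; a tiny numpy sketch:

```python
import numpy as np

def prox_l1(nu, t, lam):
    """prox of t*lam*||.||_1: elementwise (nu - t*lam)_+ - (-nu - t*lam)_+,
    i.e., soft thresholding at level t*lam."""
    return np.maximum(nu - t * lam, 0.0) - np.maximum(-nu - t * lam, 0.0)

nu = np.array([-2.0, -0.1, 0.0, 0.1, 2.0])
print(prox_l1(nu, t=1.0, lam=0.5))   # -> [-1.5, 0, 0, 0, 1.5]
```

Entries smaller than tλ in magnitude are zeroed out, which is what makes the ℓ1 local regularizer produce sparse θ_k.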
Laplacian proximal operator
- Separable across each component of θ_k
- Finding each component (θ_k)_i requires solving a Laplacian system
- Many efficient ways to solve, e.g., diagonally preconditioned CG
- These n systems can be solved in parallel
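A sketch of such a solve with scipy's sparse CG and a diagonal (Jacobi) preconditioner; the particular system matrix 2tL + I below is one illustrative form of the regularized Laplacian systems that arise:

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import cg

rng = np.random.default_rng(0)
K, n, t = 1000, 5, 0.5                       # illustrative sizes

# Weighted path graph on K nodes, and its sparse Laplacian.
w = rng.uniform(0.5, 1.5, K - 1)
W = sp.diags(w, 1, shape=(K, K))
W = W + W.T
L_G = sp.diags(np.ravel(W.sum(axis=1))) - W
A = 2 * t * L_G + sp.eye(K)                  # regularized Laplacian system

M = sp.diags(1.0 / A.diagonal())             # diagonal preconditioner
B = rng.standard_normal((K, n))              # one right-hand side per component

# The n component systems are independent and could run in parallel.
theta_hat = np.column_stack([cg(A, B[:, j], M=M)[0] for j in range(n)])
assert np.linalg.norm(A @ theta_hat - B) < 1e-2
```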
Software implementation
- Available at www.github.com/cvxgrp/strat_models
- numpy, scipy for matrix operations
- networkx for handling graphs and graph operations
- torch for L-BFGS and GPU computation
- multiprocessing for parallelism
- model.fit(X, Y, Z, G) (writes θ_k on the graph nodes)
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
House price prediction
- N ≈ 22000 records (z, x, y) from King County, WA
- Split into train/test ≈ 16200/5400
- z = (latitude bin, longitude bin)
- x ∈ R^10: features of the house; y = log of the house price
- Graph is a 50 × 50 grid with all edge weights the same; K = 2500
- Stratified ridge regression model with two hyperparameters (one for local regularization, one for Laplacian regularization)
House price prediction: Results
- Compare stratified, common, and random forest (50 trees) models

  Model           Parameters   RMS test error
  Stratified      25000        0.181
  Common          10           0.316
  Random forest   985888       0.184
House price prediction: Parameters
[Figure: heatmaps over the latitude/longitude grid of each model parameter: bedrooms, bathrooms, sqft living, sqft lot, floors, waterfront, condition, grade, yr built, intercept]
Chicago crime prediction
- N ≈ 7,000,000 records (z, y) for 2017, 2018
- Train on 2017, test on 2018
- y = number of crimes
- z = (location bin, week of year, day of week, hour of day); K ≈ 3,500,000
- Graph is the Cartesian product of a grid and three cycles; four graph edge weights
- Stratified Poisson model with four hyperparameters (one for each graph edge weight)
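The Poisson data model fits the same prox framework; a hedged sketch (the talk's exact parameterization isn't shown here, so the natural-parameter form below is an assumption):

```python
import numpy as np

# With natural parameter theta = log(rate), the Poisson negative
# log-likelihood (dropping the constant log y!) is exp(theta) - y*theta,
# which is convex in theta.
def poisson_loss(theta, y):
    return np.exp(theta) - y * theta

# Unregularized per-stratum fit: the minimizer is log of the sample mean,
# checked here by a simple grid search over theta.
y = np.array([2, 3, 1, 4, 2])
grid = np.linspace(-2.0, 3.0, 5001)
best = grid[np.argmin([poisson_loss(th, y).sum() for th in grid])]
assert abs(best - np.log(y.mean())) < 1e-2
```

With Laplacian regularization, strata with few or no recorded crimes borrow their rates from neighboring location/time bins instead of using the bare sample mean.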
Chicago crime prediction
- Compare the average negative log-likelihood of three models on the train and test data

  Model        Train   Test
  Separate     0.068   0.740
  Stratified   0.221   0.234
  Common       0.279   0.278
Chicago crime prediction
Crime rate as a function of latitude/longitude
[Figure: heatmap of crime rate over Chicago (longitude −87.85 to −87.52, latitude 41.64 to 42.02); rates range from about 0.02 to 0.12]
Chicago crime prediction
Crime rate as a function of week of year
[Figure: crime rate by week of year, Jan through Dec, ranging from about 0.036 to 0.044]
Chicago crime prediction
Crime rate as a function of hour of week
[Figure: crime rate by hour of week, Monday through Sunday, ranging from about 0.020 to 0.055]
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
Conclusions
- Stratified models combine
  – simple dependence on some features (x)
  – complex dependence on others (z)
- Often interpretable
- Laplacian regularization encodes a prior on the values of z, so models can borrow strength from their neighbors
- Effective way to build time-varying, space-varying, seasonally-varying models
- Efficient, distributed ADMM-based implementation for large-scale data