
Direct Optimization

CSC2547

Adamo Young, Dami Choi, Sepehr Abbasi Zadeh

Direct Optimization

● A way to obtain gradient estimates for directly optimizing a non-differentiable objective.

● It first appeared in structured prediction problems.

Structured Prediction

Whenever the components of the output (the goal state) are inter-dependent.

[Images from Wikipedia and http://dbmsnotes-ritu.blogspot.com/]

Structured Prediction

Scoring function s(x, y; w), where the output y ranges over a discrete (structured) set Y.

Inference: y_w = argmax_{y in Y} s(x, y; w)

Training: choose w to minimize the expected task loss E[L(y_w, y*)] of the inferred output.

Gradient Estimator

● Gradient descent on the discrete objective E[L(y_w, y*)]: the argmax y_w is piecewise constant in w, so its gradient is zero almost everywhere.

● Option 1: continuous relaxation

● Option 2: estimate the gradient directly

Loss Gradient Theorem (McAllester et al., 2010; Song et al., 2016)

Inference: y_w = argmax_{y in Y} s(x, y; w)

Loss-augmented inference: y_eps = argmax_{y in Y} [ s(x, y; w) + eps * L(y, y*) ]

Then: grad_w E[L(y_w, y*)] = lim_{eps -> 0} (1/eps) * E[ grad_w s(x, y_eps; w) - grad_w s(x, y_w; w) ]

With +eps*L the update pushes the model “away from worse” outputs; with -eps*L it pulls it “towards better” outputs.
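A minimal sketch (not from the slides) of this estimator for a linear scoring function s(x, y; w) = w · f(x, y); the feature map f, the loss, and the enumerable output set Y are hypothetical stand-ins supplied by the caller:

    import numpy as np

    def direct_loss_gradient(w, x, y_star, Y, f, loss, eps=1.0):
        """'Away-from-worse' direct loss gradient estimate (finite eps, hence biased)."""
        scores = {y: float(w @ f(x, y)) for y in Y}
        y_w   = max(Y, key=lambda y: scores[y])                          # inference
        y_eps = max(Y, key=lambda y: scores[y] + eps * loss(y, y_star))  # loss-augmented inference
        # For a linear score, grad_w s(x, y; w) = f(x, y)
        return (f(x, y_eps) - f(x, y_w)) / eps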

Limitations

● Existence of the limit: in practice eps is finite
  ○ Bias/variance trade-off in choosing eps

● Solving the argmax of loss-augmented inference

Applications

● Phoneme-to-speech alignment (McAllester et al. 2010)

● Maximizing average precision for ranking (Song et al. 2016)

● Discrete structured VAE (Lorberbom et al. 2018)

● RL with discrete action spaces (Lorberbom et al. 2019)


Direct Optimization through arg max for Discrete Variational Auto-Encoder

Guy Lorberbom, Andreea Gane, Tommi Jaakkola, Tamir Hazan

Probability Background

● Gumbel Distribution
● Various Sampling “Tricks”
  ○ Reparameterization
  ○ Gumbel-Max
  ○ Gumbel-Softmax

Gumbel Distribution

Intuitively: the limiting distribution of the maximum of many i.i.d. samples (e.g., exponentially or normally distributed).

[Plot of the Gumbel density p(x); from https://en.wikipedia.org/wiki/Gumbel_distribution]

Gradient Estimators for Stochastic Computation Graphs

Schulman et al 2016

Legend: dot = parameter node; rectangle = deterministic node; circle = stochastic node; line = functional dependency; red line = gradient propagation.

Reparameterization Trick

Kingma et al 2015

[Computation-graph comparison: REINFORCE/REBAR/RELAX (Williams 1988, Tucker et al 2017, Grathwohl et al 2018) vs. the reparameterized estimator]
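As a quick illustration (not from the slides), a Gaussian reparameterized sample in which the gradient flows through mu and log_var rather than through the sampling operation:

    import torch

    def reparameterize(mu, log_var):
        eps = torch.randn_like(mu)                     # noise, independent of the parameters
        return mu + torch.exp(0.5 * log_var) * eps     # z = mu + sigma * eps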

Gumbel-Max Trick

To sample z ~ Categorical(theta): draw i.i.d. g_i ~ Gumbel(0, 1) and set z = argmax_i (log theta_i + g_i).

[Computation-graph comparison: REINFORCE/REBAR/RELAX vs. Direct Optimization]
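A minimal sketch of the trick (illustrative; `logits` are the unnormalized log theta):

    import torch

    def gumbel_max_sample(logits):
        g = -torch.log(-torch.log(torch.rand_like(logits)))  # Gumbel(0, 1) noise
        return torch.argmax(logits + g, dim=-1)               # exact categorical sample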

Gumbel-Softmax Trick

Jang et al 2017, Maddison et al 2017

Relax the argmax to a temperature-tau softmax: z = softmax((log theta + g) / tau), which is differentiable but no longer exactly one-hot.

[Computation-graph comparison: REINFORCE/REBAR/RELAX vs. CONCRETE]
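A matching sketch of the relaxation (illustrative):

    import torch

    def gumbel_softmax_sample(logits, tau=1.0):
        g = -torch.log(-torch.log(torch.rand_like(logits)))  # Gumbel(0, 1) noise
        return torch.softmax((logits + g) / tau, dim=-1)      # soft, approximately one-hot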

Gumbel-Softmax Distribution

Jang et al 2017

Why discrete latent variables?

● Stronger inductive bias
● Interpretability
● Allow structural relations in the encoder

Standard (Gaussian) VAE

Kingma et al 2013

ELBO: log p_theta(x) >= E_{q_phi(z|x)}[log p_theta(x|z)] - KL(q_phi(z|x) || p(z)), with q_phi(z|x) = N(mu_phi(x), sigma_phi(x)^2) and the reparameterization z = mu + sigma * eps, eps ~ N(0, I).
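A compact sketch (illustrative) of the resulting training loss, i.e. the negative ELBO with the analytic Gaussian KL; `decoder` is a hypothetical module returning a torch distribution over x:

    import torch

    def gaussian_vae_loss(x, mu, log_var, decoder):
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * log_var) * eps                            # reparameterized sample
        recon = -decoder(z).log_prob(x).sum(dim=-1)                        # -log p_theta(x|z)
        kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=-1)  # KL(q || N(0, I))
        return (recon + kl).mean()                                         # negative ELBO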

Naive Categorical VAE

Replace the Gaussian latent with a categorical latent z ~ Categorical(theta_phi(x)); the Gaussian reparameterization trick no longer applies.

We can apply standard gradient estimators (REINFORCE/REBAR/RELAX).
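For example, a REINFORCE-style surrogate for the encoder (a sketch; `f` is a hypothetical per-sample objective such as the ELBO term for a given z):

    import torch

    def reinforce_surrogate(logits, f):
        dist = torch.distributions.Categorical(logits=logits)
        z = dist.sample()                    # non-differentiable discrete sample
        reward = f(z).detach()               # objective value, held fixed for the estimator
        # Differentiating this matches the score-function estimator E[f(z) * grad log q(z|x)]
        return dist.log_prob(z) * reward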

Gumbel-Max VAE + Direct Optimization

Write the categorical sample through an argmax, z = argmax_i (log theta_i(x) + g_i) with g ~ Gumbel(0, 1), and apply the direct (loss-gradient) estimator through the argmax.

Algorithm:
1) Sample Gumbel noise g
2) Compute the argmax z and the eps-perturbed (loss-augmented) argmax z_eps
3) Estimate the gradient from the difference of the encoder scores at z_eps and z, scaled by 1/eps
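A per-example sketch (not the authors' implementation): `logits` has shape [K] and `decoder_loss(k)` is a hypothetical helper returning the decoder loss for category k:

    import torch

    def direct_surrogate(logits, decoder_loss, eps=1.0):
        g = -torch.log(-torch.log(torch.rand_like(logits)))                 # Gumbel(0, 1) noise
        losses = torch.as_tensor([float(decoder_loss(k)) for k in range(logits.shape[-1])])
        z     = torch.argmax(logits + g)                                    # sampled category
        z_eps = torch.argmax(logits + g + eps * losses)                     # loss-perturbed category
        # Minimizing this surrogate applies the (1/eps) * (grad at z_eps - grad at z) update
        return (logits[z_eps] - logits[z]) / eps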

Structured Encoder

No structure: the argmax factorizes over independent categorical latent variables.

Pairwise relationships: add pairwise potentials between latent variables; solve the argmax with a quadratic integer program (CPLEX) or max-flow.

Not practical with Gumbel-Softmax: the normalizing denominator would require summing over an exponential number of joint configurations.

Structured Encoder may help

Gradient Bias-Variance Tradeoff

[Plots: Direct Gumbel-Max VAE (varying epsilon) vs. Gumbel-Softmax VAE (varying tau)]

Direct Gumbel-Max VAE trains faster (K = 10)

VAE Comparison

Standard (Gaussian):
  + Unbiased, low variance gradients
  - Continuous latent variables
  - Limited structural relations

Gumbel-Softmax:
  + Discrete latent variables
  - Biased gradients
  - Limited structural relations
  - Extra parameter (tau)

Naive Categorical + standard gradient estimator:
  + Discrete latent variables
  + Unbiased gradients
  - Limited structural relations

Gumbel-Max + Direct:
  + Discrete latent variables
  + Allows structural relations
  - Biased gradients
  - Extra parameter (epsilon)
  - Optimization subproblem to get gradients

Direct Policy Gradients: Direct Optimization of Policies in Discrete Action Spaces

Guy Lorberbom, Chris J. Maddison, Nicolas Heess, Tamir Hazan, Daniel Tarlow

Reinforcement Learning

[Diagram: the agent sends an action to the environment; the environment returns a reward and the next state]

Goal: maximize cumulative reward

Policy Gradient Method

Goal: maximize J(theta) = E_{a ~ pi_theta}[R(a)], the expected return over trajectories a.

Want: grad_theta J(theta)

REINFORCE: grad_theta J(theta) = E_{a ~ pi_theta}[ R(a) * grad_theta log pi_theta(a) ]

Direct Policy Gradient: estimate grad_theta J(theta) as (1/eps) * (grad_theta log pi_theta(a_direct) - grad_theta log pi_theta(a_opt)), with a_opt and a_direct defined via the Gumbel-max reparameterization below.
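A per-trajectory sketch of the REINFORCE surrogate for comparison (illustrative):

    import torch

    def reinforce_policy_surrogate(step_log_probs, rewards):
        # step_log_probs: list of log pi_theta(a_t | s_t); rewards: list of per-step rewards
        ret = sum(rewards)                                 # total return R(a)
        return ret * torch.stack(step_log_probs).sum()     # gradient matches R(a) * grad log pi_theta(a)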

State Reward Tree

● Tree of all possible trajectories (fix the seed of the environment).
● Separates environment stochasticity from policy stochasticity.
● Given the tree, we can sample trajectories by sampling an action at each node.

Reparameterize the Policy

Instead of sampling per-timestep, we sample per-trajectory.

Given action sequences a = (a_1, ..., a_T), define the trajectory score log pi_theta(a) = sum_t log pi_theta(a_t | s_t).

Gumbel-max reparameterization

Now that we have a score log pi_theta(a) for every action sequence a:

Let G(a) = log pi_theta(a) + gamma(a) for each trajectory a, where the gamma(a) are i.i.d. Gumbel(0, 1), and let a_opt = argmax_a G(a).

Then under this reparameterization, a_opt ~ pi_theta: the argmax trajectory is an exact sample from the policy.
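A sketch over a small, enumerable set of action sequences (illustrative; real problems need the lazy search described later):

    import torch

    def sample_trajectory(traj_log_probs):
        # traj_log_probs: tensor with log pi_theta(a) for every candidate action sequence a
        gamma = -torch.log(-torch.log(torch.rand_like(traj_log_probs)))  # Gumbel(0, 1) noise
        G = traj_log_probs + gamma
        a_opt = torch.argmax(G)                                          # a_opt ~ pi_theta
        return a_opt, G[a_opt]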

Structured Prediction  |  RL

● Discrete configurations: outputs y in Y  |  action sequences (trajectories) a
● Scoring function: s(x, y; w)  |  G(a) = log pi_theta(a) + gamma(a)
● Loss: L(y, y*)  |  negative return -R(a)
● Inference: argmax_y s(x, y; w)  |  a_opt = argmax_a G(a)
● Loss-augmented inference: argmax_y [s(x, y; w) + eps * L(y, y*)]  |  a_direct = argmax_a [G(a) + eps * R(a)]

Direct Policy Gradient (DirPG)

Algorithm. For every training step:
1. Sample a_opt ~ pi_theta (via the Gumbel-max reparameterization)
2. Find a_direct = argmax_a [G(a) + eps * R(a)]
3. Compute gradients: (1/eps) * (grad_theta log pi_theta(a_direct) - grad_theta log pi_theta(a_opt))

Problem. For every training step:
1. Sample a_opt ~ pi_theta
2. Find a_direct = argmax_a [G(a) + eps * R(a)]  ⇐ How to obtain this?
3. Compute gradients

Solution: A* sampling (Maddison et al., 2014)

Use heuristic search to find a trajectory whose direct objective G(a) + eps * R(a) is better than that of a_opt.

Complete Algorithm. For every training step:
1. Sample a_opt and compute its direct objective G(a_opt) + eps * R(a_opt)
2. While budget not exceeded:
   a. Obtain a candidate trajectory a from heuristic search
   b. End search if G(a) + eps * R(a) exceeds the direct objective of a_opt; set a_direct = a
3. Compute gradients
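A sketch of the resulting update over an enumerable trajectory set (illustrative; a real implementation replaces the enumeration with the A* search above):

    import torch

    def dirpg_surrogate(traj_log_probs, returns, eps=1.0):
        # traj_log_probs: differentiable log pi_theta(a) per trajectory; returns: tensor of R(a)
        gamma = -torch.log(-torch.log(torch.rand_like(traj_log_probs)))
        G = traj_log_probs.detach() + gamma
        a_opt    = torch.argmax(G)                        # sample from the policy
        a_direct = torch.argmax(G + eps * returns)        # reward-augmented argmax
        # Gradient ascent on this surrogate applies (1/eps) * (grad log pi(a_direct) - grad log pi(a_opt))
        return (traj_log_probs[a_direct] - traj_log_probs[a_opt]) / eps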

Limitations

● The search must be able to reset the environment to previously visited states.
● Termination on first improvement: the search stops at the first trajectory that beats a_opt's direct objective.

Combinatorial bandits

The number of trajectories searched to find a_direct increases as training progresses for combinatorial bandits.

MiniGrid

Comparisons between different heuristics for DirPG and REINFORCE on MiniGrid.

Evidence of “pulling up” on MiniGrid.

Related Work

● Gradient Estimators
  ○ REINFORCE (Williams 1988)
  ○ REBAR (Tucker et al 2017)
  ○ RELAX (Grathwohl et al 2018)
  ○ Gumbel-Softmax (Jang et al 2017, Maddison et al 2017)
● Discrete Deep Generative Models
  ○ VQ-VAE (Oord et al 2017)
  ○ Discrete VAE (Rolfe 2017)
  ○ Gumbel-Sinkhorn (Mena et al 2018)
● Reinforcement Learning

Top-Down sampling using A* Sampling

Non-starters

● Compute the perturbed objective for all possible trajectories (exponentially many).
● Roll out many trajectories and select the best.

Gumbel Process

We know: with i.i.d. gamma(a) ~ Gumbel(0, 1), max_{a in B} [log p(a) + gamma(a)] ~ Gumbel(log sum_{a in B} p(a)), and the argmax is distributed as p restricted to B.

Therefore: we can sample the maximum perturbed value of an entire region of trajectories without enumerating it, and recursively split a region into sub-regions (A and B) while keeping the maxima consistent.

[Figure: recursive partition of the trajectory space into regions A and B with their Gumbel maxima]
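A small illustration of the max-stability property (not the authors' code): the max over a region can be drawn in one step from a Gumbel located at the region's log mass.

    import math, random

    def sample_gumbel(loc):
        return loc - math.log(-math.log(random.random()))

    log_p = [math.log(p) for p in (0.1, 0.2, 0.3, 0.4)]        # a toy region B
    brute_force = max(sample_gumbel(lp) for lp in log_p)        # enumerate B
    log_mass = math.log(sum(math.exp(lp) for lp in log_p))      # log sum_{a in B} p(a)
    direct = sample_gumbel(log_mass)                            # one draw, same distribution as brute_force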

Trajectory Generation

● Lazily create partitions of trajectories.
● Recursion rule:
  ○ For the child containing the parent's argmax trajectory, copy the parent node's value.
  ○ For the remaining choices of actions, group them and compute a truncated value.
● Repeat until a terminating state is found.
● Yield the trajectory and its perturbed value G(a).

[Figure: the recursion on a small action tree, propagating Gumbel values such as 1.3, 1.1, and 0.19]

Recall, Goal: find a_direct = argmax_a [G(a) + eps * R(a)].

How do we prioritize which node to expand next?

Search for large G(a) + eps * R(a) using A* Sampling

● Lower bound on the accumulated reward (L)
● Upper bound on the reward-to-go (U)
● In practice: combine the node's Gumbel value with eps times these reward bounds to form the search priority (a lazy-expansion sketch follows).
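A self-contained sketch of the lazy top-down expansion (illustrative; `env` and `policy` are hypothetical interfaces, each sibling action gets its own truncated Gumbel rather than being grouped as in the slides, and the priority is just the Gumbel value without the reward bounds L and U):

    import heapq, itertools, math, random

    def sample_gumbel(loc):
        return loc - math.log(-math.log(random.random()))

    def sample_truncated_gumbel(loc, bound):
        # Gumbel(loc) conditioned to be <= bound (inverse-CDF sampling)
        return loc - math.log(math.exp(loc - bound) - math.log(random.random()))

    def generate_trajectories(env, policy, max_nodes=1000):
        # Each node covers the region of trajectories sharing `prefix`; `loc` is the region's
        # log probability mass and `g` the region's maximum perturbed value G.
        tie = itertools.count()
        heap = [(-sample_gumbel(0.0), next(tie), 0.0, env.initial_state(), [])]
        for _ in range(max_nodes):
            if not heap:
                return
            neg_g, _, loc, state, prefix = heapq.heappop(heap)
            g = -neg_g
            if env.is_terminal(state):
                yield prefix, g                        # a complete trajectory and its G(a)
                continue
            log_probs = policy.log_probs(state)        # log pi(a | state) for each action
            # The child attaining the region's maximum is chosen with probability pi(a | state)
            winner = random.choices(range(len(log_probs)),
                                    weights=[math.exp(lp) for lp in log_probs])[0]
            for a, lp in enumerate(log_probs):
                child_loc = loc + lp
                child_g = g if a == winner else sample_truncated_gumbel(child_loc, g)
                heapq.heappush(heap, (-child_g, next(tie), child_loc, env.step(state, a), prefix + [a]))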
